From 6b02dcd725f6e32b23349689a794ca2edd1d7dca Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 25 May 2020 17:22:46 +0200 Subject: [PATCH 1/2] Transition to table structs, use OnceAgain for fetching tables --- network/state/exists.go | 36 +++++++--------- network/state/info.go | 36 ++++++++-------- network/state/lookup.go | 60 +++++++++++---------------- network/state/tables.go | 80 +++++++++++------------------------ network/state/tcp.go | 31 ++++++++++++++ network/state/udp.go | 92 ++++++++++++++++++++++++++++------------- 6 files changed, 173 insertions(+), 162 deletions(-) create mode 100644 network/state/tcp.go diff --git a/network/state/exists.go b/network/state/exists.go index 7b308608..d401058a 100644 --- a/network/state/exists.go +++ b/network/state/exists.go @@ -19,38 +19,33 @@ func Exists(pktInfo *packet.Info, now time.Time) (exists bool) { switch { case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP: - tcp4Lock.Lock() - defer tcp4Lock.Unlock() - return existsTCP(tcp4Connections, pktInfo) + return tcp4Table.exists(pktInfo) case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP: - tcp6Lock.Lock() - defer tcp6Lock.Unlock() - return existsTCP(tcp6Connections, pktInfo) + return tcp6Table.exists(pktInfo) case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP: - udp4Lock.Lock() - defer udp4Lock.Unlock() - return existsUDP(udp4Binds, udp4States, pktInfo, now) + return udp4Table.exists(pktInfo, now) case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP: - udp6Lock.Lock() - defer udp6Lock.Unlock() - return existsUDP(udp6Binds, udp6States, pktInfo, now) + return udp6Table.exists(pktInfo, now) default: return false } } -func existsTCP(connections []*socket.ConnectionInfo, pktInfo *packet.Info) (exists bool) { +func (table *tcpTable) exists(pktInfo *packet.Info) (exists bool) { + table.lock.RLock() + defer table.lock.RUnlock() + localIP := pktInfo.LocalIP() localPort := pktInfo.LocalPort() remoteIP := pktInfo.RemoteIP() remotePort := pktInfo.RemotePort() // search connections - for _, socketInfo := range connections { + for _, socketInfo := range table.connections { if localPort == socketInfo.Local.Port && remotePort == socketInfo.Remote.Port && remoteIP.Equal(socketInfo.Remote.IP) && @@ -62,12 +57,9 @@ func existsTCP(connections []*socket.ConnectionInfo, pktInfo *packet.Info) (exis return false } -func existsUDP( - binds []*socket.BindInfo, - udpStates map[string]map[string]*udpState, - pktInfo *packet.Info, - now time.Time, -) (exists bool) { +func (table *udpTable) exists(pktInfo *packet.Info, now time.Time) (exists bool) { + table.lock.RLock() + defer table.lock.RUnlock() localIP := pktInfo.LocalIP() localPort := pktInfo.LocalPort() @@ -77,11 +69,11 @@ func existsUDP( connThreshhold := now.Add(-UDPConnectionTTL) // search binds - for _, socketInfo := range binds { + for _, socketInfo := range table.binds { if localPort == socketInfo.Local.Port && (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { - udpConnState, ok := getUDPConnState(socketInfo, udpStates, socket.Address{ + udpConnState, ok := table.getConnState(socketInfo, socket.Address{ IP: remoteIP, Port: remotePort, }) diff --git a/network/state/info.go b/network/state/info.go index 5d4b0d4d..7d4c11fc 100644 --- a/network/state/info.go +++ b/network/state/info.go @@ -25,27 +25,27 @@ type Info struct { func GetInfo() *Info { info := &Info{} - tcp4Lock.Lock() - updateTCP4Tables() - info.TCP4Connections = tcp4Connections - info.TCP4Listeners = tcp4Listeners - tcp4Lock.Unlock() + tcp4Table.updateTables() + tcp4Table.lock.RLock() + info.TCP4Connections = tcp4Table.connections + info.TCP4Listeners = tcp4Table.listeners + tcp4Table.lock.RUnlock() - tcp6Lock.Lock() - updateTCP6Tables() - info.TCP6Connections = tcp6Connections - info.TCP6Listeners = tcp6Listeners - tcp6Lock.Unlock() + tcp6Table.updateTables() + tcp6Table.lock.RLock() + info.TCP6Connections = tcp6Table.connections + info.TCP6Listeners = tcp6Table.listeners + tcp6Table.lock.RUnlock() - udp4Lock.Lock() - updateUDP4Table() - info.UDP4Binds = udp4Binds - udp4Lock.Unlock() + udp4Table.updateTable() + udp4Table.lock.RLock() + info.UDP4Binds = udp4Table.binds + udp4Table.lock.RUnlock() - udp6Lock.Lock() - updateUDP6Table() - info.UDP6Binds = udp6Binds - udp6Lock.Unlock() + udp6Table.updateTable() + udp6Table.lock.RLock() + info.UDP6Binds = udp6Table.binds + udp6Table.lock.RUnlock() info.UpdateMeta() return info diff --git a/network/state/lookup.go b/network/state/lookup.go index 5aadf7fa..a9ac6916 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -2,7 +2,6 @@ package state import ( "errors" - "sync" "time" "github.com/safing/portmaster/network/netutils" @@ -31,11 +30,6 @@ var ( ) var ( - tcp4Lock sync.Mutex - tcp6Lock sync.Mutex - udp4Lock sync.Mutex - udp6Lock sync.Mutex - baseWaitTime = 3 * time.Millisecond ) @@ -52,36 +46,23 @@ func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) { switch { case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP: - tcp4Lock.Lock() - defer tcp4Lock.Unlock() - return searchTCP(tcp4Connections, tcp4Listeners, updateTCP4Tables, pktInfo) + return tcp4Table.lookup(pktInfo) case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP: - tcp6Lock.Lock() - defer tcp6Lock.Unlock() - return searchTCP(tcp6Connections, tcp6Listeners, updateTCP6Tables, pktInfo) + return tcp6Table.lookup(pktInfo) case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP: - udp4Lock.Lock() - defer udp4Lock.Unlock() - return searchUDP(udp4Binds, udp4States, updateUDP4Table, pktInfo) + return udp4Table.lookup(pktInfo) case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP: - udp6Lock.Lock() - defer udp6Lock.Unlock() - return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo) + return udp6Table.lookup(pktInfo) default: return socket.UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") } } -func searchTCP( - connections []*socket.ConnectionInfo, - listeners []*socket.BindInfo, - updateTables func() ([]*socket.ConnectionInfo, []*socket.BindInfo), - pktInfo *packet.Info, -) ( +func (table *tcpTable) lookup(pktInfo *packet.Info) ( pid int, inbound bool, err error, @@ -92,44 +73,44 @@ func searchTCP( // search until we find something for i := 0; i < 7; i++ { + table.lock.RLock() + // always search listeners first - for _, socketInfo := range listeners { + for _, socketInfo := range table.listeners { if localPort == socketInfo.Local.Port && (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { + table.lock.RUnlock() return checkBindPID(socketInfo, true) } } // search connections - for _, socketInfo := range connections { + for _, socketInfo := range table.connections { if localPort == socketInfo.Local.Port && localIP.Equal(socketInfo.Local.IP) { + table.lock.RUnlock() return checkConnectionPID(socketInfo, false) } } + table.lock.RUnlock() + // we found nothing, we could have been too fast, give the kernel some time to think // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total time.Sleep(time.Duration(i+1) * baseWaitTime) // refetch lists - connections, listeners = updateTables() + table.updateTables() } return socket.UnidentifiedProcessID, false, ErrConnectionNotFound } -func searchUDP( - binds []*socket.BindInfo, - udpStates map[string]map[string]*udpState, - updateTable func() []*socket.BindInfo, - pktInfo *packet.Info, -) ( +func (table *udpTable) lookup(pktInfo *packet.Info) ( pid int, inbound bool, err error, ) { - localIP := pktInfo.LocalIP() localPort := pktInfo.LocalPort() @@ -141,12 +122,15 @@ func searchUDP( // search until we find something for i := 0; i < 5; i++ { + table.lock.RLock() + // search binds - for _, socketInfo := range binds { + for _, socketInfo := range table.binds { if localPort == socketInfo.Local.Port && (socketInfo.Local.IP[0] == 0 || // zero IP isInboundMulticast || // inbound broadcast, multicast localIP.Equal(socketInfo.Local.IP)) { + table.lock.RUnlock() // do not check direction if remoteIP/Port is not given if pktInfo.RemotePort() == 0 { @@ -154,17 +138,19 @@ func searchUDP( } // get direction and return - connInbound := getUDPDirection(socketInfo, udpStates, pktInfo) + connInbound := table.getDirection(socketInfo, pktInfo) return checkBindPID(socketInfo, connInbound) } } + table.lock.RUnlock() + // we found nothing, we could have been too fast, give the kernel some time to think // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total time.Sleep(time.Duration(i+1) * baseWaitTime) // refetch lists - binds = updateTable() + table.updateTable() } return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound diff --git a/network/state/tables.go b/network/state/tables.go index 2f236cc6..df6e9783 100644 --- a/network/state/tables.go +++ b/network/state/tables.go @@ -2,67 +2,35 @@ package state import ( "github.com/safing/portbase/log" - "github.com/safing/portmaster/network/socket" ) -var ( - tcp4Connections []*socket.ConnectionInfo - tcp4Listeners []*socket.BindInfo +func (table *tcpTable) updateTables() { + table.fetchOnceAgain.Do(func() { + table.lock.Lock() + defer table.lock.Unlock() - tcp6Connections []*socket.ConnectionInfo - tcp6Listeners []*socket.BindInfo + connections, listeners, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err) + return + } - udp4Binds []*socket.BindInfo - - udp6Binds []*socket.BindInfo -) - -func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { - var err error - connections, listeners, err = getTCP4Table() - if err != nil { - log.Warningf("state: failed to get TCP4 socket table: %s", err) - return - } - - tcp4Connections = connections - tcp4Listeners = listeners - return + table.connections = connections + table.listeners = listeners + }) } -func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { - var err error - connections, listeners, err = getTCP6Table() - if err != nil { - log.Warningf("state: failed to get TCP6 socket table: %s", err) - return - } +func (table *udpTable) updateTable() { + table.fetchOnceAgain.Do(func() { + table.lock.Lock() + defer table.lock.Unlock() - tcp6Connections = connections - tcp6Listeners = listeners - return -} - -func updateUDP4Table() (binds []*socket.BindInfo) { - var err error - binds, err = getUDP4Table() - if err != nil { - log.Warningf("state: failed to get UDP4 socket table: %s", err) - return - } - - udp4Binds = binds - return -} - -func updateUDP6Table() (binds []*socket.BindInfo) { - var err error - binds, err = getUDP6Table() - if err != nil { - log.Warningf("state: failed to get UDP6 socket table: %s", err) - return - } - - udp6Binds = binds - return + binds, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err) + return + } + + table.binds = binds + }) } diff --git a/network/state/tcp.go b/network/state/tcp.go new file mode 100644 index 00000000..2894aada --- /dev/null +++ b/network/state/tcp.go @@ -0,0 +1,31 @@ +package state + +import ( + "sync" + + "github.com/safing/portbase/utils" + "github.com/safing/portmaster/network/socket" +) + +type tcpTable struct { + version int + + connections []*socket.ConnectionInfo + listeners []*socket.BindInfo + lock sync.RWMutex + + fetchOnceAgain utils.OnceAgain + fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) +} + +var ( + tcp4Table = &tcpTable{ + version: 4, + fetchTable: getTCP4Table, + } + + tcp6Table = &tcpTable{ + version: 6, + fetchTable: getTCP6Table, + } +) diff --git a/network/state/udp.go b/network/state/udp.go index f49b1d04..db238e47 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -2,12 +2,27 @@ package state import ( "context" + "sync" "time" + "github.com/safing/portbase/utils" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/socket" ) +type udpTable struct { + version int + + binds []*socket.BindInfo + lock sync.RWMutex + + fetchOnceAgain utils.OnceAgain + fetchTable func() (binds []*socket.BindInfo, err error) + + states map[string]map[string]*udpState + statesLock sync.Mutex +} + type udpState struct { inbound bool lastSeen time.Time @@ -25,12 +40,38 @@ const ( ) var ( - udp4States = make(map[string]map[string]*udpState) // locked with udp4Lock - udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock + udp4Table = &udpTable{ + version: 4, + fetchTable: getUDP4Table, + states: make(map[string]map[string]*udpState), + } + + udp6Table = &udpTable{ + version: 6, + fetchTable: getUDP6Table, + states: make(map[string]map[string]*udpState), + } ) -func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteAddress socket.Address) (udpConnState *udpState, ok bool) { - bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local)] +// CleanUDPStates cleans the udp connection states which save connection directions. +func CleanUDPStates(_ context.Context) { + now := time.Now().UTC() + + udp4Table.updateTable() + udp4Table.cleanStates(now) + + udp6Table.updateTable() + udp6Table.cleanStates(now) +} + +func (table *udpTable) getConnState( + socketInfo *socket.BindInfo, + remoteAddress socket.Address, +) (udpConnState *udpState, ok bool) { + table.statesLock.Lock() + defer table.statesLock.Unlock() + + bindMap, ok := table.states[makeUDPStateKey(socketInfo.Local)] if ok { udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)] return @@ -39,13 +80,19 @@ func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[strin return nil, false } -func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) { +func (table *udpTable) getDirection( + socketInfo *socket.BindInfo, + pktInfo *packet.Info, +) (connDirection bool) { + table.statesLock.Lock() + defer table.statesLock.Unlock() + localKey := makeUDPStateKey(socketInfo.Local) - bindMap, ok := udpStates[localKey] + bindMap, ok := table.states[localKey] if !ok { bindMap = make(map[string]*udpState) - udpStates[localKey] = bindMap + table.states[localKey] = bindMap } remoteKey := makeUDPStateKey(socket.Address{ @@ -65,38 +112,25 @@ func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[strin return udpConnState.inbound } -// CleanUDPStates cleans the udp connection states which save connection directions. -func CleanUDPStates(_ context.Context) { - now := time.Now().UTC() +func (table *udpTable) cleanStates(now time.Time) { - udp4Lock.Lock() - updateUDP4Table() - cleanStates(udp4Binds, udp4States, now) - udp4Lock.Unlock() - - udp6Lock.Lock() - updateUDP6Table() - cleanStates(udp6Binds, udp6States, now) - udp6Lock.Unlock() -} - -func cleanStates( - binds []*socket.BindInfo, - udpStates map[string]map[string]*udpState, - now time.Time, -) { // compute thresholds threshold := now.Add(-UDPConnStateTTL) shortThreshhold := now.Add(-UDPConnStateShortenedTTL) // make lookup map of all active keys bindKeys := make(map[string]struct{}) - for _, socketInfo := range binds { + table.lock.RLock() + for _, socketInfo := range table.binds { bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{} } + table.lock.RUnlock() + + table.statesLock.Lock() + defer table.statesLock.Unlock() // clean the udp state storage - for localKey, bindMap := range udpStates { + for localKey, bindMap := range table.states { if _, active := bindKeys[localKey]; active { // clean old entries for remoteKey, udpConnState := range bindMap { @@ -114,7 +148,7 @@ func cleanStates( } } else { // delete the whole thing - delete(udpStates, localKey) + delete(table.states, localKey) } } } From 93a3ad2a8096a55809759fcd03cbe28a2594cd71 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 25 May 2020 17:35:14 +0200 Subject: [PATCH 2/2] Improve retrying --- network/state/lookup.go | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/network/state/lookup.go b/network/state/lookup.go index a9ac6916..b05fe237 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -30,7 +30,8 @@ var ( ) var ( - baseWaitTime = 3 * time.Millisecond + baseWaitTime = 3 * time.Millisecond + lookupRetries = 7 ) // Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound. @@ -72,7 +73,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) ( localPort := pktInfo.LocalPort() // search until we find something - for i := 0; i < 7; i++ { + for i := 0; i <= lookupRetries; i++ { table.lock.RLock() // always search listeners first @@ -95,12 +96,15 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) ( table.lock.RUnlock() - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) + // every time, except for the last iteration + if i < lookupRetries { + // we found nothing, we could have been too fast, give the kernel some time to think + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) - // refetch lists - table.updateTables() + // refetch lists + table.updateTables() + } } return socket.UnidentifiedProcessID, false, ErrConnectionNotFound @@ -121,7 +125,7 @@ func (table *udpTable) lookup(pktInfo *packet.Info) ( // binding to different addresses. This highly unusual for clients. // search until we find something - for i := 0; i < 5; i++ { + for i := 0; i <= lookupRetries; i++ { table.lock.RLock() // search binds @@ -145,12 +149,15 @@ func (table *udpTable) lookup(pktInfo *packet.Info) ( table.lock.RUnlock() - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) + // every time, except for the last iteration + if i < lookupRetries { + // we found nothing, we could have been too fast, give the kernel some time to think + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) - // refetch lists - table.updateTable() + // refetch lists + table.updateTable() + } } return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound