diff --git a/network/state/info.go b/network/state/info.go index 0b75d37a..818efda3 100644 --- a/network/state/info.go +++ b/network/state/info.go @@ -3,9 +3,8 @@ package state import ( "sync" - "github.com/safing/portmaster/netenv" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/netenv" "github.com/safing/portmaster/network/socket" ) @@ -26,28 +25,12 @@ type Info struct { func GetInfo() *Info { info := &Info{} - tcp4Table.updateTables() - tcp4Table.lock.RLock() - info.TCP4Connections = tcp4Table.connections - info.TCP4Listeners = tcp4Table.listeners - tcp4Table.lock.RUnlock() - - udp4Table.updateTable() - udp4Table.lock.RLock() - info.UDP4Binds = udp4Table.binds - udp4Table.lock.RUnlock() + info.TCP4Connections, info.TCP4Listeners, _ = tcp4Table.updateTables(0) + info.UDP4Binds, _ = udp4Table.updateTables(0) if netenv.IPv6Enabled() { - tcp6Table.updateTables() - tcp6Table.lock.RLock() - info.TCP6Connections = tcp6Table.connections - info.TCP6Listeners = tcp6Table.listeners - tcp6Table.lock.RUnlock() - - udp6Table.updateTable() - udp6Table.lock.RLock() - info.UDP6Binds = udp6Table.binds - udp6Table.lock.RUnlock() + info.TCP6Connections, info.TCP6Listeners, _ = tcp6Table.updateTables(tcp6Table.updateIter.Load()) + info.UDP6Binds, _ = udp6Table.updateTables(0) } info.UpdateMeta() diff --git a/network/state/lookup.go b/network/state/lookup.go index 0586b17b..a764e522 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -69,19 +69,52 @@ func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) ( inbound bool, err error, ) { - // Search pattern: search, refresh, search, wait, search, refresh, search, wait, ... + // Prepare variables. + var ( + connections []*socket.ConnectionInfo + listeners []*socket.BindInfo + updateIter uint64 + + dualStackConnections []*socket.ConnectionInfo + dualStackListeners []*socket.BindInfo + dualStackUpdateIter uint64 + ) // Search for the socket until found. for i := 1; i <= lookupRetries; i++ { - // Check main table for socket. - socketInfo, inbound := table.findSocket(pktInfo) - if socketInfo == nil && table.dualStack != nil { - // If there was no match in the main table and we are dual-stack, check - // the dual-stack table for the socket. - socketInfo, inbound = table.dualStack.findSocket(pktInfo) + // Get or update tables. + if i == 1 { + connections, listeners, updateIter = table.getCurrentTables() + } else { + connections, listeners, updateIter = table.updateTables(updateIter) } - // If there's a match, check we have the PID and return. + // Check tables for socket. + socketInfo, inbound := findTCPSocket(pktInfo, connections, listeners) + + // If there's a match, check if we have the PID and return. + if socketInfo != nil { + return checkPID(socketInfo, inbound) + } + + // DUAL-STACK + + // Skip if dualStack is not enabled. + if table.dualStack == nil { + continue + } + + // Get or update tables. + if i == 0 { + dualStackConnections, dualStackListeners, dualStackUpdateIter = table.dualStack.getCurrentTables() + } else { + dualStackConnections, dualStackListeners, dualStackUpdateIter = table.dualStack.updateTables(dualStackUpdateIter) + } + + // Check tables for socket. + socketInfo, inbound = findTCPSocket(pktInfo, dualStackConnections, dualStackListeners) + + // If there's a match, check if we have the PID and return. if socketInfo != nil { return checkPID(socketInfo, inbound) } @@ -90,39 +123,24 @@ func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) ( if fast && i < fastLookupRetries { break } - - // every time, except for the last iteration - if i < lookupRetries { - // Take turns in waiting and refreshing in order to satisfy the search pattern. - if i%2 == 0 { - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - } else { - // refetch lists - table.updateTables() - if table.dualStack != nil { - table.dualStack.updateTables() - } - } - } } return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound } -func (table *tcpTable) findSocket(pktInfo *packet.Info) ( +func findTCPSocket( + pktInfo *packet.Info, + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, +) ( socketInfo socket.Info, inbound bool, ) { localIP := pktInfo.LocalIP() localPort := pktInfo.LocalPort() - table.lock.RLock() - defer table.lock.RUnlock() - // always search listeners first - for _, socketInfo := range table.listeners { + for _, socketInfo := range listeners { if localPort == socketInfo.Local.Port && (socketInfo.ListensAny || localIP.Equal(socketInfo.Local.IP)) { return socketInfo, true @@ -133,7 +151,7 @@ func (table *tcpTable) findSocket(pktInfo *packet.Info) ( remotePort := pktInfo.RemotePort() // search connections - for _, socketInfo := range table.connections { + for _, socketInfo := range connections { if localPort == socketInfo.Local.Port && remotePort == socketInfo.Remote.Port && remoteIP.Equal(socketInfo.Remote.IP) && @@ -150,25 +168,67 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( inbound bool, err error, ) { - // Search pattern: search, refresh, search, wait, search, refresh, search, wait, ... - // TODO: Currently broadcast/multicast scopes are not checked, so we might // attribute an incoming broadcast/multicast packet to the wrong process if // there are multiple processes listening on the same local port, but // binding to different addresses. This highly unusual for clients. isInboundMulticast := pktInfo.Inbound && netutils.GetIPScope(pktInfo.LocalIP()) == netutils.LocalMulticast + // Prepare variables. + var ( + binds []*socket.BindInfo + updateIter uint64 + + dualStackBinds []*socket.BindInfo + dualStackUpdateIter uint64 + ) + // Search for the socket until found. for i := 1; i <= lookupRetries; i++ { - // Check main table for socket. - socketInfo := table.findSocket(pktInfo, isInboundMulticast) - if socketInfo == nil && table.dualStack != nil { - // If there was no match in the main table and we are dual-stack, check - // the dual-stack table for the socket. - socketInfo = table.dualStack.findSocket(pktInfo, isInboundMulticast) + // Get or update tables. + if i == 1 { + binds, updateIter = table.getCurrentTables() + } else { + binds, updateIter = table.updateTables(updateIter) } - // If there's a match, get the direction and check we have the PID, then return. + // Check tables for socket. + socketInfo := findUDPSocket(pktInfo, binds, isInboundMulticast) + + // If there's a match, do some last checks and return. + if socketInfo != nil { + // If there is no remote port, do check for the direction of the + // connection. This will be the case for pure checking functions + // that do not want to change direction state. + if pktInfo.RemotePort() == 0 { + return checkPID(socketInfo, pktInfo.Inbound) + } + + // Get (and save) the direction of the connection. + connInbound := table.getDirection(socketInfo, pktInfo) + + // Check we have the PID and return. + return checkPID(socketInfo, connInbound) + } + + // DUAL-STACK + + // Skip if dualStack is not enabled. + if table.dualStack == nil { + continue + } + + // Get or update tables. + if i == 0 { + dualStackBinds, dualStackUpdateIter = table.dualStack.getCurrentTables() + } else { + dualStackBinds, dualStackUpdateIter = table.dualStack.updateTables(dualStackUpdateIter) + } + + // Check tables for socket. + socketInfo = findUDPSocket(pktInfo, dualStackBinds, isInboundMulticast) + + // If there's a match, do some last checks and return. if socketInfo != nil { // If there is no remote port, do check for the direction of the // connection. This will be the case for pure checking functions @@ -188,36 +248,17 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( if fast && i < fastLookupRetries { break } - - // every time, except for the last iteration - if i < lookupRetries { - // Take turns in waiting and refreshing in order to satisfy the search pattern. - if i%2 == 0 { - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - } else { - // refetch lists - table.updateTable() - if table.dualStack != nil { - table.dualStack.updateTable() - } - } - } } return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound } -func (table *udpTable) findSocket(pktInfo *packet.Info, isInboundMulticast bool) (socketInfo *socket.BindInfo) { +func findUDPSocket(pktInfo *packet.Info, binds []*socket.BindInfo, isInboundMulticast bool) (socketInfo *socket.BindInfo) { localIP := pktInfo.LocalIP() localPort := pktInfo.LocalPort() - table.lock.RLock() - defer table.lock.RUnlock() - // search binds - for _, socketInfo := range table.binds { + for _, socketInfo := range binds { if localPort == socketInfo.Local.Port && (socketInfo.ListensAny || // zero IP (dual-stack) isInboundMulticast || // inbound broadcast, multicast diff --git a/network/state/tables.go b/network/state/tables.go index e99957ab..7bf2df5b 100644 --- a/network/state/tables.go +++ b/network/state/tables.go @@ -1,48 +1 @@ package state - -import ( - "net" - - "github.com/safing/portbase/log" -) - -func (table *tcpTable) updateTables() { - table.fetchOnceAgain.Do(func() { - table.lock.Lock() - defer table.lock.Unlock() - - connections, listeners, err := table.fetchTable() - if err != nil { - log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err) - return - } - - // Pre-check for any listeners. - for _, bindInfo := range listeners { - bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) - } - - table.connections = connections - table.listeners = listeners - }) -} - -func (table *udpTable) updateTable() { - table.fetchOnceAgain.Do(func() { - table.lock.Lock() - defer table.lock.Unlock() - - binds, err := table.fetchTable() - if err != nil { - log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err) - return - } - - // Pre-check for any listeners. - for _, bindInfo := range binds { - bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) - } - - table.binds = binds - }) -} diff --git a/network/state/tcp.go b/network/state/tcp.go index 020de2aa..4b423a21 100644 --- a/network/state/tcp.go +++ b/network/state/tcp.go @@ -1,34 +1,44 @@ package state import ( + "net" "sync" + "sync/atomic" + "time" - "github.com/safing/portbase/utils" + "github.com/safing/portbase/log" "github.com/safing/portmaster/network/socket" ) +const maxUpdateTries = 100 + type tcpTable struct { version int connections []*socket.ConnectionInfo listeners []*socket.BindInfo + updateIter atomic.Uint64 lock sync.RWMutex - fetchOnceAgain utils.OnceAgain - fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) + fetchingLock sync.Mutex + fetchingInProgress bool + fetchingDoneSignal chan struct{} + fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) dualStack *tcpTable } var ( tcp6Table = &tcpTable{ - version: 6, - fetchTable: getTCP6Table, + version: 6, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getTCP6Table, } tcp4Table = &tcpTable{ - version: 4, - fetchTable: getTCP4Table, + version: 4, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getTCP4Table, } ) @@ -37,3 +47,100 @@ var ( func EnableTCPDualStack() { tcp4Table.dualStack = tcp6Table } + +func (table *tcpTable) getCurrentTables() ( + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, + updateIter uint64, +) { + table.lock.RLock() + defer table.lock.RUnlock() + + return table.connections, table.listeners, table.updateIter.Load() +} + +func (table *tcpTable) checkFetchingState() (fetch bool, signal chan struct{}) { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // If fetching is already in progress, just return the signal. + if table.fetchingInProgress { + return false, table.fetchingDoneSignal + } + + // Otherwise, tell caller to fetch. + table.fetchingInProgress = true + return true, nil +} + +func (table *tcpTable) signalFetchComplete() { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // Set fetching state. + table.fetchingInProgress = false + + // Signal waiting goroutines. + close(table.fetchingDoneSignal) + table.fetchingDoneSignal = make(chan struct{}) +} + +func (table *tcpTable) updateTables(previousUpdateIter uint64) ( + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, + updateIter uint64, +) { + var tries int + + // Attempt to update the tables until we get a new version of the tables. + for previousUpdateIter == table.updateIter.Load() { + // Abort if it takes too long. + tries++ + if tries > maxUpdateTries { + log.Warningf("state: failed to upate TCP%d socket table %d times", table.version, tries-1) + return table.getCurrentTables() + } + + // Check if someone is fetching or if we should fetch. + fetch, signal := table.checkFetchingState() + if fetch { + defer table.signalFetchComplete() + + // Just to be sure, check again if there is a new version. + if previousUpdateIter < table.updateIter.Load() { + return table.getCurrentTables() + } + + // Wait for 5 milliseconds. + time.Sleep(5 * time.Millisecond) + + // Fetch new tables from system. + connections, listeners, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err) + // Return the current tables as fallback, as we need to trigger the defer to complete the fetch. + return table.getCurrentTables() + } + + // Pre-check for any listeners. + for _, bindInfo := range listeners { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) + } + + // Apply new tables. + table.lock.Lock() + defer table.lock.Unlock() + table.connections = connections + table.listeners = listeners + table.updateIter.Add(1) + + // Return new tables immediately. + return table.connections, table.listeners, table.updateIter.Load() + } + + // Otherwise, wait for fetch to complete. + <-signal + } + + return table.getCurrentTables() +} diff --git a/network/state/udp.go b/network/state/udp.go index 41e13d25..353ba9b6 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -2,13 +2,14 @@ package state import ( "context" + "net" "strconv" "sync" + "sync/atomic" "time" + "github.com/safing/portbase/log" "github.com/safing/portmaster/netenv" - - "github.com/safing/portbase/utils" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/socket" ) @@ -16,11 +17,14 @@ import ( type udpTable struct { version int - binds []*socket.BindInfo - lock sync.RWMutex + binds []*socket.BindInfo + updateIter atomic.Uint64 + lock sync.RWMutex - fetchOnceAgain utils.OnceAgain - fetchTable func() (binds []*socket.BindInfo, err error) + fetchingLock sync.Mutex + fetchingInProgress bool + fetchingDoneSignal chan struct{} + fetchTable func() (binds []*socket.BindInfo, err error) states map[string]map[string]*udpState statesLock sync.Mutex @@ -46,15 +50,17 @@ const ( var ( udp6Table = &udpTable{ - version: 6, - fetchTable: getUDP6Table, - states: make(map[string]map[string]*udpState), + version: 6, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getUDP6Table, + states: make(map[string]map[string]*udpState), } udp4Table = &udpTable{ - version: 4, - fetchTable: getUDP4Table, - states: make(map[string]map[string]*udpState), + version: 4, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getUDP4Table, + states: make(map[string]map[string]*udpState), } ) @@ -64,15 +70,109 @@ func EnableUDPDualStack() { udp4Table.dualStack = udp6Table } +func (table *udpTable) getCurrentTables() ( + binds []*socket.BindInfo, + updateIter uint64, +) { + table.lock.RLock() + defer table.lock.RUnlock() + + return table.binds, table.updateIter.Load() +} + +func (table *udpTable) checkFetchingState() (fetch bool, signal chan struct{}) { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // If fetching is already in progress, just return the signal. + if table.fetchingInProgress { + return false, table.fetchingDoneSignal + } + + // Otherwise, tell caller to fetch. + table.fetchingInProgress = true + return true, nil +} + +func (table *udpTable) signalFetchComplete() { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // Set fetching state. + table.fetchingInProgress = false + + // Signal waiting goroutines. + close(table.fetchingDoneSignal) + table.fetchingDoneSignal = make(chan struct{}) +} + +func (table *udpTable) updateTables(previousUpdateIter uint64) ( + binds []*socket.BindInfo, + updateIter uint64, +) { + var tries int + + // Attempt to update the tables until we get a new version of the tables. + for previousUpdateIter == table.updateIter.Load() { + // Abort if it takes too long. + tries++ + if tries > maxUpdateTries { + log.Warningf("state: failed to upate UDP%d socket table %d times", table.version, tries-1) + return table.getCurrentTables() + } + + // Check if someone is fetching or if we should fetch. + fetch, signal := table.checkFetchingState() + if fetch { + defer table.signalFetchComplete() + + // Just to be sure, check again if there is a new version. + if previousUpdateIter < table.updateIter.Load() { + return table.getCurrentTables() + } + + // Wait for 5 milliseconds. + time.Sleep(5 * time.Millisecond) + + // Fetch new tables from system. + binds, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err) + // Return the current tables as fallback, as we need to trigger the defer to complete the fetch. + return table.getCurrentTables() + } + + // Pre-check for any listeners. + for _, bindInfo := range binds { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) + } + + // Apply new tables. + table.lock.Lock() + defer table.lock.Unlock() + table.binds = binds + table.updateIter.Add(1) + + // Return new tables immediately. + return table.binds, table.updateIter.Load() + } + + // Otherwise, wait for fetch to complete. + <-signal + } + + return table.getCurrentTables() +} + // CleanUDPStates cleans the udp connection states which save connection directions. func CleanUDPStates(_ context.Context) { now := time.Now().UTC() - udp4Table.updateTable() + udp4Table.updateTables(0) udp4Table.cleanStates(now) if netenv.IPv6Enabled() { - udp6Table.updateTable() + udp6Table.updateTables(0) udp6Table.cleanStates(now) } }