diff --git a/network/state/exists.go b/network/state/exists.go index 89a0fff9..ed0c48c3 100644 --- a/network/state/exists.go +++ b/network/state/exists.go @@ -37,7 +37,7 @@ func Exists(pktInfo *packet.Info, now time.Time) (exists bool) { func (table *tcpTable) exists(pktInfo *packet.Info) (exists bool) { // Update tables if older than the connection that is checked. if table.lastUpdateAt.Load() < pktInfo.SeenAt.UnixNano() { - table.updateTables(table.updateIter.Load()) + table.updateTables() } table.lock.RLock() @@ -64,7 +64,7 @@ func (table *tcpTable) exists(pktInfo *packet.Info) (exists bool) { func (table *udpTable) exists(pktInfo *packet.Info, now time.Time) (exists bool) { // Update tables if older than the connection that is checked. if table.lastUpdateAt.Load() < pktInfo.SeenAt.UnixNano() { - table.updateTables(table.updateIter.Load()) + table.updateTables() } table.lock.RLock() diff --git a/network/state/info.go b/network/state/info.go index 5dbdc788..483cd66e 100644 --- a/network/state/info.go +++ b/network/state/info.go @@ -25,12 +25,12 @@ type Info struct { func GetInfo() *Info { info := &Info{} - info.TCP4Connections, info.TCP4Listeners, _ = tcp4Table.updateTables(tcp4Table.updateIter.Load()) - info.UDP4Binds, _ = udp4Table.updateTables(udp4Table.updateIter.Load()) + info.TCP4Connections, info.TCP4Listeners = tcp4Table.updateTables() + info.UDP4Binds = udp4Table.updateTables() if netenv.IPv6Enabled() { - info.TCP6Connections, info.TCP6Listeners, _ = tcp6Table.updateTables(tcp6Table.updateIter.Load()) - info.UDP6Binds, _ = udp6Table.updateTables(udp6Table.updateIter.Load()) + info.TCP6Connections, info.TCP6Listeners = tcp6Table.updateTables() + info.UDP6Binds = udp6Table.updateTables() } info.UpdateMeta() diff --git a/network/state/lookup.go b/network/state/lookup.go index eea7ec37..42eb3da5 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -66,20 +66,18 @@ func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) ( var ( connections []*socket.ConnectionInfo listeners []*socket.BindInfo - updateIter uint64 dualStackConnections []*socket.ConnectionInfo dualStackListeners []*socket.BindInfo - dualStackUpdateIter uint64 ) // Search for the socket until found. for i := 1; i <= lookupTries; i++ { - // Get or update tables. + // Use existing tables for first check if packet was seen after last table update. if i == 1 && pktInfo.SeenAt.UnixNano() >= table.lastUpdateAt.Load() { - connections, listeners, updateIter = table.getCurrentTables() + connections, listeners = table.getCurrentTables() } else { - connections, listeners, updateIter = table.updateTables(updateIter) + connections, listeners = table.updateTables() } // Check tables for socket. @@ -97,11 +95,11 @@ func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) ( continue } - // Get or update tables. - if i == 0 { - dualStackConnections, dualStackListeners, dualStackUpdateIter = table.dualStack.getCurrentTables() + // Use existing tables for first check if packet was seen after last table update. + if i == 1 && pktInfo.SeenAt.UnixNano() >= table.dualStack.lastUpdateAt.Load() { + dualStackConnections, dualStackListeners = table.dualStack.getCurrentTables() } else { - dualStackConnections, dualStackListeners, dualStackUpdateIter = table.dualStack.updateTables(dualStackUpdateIter) + dualStackConnections, dualStackListeners = table.dualStack.updateTables() } // Check tables for socket. @@ -169,20 +167,17 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( // Prepare variables. var ( - binds []*socket.BindInfo - updateIter uint64 - - dualStackBinds []*socket.BindInfo - dualStackUpdateIter uint64 + binds []*socket.BindInfo + dualStackBinds []*socket.BindInfo ) // Search for the socket until found. for i := 1; i <= lookupTries; i++ { // Get or update tables. if i == 1 && pktInfo.SeenAt.UnixNano() >= table.lastUpdateAt.Load() { - binds, updateIter = table.getCurrentTables() + binds = table.getCurrentTables() } else { - binds, updateIter = table.updateTables(updateIter) + binds = table.updateTables() } // Check tables for socket. @@ -212,10 +207,10 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( } // Get or update tables. - if i == 0 { - dualStackBinds, dualStackUpdateIter = table.dualStack.getCurrentTables() + if i == 1 && pktInfo.SeenAt.UnixNano() >= table.lastUpdateAt.Load() { + dualStackBinds = table.dualStack.getCurrentTables() } else { - dualStackBinds, dualStackUpdateIter = table.dualStack.updateTables(dualStackUpdateIter) + dualStackBinds = table.dualStack.updateTables() } // Check tables for socket. diff --git a/network/state/tcp.go b/network/state/tcp.go index 35aa9ef9..5f8c03d7 100644 --- a/network/state/tcp.go +++ b/network/state/tcp.go @@ -7,10 +7,13 @@ import ( "time" "github.com/safing/portbase/log" + "github.com/safing/portbase/utils" "github.com/safing/portmaster/network/socket" ) -const maxUpdateTries = 100 +const ( + minDurationBetweenTableUpdates = 10 * time.Millisecond +) type tcpTable struct { version int @@ -19,29 +22,26 @@ type tcpTable struct { listeners []*socket.BindInfo lock sync.RWMutex - updateIter atomic.Uint64 // lastUpdateAt stores the time when the tables where last updated as unix nanoseconds. lastUpdateAt atomic.Int64 - fetchingLock sync.Mutex - fetchingInProgress bool - fetchingDoneSignal chan struct{} - fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) + fetchLimiter *utils.CallLimiter + fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) dualStack *tcpTable } var ( tcp6Table = &tcpTable{ - version: 6, - fetchingDoneSignal: make(chan struct{}), - fetchTable: getTCP6Table, + version: 6, + fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates), + fetchTable: getTCP6Table, } tcp4Table = &tcpTable{ - version: 4, - fetchingDoneSignal: make(chan struct{}), - fetchTable: getTCP4Table, + version: 4, + fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates), + fetchTable: getTCP4Table, } ) @@ -54,97 +54,38 @@ func EnableTCPDualStack() { func (table *tcpTable) getCurrentTables() ( connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, - updateIter uint64, ) { table.lock.RLock() defer table.lock.RUnlock() - return table.connections, table.listeners, table.updateIter.Load() + return table.connections, table.listeners } -func (table *tcpTable) checkFetchingState() (fetch bool, signal chan struct{}) { - table.fetchingLock.Lock() - defer table.fetchingLock.Unlock() - - // If fetching is already in progress, just return the signal. - if table.fetchingInProgress { - return false, table.fetchingDoneSignal - } - - // Otherwise, tell caller to fetch. - table.fetchingInProgress = true - return true, nil -} - -func (table *tcpTable) signalFetchComplete() { - table.fetchingLock.Lock() - defer table.fetchingLock.Unlock() - - // Set fetching state. - table.fetchingInProgress = false - - // Signal waiting goroutines. - close(table.fetchingDoneSignal) - table.fetchingDoneSignal = make(chan struct{}) -} - -func (table *tcpTable) updateTables(previousUpdateIter uint64) ( +func (table *tcpTable) updateTables() ( connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, - updateIter uint64, ) { - var tries int - - // Attempt to update the tables until we get a new version of the tables. - for previousUpdateIter == table.updateIter.Load() { - // Abort if it takes too long. - tries++ - if tries > maxUpdateTries { - log.Warningf("state: failed to upate TCP%d socket table %d times", table.version, tries-1) - return table.getCurrentTables() + // Fetch tables. + table.fetchLimiter.Do(func() { + // Fetch new tables from system. + connections, listeners, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err) + return } - // Check if someone is fetching or if we should fetch. - fetch, signal := table.checkFetchingState() - if fetch { - defer table.signalFetchComplete() - - // Just to be sure, check again if there is a new version. - if previousUpdateIter < table.updateIter.Load() { - return table.getCurrentTables() - } - - // Wait for 5 milliseconds. - time.Sleep(5 * time.Millisecond) - - // Fetch new tables from system. - connections, listeners, err := table.fetchTable() - if err != nil { - log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err) - // Return the current tables as fallback, as we need to trigger the defer to complete the fetch. - return table.getCurrentTables() - } - - // Pre-check for any listeners. - for _, bindInfo := range listeners { - bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) - } - - // Apply new tables. - table.lock.Lock() - defer table.lock.Unlock() - table.connections = connections - table.listeners = listeners - table.updateIter.Add(1) - table.lastUpdateAt.Store(time.Now().UnixNano()) - - // Return new tables immediately. - return table.connections, table.listeners, table.updateIter.Load() + // Pre-check for any listeners. + for _, bindInfo := range listeners { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) } - // Otherwise, wait for fetch to complete. - <-signal - } + // Apply new tables. + table.lock.Lock() + defer table.lock.Unlock() + table.connections = connections + table.listeners = listeners + table.lastUpdateAt.Store(time.Now().UnixNano()) + }) return table.getCurrentTables() } diff --git a/network/state/udp.go b/network/state/udp.go index cdfd604e..40696820 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -9,6 +9,7 @@ import ( "time" "github.com/safing/portbase/log" + "github.com/safing/portbase/utils" "github.com/safing/portmaster/netenv" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/socket" @@ -20,14 +21,11 @@ type udpTable struct { binds []*socket.BindInfo lock sync.RWMutex - updateIter atomic.Uint64 // lastUpdateAt stores the time when the tables where last updated as unix nanoseconds. lastUpdateAt atomic.Int64 - fetchingLock sync.Mutex - fetchingInProgress bool - fetchingDoneSignal chan struct{} - fetchTable func() (binds []*socket.BindInfo, err error) + fetchLimiter *utils.CallLimiter + fetchTable func() (binds []*socket.BindInfo, err error) states map[string]map[string]*udpState statesLock sync.Mutex @@ -53,17 +51,17 @@ const ( var ( udp6Table = &udpTable{ - version: 6, - fetchingDoneSignal: make(chan struct{}), - fetchTable: getUDP6Table, - states: make(map[string]map[string]*udpState), + version: 6, + fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates), + fetchTable: getUDP6Table, + states: make(map[string]map[string]*udpState), } udp4Table = &udpTable{ - version: 4, - fetchingDoneSignal: make(chan struct{}), - fetchTable: getUDP4Table, - states: make(map[string]map[string]*udpState), + version: 4, + fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates), + fetchTable: getUDP4Table, + states: make(map[string]map[string]*udpState), } ) @@ -73,97 +71,34 @@ func EnableUDPDualStack() { udp4Table.dualStack = udp6Table } -func (table *udpTable) getCurrentTables() ( - binds []*socket.BindInfo, - updateIter uint64, -) { +func (table *udpTable) getCurrentTables() (binds []*socket.BindInfo) { table.lock.RLock() defer table.lock.RUnlock() - return table.binds, table.updateIter.Load() + return table.binds } -func (table *udpTable) checkFetchingState() (fetch bool, signal chan struct{}) { - table.fetchingLock.Lock() - defer table.fetchingLock.Unlock() - - // If fetching is already in progress, just return the signal. - if table.fetchingInProgress { - return false, table.fetchingDoneSignal - } - - // Otherwise, tell caller to fetch. - table.fetchingInProgress = true - return true, nil -} - -func (table *udpTable) signalFetchComplete() { - table.fetchingLock.Lock() - defer table.fetchingLock.Unlock() - - // Set fetching state. - table.fetchingInProgress = false - - // Signal waiting goroutines. - close(table.fetchingDoneSignal) - table.fetchingDoneSignal = make(chan struct{}) -} - -func (table *udpTable) updateTables(previousUpdateIter uint64) ( - binds []*socket.BindInfo, - updateIter uint64, -) { - var tries int - - // Attempt to update the tables until we get a new version of the tables. - for previousUpdateIter == table.updateIter.Load() { - // Abort if it takes too long. - tries++ - if tries > maxUpdateTries { - log.Warningf("state: failed to upate UDP%d socket table %d times", table.version, tries-1) - return table.getCurrentTables() +func (table *udpTable) updateTables() (binds []*socket.BindInfo) { + // Fetch tables. + table.fetchLimiter.Do(func() { + // Fetch new tables from system. + binds, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err) + return } - // Check if someone is fetching or if we should fetch. - fetch, signal := table.checkFetchingState() - if fetch { - defer table.signalFetchComplete() - - // Just to be sure, check again if there is a new version. - if previousUpdateIter < table.updateIter.Load() { - return table.getCurrentTables() - } - - // Wait for 5 milliseconds. - time.Sleep(5 * time.Millisecond) - - // Fetch new tables from system. - binds, err := table.fetchTable() - if err != nil { - log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err) - // Return the current tables as fallback, as we need to trigger the defer to complete the fetch. - return table.getCurrentTables() - } - - // Pre-check for any listeners. - for _, bindInfo := range binds { - bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) - } - - // Apply new tables. - table.lock.Lock() - defer table.lock.Unlock() - table.binds = binds - table.updateIter.Add(1) - table.lastUpdateAt.Store(time.Now().UnixNano()) - - // Return new tables immediately. - return table.binds, table.updateIter.Load() + // Pre-check for any listeners. + for _, bindInfo := range binds { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) } - // Otherwise, wait for fetch to complete. - <-signal - } + // Apply new tables. + table.lock.Lock() + defer table.lock.Unlock() + table.binds = binds + table.lastUpdateAt.Store(time.Now().UnixNano()) + }) return table.getCurrentTables() } @@ -172,11 +107,11 @@ func (table *udpTable) updateTables(previousUpdateIter uint64) ( func CleanUDPStates(_ context.Context) { now := time.Now().UTC() - udp4Table.updateTables(udp4Table.updateIter.Load()) + udp4Table.updateTables() udp4Table.cleanStates(now) if netenv.IPv6Enabled() { - udp6Table.updateTables(udp6Table.updateIter.Load()) + udp6Table.updateTables() udp6Table.cleanStates(now) } }