diff --git a/base/utils/call_limiter2.go b/base/utils/call_limiter2.go new file mode 100644 index 00000000..30d1b1ef --- /dev/null +++ b/base/utils/call_limiter2.go @@ -0,0 +1,74 @@ +package utils + +import ( + "sync" + "sync/atomic" + "time" +) + +// CallLimiter2 bundles concurrent calls and optionally limits how fast a function is called. +type CallLimiter2 struct { + pause time.Duration + + slot atomic.Int64 + slotWait sync.RWMutex + + executing atomic.Bool + lastExec time.Time +} + +// NewCallLimiter2 returns a new call limiter. +// Set minPause to zero to disable the minimum pause between calls. +func NewCallLimiter2(minPause time.Duration) *CallLimiter2 { + return &CallLimiter2{ + pause: minPause, + } +} + +// Do executes the given function. +// All concurrent calls to Do are bundled and return when f() finishes. +// Waits until the minimum pause is over before executing f() again. +func (l *CallLimiter2) Do(f func()) { + // Get ticket number. + slot := l.slot.Load() + + // Check if we can execute. + if l.executing.CompareAndSwap(false, true) { + // Make others wait. + l.slotWait.Lock() + defer l.slotWait.Unlock() + + // Execute and return. + l.waitAndExec(f) + return + } + + // Wait for slot to end and check if slot is done. + for l.slot.Load() == slot { + time.Sleep(100 * time.Microsecond) + l.slotWait.RLock() + l.slotWait.RUnlock() //nolint:staticcheck + } +} + +func (l *CallLimiter2) waitAndExec(f func()) { + defer func() { + // Update last exec time. + l.lastExec = time.Now().UTC() + // Enable next execution first. + l.executing.Store(false) + // Move to next slot aftewards to prevent wait loops. + l.slot.Add(1) + }() + + // Wait for the minimum duration between executions. + if l.pause > 0 { + sinceLastExec := time.Since(l.lastExec) + if sinceLastExec < l.pause { + time.Sleep(l.pause - sinceLastExec) + } + } + + // Execute. + f() +} diff --git a/base/utils/call_limiter_test.go b/base/utils/call_limiter_test.go index 3144644e..2343673a 100644 --- a/base/utils/call_limiter_test.go +++ b/base/utils/call_limiter_test.go @@ -13,7 +13,7 @@ func TestCallLimiter(t *testing.T) { t.Parallel() pause := 10 * time.Millisecond - oa := NewCallLimiter(pause) + oa := NewCallLimiter2(pause) executed := abool.New() var testWg sync.WaitGroup @@ -41,14 +41,14 @@ func TestCallLimiter(t *testing.T) { executed.UnSet() // reset check } - // Wait for pause to reset. - time.Sleep(pause) + // Wait for 2x pause to reset. + time.Sleep(2 * pause) // Continuous use with re-execution. // Choose values so that about 10 executions are expected var execs uint32 - testWg.Add(200) - for range 200 { + testWg.Add(100) + for range 100 { go func() { oa.Do(func() { atomic.AddUint32(&execs, 1) @@ -69,8 +69,8 @@ func TestCallLimiter(t *testing.T) { t.Errorf("unexpected high exec count: %d", execs) } - // Wait for pause to reset. - time.Sleep(pause) + // Wait for 2x pause to reset. + time.Sleep(2 * pause) // Check if the limiter correctly handles panics. testWg.Add(100) diff --git a/service/network/proc/pids_by_user.go b/service/network/proc/pids_by_user.go index 7fd451e7..8ea43dc4 100644 --- a/service/network/proc/pids_by_user.go +++ b/service/network/proc/pids_by_user.go @@ -19,7 +19,7 @@ var ( // pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs. pidsByUser = make(map[int][]int) pidsByUserLock sync.RWMutex - fetchPidsByUser = utils.NewCallLimiter(10 * time.Millisecond) + fetchPidsByUser = utils.NewCallLimiter2(10 * time.Millisecond) ) // getPidsByUser returns the cached PIDs for the given UID. diff --git a/service/network/state/tcp.go b/service/network/state/tcp.go index 5d08b054..e3883847 100644 --- a/service/network/state/tcp.go +++ b/service/network/state/tcp.go @@ -25,7 +25,7 @@ type tcpTable struct { // lastUpdateAt stores the time when the tables where last updated as unix nanoseconds. lastUpdateAt atomic.Int64 - fetchLimiter *utils.CallLimiter + fetchLimiter *utils.CallLimiter2 fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) dualStack *tcpTable @@ -34,13 +34,13 @@ type tcpTable struct { var ( tcp6Table = &tcpTable{ version: 6, - fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates), + fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates), fetchTable: getTCP6Table, } tcp4Table = &tcpTable{ version: 4, - fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates), + fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates), fetchTable: getTCP4Table, } ) diff --git a/service/network/state/udp.go b/service/network/state/udp.go index 1c534b7f..0dcecc84 100644 --- a/service/network/state/udp.go +++ b/service/network/state/udp.go @@ -24,7 +24,7 @@ type udpTable struct { // lastUpdateAt stores the time when the tables where last updated as unix nanoseconds. lastUpdateAt atomic.Int64 - fetchLimiter *utils.CallLimiter + fetchLimiter *utils.CallLimiter2 fetchTable func() (binds []*socket.BindInfo, err error) states map[string]map[string]*udpState @@ -52,14 +52,14 @@ const ( var ( udp6Table = &udpTable{ version: 6, - fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates), + fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates), fetchTable: getUDP6Table, states: make(map[string]map[string]*udpState), } udp4Table = &udpTable{ version: 4, - fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates), + fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates), fetchTable: getUDP4Table, states: make(map[string]map[string]*udpState), }