Merge pull request #1860 from safing/feature/call-limiter-2
CallLimiter2
base/utils/call_limiter2.go (new file, 74 lines)
@@ -0,0 +1,74 @@
+package utils
+
+import (
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// CallLimiter2 bundles concurrent calls and optionally limits how fast a function is called.
+type CallLimiter2 struct {
+	pause time.Duration
+
+	slot     atomic.Int64
+	slotWait sync.RWMutex
+
+	executing atomic.Bool
+	lastExec  time.Time
+}
+
+// NewCallLimiter2 returns a new call limiter.
+// Set minPause to zero to disable the minimum pause between calls.
+func NewCallLimiter2(minPause time.Duration) *CallLimiter2 {
+	return &CallLimiter2{
+		pause: minPause,
+	}
+}
+
+// Do executes the given function.
+// All concurrent calls to Do are bundled and return when f() finishes.
+// Waits until the minimum pause is over before executing f() again.
+func (l *CallLimiter2) Do(f func()) {
+	// Get ticket number.
+	slot := l.slot.Load()
+
+	// Check if we can execute.
+	if l.executing.CompareAndSwap(false, true) {
+		// Make others wait.
+		l.slotWait.Lock()
+		defer l.slotWait.Unlock()
+
+		// Execute and return.
+		l.waitAndExec(f)
+		return
+	}
+
+	// Wait for slot to end and check if slot is done.
+	for l.slot.Load() == slot {
+		time.Sleep(100 * time.Microsecond)
+		l.slotWait.RLock()
+		l.slotWait.RUnlock() //nolint:staticcheck
+	}
+}
+
+func (l *CallLimiter2) waitAndExec(f func()) {
+	defer func() {
+		// Update last exec time.
+		l.lastExec = time.Now().UTC()
+		// Enable next execution first.
+		l.executing.Store(false)
+		// Move to next slot afterwards to prevent wait loops.
+		l.slot.Add(1)
+	}()
+
+	// Wait for the minimum duration between executions.
+	if l.pause > 0 {
+		sinceLastExec := time.Since(l.lastExec)
+		if sinceLastExec < l.pause {
+			time.Sleep(l.pause - sinceLastExec)
+		}
+	}
+
+	// Execute.
+	f()
+}
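For context, a minimal usage sketch of the new limiter, written as a testable example that would sit next to the implementation in package utils. It is not part of this commit, and the counts are illustrative: 100 goroutines call Do at roughly the same time, the calls are bundled into a handful of executions, so the counter ends up far below 100.

package utils

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// Not part of this commit: a rough demonstration of call bundling.
func ExampleCallLimiter2() {
	limiter := NewCallLimiter2(10 * time.Millisecond)

	var (
		execs atomic.Int64
		wg    sync.WaitGroup
	)
	wg.Add(100)
	for range 100 {
		go func() {
			defer wg.Done()
			limiter.Do(func() {
				execs.Add(1)
				time.Sleep(time.Millisecond) // simulate work
			})
		}()
	}
	wg.Wait()

	// All 100 calls returned, but only a few executions actually ran.
	fmt.Println(execs.Load() < 100)
	// Output: true
}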
@@ -13,7 +13,7 @@ func TestCallLimiter(t *testing.T) {
 	t.Parallel()
 
 	pause := 10 * time.Millisecond
-	oa := NewCallLimiter(pause)
+	oa := NewCallLimiter2(pause)
 	executed := abool.New()
 	var testWg sync.WaitGroup
 
@@ -41,14 +41,14 @@ func TestCallLimiter(t *testing.T) {
 		executed.UnSet() // reset check
 	}
 
-	// Wait for pause to reset.
-	time.Sleep(pause)
+	// Wait for 2x pause to reset.
+	time.Sleep(2 * pause)
 
 	// Continuous use with re-execution.
 	// Choose values so that about 10 executions are expected
 	var execs uint32
-	testWg.Add(200)
-	for range 200 {
+	testWg.Add(100)
+	for range 100 {
 		go func() {
 			oa.Do(func() {
 				atomic.AddUint32(&execs, 1)
@@ -69,8 +69,8 @@ func TestCallLimiter(t *testing.T) {
 		t.Errorf("unexpected high exec count: %d", execs)
 	}
 
-	// Wait for pause to reset.
-	time.Sleep(pause)
+	// Wait for 2x pause to reset.
+	time.Sleep(2 * pause)
 
 	// Check if the limiter correctly handles panics.
 	testWg.Add(100)
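A note on the doubled sleeps above: waitAndExec records lastExec only after f() returns, and the next execution waits out the remaining pause measured from that timestamp. Sleeping exactly one pause in the test can therefore race with an execution that is still finishing; 2 * pause leaves the limiter reliably idle before the next test phase.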
@@ -19,7 +19,7 @@ var (
 	// pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs.
 	pidsByUser      = make(map[int][]int)
 	pidsByUserLock  sync.RWMutex
-	fetchPidsByUser = utils.NewCallLimiter(10 * time.Millisecond)
+	fetchPidsByUser = utils.NewCallLimiter2(10 * time.Millisecond)
 )
 
 // getPidsByUser returns the cached PIDs for the given UID.
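The body of getPidsByUser is unchanged by this commit and not shown in the diff. As a sketch of the pattern the limiter enables there (updatePids and the exact control flow are assumptions for illustration, not code from the repository): concurrent cache misses are bundled, so the process table is scanned at most once per 10ms and every waiting caller then reads the refreshed map.

// Sketch only; updatePids is a hypothetical refresh function.
func getPidsByUser(uid int) (pids []int, ok bool) {
	pidsByUserLock.RLock()
	pids, ok = pidsByUser[uid]
	pidsByUserLock.RUnlock()
	if ok {
		return pids, true
	}

	// Cache miss: bundle concurrent refreshes, at most one per 10ms.
	fetchPidsByUser.Do(updatePids)

	pidsByUserLock.RLock()
	defer pidsByUserLock.RUnlock()
	pids, ok = pidsByUser[uid]
	return pids, ok
}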
@@ -25,7 +25,7 @@ type tcpTable struct {
 	// lastUpdateAt stores the time when the tables were last updated as unix nanoseconds.
 	lastUpdateAt atomic.Int64
 
-	fetchLimiter *utils.CallLimiter
+	fetchLimiter *utils.CallLimiter2
 	fetchTable   func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error)
 
 	dualStack *tcpTable
@@ -34,13 +34,13 @@ type tcpTable struct {
 var (
 	tcp6Table = &tcpTable{
 		version:      6,
-		fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
+		fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates),
 		fetchTable:   getTCP6Table,
 	}
 
 	tcp4Table = &tcpTable{
 		version:      4,
-		fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
+		fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates),
 		fetchTable:   getTCP4Table,
 	}
 )
@@ -24,7 +24,7 @@ type udpTable struct {
 	// lastUpdateAt stores the time when the tables were last updated as unix nanoseconds.
 	lastUpdateAt atomic.Int64
 
-	fetchLimiter *utils.CallLimiter
+	fetchLimiter *utils.CallLimiter2
 	fetchTable   func() (binds []*socket.BindInfo, err error)
 
 	states map[string]map[string]*udpState
@@ -52,14 +52,14 @@ const (
 var (
 	udp6Table = &udpTable{
 		version:      6,
-		fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
+		fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates),
 		fetchTable:   getUDP6Table,
 		states:       make(map[string]map[string]*udpState),
 	}
 
 	udp4Table = &udpTable{
 		version:      4,
-		fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
+		fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates),
 		fetchTable:   getUDP4Table,
 		states:       make(map[string]map[string]*udpState),
 	}