From 4b29eed2cf68e92d45b35d5e81fbfc2949dee0e7 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 29 Mar 2023 14:57:32 +0200 Subject: [PATCH 1/8] Improve support for disabled IPv6 stack --- network/module.go | 9 ++++++++- network/state/info.go | 24 ++++++++++++++---------- network/state/tcp.go | 7 ++++++- network/state/udp.go | 15 ++++++++++++--- 4 files changed, 40 insertions(+), 15 deletions(-) diff --git a/network/module.go b/network/module.go index 23022dbc..1a7ee708 100644 --- a/network/module.go +++ b/network/module.go @@ -2,6 +2,8 @@ package network import ( "github.com/safing/portbase/modules" + "github.com/safing/portmaster/netenv" + "github.com/safing/portmaster/network/state" ) var ( @@ -11,7 +13,7 @@ var ( ) func init() { - module = modules.Register("network", prep, start, nil, "base", "processes") + module = modules.Register("network", prep, start, nil, "base", "netenv", "processes") } // SetDefaultFirewallHandler sets the default firewall handler. @@ -22,6 +24,11 @@ func SetDefaultFirewallHandler(handler FirewallHandler) { } func prep() error { + if netenv.IPv6Enabled() { + state.EnableTCPDualStack() + state.EnableUDPDualStack() + } + return registerAPIEndpoints() } diff --git a/network/state/info.go b/network/state/info.go index 6f97a74e..0b75d37a 100644 --- a/network/state/info.go +++ b/network/state/info.go @@ -3,6 +3,8 @@ package state import ( "sync" + "github.com/safing/portmaster/netenv" + "github.com/safing/portbase/database/record" "github.com/safing/portmaster/network/socket" ) @@ -30,21 +32,23 @@ func GetInfo() *Info { info.TCP4Listeners = tcp4Table.listeners tcp4Table.lock.RUnlock() - tcp6Table.updateTables() - tcp6Table.lock.RLock() - info.TCP6Connections = tcp6Table.connections - info.TCP6Listeners = tcp6Table.listeners - tcp6Table.lock.RUnlock() - udp4Table.updateTable() udp4Table.lock.RLock() info.UDP4Binds = udp4Table.binds udp4Table.lock.RUnlock() - udp6Table.updateTable() - udp6Table.lock.RLock() - info.UDP6Binds = udp6Table.binds - udp6Table.lock.RUnlock() + if netenv.IPv6Enabled() { + tcp6Table.updateTables() + tcp6Table.lock.RLock() + info.TCP6Connections = tcp6Table.connections + info.TCP6Listeners = tcp6Table.listeners + tcp6Table.lock.RUnlock() + + udp6Table.updateTable() + udp6Table.lock.RLock() + info.UDP6Binds = udp6Table.binds + udp6Table.lock.RUnlock() + } info.UpdateMeta() return info diff --git a/network/state/tcp.go b/network/state/tcp.go index 8b04cd4d..020de2aa 100644 --- a/network/state/tcp.go +++ b/network/state/tcp.go @@ -29,6 +29,11 @@ var ( tcp4Table = &tcpTable{ version: 4, fetchTable: getTCP4Table, - dualStack: tcp6Table, } ) + +// EnableTCPDualStack adds the TCP6 table to the TCP4 table as a dual-stack. +// Must be called before any lookup operation. +func EnableTCPDualStack() { + tcp4Table.dualStack = tcp6Table +} diff --git a/network/state/udp.go b/network/state/udp.go index 0302be14..41e13d25 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -6,6 +6,8 @@ import ( "sync" "time" + "github.com/safing/portmaster/netenv" + "github.com/safing/portbase/utils" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/socket" @@ -53,10 +55,15 @@ var ( version: 4, fetchTable: getUDP4Table, states: make(map[string]map[string]*udpState), - dualStack: udp6Table, } ) +// EnableUDPDualStack adds the UDP6 table to the UDP4 table as a dual-stack. +// Must be called before any lookup operation. +func EnableUDPDualStack() { + udp4Table.dualStack = udp6Table +} + // CleanUDPStates cleans the udp connection states which save connection directions. func CleanUDPStates(_ context.Context) { now := time.Now().UTC() @@ -64,8 +71,10 @@ func CleanUDPStates(_ context.Context) { udp4Table.updateTable() udp4Table.cleanStates(now) - udp6Table.updateTable() - udp6Table.cleanStates(now) + if netenv.IPv6Enabled() { + udp6Table.updateTable() + udp6Table.cleanStates(now) + } } func (table *udpTable) getConnState( From b9488c1a8ff6ac0b3bfbe567c9a49501ea901dc2 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 30 Mar 2023 11:36:44 +0200 Subject: [PATCH 2/8] Fix matching TCP connections --- network/state/lookup.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/network/state/lookup.go b/network/state/lookup.go index 46aac9a6..0586b17b 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -129,9 +129,14 @@ func (table *tcpTable) findSocket(pktInfo *packet.Info) ( } } + remoteIP := pktInfo.RemoteIP() + remotePort := pktInfo.RemotePort() + // search connections for _, socketInfo := range table.connections { if localPort == socketInfo.Local.Port && + remotePort == socketInfo.Remote.Port && + remoteIP.Equal(socketInfo.Remote.IP) && localIP.Equal(socketInfo.Local.IP) { return socketInfo, false } From 834282cb0e82bb2908c26e81c6cdab6eba8eb8d2 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Thu, 30 Mar 2023 14:43:38 +0200 Subject: [PATCH 3/8] Fix process identification key --- intel/module.go | 2 +- netquery/manager.go | 3 +- network/connection.go | 5 +- network/database.go | 25 ++++------ network/iphelper/iphelper.go | 2 +- process/database.go | 12 ++--- process/process.go | 90 +++++++++++++++++++++++++----------- process/special.go | 41 ++++++++-------- 8 files changed, 107 insertions(+), 73 deletions(-) diff --git a/intel/module.go b/intel/module.go index ceec6b64..81f3efc5 100644 --- a/intel/module.go +++ b/intel/module.go @@ -9,5 +9,5 @@ import ( var Module *modules.Module func init() { - Module = modules.Register("intel", nil, nil, nil, "geoip", "filterlists", "customlists") + Module = modules.Register("intel", nil, nil, nil, "geoip", "filterlists") } diff --git a/netquery/manager.go b/netquery/manager.go index fb244526..16dcbe0e 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -197,7 +197,8 @@ func convertConnection(conn *network.Connection) (*Conn, error) { } extraData := map[string]interface{}{ - "pid": conn.ProcessContext.PID, + "pid": conn.ProcessContext.PID, + "processCreatedAt": conn.ProcessContext.CreatedAt, } if conn.TunnelContext != nil { diff --git a/network/connection.go b/network/connection.go index a1af8c06..f9b2379d 100644 --- a/network/connection.go +++ b/network/connection.go @@ -26,7 +26,7 @@ import ( type FirewallHandler func(conn *Connection, pkt packet.Packet) // ProcessContext holds additional information about the process -// that iniated a connection. +// that initiated a connection. type ProcessContext struct { // ProcessName is the name of the process. ProcessName string @@ -38,6 +38,8 @@ type ProcessContext struct { CmdLine string // PID is the process identifier. PID int + // CreatedAt the time when the process was created. + CreatedAt int64 // Profile is the ID of the main profile that // is applied to the process. Profile string @@ -223,6 +225,7 @@ func getProcessContext(ctx context.Context, proc *process.Process) ProcessContex BinaryPath: proc.Path, CmdLine: proc.CmdLine, PID: proc.Pid, + CreatedAt: proc.CreatedAt, } // Get local profile. diff --git a/network/database.go b/network/database.go index 20c4cbfb..159ba1d3 100644 --- a/network/database.go +++ b/network/database.go @@ -45,7 +45,7 @@ func makeKey(pid int, scope, id string) string { return fmt.Sprintf("network:tree/%d/%s/%s", pid, scope, id) } -func parseDBKey(key string) (pid int, scope, id string, ok bool) { +func parseDBKey(key string) (processKey string, scope, id string, ok bool) { // Split into segments. segments := strings.Split(key, "/") @@ -65,27 +65,18 @@ func parseDBKey(key string) (pid int, scope, id string, ok bool) { // TODO: For queries, also prefixes of these values are valid. default: // Unknown scope. - return 0, "", "", false + return "", "", "", false } fallthrough case 2: - var err error - if segments[1] == "" { - pid = process.UndefinedProcessID - } else { - pid, err = strconv.Atoi(segments[1]) - if err != nil { - return 0, "", "", false - } - } - - return pid, scope, id, true + processKey = segments[1] + return processKey, scope, id, true case 1: // This is a valid query prefix, but not process ID was given. - return process.UndefinedProcessID, "", "", true + return "", "", "", true default: - return 0, "", "", false + return "", "", "", false } } @@ -93,7 +84,7 @@ func parseDBKey(key string) (pid int, scope, id string, ok bool) { func (s *StorageInterface) Get(key string) (record.Record, error) { // Parse key and check if valid. pid, scope, id, ok := parseDBKey(strings.TrimPrefix(key, "network:")) - if !ok || pid == process.UndefinedProcessID { + if !ok || pid == "" { return nil, storage.ErrNotFound } @@ -135,7 +126,7 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { return } - if pid == process.UndefinedProcessID { + if pid == "" { // processes for _, proc := range process.All() { func() { diff --git a/network/iphelper/iphelper.go b/network/iphelper/iphelper.go index 5fa1ea12..2256434b 100644 --- a/network/iphelper/iphelper.go +++ b/network/iphelper/iphelper.go @@ -11,7 +11,7 @@ import ( ) var ( - errInvalid = errors.New("IPHelper not initialzed or broken") + errInvalid = errors.New("IPHelper not initialized or broken") ) // IPHelper represents a subset of the Windows iphlpapi.dll. diff --git a/process/database.go b/process/database.go index 62df8f09..b457d069 100644 --- a/process/database.go +++ b/process/database.go @@ -15,7 +15,7 @@ import ( const processDatabaseNamespace = "network:tree" var ( - processes = make(map[int]*Process) + processes = make(map[string]*Process) processesLock sync.RWMutex dbController *database.Controller @@ -25,11 +25,11 @@ var ( ) // GetProcessFromStorage returns a process from the internal storage. -func GetProcessFromStorage(pid int) (*Process, bool) { +func GetProcessFromStorage(key string) (*Process, bool) { processesLock.RLock() defer processesLock.RUnlock() - p, ok := processes[pid] + p, ok := processes[key] return p, ok } @@ -55,11 +55,11 @@ func (p *Process) Save() { if !p.KeyIsSet() { // set key - p.SetKey(fmt.Sprintf("%s/%d", processDatabaseNamespace, p.Pid)) + p.SetKey(fmt.Sprintf("%s/%s", processDatabaseNamespace, getProcessKey(int32(p.Pid), p.CreatedAt))) // save processesLock.Lock() - processes[p.Pid] = p + processes[p.key] = p processesLock.Unlock() } @@ -75,7 +75,7 @@ func (p *Process) Delete() { // delete from internal storage processesLock.Lock() - delete(processes, p.Pid) + delete(processes, p.key) processesLock.Unlock() // propagate delete diff --git a/process/process.go b/process/process.go index 9f3ad2e4..777a6aa9 100644 --- a/process/process.go +++ b/process/process.go @@ -6,7 +6,6 @@ import ( "fmt" "path/filepath" "runtime" - "strconv" "strings" "sync" "time" @@ -31,18 +30,23 @@ type Process struct { // Process attributes. // Don't change; safe for concurrent access. - Name string - UserID int - UserName string - UserHome string - Pid int - ParentPid int - Path string - ExecName string - Cwd string - CmdLine string - FirstArg string - Env map[string]string + Name string + UserID int + UserName string + UserHome string + Pid int + CreatedAt int64 + ParentPid int + ParentCreatedAt int64 + Path string + ExecName string + Cwd string + CmdLine string + FirstArg string + Env map[string]string + + // unique process identifier ("Pid-CreatedAt") + key string // Profile attributes. // Once set, these don't change; safe for concurrent access. @@ -156,8 +160,21 @@ func (p *Process) String() string { func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { log.Tracer(ctx).Tracef("process: getting process for PID %d", pid) - p, err, _ := getProcessSingleInflight.Do(strconv.Itoa(pid), func() (interface{}, error) { - return loadProcess(ctx, pid) + // Get pid and created time for identification. + pInfo, err := processInfo.NewProcessWithContext(ctx, int32(pid)) + if err != nil { + return nil, err + } + + createdTime, err := pInfo.CreateTimeWithContext(ctx) + if err != nil { + return nil, err + } + + key := getProcessKey(int32(pid), createdTime) + + p, err, _ := getProcessSingleInflight.Do(key, func() (interface{}, error) { + return loadProcess(ctx, key, pInfo) }) if err != nil { return nil, err @@ -169,8 +186,8 @@ func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { return p.(*Process), nil // nolint:forcetypeassert // Can only be a *Process. } -func loadProcess(ctx context.Context, pid int) (*Process, error) { - switch pid { +func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (*Process, error) { + switch pInfo.Pid { case UnidentifiedProcessID: return GetUnidentifiedProcess(ctx), nil case UnsolicitedProcessID: @@ -179,19 +196,24 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { return GetSystemProcess(ctx), nil } - process, ok := GetProcessFromStorage(pid) + // Get created time of process. The value should be cached. + createdAt, _ := pInfo.CreateTimeWithContext(ctx) + + process, ok := GetProcessFromStorage(getProcessKey(pInfo.Pid, createdAt)) if ok { return process, nil } // Create new a process object. process = &Process{ - Pid: pid, + Pid: int(pInfo.Pid), + CreatedAt: createdAt, FirstSeen: time.Now().Unix(), + key: key, } // Get process information from the system. - pInfo, err := processInfo.NewProcessWithContext(ctx, int32(pid)) + pInfo, err := processInfo.NewProcessWithContext(ctx, pInfo.Pid) if err != nil { return nil, err } @@ -202,7 +224,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { var uids []int32 uids, err = pInfo.UidsWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get UID for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get UID for p%d: %w", pInfo.Pid, err) } process.UserID = int(uids[0]) } @@ -210,23 +232,30 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // Username process.UserName, err = pInfo.UsernameWithContext(ctx) if err != nil { - return nil, fmt.Errorf("process: failed to get Username for p%d: %w", pid, err) + return nil, fmt.Errorf("process: failed to get Username for p%d: %w", pInfo.Pid, err) } // TODO: User Home // new.UserHome, err = - // PPID + // Parent process id ppid, err := pInfo.PpidWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get PPID for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get PPID for p%d: %w", pInfo.Pid, err) } process.ParentPid = int(ppid) + // Parent created at time + parentCreatedAt, err := pInfo.CreateTimeWithContext(ctx) + if err != nil { + return nil, err + } + process.ParentCreatedAt = parentCreatedAt + // Path process.Path, err = pInfo.ExeWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get Path for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get Path for p%d: %w", pInfo.Pid, err) } // remove linux " (deleted)" suffix for deleted files if onLinux { @@ -247,13 +276,13 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // Command line arguments process.CmdLine, err = pInfo.CmdlineWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get Cmdline for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get Cmdline for p%d: %w", pInfo.Pid, err) } // Name process.Name, err = pInfo.NameWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get Name for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get Name for p%d: %w", pInfo.Pid, err) } if process.Name == "" { process.Name = process.ExecName @@ -262,7 +291,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // Get all environment variables env, err := pInfo.EnvironWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get the environment for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get the environment for p%d: %w", pInfo.Pid, err) } // Split env variables in key and value. process.Env = make(map[string]string, len(env)) @@ -283,6 +312,11 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { return process, nil } +// Builds a unique identifier for a processes. +func getProcessKey(pid int32, createdTime int64) string { + return fmt.Sprintf("%d-%d", pid, createdTime) +} + // MatchingData returns the matching data for the process. func (p *Process) MatchingData() *MatchingData { return &MatchingData{p} diff --git a/process/special.go b/process/special.go index 7d2c3e93..6ecd2d0b 100644 --- a/process/special.go +++ b/process/special.go @@ -2,7 +2,6 @@ package process import ( "context" - "strconv" "time" "golang.org/x/sync/singleflight" @@ -39,29 +38,35 @@ func init() { var ( // unidentifiedProcess is used for non-attributed outgoing connections. unidentifiedProcess = &Process{ - UserID: UnidentifiedProcessID, - UserName: "Unknown", - Pid: UnidentifiedProcessID, - ParentPid: UnidentifiedProcessID, - Name: profile.UnidentifiedProfileName, + UserID: UnidentifiedProcessID, + UserName: "Unknown", + Pid: UnidentifiedProcessID, + CreatedAt: 1, + ParentPid: UnidentifiedProcessID, + ParentCreatedAt: 1, + Name: profile.UnidentifiedProfileName, } // unsolicitedProcess is used for non-attributed incoming connections. unsolicitedProcess = &Process{ - UserID: UnsolicitedProcessID, - UserName: "Unknown", - Pid: UnsolicitedProcessID, - ParentPid: UnsolicitedProcessID, - Name: profile.UnsolicitedProfileName, + UserID: UnsolicitedProcessID, + UserName: "Unknown", + Pid: UnsolicitedProcessID, + CreatedAt: 1, + ParentPid: UnsolicitedProcessID, + ParentCreatedAt: 1, + Name: profile.UnsolicitedProfileName, } // systemProcess is used to represent the Kernel. systemProcess = &Process{ - UserID: SystemProcessID, - UserName: "Kernel", - Pid: SystemProcessID, - ParentPid: SystemProcessID, - Name: profile.SystemProfileName, + UserID: SystemProcessID, + UserName: "Kernel", + Pid: SystemProcessID, + CreatedAt: 1, + ParentPid: SystemProcessID, + ParentCreatedAt: 1, + Name: profile.SystemProfileName, } getSpecialProcessSingleInflight singleflight.Group @@ -83,9 +88,9 @@ func GetSystemProcess(ctx context.Context) *Process { } func getSpecialProcess(ctx context.Context, template *Process) *Process { - p, _, _ := getSpecialProcessSingleInflight.Do(strconv.Itoa(template.Pid), func() (interface{}, error) { + p, _, _ := getSpecialProcessSingleInflight.Do(template.key, func() (interface{}, error) { // Check if we have already loaded the special process. - process, ok := GetProcessFromStorage(template.Pid) + process, ok := GetProcessFromStorage(template.key) if ok { return process, nil } From 79d6bcb6c605b87cb23b353f3fc528199f49a511 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Thu, 30 Mar 2023 15:18:18 +0200 Subject: [PATCH 4/8] Fix parent process key an refactoring --- intel/module.go | 2 +- process/database.go | 4 ++-- process/process.go | 38 ++++++++++++++++++++++---------------- process/special.go | 4 ++-- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/intel/module.go b/intel/module.go index 81f3efc5..ceec6b64 100644 --- a/intel/module.go +++ b/intel/module.go @@ -9,5 +9,5 @@ import ( var Module *modules.Module func init() { - Module = modules.Register("intel", nil, nil, nil, "geoip", "filterlists") + Module = modules.Register("intel", nil, nil, nil, "geoip", "filterlists", "customlists") } diff --git a/process/database.go b/process/database.go index b457d069..4def9b01 100644 --- a/process/database.go +++ b/process/database.go @@ -59,7 +59,7 @@ func (p *Process) Save() { // save processesLock.Lock() - processes[p.key] = p + processes[p.processKey] = p processesLock.Unlock() } @@ -75,7 +75,7 @@ func (p *Process) Delete() { // delete from internal storage processesLock.Lock() - delete(processes, p.key) + delete(processes, p.processKey) processesLock.Unlock() // propagate delete diff --git a/process/process.go b/process/process.go index 777a6aa9..f7d12973 100644 --- a/process/process.go +++ b/process/process.go @@ -46,7 +46,7 @@ type Process struct { Env map[string]string // unique process identifier ("Pid-CreatedAt") - key string + processKey string // Profile attributes. // Once set, these don't change; safe for concurrent access. @@ -160,6 +160,16 @@ func (p *Process) String() string { func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { log.Tracer(ctx).Tracef("process: getting process for PID %d", pid) + // Check for special processes + switch pid { + case UnidentifiedProcessID: + return GetUnidentifiedProcess(ctx), nil + case UnsolicitedProcessID: + return GetUnsolicitedProcess(ctx), nil + case SystemProcessID: + return GetSystemProcess(ctx), nil + } + // Get pid and created time for identification. pInfo, err := processInfo.NewProcessWithContext(ctx, int32(pid)) if err != nil { @@ -187,29 +197,20 @@ func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { } func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (*Process, error) { - switch pInfo.Pid { - case UnidentifiedProcessID: - return GetUnidentifiedProcess(ctx), nil - case UnsolicitedProcessID: - return GetUnsolicitedProcess(ctx), nil - case SystemProcessID: - return GetSystemProcess(ctx), nil - } - // Get created time of process. The value should be cached. createdAt, _ := pInfo.CreateTimeWithContext(ctx) - process, ok := GetProcessFromStorage(getProcessKey(pInfo.Pid, createdAt)) + process, ok := GetProcessFromStorage(key) if ok { return process, nil } // Create new a process object. process = &Process{ - Pid: int(pInfo.Pid), - CreatedAt: createdAt, - FirstSeen: time.Now().Unix(), - key: key, + Pid: int(pInfo.Pid), + CreatedAt: createdAt, + FirstSeen: time.Now().Unix(), + processKey: key, } // Get process information from the system. @@ -246,7 +247,12 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (* process.ParentPid = int(ppid) // Parent created at time - parentCreatedAt, err := pInfo.CreateTimeWithContext(ctx) + parentPInfo, err := processInfo.NewProcessWithContext(ctx, ppid) + if err != nil { + return nil, err + } + + parentCreatedAt, err := parentPInfo.CreateTimeWithContext(ctx) if err != nil { return nil, err } diff --git a/process/special.go b/process/special.go index 6ecd2d0b..93f17c9a 100644 --- a/process/special.go +++ b/process/special.go @@ -88,9 +88,9 @@ func GetSystemProcess(ctx context.Context) *Process { } func getSpecialProcess(ctx context.Context, template *Process) *Process { - p, _, _ := getSpecialProcessSingleInflight.Do(template.key, func() (interface{}, error) { + p, _, _ := getSpecialProcessSingleInflight.Do(template.processKey, func() (interface{}, error) { // Check if we have already loaded the special process. - process, ok := GetProcessFromStorage(template.key) + process, ok := GetProcessFromStorage(template.processKey) if ok { return process, nil } From 94b4fa3bfd65e776664cfd2d1e5f3dc1855de2a9 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 30 Mar 2023 15:53:47 +0200 Subject: [PATCH 5/8] Improve network state caching --- network/state/info.go | 27 ++----- network/state/lookup.go | 161 +++++++++++++++++++++++++--------------- network/state/tables.go | 47 ------------ network/state/tcp.go | 121 ++++++++++++++++++++++++++++-- network/state/udp.go | 128 ++++++++++++++++++++++++++++---- 5 files changed, 334 insertions(+), 150 deletions(-) diff --git a/network/state/info.go b/network/state/info.go index 0b75d37a..818efda3 100644 --- a/network/state/info.go +++ b/network/state/info.go @@ -3,9 +3,8 @@ package state import ( "sync" - "github.com/safing/portmaster/netenv" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/netenv" "github.com/safing/portmaster/network/socket" ) @@ -26,28 +25,12 @@ type Info struct { func GetInfo() *Info { info := &Info{} - tcp4Table.updateTables() - tcp4Table.lock.RLock() - info.TCP4Connections = tcp4Table.connections - info.TCP4Listeners = tcp4Table.listeners - tcp4Table.lock.RUnlock() - - udp4Table.updateTable() - udp4Table.lock.RLock() - info.UDP4Binds = udp4Table.binds - udp4Table.lock.RUnlock() + info.TCP4Connections, info.TCP4Listeners, _ = tcp4Table.updateTables(0) + info.UDP4Binds, _ = udp4Table.updateTables(0) if netenv.IPv6Enabled() { - tcp6Table.updateTables() - tcp6Table.lock.RLock() - info.TCP6Connections = tcp6Table.connections - info.TCP6Listeners = tcp6Table.listeners - tcp6Table.lock.RUnlock() - - udp6Table.updateTable() - udp6Table.lock.RLock() - info.UDP6Binds = udp6Table.binds - udp6Table.lock.RUnlock() + info.TCP6Connections, info.TCP6Listeners, _ = tcp6Table.updateTables(tcp6Table.updateIter.Load()) + info.UDP6Binds, _ = udp6Table.updateTables(0) } info.UpdateMeta() diff --git a/network/state/lookup.go b/network/state/lookup.go index 0586b17b..a764e522 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -69,19 +69,52 @@ func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) ( inbound bool, err error, ) { - // Search pattern: search, refresh, search, wait, search, refresh, search, wait, ... + // Prepare variables. + var ( + connections []*socket.ConnectionInfo + listeners []*socket.BindInfo + updateIter uint64 + + dualStackConnections []*socket.ConnectionInfo + dualStackListeners []*socket.BindInfo + dualStackUpdateIter uint64 + ) // Search for the socket until found. for i := 1; i <= lookupRetries; i++ { - // Check main table for socket. - socketInfo, inbound := table.findSocket(pktInfo) - if socketInfo == nil && table.dualStack != nil { - // If there was no match in the main table and we are dual-stack, check - // the dual-stack table for the socket. - socketInfo, inbound = table.dualStack.findSocket(pktInfo) + // Get or update tables. + if i == 1 { + connections, listeners, updateIter = table.getCurrentTables() + } else { + connections, listeners, updateIter = table.updateTables(updateIter) } - // If there's a match, check we have the PID and return. + // Check tables for socket. + socketInfo, inbound := findTCPSocket(pktInfo, connections, listeners) + + // If there's a match, check if we have the PID and return. + if socketInfo != nil { + return checkPID(socketInfo, inbound) + } + + // DUAL-STACK + + // Skip if dualStack is not enabled. + if table.dualStack == nil { + continue + } + + // Get or update tables. + if i == 0 { + dualStackConnections, dualStackListeners, dualStackUpdateIter = table.dualStack.getCurrentTables() + } else { + dualStackConnections, dualStackListeners, dualStackUpdateIter = table.dualStack.updateTables(dualStackUpdateIter) + } + + // Check tables for socket. + socketInfo, inbound = findTCPSocket(pktInfo, dualStackConnections, dualStackListeners) + + // If there's a match, check if we have the PID and return. if socketInfo != nil { return checkPID(socketInfo, inbound) } @@ -90,39 +123,24 @@ func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) ( if fast && i < fastLookupRetries { break } - - // every time, except for the last iteration - if i < lookupRetries { - // Take turns in waiting and refreshing in order to satisfy the search pattern. - if i%2 == 0 { - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - } else { - // refetch lists - table.updateTables() - if table.dualStack != nil { - table.dualStack.updateTables() - } - } - } } return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound } -func (table *tcpTable) findSocket(pktInfo *packet.Info) ( +func findTCPSocket( + pktInfo *packet.Info, + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, +) ( socketInfo socket.Info, inbound bool, ) { localIP := pktInfo.LocalIP() localPort := pktInfo.LocalPort() - table.lock.RLock() - defer table.lock.RUnlock() - // always search listeners first - for _, socketInfo := range table.listeners { + for _, socketInfo := range listeners { if localPort == socketInfo.Local.Port && (socketInfo.ListensAny || localIP.Equal(socketInfo.Local.IP)) { return socketInfo, true @@ -133,7 +151,7 @@ func (table *tcpTable) findSocket(pktInfo *packet.Info) ( remotePort := pktInfo.RemotePort() // search connections - for _, socketInfo := range table.connections { + for _, socketInfo := range connections { if localPort == socketInfo.Local.Port && remotePort == socketInfo.Remote.Port && remoteIP.Equal(socketInfo.Remote.IP) && @@ -150,25 +168,67 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( inbound bool, err error, ) { - // Search pattern: search, refresh, search, wait, search, refresh, search, wait, ... - // TODO: Currently broadcast/multicast scopes are not checked, so we might // attribute an incoming broadcast/multicast packet to the wrong process if // there are multiple processes listening on the same local port, but // binding to different addresses. This highly unusual for clients. isInboundMulticast := pktInfo.Inbound && netutils.GetIPScope(pktInfo.LocalIP()) == netutils.LocalMulticast + // Prepare variables. + var ( + binds []*socket.BindInfo + updateIter uint64 + + dualStackBinds []*socket.BindInfo + dualStackUpdateIter uint64 + ) + // Search for the socket until found. for i := 1; i <= lookupRetries; i++ { - // Check main table for socket. - socketInfo := table.findSocket(pktInfo, isInboundMulticast) - if socketInfo == nil && table.dualStack != nil { - // If there was no match in the main table and we are dual-stack, check - // the dual-stack table for the socket. - socketInfo = table.dualStack.findSocket(pktInfo, isInboundMulticast) + // Get or update tables. + if i == 1 { + binds, updateIter = table.getCurrentTables() + } else { + binds, updateIter = table.updateTables(updateIter) } - // If there's a match, get the direction and check we have the PID, then return. + // Check tables for socket. + socketInfo := findUDPSocket(pktInfo, binds, isInboundMulticast) + + // If there's a match, do some last checks and return. + if socketInfo != nil { + // If there is no remote port, do check for the direction of the + // connection. This will be the case for pure checking functions + // that do not want to change direction state. + if pktInfo.RemotePort() == 0 { + return checkPID(socketInfo, pktInfo.Inbound) + } + + // Get (and save) the direction of the connection. + connInbound := table.getDirection(socketInfo, pktInfo) + + // Check we have the PID and return. + return checkPID(socketInfo, connInbound) + } + + // DUAL-STACK + + // Skip if dualStack is not enabled. + if table.dualStack == nil { + continue + } + + // Get or update tables. + if i == 0 { + dualStackBinds, dualStackUpdateIter = table.dualStack.getCurrentTables() + } else { + dualStackBinds, dualStackUpdateIter = table.dualStack.updateTables(dualStackUpdateIter) + } + + // Check tables for socket. + socketInfo = findUDPSocket(pktInfo, dualStackBinds, isInboundMulticast) + + // If there's a match, do some last checks and return. if socketInfo != nil { // If there is no remote port, do check for the direction of the // connection. This will be the case for pure checking functions @@ -188,36 +248,17 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( if fast && i < fastLookupRetries { break } - - // every time, except for the last iteration - if i < lookupRetries { - // Take turns in waiting and refreshing in order to satisfy the search pattern. - if i%2 == 0 { - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - } else { - // refetch lists - table.updateTable() - if table.dualStack != nil { - table.dualStack.updateTable() - } - } - } } return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound } -func (table *udpTable) findSocket(pktInfo *packet.Info, isInboundMulticast bool) (socketInfo *socket.BindInfo) { +func findUDPSocket(pktInfo *packet.Info, binds []*socket.BindInfo, isInboundMulticast bool) (socketInfo *socket.BindInfo) { localIP := pktInfo.LocalIP() localPort := pktInfo.LocalPort() - table.lock.RLock() - defer table.lock.RUnlock() - // search binds - for _, socketInfo := range table.binds { + for _, socketInfo := range binds { if localPort == socketInfo.Local.Port && (socketInfo.ListensAny || // zero IP (dual-stack) isInboundMulticast || // inbound broadcast, multicast diff --git a/network/state/tables.go b/network/state/tables.go index e99957ab..7bf2df5b 100644 --- a/network/state/tables.go +++ b/network/state/tables.go @@ -1,48 +1 @@ package state - -import ( - "net" - - "github.com/safing/portbase/log" -) - -func (table *tcpTable) updateTables() { - table.fetchOnceAgain.Do(func() { - table.lock.Lock() - defer table.lock.Unlock() - - connections, listeners, err := table.fetchTable() - if err != nil { - log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err) - return - } - - // Pre-check for any listeners. - for _, bindInfo := range listeners { - bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) - } - - table.connections = connections - table.listeners = listeners - }) -} - -func (table *udpTable) updateTable() { - table.fetchOnceAgain.Do(func() { - table.lock.Lock() - defer table.lock.Unlock() - - binds, err := table.fetchTable() - if err != nil { - log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err) - return - } - - // Pre-check for any listeners. - for _, bindInfo := range binds { - bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) - } - - table.binds = binds - }) -} diff --git a/network/state/tcp.go b/network/state/tcp.go index 020de2aa..4b423a21 100644 --- a/network/state/tcp.go +++ b/network/state/tcp.go @@ -1,34 +1,44 @@ package state import ( + "net" "sync" + "sync/atomic" + "time" - "github.com/safing/portbase/utils" + "github.com/safing/portbase/log" "github.com/safing/portmaster/network/socket" ) +const maxUpdateTries = 100 + type tcpTable struct { version int connections []*socket.ConnectionInfo listeners []*socket.BindInfo + updateIter atomic.Uint64 lock sync.RWMutex - fetchOnceAgain utils.OnceAgain - fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) + fetchingLock sync.Mutex + fetchingInProgress bool + fetchingDoneSignal chan struct{} + fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) dualStack *tcpTable } var ( tcp6Table = &tcpTable{ - version: 6, - fetchTable: getTCP6Table, + version: 6, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getTCP6Table, } tcp4Table = &tcpTable{ - version: 4, - fetchTable: getTCP4Table, + version: 4, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getTCP4Table, } ) @@ -37,3 +47,100 @@ var ( func EnableTCPDualStack() { tcp4Table.dualStack = tcp6Table } + +func (table *tcpTable) getCurrentTables() ( + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, + updateIter uint64, +) { + table.lock.RLock() + defer table.lock.RUnlock() + + return table.connections, table.listeners, table.updateIter.Load() +} + +func (table *tcpTable) checkFetchingState() (fetch bool, signal chan struct{}) { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // If fetching is already in progress, just return the signal. + if table.fetchingInProgress { + return false, table.fetchingDoneSignal + } + + // Otherwise, tell caller to fetch. + table.fetchingInProgress = true + return true, nil +} + +func (table *tcpTable) signalFetchComplete() { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // Set fetching state. + table.fetchingInProgress = false + + // Signal waiting goroutines. + close(table.fetchingDoneSignal) + table.fetchingDoneSignal = make(chan struct{}) +} + +func (table *tcpTable) updateTables(previousUpdateIter uint64) ( + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, + updateIter uint64, +) { + var tries int + + // Attempt to update the tables until we get a new version of the tables. + for previousUpdateIter == table.updateIter.Load() { + // Abort if it takes too long. + tries++ + if tries > maxUpdateTries { + log.Warningf("state: failed to upate TCP%d socket table %d times", table.version, tries-1) + return table.getCurrentTables() + } + + // Check if someone is fetching or if we should fetch. + fetch, signal := table.checkFetchingState() + if fetch { + defer table.signalFetchComplete() + + // Just to be sure, check again if there is a new version. + if previousUpdateIter < table.updateIter.Load() { + return table.getCurrentTables() + } + + // Wait for 5 milliseconds. + time.Sleep(5 * time.Millisecond) + + // Fetch new tables from system. + connections, listeners, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err) + // Return the current tables as fallback, as we need to trigger the defer to complete the fetch. + return table.getCurrentTables() + } + + // Pre-check for any listeners. + for _, bindInfo := range listeners { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) + } + + // Apply new tables. + table.lock.Lock() + defer table.lock.Unlock() + table.connections = connections + table.listeners = listeners + table.updateIter.Add(1) + + // Return new tables immediately. + return table.connections, table.listeners, table.updateIter.Load() + } + + // Otherwise, wait for fetch to complete. + <-signal + } + + return table.getCurrentTables() +} diff --git a/network/state/udp.go b/network/state/udp.go index 41e13d25..353ba9b6 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -2,13 +2,14 @@ package state import ( "context" + "net" "strconv" "sync" + "sync/atomic" "time" + "github.com/safing/portbase/log" "github.com/safing/portmaster/netenv" - - "github.com/safing/portbase/utils" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/socket" ) @@ -16,11 +17,14 @@ import ( type udpTable struct { version int - binds []*socket.BindInfo - lock sync.RWMutex + binds []*socket.BindInfo + updateIter atomic.Uint64 + lock sync.RWMutex - fetchOnceAgain utils.OnceAgain - fetchTable func() (binds []*socket.BindInfo, err error) + fetchingLock sync.Mutex + fetchingInProgress bool + fetchingDoneSignal chan struct{} + fetchTable func() (binds []*socket.BindInfo, err error) states map[string]map[string]*udpState statesLock sync.Mutex @@ -46,15 +50,17 @@ const ( var ( udp6Table = &udpTable{ - version: 6, - fetchTable: getUDP6Table, - states: make(map[string]map[string]*udpState), + version: 6, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getUDP6Table, + states: make(map[string]map[string]*udpState), } udp4Table = &udpTable{ - version: 4, - fetchTable: getUDP4Table, - states: make(map[string]map[string]*udpState), + version: 4, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getUDP4Table, + states: make(map[string]map[string]*udpState), } ) @@ -64,15 +70,109 @@ func EnableUDPDualStack() { udp4Table.dualStack = udp6Table } +func (table *udpTable) getCurrentTables() ( + binds []*socket.BindInfo, + updateIter uint64, +) { + table.lock.RLock() + defer table.lock.RUnlock() + + return table.binds, table.updateIter.Load() +} + +func (table *udpTable) checkFetchingState() (fetch bool, signal chan struct{}) { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // If fetching is already in progress, just return the signal. + if table.fetchingInProgress { + return false, table.fetchingDoneSignal + } + + // Otherwise, tell caller to fetch. + table.fetchingInProgress = true + return true, nil +} + +func (table *udpTable) signalFetchComplete() { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // Set fetching state. + table.fetchingInProgress = false + + // Signal waiting goroutines. + close(table.fetchingDoneSignal) + table.fetchingDoneSignal = make(chan struct{}) +} + +func (table *udpTable) updateTables(previousUpdateIter uint64) ( + binds []*socket.BindInfo, + updateIter uint64, +) { + var tries int + + // Attempt to update the tables until we get a new version of the tables. + for previousUpdateIter == table.updateIter.Load() { + // Abort if it takes too long. + tries++ + if tries > maxUpdateTries { + log.Warningf("state: failed to upate UDP%d socket table %d times", table.version, tries-1) + return table.getCurrentTables() + } + + // Check if someone is fetching or if we should fetch. + fetch, signal := table.checkFetchingState() + if fetch { + defer table.signalFetchComplete() + + // Just to be sure, check again if there is a new version. + if previousUpdateIter < table.updateIter.Load() { + return table.getCurrentTables() + } + + // Wait for 5 milliseconds. + time.Sleep(5 * time.Millisecond) + + // Fetch new tables from system. + binds, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err) + // Return the current tables as fallback, as we need to trigger the defer to complete the fetch. + return table.getCurrentTables() + } + + // Pre-check for any listeners. + for _, bindInfo := range binds { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) + } + + // Apply new tables. + table.lock.Lock() + defer table.lock.Unlock() + table.binds = binds + table.updateIter.Add(1) + + // Return new tables immediately. + return table.binds, table.updateIter.Load() + } + + // Otherwise, wait for fetch to complete. + <-signal + } + + return table.getCurrentTables() +} + // CleanUDPStates cleans the udp connection states which save connection directions. func CleanUDPStates(_ context.Context) { now := time.Now().UTC() - udp4Table.updateTable() + udp4Table.updateTables(0) udp4Table.cleanStates(now) if netenv.IPv6Enabled() { - udp6Table.updateTable() + udp6Table.updateTables(0) udp6Table.cleanStates(now) } } From 607036639953a1003a633b0a0ab828abac9aa465 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 30 Mar 2023 15:54:12 +0200 Subject: [PATCH 6/8] Ignore if parent process does not exist --- process/process.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/process/process.go b/process/process.go index f7d12973..95ca1daf 100644 --- a/process/process.go +++ b/process/process.go @@ -246,18 +246,16 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (* } process.ParentPid = int(ppid) - // Parent created at time + // Parent created time parentPInfo, err := processInfo.NewProcessWithContext(ctx, ppid) - if err != nil { - return nil, err + if err == nil { + parentCreatedAt, err := parentPInfo.CreateTimeWithContext(ctx) + if err != nil { + return nil, err + } + process.ParentCreatedAt = parentCreatedAt } - parentCreatedAt, err := parentPInfo.CreateTimeWithContext(ctx) - if err != nil { - return nil, err - } - process.ParentCreatedAt = parentCreatedAt - // Path process.Path, err = pInfo.ExeWithContext(ctx) if err != nil { From d9f4a9e70b2f68bd3955230fa341e2d8ea9ae762 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 30 Mar 2023 16:03:06 +0200 Subject: [PATCH 7/8] Fix getting updated tables immediately --- network/state/info.go | 6 +++--- network/state/udp.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/network/state/info.go b/network/state/info.go index 818efda3..5dbdc788 100644 --- a/network/state/info.go +++ b/network/state/info.go @@ -25,12 +25,12 @@ type Info struct { func GetInfo() *Info { info := &Info{} - info.TCP4Connections, info.TCP4Listeners, _ = tcp4Table.updateTables(0) - info.UDP4Binds, _ = udp4Table.updateTables(0) + info.TCP4Connections, info.TCP4Listeners, _ = tcp4Table.updateTables(tcp4Table.updateIter.Load()) + info.UDP4Binds, _ = udp4Table.updateTables(udp4Table.updateIter.Load()) if netenv.IPv6Enabled() { info.TCP6Connections, info.TCP6Listeners, _ = tcp6Table.updateTables(tcp6Table.updateIter.Load()) - info.UDP6Binds, _ = udp6Table.updateTables(0) + info.UDP6Binds, _ = udp6Table.updateTables(udp6Table.updateIter.Load()) } info.UpdateMeta() diff --git a/network/state/udp.go b/network/state/udp.go index 353ba9b6..1823a371 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -168,11 +168,11 @@ func (table *udpTable) updateTables(previousUpdateIter uint64) ( func CleanUDPStates(_ context.Context) { now := time.Now().UTC() - udp4Table.updateTables(0) + udp4Table.updateTables(udp4Table.updateIter.Load()) udp4Table.cleanStates(now) if netenv.IPv6Enabled() { - udp6Table.updateTables(0) + udp6Table.updateTables(udp6Table.updateIter.Load()) udp6Table.cleanStates(now) } } From 56998a010d0581379204a1b8b36987649d2e4c4e Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 30 Mar 2023 16:09:46 +0200 Subject: [PATCH 8/8] Improve lookup tries for network state --- network/state/lookup.go | 14 ++++++-------- network/state/system_linux.go | 6 ++++-- network/state/tables.go | 1 - 3 files changed, 10 insertions(+), 11 deletions(-) delete mode 100644 network/state/tables.go diff --git a/network/state/lookup.go b/network/state/lookup.go index a764e522..479bcf94 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -2,7 +2,6 @@ package state import ( "errors" - "time" "github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" @@ -30,9 +29,8 @@ var ( ) var ( - baseWaitTime = 3 * time.Millisecond - lookupRetries = 7 * 2 // Every retry takes two full passes. - fastLookupRetries = 2 * 2 + lookupTries = 15 // With a max wait of 5ms, this amounts to up to 75ms. + fastLookupTries = 2 ) // Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound. @@ -81,7 +79,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) ( ) // Search for the socket until found. - for i := 1; i <= lookupRetries; i++ { + for i := 1; i <= lookupTries; i++ { // Get or update tables. if i == 1 { connections, listeners, updateIter = table.getCurrentTables() @@ -120,7 +118,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) ( } // Search less if we want to be fast. - if fast && i < fastLookupRetries { + if fast && i >= fastLookupTries { break } } @@ -184,7 +182,7 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( ) // Search for the socket until found. - for i := 1; i <= lookupRetries; i++ { + for i := 1; i <= lookupTries; i++ { // Get or update tables. if i == 1 { binds, updateIter = table.getCurrentTables() @@ -245,7 +243,7 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( } // Search less if we want to be fast. - if fast && i < fastLookupRetries { + if fast && i >= fastLookupTries { break } } diff --git a/network/state/system_linux.go b/network/state/system_linux.go index abef01c1..4f7c4138 100644 --- a/network/state/system_linux.go +++ b/network/state/system_linux.go @@ -14,8 +14,10 @@ var ( getUDP6Table = proc.GetUDP6Table ) +var baseWaitTime = 3 * time.Millisecond + func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) { - for i := 0; i <= lookupRetries; i++ { + for i := 1; i <= lookupTries; i++ { // look for PID pid = proc.GetPID(socketInfo) if pid != socket.UndefinedProcessID { @@ -24,7 +26,7 @@ func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, } // every time, except for the last iteration - if i < lookupRetries { + if i < lookupTries { // we found no PID, we could have been too fast, give the kernel some time to think // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total time.Sleep(time.Duration(i+1) * baseWaitTime) diff --git a/network/state/tables.go b/network/state/tables.go deleted file mode 100644 index 7bf2df5b..00000000 --- a/network/state/tables.go +++ /dev/null @@ -1 +0,0 @@ -package state