diff --git a/netquery/manager.go b/netquery/manager.go index fb244526..16dcbe0e 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -197,7 +197,8 @@ func convertConnection(conn *network.Connection) (*Conn, error) { } extraData := map[string]interface{}{ - "pid": conn.ProcessContext.PID, + "pid": conn.ProcessContext.PID, + "processCreatedAt": conn.ProcessContext.CreatedAt, } if conn.TunnelContext != nil { diff --git a/network/connection.go b/network/connection.go index a1af8c06..f9b2379d 100644 --- a/network/connection.go +++ b/network/connection.go @@ -26,7 +26,7 @@ import ( type FirewallHandler func(conn *Connection, pkt packet.Packet) // ProcessContext holds additional information about the process -// that iniated a connection. +// that initiated a connection. type ProcessContext struct { // ProcessName is the name of the process. ProcessName string @@ -38,6 +38,8 @@ type ProcessContext struct { CmdLine string // PID is the process identifier. PID int + // CreatedAt the time when the process was created. + CreatedAt int64 // Profile is the ID of the main profile that // is applied to the process. Profile string @@ -223,6 +225,7 @@ func getProcessContext(ctx context.Context, proc *process.Process) ProcessContex BinaryPath: proc.Path, CmdLine: proc.CmdLine, PID: proc.Pid, + CreatedAt: proc.CreatedAt, } // Get local profile. diff --git a/network/database.go b/network/database.go index 20c4cbfb..159ba1d3 100644 --- a/network/database.go +++ b/network/database.go @@ -45,7 +45,7 @@ func makeKey(pid int, scope, id string) string { return fmt.Sprintf("network:tree/%d/%s/%s", pid, scope, id) } -func parseDBKey(key string) (pid int, scope, id string, ok bool) { +func parseDBKey(key string) (processKey string, scope, id string, ok bool) { // Split into segments. segments := strings.Split(key, "/") @@ -65,27 +65,18 @@ func parseDBKey(key string) (pid int, scope, id string, ok bool) { // TODO: For queries, also prefixes of these values are valid. default: // Unknown scope. - return 0, "", "", false + return "", "", "", false } fallthrough case 2: - var err error - if segments[1] == "" { - pid = process.UndefinedProcessID - } else { - pid, err = strconv.Atoi(segments[1]) - if err != nil { - return 0, "", "", false - } - } - - return pid, scope, id, true + processKey = segments[1] + return processKey, scope, id, true case 1: // This is a valid query prefix, but not process ID was given. - return process.UndefinedProcessID, "", "", true + return "", "", "", true default: - return 0, "", "", false + return "", "", "", false } } @@ -93,7 +84,7 @@ func parseDBKey(key string) (pid int, scope, id string, ok bool) { func (s *StorageInterface) Get(key string) (record.Record, error) { // Parse key and check if valid. pid, scope, id, ok := parseDBKey(strings.TrimPrefix(key, "network:")) - if !ok || pid == process.UndefinedProcessID { + if !ok || pid == "" { return nil, storage.ErrNotFound } @@ -135,7 +126,7 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { return } - if pid == process.UndefinedProcessID { + if pid == "" { // processes for _, proc := range process.All() { func() { diff --git a/network/iphelper/iphelper.go b/network/iphelper/iphelper.go index 5fa1ea12..2256434b 100644 --- a/network/iphelper/iphelper.go +++ b/network/iphelper/iphelper.go @@ -11,7 +11,7 @@ import ( ) var ( - errInvalid = errors.New("IPHelper not initialzed or broken") + errInvalid = errors.New("IPHelper not initialized or broken") ) // IPHelper represents a subset of the Windows iphlpapi.dll. diff --git a/network/module.go b/network/module.go index 23022dbc..1a7ee708 100644 --- a/network/module.go +++ b/network/module.go @@ -2,6 +2,8 @@ package network import ( "github.com/safing/portbase/modules" + "github.com/safing/portmaster/netenv" + "github.com/safing/portmaster/network/state" ) var ( @@ -11,7 +13,7 @@ var ( ) func init() { - module = modules.Register("network", prep, start, nil, "base", "processes") + module = modules.Register("network", prep, start, nil, "base", "netenv", "processes") } // SetDefaultFirewallHandler sets the default firewall handler. @@ -22,6 +24,11 @@ func SetDefaultFirewallHandler(handler FirewallHandler) { } func prep() error { + if netenv.IPv6Enabled() { + state.EnableTCPDualStack() + state.EnableUDPDualStack() + } + return registerAPIEndpoints() } diff --git a/network/state/info.go b/network/state/info.go index 6f97a74e..5dbdc788 100644 --- a/network/state/info.go +++ b/network/state/info.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/netenv" "github.com/safing/portmaster/network/socket" ) @@ -24,27 +25,13 @@ type Info struct { func GetInfo() *Info { info := &Info{} - tcp4Table.updateTables() - tcp4Table.lock.RLock() - info.TCP4Connections = tcp4Table.connections - info.TCP4Listeners = tcp4Table.listeners - tcp4Table.lock.RUnlock() + info.TCP4Connections, info.TCP4Listeners, _ = tcp4Table.updateTables(tcp4Table.updateIter.Load()) + info.UDP4Binds, _ = udp4Table.updateTables(udp4Table.updateIter.Load()) - tcp6Table.updateTables() - tcp6Table.lock.RLock() - info.TCP6Connections = tcp6Table.connections - info.TCP6Listeners = tcp6Table.listeners - tcp6Table.lock.RUnlock() - - udp4Table.updateTable() - udp4Table.lock.RLock() - info.UDP4Binds = udp4Table.binds - udp4Table.lock.RUnlock() - - udp6Table.updateTable() - udp6Table.lock.RLock() - info.UDP6Binds = udp6Table.binds - udp6Table.lock.RUnlock() + if netenv.IPv6Enabled() { + info.TCP6Connections, info.TCP6Listeners, _ = tcp6Table.updateTables(tcp6Table.updateIter.Load()) + info.UDP6Binds, _ = udp6Table.updateTables(udp6Table.updateIter.Load()) + } info.UpdateMeta() return info diff --git a/network/state/lookup.go b/network/state/lookup.go index 46aac9a6..479bcf94 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -2,7 +2,6 @@ package state import ( "errors" - "time" "github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" @@ -30,9 +29,8 @@ var ( ) var ( - baseWaitTime = 3 * time.Millisecond - lookupRetries = 7 * 2 // Every retry takes two full passes. - fastLookupRetries = 2 * 2 + lookupTries = 15 // With a max wait of 5ms, this amounts to up to 75ms. + fastLookupTries = 2 ) // Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound. @@ -69,69 +67,92 @@ func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) ( inbound bool, err error, ) { - // Search pattern: search, refresh, search, wait, search, refresh, search, wait, ... + // Prepare variables. + var ( + connections []*socket.ConnectionInfo + listeners []*socket.BindInfo + updateIter uint64 + + dualStackConnections []*socket.ConnectionInfo + dualStackListeners []*socket.BindInfo + dualStackUpdateIter uint64 + ) // Search for the socket until found. - for i := 1; i <= lookupRetries; i++ { - // Check main table for socket. - socketInfo, inbound := table.findSocket(pktInfo) - if socketInfo == nil && table.dualStack != nil { - // If there was no match in the main table and we are dual-stack, check - // the dual-stack table for the socket. - socketInfo, inbound = table.dualStack.findSocket(pktInfo) + for i := 1; i <= lookupTries; i++ { + // Get or update tables. + if i == 1 { + connections, listeners, updateIter = table.getCurrentTables() + } else { + connections, listeners, updateIter = table.updateTables(updateIter) } - // If there's a match, check we have the PID and return. + // Check tables for socket. + socketInfo, inbound := findTCPSocket(pktInfo, connections, listeners) + + // If there's a match, check if we have the PID and return. + if socketInfo != nil { + return checkPID(socketInfo, inbound) + } + + // DUAL-STACK + + // Skip if dualStack is not enabled. + if table.dualStack == nil { + continue + } + + // Get or update tables. + if i == 0 { + dualStackConnections, dualStackListeners, dualStackUpdateIter = table.dualStack.getCurrentTables() + } else { + dualStackConnections, dualStackListeners, dualStackUpdateIter = table.dualStack.updateTables(dualStackUpdateIter) + } + + // Check tables for socket. + socketInfo, inbound = findTCPSocket(pktInfo, dualStackConnections, dualStackListeners) + + // If there's a match, check if we have the PID and return. if socketInfo != nil { return checkPID(socketInfo, inbound) } // Search less if we want to be fast. - if fast && i < fastLookupRetries { + if fast && i >= fastLookupTries { break } - - // every time, except for the last iteration - if i < lookupRetries { - // Take turns in waiting and refreshing in order to satisfy the search pattern. - if i%2 == 0 { - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - } else { - // refetch lists - table.updateTables() - if table.dualStack != nil { - table.dualStack.updateTables() - } - } - } } return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound } -func (table *tcpTable) findSocket(pktInfo *packet.Info) ( +func findTCPSocket( + pktInfo *packet.Info, + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, +) ( socketInfo socket.Info, inbound bool, ) { localIP := pktInfo.LocalIP() localPort := pktInfo.LocalPort() - table.lock.RLock() - defer table.lock.RUnlock() - // always search listeners first - for _, socketInfo := range table.listeners { + for _, socketInfo := range listeners { if localPort == socketInfo.Local.Port && (socketInfo.ListensAny || localIP.Equal(socketInfo.Local.IP)) { return socketInfo, true } } + remoteIP := pktInfo.RemoteIP() + remotePort := pktInfo.RemotePort() + // search connections - for _, socketInfo := range table.connections { + for _, socketInfo := range connections { if localPort == socketInfo.Local.Port && + remotePort == socketInfo.Remote.Port && + remoteIP.Equal(socketInfo.Remote.IP) && localIP.Equal(socketInfo.Local.IP) { return socketInfo, false } @@ -145,25 +166,67 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( inbound bool, err error, ) { - // Search pattern: search, refresh, search, wait, search, refresh, search, wait, ... - // TODO: Currently broadcast/multicast scopes are not checked, so we might // attribute an incoming broadcast/multicast packet to the wrong process if // there are multiple processes listening on the same local port, but // binding to different addresses. This highly unusual for clients. isInboundMulticast := pktInfo.Inbound && netutils.GetIPScope(pktInfo.LocalIP()) == netutils.LocalMulticast + // Prepare variables. + var ( + binds []*socket.BindInfo + updateIter uint64 + + dualStackBinds []*socket.BindInfo + dualStackUpdateIter uint64 + ) + // Search for the socket until found. - for i := 1; i <= lookupRetries; i++ { - // Check main table for socket. - socketInfo := table.findSocket(pktInfo, isInboundMulticast) - if socketInfo == nil && table.dualStack != nil { - // If there was no match in the main table and we are dual-stack, check - // the dual-stack table for the socket. - socketInfo = table.dualStack.findSocket(pktInfo, isInboundMulticast) + for i := 1; i <= lookupTries; i++ { + // Get or update tables. + if i == 1 { + binds, updateIter = table.getCurrentTables() + } else { + binds, updateIter = table.updateTables(updateIter) } - // If there's a match, get the direction and check we have the PID, then return. + // Check tables for socket. + socketInfo := findUDPSocket(pktInfo, binds, isInboundMulticast) + + // If there's a match, do some last checks and return. + if socketInfo != nil { + // If there is no remote port, do check for the direction of the + // connection. This will be the case for pure checking functions + // that do not want to change direction state. + if pktInfo.RemotePort() == 0 { + return checkPID(socketInfo, pktInfo.Inbound) + } + + // Get (and save) the direction of the connection. + connInbound := table.getDirection(socketInfo, pktInfo) + + // Check we have the PID and return. + return checkPID(socketInfo, connInbound) + } + + // DUAL-STACK + + // Skip if dualStack is not enabled. + if table.dualStack == nil { + continue + } + + // Get or update tables. + if i == 0 { + dualStackBinds, dualStackUpdateIter = table.dualStack.getCurrentTables() + } else { + dualStackBinds, dualStackUpdateIter = table.dualStack.updateTables(dualStackUpdateIter) + } + + // Check tables for socket. + socketInfo = findUDPSocket(pktInfo, dualStackBinds, isInboundMulticast) + + // If there's a match, do some last checks and return. if socketInfo != nil { // If there is no remote port, do check for the direction of the // connection. This will be the case for pure checking functions @@ -180,39 +243,20 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( } // Search less if we want to be fast. - if fast && i < fastLookupRetries { + if fast && i >= fastLookupTries { break } - - // every time, except for the last iteration - if i < lookupRetries { - // Take turns in waiting and refreshing in order to satisfy the search pattern. - if i%2 == 0 { - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - } else { - // refetch lists - table.updateTable() - if table.dualStack != nil { - table.dualStack.updateTable() - } - } - } } return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound } -func (table *udpTable) findSocket(pktInfo *packet.Info, isInboundMulticast bool) (socketInfo *socket.BindInfo) { +func findUDPSocket(pktInfo *packet.Info, binds []*socket.BindInfo, isInboundMulticast bool) (socketInfo *socket.BindInfo) { localIP := pktInfo.LocalIP() localPort := pktInfo.LocalPort() - table.lock.RLock() - defer table.lock.RUnlock() - // search binds - for _, socketInfo := range table.binds { + for _, socketInfo := range binds { if localPort == socketInfo.Local.Port && (socketInfo.ListensAny || // zero IP (dual-stack) isInboundMulticast || // inbound broadcast, multicast diff --git a/network/state/system_linux.go b/network/state/system_linux.go index abef01c1..4f7c4138 100644 --- a/network/state/system_linux.go +++ b/network/state/system_linux.go @@ -14,8 +14,10 @@ var ( getUDP6Table = proc.GetUDP6Table ) +var baseWaitTime = 3 * time.Millisecond + func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) { - for i := 0; i <= lookupRetries; i++ { + for i := 1; i <= lookupTries; i++ { // look for PID pid = proc.GetPID(socketInfo) if pid != socket.UndefinedProcessID { @@ -24,7 +26,7 @@ func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, } // every time, except for the last iteration - if i < lookupRetries { + if i < lookupTries { // we found no PID, we could have been too fast, give the kernel some time to think // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total time.Sleep(time.Duration(i+1) * baseWaitTime) diff --git a/network/state/tables.go b/network/state/tables.go deleted file mode 100644 index e99957ab..00000000 --- a/network/state/tables.go +++ /dev/null @@ -1,48 +0,0 @@ -package state - -import ( - "net" - - "github.com/safing/portbase/log" -) - -func (table *tcpTable) updateTables() { - table.fetchOnceAgain.Do(func() { - table.lock.Lock() - defer table.lock.Unlock() - - connections, listeners, err := table.fetchTable() - if err != nil { - log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err) - return - } - - // Pre-check for any listeners. - for _, bindInfo := range listeners { - bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) - } - - table.connections = connections - table.listeners = listeners - }) -} - -func (table *udpTable) updateTable() { - table.fetchOnceAgain.Do(func() { - table.lock.Lock() - defer table.lock.Unlock() - - binds, err := table.fetchTable() - if err != nil { - log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err) - return - } - - // Pre-check for any listeners. - for _, bindInfo := range binds { - bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) - } - - table.binds = binds - }) -} diff --git a/network/state/tcp.go b/network/state/tcp.go index 8b04cd4d..4b423a21 100644 --- a/network/state/tcp.go +++ b/network/state/tcp.go @@ -1,34 +1,146 @@ package state import ( + "net" "sync" + "sync/atomic" + "time" - "github.com/safing/portbase/utils" + "github.com/safing/portbase/log" "github.com/safing/portmaster/network/socket" ) +const maxUpdateTries = 100 + type tcpTable struct { version int connections []*socket.ConnectionInfo listeners []*socket.BindInfo + updateIter atomic.Uint64 lock sync.RWMutex - fetchOnceAgain utils.OnceAgain - fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) + fetchingLock sync.Mutex + fetchingInProgress bool + fetchingDoneSignal chan struct{} + fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) dualStack *tcpTable } var ( tcp6Table = &tcpTable{ - version: 6, - fetchTable: getTCP6Table, + version: 6, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getTCP6Table, } tcp4Table = &tcpTable{ - version: 4, - fetchTable: getTCP4Table, - dualStack: tcp6Table, + version: 4, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getTCP4Table, } ) + +// EnableTCPDualStack adds the TCP6 table to the TCP4 table as a dual-stack. +// Must be called before any lookup operation. +func EnableTCPDualStack() { + tcp4Table.dualStack = tcp6Table +} + +func (table *tcpTable) getCurrentTables() ( + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, + updateIter uint64, +) { + table.lock.RLock() + defer table.lock.RUnlock() + + return table.connections, table.listeners, table.updateIter.Load() +} + +func (table *tcpTable) checkFetchingState() (fetch bool, signal chan struct{}) { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // If fetching is already in progress, just return the signal. + if table.fetchingInProgress { + return false, table.fetchingDoneSignal + } + + // Otherwise, tell caller to fetch. + table.fetchingInProgress = true + return true, nil +} + +func (table *tcpTable) signalFetchComplete() { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // Set fetching state. + table.fetchingInProgress = false + + // Signal waiting goroutines. + close(table.fetchingDoneSignal) + table.fetchingDoneSignal = make(chan struct{}) +} + +func (table *tcpTable) updateTables(previousUpdateIter uint64) ( + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, + updateIter uint64, +) { + var tries int + + // Attempt to update the tables until we get a new version of the tables. + for previousUpdateIter == table.updateIter.Load() { + // Abort if it takes too long. + tries++ + if tries > maxUpdateTries { + log.Warningf("state: failed to upate TCP%d socket table %d times", table.version, tries-1) + return table.getCurrentTables() + } + + // Check if someone is fetching or if we should fetch. + fetch, signal := table.checkFetchingState() + if fetch { + defer table.signalFetchComplete() + + // Just to be sure, check again if there is a new version. + if previousUpdateIter < table.updateIter.Load() { + return table.getCurrentTables() + } + + // Wait for 5 milliseconds. + time.Sleep(5 * time.Millisecond) + + // Fetch new tables from system. + connections, listeners, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err) + // Return the current tables as fallback, as we need to trigger the defer to complete the fetch. + return table.getCurrentTables() + } + + // Pre-check for any listeners. + for _, bindInfo := range listeners { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) + } + + // Apply new tables. + table.lock.Lock() + defer table.lock.Unlock() + table.connections = connections + table.listeners = listeners + table.updateIter.Add(1) + + // Return new tables immediately. + return table.connections, table.listeners, table.updateIter.Load() + } + + // Otherwise, wait for fetch to complete. + <-signal + } + + return table.getCurrentTables() +} diff --git a/network/state/udp.go b/network/state/udp.go index 0302be14..1823a371 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -2,11 +2,14 @@ package state import ( "context" + "net" "strconv" "sync" + "sync/atomic" "time" - "github.com/safing/portbase/utils" + "github.com/safing/portbase/log" + "github.com/safing/portmaster/netenv" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/socket" ) @@ -14,11 +17,14 @@ import ( type udpTable struct { version int - binds []*socket.BindInfo - lock sync.RWMutex + binds []*socket.BindInfo + updateIter atomic.Uint64 + lock sync.RWMutex - fetchOnceAgain utils.OnceAgain - fetchTable func() (binds []*socket.BindInfo, err error) + fetchingLock sync.Mutex + fetchingInProgress bool + fetchingDoneSignal chan struct{} + fetchTable func() (binds []*socket.BindInfo, err error) states map[string]map[string]*udpState statesLock sync.Mutex @@ -44,28 +50,131 @@ const ( var ( udp6Table = &udpTable{ - version: 6, - fetchTable: getUDP6Table, - states: make(map[string]map[string]*udpState), + version: 6, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getUDP6Table, + states: make(map[string]map[string]*udpState), } udp4Table = &udpTable{ - version: 4, - fetchTable: getUDP4Table, - states: make(map[string]map[string]*udpState), - dualStack: udp6Table, + version: 4, + fetchingDoneSignal: make(chan struct{}), + fetchTable: getUDP4Table, + states: make(map[string]map[string]*udpState), } ) +// EnableUDPDualStack adds the UDP6 table to the UDP4 table as a dual-stack. +// Must be called before any lookup operation. +func EnableUDPDualStack() { + udp4Table.dualStack = udp6Table +} + +func (table *udpTable) getCurrentTables() ( + binds []*socket.BindInfo, + updateIter uint64, +) { + table.lock.RLock() + defer table.lock.RUnlock() + + return table.binds, table.updateIter.Load() +} + +func (table *udpTable) checkFetchingState() (fetch bool, signal chan struct{}) { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // If fetching is already in progress, just return the signal. + if table.fetchingInProgress { + return false, table.fetchingDoneSignal + } + + // Otherwise, tell caller to fetch. + table.fetchingInProgress = true + return true, nil +} + +func (table *udpTable) signalFetchComplete() { + table.fetchingLock.Lock() + defer table.fetchingLock.Unlock() + + // Set fetching state. + table.fetchingInProgress = false + + // Signal waiting goroutines. + close(table.fetchingDoneSignal) + table.fetchingDoneSignal = make(chan struct{}) +} + +func (table *udpTable) updateTables(previousUpdateIter uint64) ( + binds []*socket.BindInfo, + updateIter uint64, +) { + var tries int + + // Attempt to update the tables until we get a new version of the tables. + for previousUpdateIter == table.updateIter.Load() { + // Abort if it takes too long. + tries++ + if tries > maxUpdateTries { + log.Warningf("state: failed to upate UDP%d socket table %d times", table.version, tries-1) + return table.getCurrentTables() + } + + // Check if someone is fetching or if we should fetch. + fetch, signal := table.checkFetchingState() + if fetch { + defer table.signalFetchComplete() + + // Just to be sure, check again if there is a new version. + if previousUpdateIter < table.updateIter.Load() { + return table.getCurrentTables() + } + + // Wait for 5 milliseconds. + time.Sleep(5 * time.Millisecond) + + // Fetch new tables from system. + binds, err := table.fetchTable() + if err != nil { + log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err) + // Return the current tables as fallback, as we need to trigger the defer to complete the fetch. + return table.getCurrentTables() + } + + // Pre-check for any listeners. + for _, bindInfo := range binds { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) + } + + // Apply new tables. + table.lock.Lock() + defer table.lock.Unlock() + table.binds = binds + table.updateIter.Add(1) + + // Return new tables immediately. + return table.binds, table.updateIter.Load() + } + + // Otherwise, wait for fetch to complete. + <-signal + } + + return table.getCurrentTables() +} + // CleanUDPStates cleans the udp connection states which save connection directions. func CleanUDPStates(_ context.Context) { now := time.Now().UTC() - udp4Table.updateTable() + udp4Table.updateTables(udp4Table.updateIter.Load()) udp4Table.cleanStates(now) - udp6Table.updateTable() - udp6Table.cleanStates(now) + if netenv.IPv6Enabled() { + udp6Table.updateTables(udp6Table.updateIter.Load()) + udp6Table.cleanStates(now) + } } func (table *udpTable) getConnState( diff --git a/process/database.go b/process/database.go index 62df8f09..4def9b01 100644 --- a/process/database.go +++ b/process/database.go @@ -15,7 +15,7 @@ import ( const processDatabaseNamespace = "network:tree" var ( - processes = make(map[int]*Process) + processes = make(map[string]*Process) processesLock sync.RWMutex dbController *database.Controller @@ -25,11 +25,11 @@ var ( ) // GetProcessFromStorage returns a process from the internal storage. -func GetProcessFromStorage(pid int) (*Process, bool) { +func GetProcessFromStorage(key string) (*Process, bool) { processesLock.RLock() defer processesLock.RUnlock() - p, ok := processes[pid] + p, ok := processes[key] return p, ok } @@ -55,11 +55,11 @@ func (p *Process) Save() { if !p.KeyIsSet() { // set key - p.SetKey(fmt.Sprintf("%s/%d", processDatabaseNamespace, p.Pid)) + p.SetKey(fmt.Sprintf("%s/%s", processDatabaseNamespace, getProcessKey(int32(p.Pid), p.CreatedAt))) // save processesLock.Lock() - processes[p.Pid] = p + processes[p.processKey] = p processesLock.Unlock() } @@ -75,7 +75,7 @@ func (p *Process) Delete() { // delete from internal storage processesLock.Lock() - delete(processes, p.Pid) + delete(processes, p.processKey) processesLock.Unlock() // propagate delete diff --git a/process/process.go b/process/process.go index 9f3ad2e4..95ca1daf 100644 --- a/process/process.go +++ b/process/process.go @@ -6,7 +6,6 @@ import ( "fmt" "path/filepath" "runtime" - "strconv" "strings" "sync" "time" @@ -31,18 +30,23 @@ type Process struct { // Process attributes. // Don't change; safe for concurrent access. - Name string - UserID int - UserName string - UserHome string - Pid int - ParentPid int - Path string - ExecName string - Cwd string - CmdLine string - FirstArg string - Env map[string]string + Name string + UserID int + UserName string + UserHome string + Pid int + CreatedAt int64 + ParentPid int + ParentCreatedAt int64 + Path string + ExecName string + Cwd string + CmdLine string + FirstArg string + Env map[string]string + + // unique process identifier ("Pid-CreatedAt") + processKey string // Profile attributes. // Once set, these don't change; safe for concurrent access. @@ -156,8 +160,31 @@ func (p *Process) String() string { func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { log.Tracer(ctx).Tracef("process: getting process for PID %d", pid) - p, err, _ := getProcessSingleInflight.Do(strconv.Itoa(pid), func() (interface{}, error) { - return loadProcess(ctx, pid) + // Check for special processes + switch pid { + case UnidentifiedProcessID: + return GetUnidentifiedProcess(ctx), nil + case UnsolicitedProcessID: + return GetUnsolicitedProcess(ctx), nil + case SystemProcessID: + return GetSystemProcess(ctx), nil + } + + // Get pid and created time for identification. + pInfo, err := processInfo.NewProcessWithContext(ctx, int32(pid)) + if err != nil { + return nil, err + } + + createdTime, err := pInfo.CreateTimeWithContext(ctx) + if err != nil { + return nil, err + } + + key := getProcessKey(int32(pid), createdTime) + + p, err, _ := getProcessSingleInflight.Do(key, func() (interface{}, error) { + return loadProcess(ctx, key, pInfo) }) if err != nil { return nil, err @@ -169,29 +196,25 @@ func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { return p.(*Process), nil // nolint:forcetypeassert // Can only be a *Process. } -func loadProcess(ctx context.Context, pid int) (*Process, error) { - switch pid { - case UnidentifiedProcessID: - return GetUnidentifiedProcess(ctx), nil - case UnsolicitedProcessID: - return GetUnsolicitedProcess(ctx), nil - case SystemProcessID: - return GetSystemProcess(ctx), nil - } +func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (*Process, error) { + // Get created time of process. The value should be cached. + createdAt, _ := pInfo.CreateTimeWithContext(ctx) - process, ok := GetProcessFromStorage(pid) + process, ok := GetProcessFromStorage(key) if ok { return process, nil } // Create new a process object. process = &Process{ - Pid: pid, - FirstSeen: time.Now().Unix(), + Pid: int(pInfo.Pid), + CreatedAt: createdAt, + FirstSeen: time.Now().Unix(), + processKey: key, } // Get process information from the system. - pInfo, err := processInfo.NewProcessWithContext(ctx, int32(pid)) + pInfo, err := processInfo.NewProcessWithContext(ctx, pInfo.Pid) if err != nil { return nil, err } @@ -202,7 +225,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { var uids []int32 uids, err = pInfo.UidsWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get UID for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get UID for p%d: %w", pInfo.Pid, err) } process.UserID = int(uids[0]) } @@ -210,23 +233,33 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // Username process.UserName, err = pInfo.UsernameWithContext(ctx) if err != nil { - return nil, fmt.Errorf("process: failed to get Username for p%d: %w", pid, err) + return nil, fmt.Errorf("process: failed to get Username for p%d: %w", pInfo.Pid, err) } // TODO: User Home // new.UserHome, err = - // PPID + // Parent process id ppid, err := pInfo.PpidWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get PPID for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get PPID for p%d: %w", pInfo.Pid, err) } process.ParentPid = int(ppid) + // Parent created time + parentPInfo, err := processInfo.NewProcessWithContext(ctx, ppid) + if err == nil { + parentCreatedAt, err := parentPInfo.CreateTimeWithContext(ctx) + if err != nil { + return nil, err + } + process.ParentCreatedAt = parentCreatedAt + } + // Path process.Path, err = pInfo.ExeWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get Path for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get Path for p%d: %w", pInfo.Pid, err) } // remove linux " (deleted)" suffix for deleted files if onLinux { @@ -247,13 +280,13 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // Command line arguments process.CmdLine, err = pInfo.CmdlineWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get Cmdline for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get Cmdline for p%d: %w", pInfo.Pid, err) } // Name process.Name, err = pInfo.NameWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get Name for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get Name for p%d: %w", pInfo.Pid, err) } if process.Name == "" { process.Name = process.ExecName @@ -262,7 +295,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // Get all environment variables env, err := pInfo.EnvironWithContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get the environment for p%d: %w", pid, err) + return nil, fmt.Errorf("failed to get the environment for p%d: %w", pInfo.Pid, err) } // Split env variables in key and value. process.Env = make(map[string]string, len(env)) @@ -283,6 +316,11 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { return process, nil } +// Builds a unique identifier for a processes. +func getProcessKey(pid int32, createdTime int64) string { + return fmt.Sprintf("%d-%d", pid, createdTime) +} + // MatchingData returns the matching data for the process. func (p *Process) MatchingData() *MatchingData { return &MatchingData{p} diff --git a/process/special.go b/process/special.go index 7d2c3e93..93f17c9a 100644 --- a/process/special.go +++ b/process/special.go @@ -2,7 +2,6 @@ package process import ( "context" - "strconv" "time" "golang.org/x/sync/singleflight" @@ -39,29 +38,35 @@ func init() { var ( // unidentifiedProcess is used for non-attributed outgoing connections. unidentifiedProcess = &Process{ - UserID: UnidentifiedProcessID, - UserName: "Unknown", - Pid: UnidentifiedProcessID, - ParentPid: UnidentifiedProcessID, - Name: profile.UnidentifiedProfileName, + UserID: UnidentifiedProcessID, + UserName: "Unknown", + Pid: UnidentifiedProcessID, + CreatedAt: 1, + ParentPid: UnidentifiedProcessID, + ParentCreatedAt: 1, + Name: profile.UnidentifiedProfileName, } // unsolicitedProcess is used for non-attributed incoming connections. unsolicitedProcess = &Process{ - UserID: UnsolicitedProcessID, - UserName: "Unknown", - Pid: UnsolicitedProcessID, - ParentPid: UnsolicitedProcessID, - Name: profile.UnsolicitedProfileName, + UserID: UnsolicitedProcessID, + UserName: "Unknown", + Pid: UnsolicitedProcessID, + CreatedAt: 1, + ParentPid: UnsolicitedProcessID, + ParentCreatedAt: 1, + Name: profile.UnsolicitedProfileName, } // systemProcess is used to represent the Kernel. systemProcess = &Process{ - UserID: SystemProcessID, - UserName: "Kernel", - Pid: SystemProcessID, - ParentPid: SystemProcessID, - Name: profile.SystemProfileName, + UserID: SystemProcessID, + UserName: "Kernel", + Pid: SystemProcessID, + CreatedAt: 1, + ParentPid: SystemProcessID, + ParentCreatedAt: 1, + Name: profile.SystemProfileName, } getSpecialProcessSingleInflight singleflight.Group @@ -83,9 +88,9 @@ func GetSystemProcess(ctx context.Context) *Process { } func getSpecialProcess(ctx context.Context, template *Process) *Process { - p, _, _ := getSpecialProcessSingleInflight.Do(strconv.Itoa(template.Pid), func() (interface{}, error) { + p, _, _ := getSpecialProcessSingleInflight.Do(template.processKey, func() (interface{}, error) { // Check if we have already loaded the special process. - process, ok := GetProcessFromStorage(template.Pid) + process, ok := GetProcessFromStorage(template.processKey) if ok { return process, nil }