diff --git a/network/proc/findpid.go b/network/proc/findpid.go index 3b123f4f..0610e361 100644 --- a/network/proc/findpid.go +++ b/network/proc/findpid.go @@ -21,31 +21,18 @@ var ( pidsByUser = make(map[int][]int) ) -// FindConnectionPID returns the pid of the given socket info. -func FindConnectionPID(socketInfo *socket.ConnectionInfo) (pid int) { +// GetPID returns the already existing pid of the given socket info or searches for it. +// This also acts as a getter for socket.*Info.PID, as locking for that occurs here. +func GetPID(socketInfo socket.Info) (pid int) { pidsByUserLock.Lock() defer pidsByUserLock.Unlock() - if socketInfo.PID != socket.UnidentifiedProcessID { - return socket.UnidentifiedProcessID + if socketInfo.GetPID() != socket.UnidentifiedProcessID { + return socketInfo.GetPID() } - pid = findPID(socketInfo.UID, socketInfo.Inode) - socketInfo.PID = pid - return pid -} - -// FindBindPID returns the pid of the given socket info. -func FindBindPID(socketInfo *socket.BindInfo) (pid int) { - pidsByUserLock.Lock() - defer pidsByUserLock.Unlock() - - if socketInfo.PID != socket.UnidentifiedProcessID { - return socket.UnidentifiedProcessID - } - - pid = findPID(socketInfo.UID, socketInfo.Inode) - socketInfo.PID = pid + pid = findPID(socketInfo.GetUID(), socketInfo.GetInode()) + socketInfo.SetPID(pid) return pid } @@ -175,6 +162,9 @@ entryLoop: } +// readDirNames only reads the directory names. Using ioutil.ReadDir() would call `lstat` on every +// resulting directory name, which we don't need. This function will be called a lot, so we should +// refrain from unnecessary work. func readDirNames(dir string) (names []string) { file, err := os.Open(dir) if err != nil { diff --git a/network/proc/tables_test.go b/network/proc/tables_test.go index bcce8498..d5b1959f 100644 --- a/network/proc/tables_test.go +++ b/network/proc/tables_test.go @@ -14,12 +14,12 @@ func TestSockets(t *testing.T) { } fmt.Println("\nTCP 4 connections:") for _, connection := range connections { - pid := FindConnectionPID(connection) + pid := GetPID(connection) fmt.Printf("%d: %+v\n", pid, connection) } fmt.Println("\nTCP 4 listeners:") for _, listener := range listeners { - pid := FindBindPID(listener) + pid := GetPID(listener) fmt.Printf("%d: %+v\n", pid, listener) } @@ -29,12 +29,12 @@ func TestSockets(t *testing.T) { } fmt.Println("\nTCP 6 connections:") for _, connection := range connections { - pid := FindConnectionPID(connection) + pid := GetPID(connection) fmt.Printf("%d: %+v\n", pid, connection) } fmt.Println("\nTCP 6 listeners:") for _, listener := range listeners { - pid := FindBindPID(listener) + pid := GetPID(listener) fmt.Printf("%d: %+v\n", pid, listener) } @@ -44,7 +44,7 @@ func TestSockets(t *testing.T) { } fmt.Println("\nUDP 4 binds:") for _, bind := range binds { - pid := FindBindPID(bind) + pid := GetPID(bind) fmt.Printf("%d: %+v\n", pid, bind) } @@ -54,7 +54,7 @@ func TestSockets(t *testing.T) { } fmt.Println("\nUDP 6 binds:") for _, bind := range binds { - pid := FindBindPID(bind) + pid := GetPID(bind) fmt.Printf("%d: %+v\n", pid, bind) } } diff --git a/network/socket/socket.go b/network/socket/socket.go index e8dfe1d9..22f37ef0 100644 --- a/network/socket/socket.go +++ b/network/socket/socket.go @@ -29,3 +29,35 @@ type Address struct { IP net.IP Port uint16 } + +// Info is a generic interface to both ConnectionInfo and BindInfo. +type Info interface { + GetPID() int + SetPID(int) + GetUID() int + GetInode() int +} + +// GetPID returns the PID. +func (i *ConnectionInfo) GetPID() int { return i.PID } + +// SetPID sets the PID to the given value. +func (i *ConnectionInfo) SetPID(pid int) { i.PID = pid } + +// GetUID returns the UID. +func (i *ConnectionInfo) GetUID() int { return i.UID } + +// GetInode returns the Inode. +func (i *ConnectionInfo) GetInode() int { return i.Inode } + +// GetPID returns the PID. +func (i *BindInfo) GetPID() int { return i.PID } + +// SetPID sets the PID to the given value. +func (i *BindInfo) SetPID(pid int) { i.PID = pid } + +// GetUID returns the UID. +func (i *BindInfo) GetUID() int { return i.UID } + +// GetInode returns the Inode. +func (i *BindInfo) GetInode() int { return i.Inode } diff --git a/network/state/lookup.go b/network/state/lookup.go index b05fe237..48d2f5b6 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -81,7 +81,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) ( if localPort == socketInfo.Local.Port && (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { table.lock.RUnlock() - return checkBindPID(socketInfo, true) + return checkPID(socketInfo, true) } } @@ -90,7 +90,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) ( if localPort == socketInfo.Local.Port && localIP.Equal(socketInfo.Local.IP) { table.lock.RUnlock() - return checkConnectionPID(socketInfo, false) + return checkPID(socketInfo, false) } } @@ -138,12 +138,12 @@ func (table *udpTable) lookup(pktInfo *packet.Info) ( // do not check direction if remoteIP/Port is not given if pktInfo.RemotePort() == 0 { - return checkBindPID(socketInfo, pktInfo.Inbound) + return checkPID(socketInfo, pktInfo.Inbound) } // get direction and return connInbound := table.getDirection(socketInfo, pktInfo) - return checkBindPID(socketInfo, connInbound) + return checkPID(socketInfo, connInbound) } } diff --git a/network/state/system_linux.go b/network/state/system_linux.go index c4b67777..9a0a41ab 100644 --- a/network/state/system_linux.go +++ b/network/state/system_linux.go @@ -14,30 +14,10 @@ var ( getUDP6Table = proc.GetUDP6Table ) -func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { +func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) { for i := 0; i <= lookupRetries; i++ { // look for PID - pid = proc.FindConnectionPID(socketInfo) - if pid != socket.UnidentifiedProcessID { - // if we found a PID, return - break - } - - // every time, except for the last iteration - if i < lookupRetries { - // we found no PID, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - } - } - - return pid, connInbound, nil -} - -func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { - for i := 0; i <= lookupRetries; i++ { - // look for PID - pid = proc.FindBindPID(socketInfo) + pid = proc.GetPID(socketInfo) if pid != socket.UnidentifiedProcessID { // if we found a PID, return break