From 914418876d303bc7596399727a6a3d31a47bbf54 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 2 Nov 2020 14:18:42 +0100 Subject: [PATCH] Add IPv6 dual-stack support --- network/socket/socket.go | 2 + network/state/lookup.go | 168 +++++++++++++++++++++++++-------------- network/state/tables.go | 12 +++ network/state/tcp.go | 13 +-- network/state/udp.go | 15 ++-- 5 files changed, 140 insertions(+), 70 deletions(-) diff --git a/network/socket/socket.go b/network/socket/socket.go index 24d03518..7ecfddf7 100644 --- a/network/socket/socket.go +++ b/network/socket/socket.go @@ -29,6 +29,8 @@ type BindInfo struct { PID int UID int Inode int + + ListensAny bool } // Address is an IP + Port pair. diff --git a/network/state/lookup.go b/network/state/lookup.go index 2202f2b9..0d1a5d92 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -31,7 +31,7 @@ var ( var ( baseWaitTime = 3 * time.Millisecond - lookupRetries = 7 + lookupRetries = 7 * 2 // Every retry takes two full passes. ) // Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound. @@ -68,97 +68,147 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) ( inbound bool, err error, ) { + // Search pattern: search, wait, search, refresh, search, wait, search, refresh, ... - localIP := pktInfo.LocalIP() - localPort := pktInfo.LocalPort() - - // search until we find something + // Search for the socket until found. for i := 0; i <= lookupRetries; i++ { - table.lock.RLock() - - // always search listeners first - for _, socketInfo := range table.listeners { - if localPort == socketInfo.Local.Port && - (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { - table.lock.RUnlock() - return checkPID(socketInfo, true) - } + // Check main table for socket. + socketInfo, inbound := table.findSocket(pktInfo) + if socketInfo == nil && table.dualStack != nil { + // If there was no match in the main table and we are dual-stack, check + // the dual-stack table for the socket. + socketInfo, inbound = table.dualStack.findSocket(pktInfo) } - // search connections - for _, socketInfo := range table.connections { - if localPort == socketInfo.Local.Port && - localIP.Equal(socketInfo.Local.IP) { - table.lock.RUnlock() - return checkPID(socketInfo, false) - } + // If there's a match, check we have the PID and return. + if socketInfo != nil { + return checkPID(socketInfo, inbound) } - table.lock.RUnlock() - // every time, except for the last iteration if i < lookupRetries { - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - - // refetch lists - table.updateTables() + // Take turns in waiting and refreshing in order to satisfy the search pattern. + if i%2 == 0 { + // we found nothing, we could have been too fast, give the kernel some time to think + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) + } else { + // refetch lists + table.updateTables() + if table.dualStack != nil { + table.dualStack.updateTables() + } + } } } return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound } +func (table *tcpTable) findSocket(pktInfo *packet.Info) ( + socketInfo socket.Info, + inbound bool, +) { + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + + table.lock.RLock() + defer table.lock.RUnlock() + + // always search listeners first + for _, socketInfo := range table.listeners { + if localPort == socketInfo.Local.Port && + (socketInfo.ListensAny || localIP.Equal(socketInfo.Local.IP)) { + return socketInfo, false + } + } + + // search connections + for _, socketInfo := range table.connections { + if localPort == socketInfo.Local.Port && + localIP.Equal(socketInfo.Local.IP) { + return socketInfo, false + } + } + + return nil, false +} + func (table *udpTable) lookup(pktInfo *packet.Info) ( pid int, inbound bool, err error, ) { - localIP := pktInfo.LocalIP() - localPort := pktInfo.LocalPort() + // Search pattern: search, wait, search, refresh, search, wait, search, refresh, ... - isInboundMulticast := pktInfo.Inbound && netutils.ClassifyIP(localIP) == netutils.LocalMulticast // TODO: Currently broadcast/multicast scopes are not checked, so we might // attribute an incoming broadcast/multicast packet to the wrong process if // there are multiple processes listening on the same local port, but // binding to different addresses. This highly unusual for clients. + isInboundMulticast := pktInfo.Inbound && netutils.ClassifyIP(pktInfo.LocalIP()) == netutils.LocalMulticast - // search until we find something + // Search for the socket until found. for i := 0; i <= lookupRetries; i++ { - table.lock.RLock() - - // search binds - for _, socketInfo := range table.binds { - if localPort == socketInfo.Local.Port && - (socketInfo.Local.IP[0] == 0 || // zero IP - isInboundMulticast || // inbound broadcast, multicast - localIP.Equal(socketInfo.Local.IP)) { - table.lock.RUnlock() - - // do not check direction if remoteIP/Port is not given - if pktInfo.RemotePort() == 0 { - return checkPID(socketInfo, pktInfo.Inbound) - } - - // get direction and return - connInbound := table.getDirection(socketInfo, pktInfo) - return checkPID(socketInfo, connInbound) - } + // Check main table for socket. + socketInfo := table.findSocket(pktInfo, isInboundMulticast) + if socketInfo == nil && table.dualStack != nil { + // If there was no match in the main table and we are dual-stack, check + // the dual-stack table for the socket. + socketInfo = table.dualStack.findSocket(pktInfo, isInboundMulticast) } - table.lock.RUnlock() + // If there's a match, get the direction and check we have the PID, then return. + if socketInfo != nil { + // If there is no remote port, do check for the direction of the + // connection. This will be the case for pure checking functions + // that do not want to change direction state. + if pktInfo.RemotePort() == 0 { + return checkPID(socketInfo, pktInfo.Inbound) + } + + // Get (and save) the direction of the connection. + connInbound := table.getDirection(socketInfo, pktInfo) + + // Check we have the PID and return. + return checkPID(socketInfo, connInbound) + } // every time, except for the last iteration if i < lookupRetries { - // we found nothing, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - - // refetch lists - table.updateTable() + // Take turns in waiting and refreshing in order to satisfy the search pattern. + if i%2 == 0 { + // we found nothing, we could have been too fast, give the kernel some time to think + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) + } else { + // refetch lists + table.updateTable() + if table.dualStack != nil { + table.dualStack.updateTable() + } + } } } return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound } + +func (table *udpTable) findSocket(pktInfo *packet.Info, isInboundMulticast bool) (socketInfo *socket.BindInfo) { + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + + table.lock.RLock() + defer table.lock.RUnlock() + + // search binds + for _, socketInfo := range table.binds { + if localPort == socketInfo.Local.Port && + (socketInfo.ListensAny || // zero IP (dual-stack) + isInboundMulticast || // inbound broadcast, multicast + localIP.Equal(socketInfo.Local.IP)) { + return socketInfo + } + } + + return nil +} diff --git a/network/state/tables.go b/network/state/tables.go index df6e9783..e99957ab 100644 --- a/network/state/tables.go +++ b/network/state/tables.go @@ -1,6 +1,8 @@ package state import ( + "net" + "github.com/safing/portbase/log" ) @@ -15,6 +17,11 @@ func (table *tcpTable) updateTables() { return } + // Pre-check for any listeners. + for _, bindInfo := range listeners { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) + } + table.connections = connections table.listeners = listeners }) @@ -31,6 +38,11 @@ func (table *udpTable) updateTable() { return } + // Pre-check for any listeners. + for _, bindInfo := range binds { + bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero) + } + table.binds = binds }) } diff --git a/network/state/tcp.go b/network/state/tcp.go index 2894aada..8b04cd4d 100644 --- a/network/state/tcp.go +++ b/network/state/tcp.go @@ -16,16 +16,19 @@ type tcpTable struct { fetchOnceAgain utils.OnceAgain fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) + + dualStack *tcpTable } var ( - tcp4Table = &tcpTable{ - version: 4, - fetchTable: getTCP4Table, - } - tcp6Table = &tcpTable{ version: 6, fetchTable: getTCP6Table, } + + tcp4Table = &tcpTable{ + version: 4, + fetchTable: getTCP4Table, + dualStack: tcp6Table, + } ) diff --git a/network/state/udp.go b/network/state/udp.go index ad596fa6..31474f5b 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -22,6 +22,8 @@ type udpTable struct { states map[string]map[string]*udpState statesLock sync.Mutex + + dualStack *udpTable } type udpState struct { @@ -41,17 +43,18 @@ const ( ) var ( - udp4Table = &udpTable{ - version: 4, - fetchTable: getUDP4Table, - states: make(map[string]map[string]*udpState), - } - udp6Table = &udpTable{ version: 6, fetchTable: getUDP6Table, states: make(map[string]map[string]*udpState), } + + udp4Table = &udpTable{ + version: 4, + fetchTable: getUDP4Table, + states: make(map[string]map[string]*udpState), + dualStack: udp6Table, + } ) // CleanUDPStates cleans the udp connection states which save connection directions.