Add IPv6 dual-stack support

This commit is contained in:
Daniel
2020-11-02 14:18:42 +01:00
parent d4dea212dd
commit 914418876d
5 changed files with 140 additions and 70 deletions

View File

@@ -29,6 +29,8 @@ type BindInfo struct {
PID int
UID int
Inode int
ListensAny bool
}
// Address is an IP + Port pair.

View File

@@ -31,7 +31,7 @@ var (
var (
baseWaitTime = 3 * time.Millisecond
lookupRetries = 7
lookupRetries = 7 * 2 // Every retry takes two full passes.
)
// Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound.
@@ -68,97 +68,147 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) (
inbound bool,
err error,
) {
// Search pattern: search, wait, search, refresh, search, wait, search, refresh, ...
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
// search until we find something
// Search for the socket until found.
for i := 0; i <= lookupRetries; i++ {
table.lock.RLock()
// always search listeners first
for _, socketInfo := range table.listeners {
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
table.lock.RUnlock()
return checkPID(socketInfo, true)
}
// Check main table for socket.
socketInfo, inbound := table.findSocket(pktInfo)
if socketInfo == nil && table.dualStack != nil {
// If there was no match in the main table and we are dual-stack, check
// the dual-stack table for the socket.
socketInfo, inbound = table.dualStack.findSocket(pktInfo)
}
// search connections
for _, socketInfo := range table.connections {
if localPort == socketInfo.Local.Port &&
localIP.Equal(socketInfo.Local.IP) {
table.lock.RUnlock()
return checkPID(socketInfo, false)
}
// If there's a match, check we have the PID and return.
if socketInfo != nil {
return checkPID(socketInfo, inbound)
}
table.lock.RUnlock()
// every time, except for the last iteration
if i < lookupRetries {
// we found nothing, we could have been too fast, give the kernel some time to think
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
time.Sleep(time.Duration(i+1) * baseWaitTime)
// refetch lists
table.updateTables()
// Take turns in waiting and refreshing in order to satisfy the search pattern.
if i%2 == 0 {
// we found nothing, we could have been too fast, give the kernel some time to think
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
time.Sleep(time.Duration(i+1) * baseWaitTime)
} else {
// refetch lists
table.updateTables()
if table.dualStack != nil {
table.dualStack.updateTables()
}
}
}
}
return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
}
func (table *tcpTable) findSocket(pktInfo *packet.Info) (
socketInfo socket.Info,
inbound bool,
) {
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
table.lock.RLock()
defer table.lock.RUnlock()
// always search listeners first
for _, socketInfo := range table.listeners {
if localPort == socketInfo.Local.Port &&
(socketInfo.ListensAny || localIP.Equal(socketInfo.Local.IP)) {
return socketInfo, false
}
}
// search connections
for _, socketInfo := range table.connections {
if localPort == socketInfo.Local.Port &&
localIP.Equal(socketInfo.Local.IP) {
return socketInfo, false
}
}
return nil, false
}
func (table *udpTable) lookup(pktInfo *packet.Info) (
pid int,
inbound bool,
err error,
) {
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
// Search pattern: search, wait, search, refresh, search, wait, search, refresh, ...
isInboundMulticast := pktInfo.Inbound && netutils.ClassifyIP(localIP) == netutils.LocalMulticast
// TODO: Currently broadcast/multicast scopes are not checked, so we might
// attribute an incoming broadcast/multicast packet to the wrong process if
// there are multiple processes listening on the same local port, but
// binding to different addresses. This highly unusual for clients.
isInboundMulticast := pktInfo.Inbound && netutils.ClassifyIP(pktInfo.LocalIP()) == netutils.LocalMulticast
// search until we find something
// Search for the socket until found.
for i := 0; i <= lookupRetries; i++ {
table.lock.RLock()
// search binds
for _, socketInfo := range table.binds {
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || // zero IP
isInboundMulticast || // inbound broadcast, multicast
localIP.Equal(socketInfo.Local.IP)) {
table.lock.RUnlock()
// do not check direction if remoteIP/Port is not given
if pktInfo.RemotePort() == 0 {
return checkPID(socketInfo, pktInfo.Inbound)
}
// get direction and return
connInbound := table.getDirection(socketInfo, pktInfo)
return checkPID(socketInfo, connInbound)
}
// Check main table for socket.
socketInfo := table.findSocket(pktInfo, isInboundMulticast)
if socketInfo == nil && table.dualStack != nil {
// If there was no match in the main table and we are dual-stack, check
// the dual-stack table for the socket.
socketInfo = table.dualStack.findSocket(pktInfo, isInboundMulticast)
}
table.lock.RUnlock()
// If there's a match, get the direction and check we have the PID, then return.
if socketInfo != nil {
// If there is no remote port, do check for the direction of the
// connection. This will be the case for pure checking functions
// that do not want to change direction state.
if pktInfo.RemotePort() == 0 {
return checkPID(socketInfo, pktInfo.Inbound)
}
// Get (and save) the direction of the connection.
connInbound := table.getDirection(socketInfo, pktInfo)
// Check we have the PID and return.
return checkPID(socketInfo, connInbound)
}
// every time, except for the last iteration
if i < lookupRetries {
// we found nothing, we could have been too fast, give the kernel some time to think
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
time.Sleep(time.Duration(i+1) * baseWaitTime)
// refetch lists
table.updateTable()
// Take turns in waiting and refreshing in order to satisfy the search pattern.
if i%2 == 0 {
// we found nothing, we could have been too fast, give the kernel some time to think
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
time.Sleep(time.Duration(i+1) * baseWaitTime)
} else {
// refetch lists
table.updateTable()
if table.dualStack != nil {
table.dualStack.updateTable()
}
}
}
}
return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
}
func (table *udpTable) findSocket(pktInfo *packet.Info, isInboundMulticast bool) (socketInfo *socket.BindInfo) {
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
table.lock.RLock()
defer table.lock.RUnlock()
// search binds
for _, socketInfo := range table.binds {
if localPort == socketInfo.Local.Port &&
(socketInfo.ListensAny || // zero IP (dual-stack)
isInboundMulticast || // inbound broadcast, multicast
localIP.Equal(socketInfo.Local.IP)) {
return socketInfo
}
}
return nil
}

View File

@@ -1,6 +1,8 @@
package state
import (
"net"
"github.com/safing/portbase/log"
)
@@ -15,6 +17,11 @@ func (table *tcpTable) updateTables() {
return
}
// Pre-check for any listeners.
for _, bindInfo := range listeners {
bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero)
}
table.connections = connections
table.listeners = listeners
})
@@ -31,6 +38,11 @@ func (table *udpTable) updateTable() {
return
}
// Pre-check for any listeners.
for _, bindInfo := range binds {
bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero)
}
table.binds = binds
})
}

View File

@@ -16,16 +16,19 @@ type tcpTable struct {
fetchOnceAgain utils.OnceAgain
fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error)
dualStack *tcpTable
}
var (
tcp4Table = &tcpTable{
version: 4,
fetchTable: getTCP4Table,
}
tcp6Table = &tcpTable{
version: 6,
fetchTable: getTCP6Table,
}
tcp4Table = &tcpTable{
version: 4,
fetchTable: getTCP4Table,
dualStack: tcp6Table,
}
)

View File

@@ -22,6 +22,8 @@ type udpTable struct {
states map[string]map[string]*udpState
statesLock sync.Mutex
dualStack *udpTable
}
type udpState struct {
@@ -41,17 +43,18 @@ const (
)
var (
udp4Table = &udpTable{
version: 4,
fetchTable: getUDP4Table,
states: make(map[string]map[string]*udpState),
}
udp6Table = &udpTable{
version: 6,
fetchTable: getUDP6Table,
states: make(map[string]map[string]*udpState),
}
udp4Table = &udpTable{
version: 4,
fetchTable: getUDP4Table,
states: make(map[string]map[string]*udpState),
dualStack: udp6Table,
}
)
// CleanUDPStates cleans the udp connection states which save connection directions.