Revamp process attribution of network connections

This commit is contained in:
Daniel
2020-05-15 17:15:22 +02:00
parent 7a03eed1ff
commit 55b0ae8944
34 changed files with 1234 additions and 1196 deletions

View File

@@ -4,6 +4,8 @@ import (
"context"
"time"
"github.com/safing/portmaster/network/state"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/process"
)
@@ -22,8 +24,12 @@ func connectionCleaner(ctx context.Context) error {
ticker.Stop()
return nil
case <-ticker.C:
// clean connections and processes
activePIDs := cleanConnections()
process.CleanProcessStorage(activePIDs)
// clean udp connection states
state.CleanUDPStates(ctx)
}
}
}
@@ -33,12 +39,9 @@ func cleanConnections() (activePIDs map[int]struct{}) {
name := "clean connections" // TODO: change to new fn
_ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error {
activeIDs := make(map[string]struct{})
for _, cID := range process.GetActiveConnectionIDs() {
activeIDs[cID] = struct{}{}
}
now := time.Now().Unix()
now := time.Now().UTC()
nowUnix := now.Unix()
deleteOlderThan := time.Now().Add(-deleteConnsAfterEndedThreshold).Unix()
// lock both together because we cannot fully guarantee in which map a connection lands
@@ -49,20 +52,20 @@ func cleanConnections() (activePIDs map[int]struct{}) {
defer dnsConnsLock.Unlock()
// network connections
for key, conn := range conns {
for _, conn := range conns {
conn.Lock()
// delete inactive connections
switch {
case conn.Ended == 0:
// Step 1: check if still active
_, ok := activeIDs[key]
if ok {
exists := state.Exists(conn.IPVersion, conn.IPProtocol, conn.LocalIP, conn.LocalPort, conn.Entity.IP, conn.Entity.Port, now)
if exists {
activePIDs[conn.process.Pid] = struct{}{}
} else {
// Step 2: mark end
activePIDs[conn.process.Pid] = struct{}{}
conn.Ended = now
conn.Ended = nowUnix
conn.Save()
}
case conn.Ended < deleteOlderThan:

View File

@@ -25,11 +25,19 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
record.Base
sync.Mutex
ID string
Scope string
Inbound bool
Entity *intel.Entity // needs locking, instance is never shared
process *process.Process
ID string
Scope string
IPVersion packet.IPVersion
Inbound bool
// local endpoint
IPProtocol packet.IPProtocol
LocalIP net.IP
LocalPort uint16
process *process.Process
// remote endpoint
Entity *intel.Entity // needs locking, instance is never shared
Verdict Verdict
Reason string
@@ -55,9 +63,18 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
}
// NewConnectionFromDNSRequest returns a new connection based on the given dns request.
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection {
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, ipVersion packet.IPVersion, localIP net.IP, localPort uint16) *Connection {
// get Process
proc, err := process.GetProcessByEndpoints(ctx, localIP, localPort, dnsAddress, dnsPort, packet.UDP)
proc, _, err := process.GetProcessByEndpoints(
ctx,
ipVersion,
packet.UDP,
localIP,
localPort,
dnsAddress, // this might not be correct, but it does not matter, as matching only occurs on the local address
dnsPort,
false, // inbound, irrevelant
)
if err != nil {
log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err)
proc = process.GetUnidentifiedProcess(ctx)
@@ -147,11 +164,18 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
}
return &Connection{
ID: pkt.GetConnectionID(),
Scope: scope,
Inbound: inbound,
Entity: entity,
process: proc,
ID: pkt.GetConnectionID(),
Scope: scope,
IPVersion: pkt.Info().Version,
Inbound: inbound,
// local endpoint
IPProtocol: pkt.Info().Protocol,
LocalIP: pkt.Info().LocalIP(),
LocalPort: pkt.Info().LocalPort(),
process: proc,
// remote endpoint
Entity: entity,
// meta
Started: time.Now().Unix(),
}
}

View File

@@ -57,6 +57,13 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
return conn, nil
}
}
// case "system":
// if len(splitted) >= 2 {
// switch splitted[1] {
// case "":
// process.Get
// }
// }
}
return nil, storage.ErrNotFound

72
network/iphelper/get.go Normal file
View File

@@ -0,0 +1,72 @@
// +build windows
package iphelper
import (
"sync"
"github.com/safing/portmaster/network/socket"
)
const (
unidentifiedProcessID = -1
)
var (
ipHelper *IPHelper
lock sync.RWMutex
)
// GetTCP4Table returns the system table for IPv4 TCP activity.
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, nil, err
}
return ipHelper.getTable(IPv4, TCP)
}
// GetTCP6Table returns the system table for IPv6 TCP activity.
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, nil, err
}
return ipHelper.getTable(IPv6, TCP)
}
// GetUDP4Table returns the system table for IPv4 UDP activity.
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, err
}
_, binds, err = ipHelper.getTable(IPv4, UDP)
return
}
// GetUDP6Table returns the system table for IPv6 UDP activity.
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, err
}
_, binds, err = ipHelper.getTable(IPv6, UDP)
return
}

View File

@@ -0,0 +1,63 @@
// +build windows
package iphelper
import (
"errors"
"fmt"
"github.com/tevino/abool"
"golang.org/x/sys/windows"
)
var (
errInvalid = errors.New("IPHelper not initialzed or broken")
)
// IPHelper represents a subset of the Windows iphlpapi.dll.
type IPHelper struct {
dll *windows.LazyDLL
getExtendedTCPTable *windows.LazyProc
getExtendedUDPTable *windows.LazyProc
valid *abool.AtomicBool
}
func checkIPHelper() (err error) {
if ipHelper == nil {
ipHelper, err = New()
return err
}
return nil
}
// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded).
func New() (*IPHelper, error) {
new := &IPHelper{}
new.valid = abool.NewBool(false)
var err error
// load dll
new.dll = windows.NewLazySystemDLL("iphlpapi.dll")
err = new.dll.Load()
if err != nil {
return nil, err
}
// load functions
new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable")
err = new.getExtendedTCPTable.Find()
if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err)
}
new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable")
err = new.getExtendedUDPTable.Find()
if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err)
}
new.valid.Set()
return new, nil
}

304
network/iphelper/tables.go Normal file
View File

@@ -0,0 +1,304 @@
// +build windows
package iphelper
import (
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
"unsafe"
"github.com/safing/portmaster/network/socket"
"golang.org/x/sys/windows"
)
// Windows API constants
const (
iphelperTCPTableOwnerPIDAll uintptr = 5
iphelperUDPTableOwnerPID uintptr = 1
iphelperTCPStateListen uint32 = 2
winErrInsufficientBuffer = uintptr(windows.ERROR_INSUFFICIENT_BUFFER)
winErrInvalidParameter = uintptr(windows.ERROR_INVALID_PARAMETER)
)
type iphelperTCPTable struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx
numEntries uint32
table [4096]iphelperTCPRow
}
type iphelperTCPRow struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366913(v=vs.85).aspx
state uint32
localAddr uint32
localPort uint32
remoteAddr uint32
remotePort uint32
owningPid uint32
}
type iphelperTCP6Table struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366905(v=vs.85).aspx
numEntries uint32
table [4096]iphelperTCP6Row
}
type iphelperTCP6Row struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366896(v=vs.85).aspx
localAddr [16]byte
_ uint32 // localScopeID
localPort uint32
remoteAddr [16]byte
_ uint32 // remoteScopeID
remotePort uint32
state uint32
owningPid uint32
}
type iphelperUDPTable struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366932(v=vs.85).aspx
numEntries uint32
table [4096]iphelperUDPRow
}
type iphelperUDPRow struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366928(v=vs.85).aspx
localAddr uint32
localPort uint32
owningPid uint32
}
type iphelperUDP6Table struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366925(v=vs.85).aspx
numEntries uint32
table [4096]iphelperUDP6Row
}
type iphelperUDP6Row struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366923(v=vs.85).aspx
localAddr [16]byte
_ uint32 // localScopeID
localPort uint32
owningPid uint32
}
// IP and Protocol constants
const (
IPv4 uint8 = 4
IPv6 uint8 = 6
TCP uint8 = 6
UDP uint8 = 17
)
const (
startBufSize = 4096
bufSizeUses = 100
)
var (
bufSize = startBufSize
bufSizeUsageLeft = bufSizeUses
bufSizeLock sync.Mutex
)
func getBufSize() int {
bufSizeLock.Lock()
defer bufSizeLock.Unlock()
// using bufSize
bufSizeUsageLeft--
// check if we want to reset
if bufSizeUsageLeft <= 0 {
// reset
bufSize = startBufSize
bufSizeUsageLeft = bufSizeUses
}
return bufSize
}
func increaseBufSize() int {
bufSizeLock.Lock()
defer bufSizeLock.Unlock()
// increase
bufSize *= 2
// not too much
if bufSize > 65536 {
bufSize = 65536
}
// reset
bufSizeUsageLeft = bufSizeUses
// return new bufSize
return bufSize
}
// getTable returns the current connection state table of Windows of the given protocol and IP version.
func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { //nolint:gocognit,gocycle // TODO
// docs: https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
if !ipHelper.valid.IsSet() {
return nil, nil, errInvalid
}
var afClass int
switch ipVersion {
case IPv4:
afClass = windows.AF_INET
case IPv6:
afClass = windows.AF_INET6
default:
return nil, nil, errors.New("invalid protocol")
}
// try max 3 times
maxTries := 3
bufSize := getBufSize()
var buf []byte
for i := 1; i <= maxTries; i++ {
buf = make([]byte, bufSize)
var r1 uintptr
switch protocol {
case TCP:
r1, _, err = ipHelper.getExtendedTCPTable.Call(
uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable
uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize
0, // _In_ BOOL bOrder
uintptr(afClass), // _In_ ULONG ulAf
iphelperTCPTableOwnerPIDAll, // _In_ TCP_TABLE_CLASS TableClass
0, // _In_ ULONG Reserved
)
case UDP:
r1, _, err = ipHelper.getExtendedUDPTable.Call(
uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable,
uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize,
0, // _In_ BOOL bOrder,
uintptr(afClass), // _In_ ULONG ulAf,
iphelperUDPTableOwnerPID, // _In_ UDP_TABLE_CLASS TableClass,
0, // _In_ ULONG Reserved
)
}
switch r1 {
case winErrInsufficientBuffer:
if i >= maxTries {
return nil, nil, fmt.Errorf("insufficient buffer error (tried %d times): [NT 0x%X] %s", i, r1, err)
}
bufSize = increaseBufSize()
case winErrInvalidParameter:
return nil, nil, fmt.Errorf("invalid parameter: [NT 0x%X] %s", r1, err)
case windows.NO_ERROR:
// success
break
default:
return nil, nil, fmt.Errorf("unexpected error: [NT 0x%X] %s", r1, err)
}
}
// parse output
switch {
case protocol == TCP && ipVersion == IPv4:
tcpTable := (*iphelperTCPTable)(unsafe.Pointer(&buf[0]))
table := tcpTable.table[:tcpTable.numEntries]
for _, row := range table {
if row.state == iphelperTCPStateListen {
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: uint16(row.localPort>>8 | row.localPort<<8),
},
PID: int(row.owningPid),
})
} else {
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: uint16(row.localPort>>8 | row.localPort<<8),
},
Remote: socket.Address{
IP: convertIPv4(row.remoteAddr),
Port: uint16(row.remotePort>>8 | row.remotePort<<8),
},
PID: int(row.owningPid),
})
}
}
case protocol == TCP && ipVersion == IPv6:
tcpTable := (*iphelperTCP6Table)(unsafe.Pointer(&buf[0]))
table := tcpTable.table[:tcpTable.numEntries]
for _, row := range table {
if row.state == iphelperTCPStateListen {
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: uint16(row.localPort>>8 | row.localPort<<8),
},
PID: int(row.owningPid),
})
} else {
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: uint16(row.localPort>>8 | row.localPort<<8),
},
Remote: socket.Address{
IP: net.IP(row.remoteAddr[:]),
Port: uint16(row.remotePort>>8 | row.remotePort<<8),
},
PID: int(row.owningPid),
})
}
}
case protocol == UDP && ipVersion == IPv4:
udpTable := (*iphelperUDPTable)(unsafe.Pointer(&buf[0]))
table := udpTable.table[:udpTable.numEntries]
for _, row := range table {
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: uint16(row.localPort>>8 | row.localPort<<8),
},
PID: int(row.owningPid),
})
}
case protocol == UDP && ipVersion == IPv6:
udpTable := (*iphelperUDP6Table)(unsafe.Pointer(&buf[0]))
table := udpTable.table[:udpTable.numEntries]
for _, row := range table {
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: uint16(row.localPort>>8 | row.localPort<<8),
},
PID: int(row.owningPid),
})
}
}
return connections, binds, nil
}
func convertIPv4(input uint32) net.IP {
addressBuf := make([]byte, 4)
binary.BigEndian.PutUint32(addressBuf, input)
return net.IP(addressBuf)
}

View File

@@ -0,0 +1,54 @@
// +build windows
package iphelper
import (
"fmt"
"testing"
)
func TestSockets(t *testing.T) {
connections, listeners, err := GetTCP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 4 connections:")
for _, connection := range connections {
fmt.Printf("%+v\n", connection)
}
fmt.Println("\nTCP 4 listeners:")
for _, listener := range listeners {
fmt.Printf("%+v\n", listener)
}
connections, listeners, err = GetTCP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 6 connections:")
for _, connection := range connections {
fmt.Printf("%+v\n", connection)
}
fmt.Println("\nTCP 6 listeners:")
for _, listener := range listeners {
fmt.Printf("%+v\n", listener)
}
binds, err := GetUDP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 4 binds:")
for _, bind := range binds {
fmt.Printf("%+v\n", bind)
}
binds, err = GetUDP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 6 binds:")
for _, bind := range binds {
fmt.Printf("%+v\n", bind)
}
}

View File

@@ -0,0 +1,62 @@
// +build windows
package main
import (
"fmt"
"github.com/safing/portmaster/process/iphelper"
)
func main() {
iph, err := iphelper.New()
if err != nil {
panic(err)
}
fmt.Printf("TCP4\n")
conns, lConns, err := iph.GetTables(iphelper.TCP, iphelper.IPv4)
if err != nil {
panic(err)
}
fmt.Printf("Connections:\n")
for _, conn := range conns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("Listeners:\n")
for _, conn := range lConns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("\nTCP6\n")
conns, lConns, err = iph.GetTables(iphelper.TCP, iphelper.IPv6)
if err != nil {
panic(err)
}
fmt.Printf("Connections:\n")
for _, conn := range conns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("Listeners:\n")
for _, conn := range lConns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("\nUDP4\n")
_, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv4)
if err != nil {
panic(err)
}
for _, conn := range lConns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("\nUDP6\n")
_, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv6)
if err != nil {
panic(err)
}
for _, conn := range lConns {
fmt.Printf("%s\n", conn)
}
}

168
network/proc/findpid.go Normal file
View File

@@ -0,0 +1,168 @@
// +build linux
package proc
import (
"fmt"
"os"
"sort"
"strconv"
"sync"
"syscall"
"github.com/safing/portbase/log"
)
const (
unidentifiedProcessID = -1
)
var (
pidsByUserLock sync.Mutex
pidsByUser = make(map[int][]int)
)
// FindPID returns the pid of the given uid and socket inode.
func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO
pidsByUserLock.Lock()
defer pidsByUserLock.Unlock()
pidsUpdated := false
// get pids of user, update if missing
pids, ok := pidsByUser[uid]
if !ok {
// log.Trace("process: no processes of user, updating table")
updatePids()
pidsUpdated = true
pids, ok = pidsByUser[uid]
}
if ok {
// if user has pids, start checking them first
var checkedUserPids []int
for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) {
return possiblePID, true
}
checkedUserPids = append(checkedUserPids, possiblePID)
}
// if we fail on the first run and have not updated, update and check the ones we haven't tried so far.
if !pidsUpdated {
// log.Trace("process: socket not found in any process of user, updating table")
// update
updatePids()
// sort for faster search
for i, j := 0, len(checkedUserPids)-1; i < j; i, j = i+1, j-1 {
checkedUserPids[i], checkedUserPids[j] = checkedUserPids[j], checkedUserPids[i]
}
len := len(checkedUserPids)
// check unchecked pids
for _, possiblePID := range pids {
// only check if not already checked
if sort.SearchInts(checkedUserPids, possiblePID) == len {
if findSocketFromPid(possiblePID, inode) {
return possiblePID, true
}
}
}
}
}
// check all other pids
// log.Trace("process: socket not found in any process of user, checking all pids")
// TODO: find best order for pidsByUser for best performance
for possibleUID, pids := range pidsByUser {
if possibleUID != uid {
for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) {
return possiblePID, true
}
}
}
}
return unidentifiedProcessID, false
}
func findSocketFromPid(pid, inode int) bool {
socketName := fmt.Sprintf("socket:[%d]", inode)
entries := readDirNames(fmt.Sprintf("/proc/%d/fd", pid))
if len(entries) == 0 {
return false
}
for _, entry := range entries {
link, err := os.Readlink(fmt.Sprintf("/proc/%d/fd/%s", pid, entry))
if err != nil {
if !os.IsNotExist(err) {
log.Warningf("process: failed to read link /proc/%d/fd/%s: %s", pid, entry, err)
}
continue
}
if link == socketName {
return true
}
}
return false
}
func updatePids() {
pidsByUser = make(map[int][]int)
entries := readDirNames("/proc")
if len(entries) == 0 {
return
}
entryLoop:
for _, entry := range entries {
pid, err := strconv.ParseInt(entry, 10, 32)
if err != nil {
continue entryLoop
}
statData, err := os.Stat(fmt.Sprintf("/proc/%d", pid))
if err != nil {
log.Warningf("process: could not stat /proc/%d: %s", pid, err)
continue entryLoop
}
sys, ok := statData.Sys().(*syscall.Stat_t)
if !ok {
log.Warningf("process: unable to parse /proc/%d: wrong type", pid)
continue entryLoop
}
pids, ok := pidsByUser[int(sys.Uid)]
if ok {
pidsByUser[int(sys.Uid)] = append(pids, int(pid))
} else {
pidsByUser[int(sys.Uid)] = []int{int(pid)}
}
}
for _, slice := range pidsByUser {
for i, j := 0, len(slice)-1; i < j; i, j = i+1, j-1 {
slice[i], slice[j] = slice[j], slice[i]
}
}
}
func readDirNames(dir string) (names []string) {
file, err := os.Open(dir)
if err != nil {
if !os.IsNotExist(err) {
log.Warningf("process: could not open directory %s: %s", dir, err)
}
return
}
defer file.Close()
names, err = file.Readdirnames(0)
if err != nil {
log.Warningf("process: could not get entries from directory %s: %s", dir, err)
return []string{}
}
return
}

218
network/proc/tables.go Normal file
View File

@@ -0,0 +1,218 @@
// +build linux
package proc
import (
"bufio"
"encoding/hex"
"net"
"os"
"strconv"
"strings"
"unicode"
"github.com/safing/portmaster/network/socket"
"github.com/safing/portbase/log"
)
/*
1. find socket inode
- by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP?
- /proc/net/{tcp|udp}[6]
2. get list of processes of uid
3. find socket inode in process fds
- if not found, refresh map of uid->pids
- if not found, check ALL pids: maybe euid != uid
4. gather process info
Cache every step!
*/
// Network Related Constants
const (
TCP4 uint8 = iota
UDP4
TCP6
UDP6
ICMP4
ICMP6
TCP4Data = "/proc/net/tcp"
UDP4Data = "/proc/net/udp"
TCP6Data = "/proc/net/tcp6"
UDP6Data = "/proc/net/udp6"
ICMP4Data = "/proc/net/icmp"
ICMP6Data = "/proc/net/icmp6"
UnfetchedProcessID = -2
tcpListenStateHex = "0A"
)
// GetTCP4Table returns the system table for IPv4 TCP activity.
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP4, TCP4Data, convertIPv4)
}
// GetTCP6Table returns the system table for IPv6 TCP activity.
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP6, TCP6Data, convertIPv6)
}
// GetUDP4Table returns the system table for IPv4 UDP activity.
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP4, UDP4Data, convertIPv4)
return
}
// GetUDP6Table returns the system table for IPv6 UDP activity.
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP6, UDP6Data, convertIPv6)
return
}
func getTableFromSource(stack uint8, procFile string, ipConverter func(string) net.IP) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
// open file
socketData, err := os.Open(procFile)
if err != nil {
return nil, nil, err
}
defer socketData.Close()
// file scanner
scanner := bufio.NewScanner(socketData)
scanner.Split(bufio.ScanLines)
// parse
scanner.Scan() // skip first line
for scanner.Scan() {
line := strings.FieldsFunc(scanner.Text(), procDelimiter)
if len(line) < 14 {
// log.Tracef("process: too short: %s", line)
continue
}
localIP := ipConverter(line[1])
if localIP == nil {
continue
}
localPort, err := strconv.ParseUint(line[2], 16, 16)
if err != nil {
log.Warningf("process: could not parse port: %s", err)
continue
}
uid, err := strconv.ParseInt(line[11], 10, 32)
// log.Tracef("uid: %s", line[11])
if err != nil {
log.Warningf("process: could not parse uid %s: %s", line[11], err)
continue
}
inode, err := strconv.ParseInt(line[13], 10, 32)
// log.Tracef("inode: %s", line[13])
if err != nil {
log.Warningf("process: could not parse inode %s: %s", line[13], err)
continue
}
switch stack {
case UDP4, UDP6:
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
PID: UnfetchedProcessID,
UID: int(uid),
Inode: int(inode),
})
case TCP4, TCP6:
if line[5] == tcpListenStateHex {
// listener
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
PID: UnfetchedProcessID,
UID: int(uid),
Inode: int(inode),
})
} else {
// connection
remoteIP := ipConverter(line[3])
if remoteIP == nil {
continue
}
remotePort, err := strconv.ParseUint(line[4], 16, 16)
if err != nil {
log.Warningf("process: could not parse port: %s", err)
continue
}
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
Remote: socket.Address{
IP: remoteIP,
Port: uint16(remotePort),
},
PID: UnfetchedProcessID,
UID: int(uid),
Inode: int(inode),
})
}
}
}
return connections, binds, nil
}
func procDelimiter(c rune) bool {
return unicode.IsSpace(c) || c == ':'
}
func convertIPv4(data string) net.IP {
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("process: could not parse IPv4 %s: %s", data, err)
return nil
}
if len(decoded) != 4 {
log.Warningf("process: decoded IPv4 %s has wrong length", decoded)
return nil
}
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
return ip
}
func convertIPv6(data string) net.IP {
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("process: could not parse IPv6 %s: %s", data, err)
return nil
}
if len(decoded) != 16 {
log.Warningf("process: decoded IPv6 %s has wrong length", decoded)
return nil
}
ip := net.IP(decoded)
return ip
}

View File

@@ -0,0 +1,60 @@
// +build linux
package proc
import (
"fmt"
"testing"
)
func TestSockets(t *testing.T) {
connections, listeners, err := GetTCP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 4 connections:")
for _, connection := range connections {
pid, _ := FindPID(connection.UID, connection.Inode)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 4 listeners:")
for _, listener := range listeners {
pid, _ := FindPID(listener.UID, listener.Inode)
fmt.Printf("%d: %+v\n", pid, listener)
}
connections, listeners, err = GetTCP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 6 connections:")
for _, connection := range connections {
pid, _ := FindPID(connection.UID, connection.Inode)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 6 listeners:")
for _, listener := range listeners {
pid, _ := FindPID(listener.UID, listener.Inode)
fmt.Printf("%d: %+v\n", pid, listener)
}
binds, err := GetUDP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 4 binds:")
for _, bind := range binds {
pid, _ := FindPID(bind.UID, bind.Inode)
fmt.Printf("%d: %+v\n", pid, bind)
}
binds, err = GetUDP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 6 binds:")
for _, bind := range binds {
pid, _ := FindPID(bind.UID, bind.Inode)
fmt.Printf("%d: %+v\n", pid, bind)
}
}

26
network/socket/socket.go Normal file
View File

@@ -0,0 +1,26 @@
package socket
import "net"
// ConnectionInfo holds socket information returned by the system.
type ConnectionInfo struct {
Local Address
Remote Address
PID int
UID int
Inode int
}
// BindInfo holds socket information returned by the system.
type BindInfo struct {
Local Address
PID int
UID int
Inode int
}
// Address is an IP + Port pair.
type Address struct {
IP net.IP
Port uint16
}

103
network/state/exists.go Normal file
View File

@@ -0,0 +1,103 @@
package state
import (
"net"
"time"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/socket"
)
const (
UDPConnectionTTL = 10 * time.Minute
)
func Exists(
ipVersion packet.IPVersion,
protocol packet.IPProtocol,
localIP net.IP,
localPort uint16,
remoteIP net.IP,
remotePort uint16,
now time.Time,
) (exists bool) {
switch {
case ipVersion == packet.IPv4 && protocol == packet.TCP:
tcp4Lock.Lock()
defer tcp4Lock.Unlock()
return existsTCP(tcp4Connections, localIP, localPort, remoteIP, remotePort)
case ipVersion == packet.IPv6 && protocol == packet.TCP:
tcp6Lock.Lock()
defer tcp6Lock.Unlock()
return existsTCP(tcp6Connections, localIP, localPort, remoteIP, remotePort)
case ipVersion == packet.IPv4 && protocol == packet.UDP:
udp4Lock.Lock()
defer udp4Lock.Unlock()
return existsUDP(udp4Binds, udp4States, localIP, localPort, remoteIP, remotePort, now)
case ipVersion == packet.IPv6 && protocol == packet.UDP:
udp6Lock.Lock()
defer udp6Lock.Unlock()
return existsUDP(udp6Binds, udp6States, localIP, localPort, remoteIP, remotePort, now)
default:
return false
}
}
func existsTCP(
connections []*socket.ConnectionInfo,
localIP net.IP,
localPort uint16,
remoteIP net.IP,
remotePort uint16,
) (exists bool) {
// search connections
for _, socketInfo := range connections {
if localPort == socketInfo.Local.Port &&
remotePort == socketInfo.Remote.Port &&
remoteIP.Equal(socketInfo.Remote.IP) &&
localIP.Equal(socketInfo.Local.IP) {
return true
}
}
return false
}
func existsUDP(
binds []*socket.BindInfo,
udpStates map[string]map[string]*udpState,
localIP net.IP,
localPort uint16,
remoteIP net.IP,
remotePort uint16,
now time.Time,
) (exists bool) {
connThreshhold := now.Add(-UDPConnectionTTL)
// search binds
for _, socketInfo := range binds {
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
udpConnState, ok := getUDPConnState(socketInfo, udpStates, remoteIP, remotePort)
switch {
case !ok:
return false
case udpConnState.lastSeen.After(connThreshhold):
return true
default:
return false
}
}
}
return false
}

189
network/state/lookup.go Normal file
View File

@@ -0,0 +1,189 @@
package state
import (
"errors"
"net"
"sync"
"time"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/socket"
)
// - TCP
// - Outbound: Match listeners (in!), then connections (out!)
// - Inbound: Match listeners (in!), then connections (out!)
// - Clean via connections
// - UDP
// - Any connection: match specific local address or zero IP
// - In or out: save direction of first packet:
// - map[<local udp bind ip+port>]map[<remote ip+port>]{direction, lastSeen}
// - only clean if <local udp bind ip+port> is removed by OS
// - limit <remote ip+port> to 256 entries?
// - clean <remote ip+port> after 72hrs?
// - switch direction to outbound if outbound packet is seen?
// - IP: Unidentified Process
const (
UnidentifiedProcessID = -1
)
// Errors
var (
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
ErrPIDNotFound = errors.New("could not find pid for socket inode")
)
var (
tcp4Lock sync.Mutex
tcp6Lock sync.Mutex
udp4Lock sync.Mutex
udp6Lock sync.Mutex
waitTime = 3 * time.Millisecond
)
func LookupWithPacket(pkt packet.Packet) (pid int, inbound bool, err error) {
meta := pkt.Info()
return Lookup(
meta.Version,
meta.Protocol,
meta.LocalIP(),
meta.LocalPort(),
meta.RemoteIP(),
meta.RemotePort(),
meta.Direction,
)
}
func Lookup(
ipVersion packet.IPVersion,
protocol packet.IPProtocol,
localIP net.IP,
localPort uint16,
remoteIP net.IP,
remotePort uint16,
pktInbound bool,
) (
pid int,
inbound bool,
err error,
) {
// auto-detect version
if ipVersion == 0 {
if ip := localIP.To4(); ip != nil {
ipVersion = packet.IPv4
} else {
ipVersion = packet.IPv6
}
}
switch {
case ipVersion == packet.IPv4 && protocol == packet.TCP:
tcp4Lock.Lock()
defer tcp4Lock.Unlock()
return searchTCP(tcp4Connections, tcp4Listeners, updateTCP4Tables, localIP, localPort)
case ipVersion == packet.IPv6 && protocol == packet.TCP:
tcp6Lock.Lock()
defer tcp6Lock.Unlock()
return searchTCP(tcp6Connections, tcp6Listeners, updateTCP6Tables, localIP, localPort)
case ipVersion == packet.IPv4 && protocol == packet.UDP:
udp4Lock.Lock()
defer udp4Lock.Unlock()
return searchUDP(udp4Binds, udp4States, updateUDP4Table, localIP, localPort, remoteIP, remotePort, pktInbound)
case ipVersion == packet.IPv6 && protocol == packet.UDP:
udp6Lock.Lock()
defer udp6Lock.Unlock()
return searchUDP(udp6Binds, udp6States, updateUDP6Table, localIP, localPort, remoteIP, remotePort, pktInbound)
default:
return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
}
}
func searchTCP(
connections []*socket.ConnectionInfo,
listeners []*socket.BindInfo,
updateTables func() ([]*socket.ConnectionInfo, []*socket.BindInfo),
localIP net.IP,
localPort uint16,
) (
pid int,
inbound bool,
err error,
) {
// search until we find something
for i := 0; i < 5; i++ {
// always search listeners first
for _, socketInfo := range listeners {
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
return checkBindPID(socketInfo, true)
}
}
// search connections
for _, socketInfo := range connections {
if localPort == socketInfo.Local.Port &&
localIP.Equal(socketInfo.Local.IP) {
return checkConnectionPID(socketInfo, false)
}
}
// we found nothing, we could have been too fast, give the kernel some time to think
time.Sleep(waitTime)
// refetch lists
connections, listeners = updateTables()
}
return UnidentifiedProcessID, false, ErrConnectionNotFound
}
func searchUDP(
binds []*socket.BindInfo,
udpStates map[string]map[string]*udpState,
updateTable func() []*socket.BindInfo,
localIP net.IP,
localPort uint16,
remoteIP net.IP,
remotePort uint16,
pktInbound bool,
) (
pid int,
inbound bool,
err error,
) {
// search until we find something
for i := 0; i < 5; i++ {
// search binds
for _, socketInfo := range binds {
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
// do not check direction if remoteIP/Port is not given
if remotePort == 0 {
return checkBindPID(socketInfo, pktInbound)
}
// get direction and return
connInbound := getUDPDirection(socketInfo, udpStates, remoteIP, remotePort, pktInbound)
return checkBindPID(socketInfo, connInbound)
}
}
// we found nothing, we could have been too fast, give the kernel some time to think
time.Sleep(waitTime)
// refetch lists
binds = updateTable()
}
return UnidentifiedProcessID, pktInbound, ErrConnectionNotFound
}

View File

@@ -0,0 +1,37 @@
package state
import (
"github.com/safing/portmaster/network/proc"
"github.com/safing/portmaster/network/socket"
)
var (
getTCP4Table = proc.GetTCP4Table
getTCP6Table = proc.GetTCP6Table
getUDP4Table = proc.GetUDP4Table
getUDP6Table = proc.GetUDP6Table
)
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID {
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode)
if ok {
socketInfo.PID = pid
} else {
socketInfo.PID = UnidentifiedProcessID
}
}
return socketInfo.PID, connInbound, nil
}
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID {
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode)
if ok {
socketInfo.PID = pid
} else {
socketInfo.PID = UnidentifiedProcessID
}
}
return socketInfo.PID, connInbound, nil
}

View File

@@ -0,0 +1,21 @@
package state
import (
"github.com/safing/portmaster/network/iphelper"
"github.com/safing/portmaster/network/socket"
)
var (
getTCP4Table = iphelper.GetTCP4Table
getTCP6Table = iphelper.GetTCP6Table
getUDP4Table = iphelper.GetUDP4Table
getUDP6Table = iphelper.GetUDP6Table
)
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
return socketInfo.PID, connInbound, nil
}
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
return socketInfo.PID, connInbound, nil
}

66
network/state/tables.go Normal file
View File

@@ -0,0 +1,66 @@
package state
import (
"github.com/safing/portbase/log"
"github.com/safing/portmaster/network/socket"
)
var (
tcp4Connections []*socket.ConnectionInfo
tcp4Listeners []*socket.BindInfo
tcp6Connections []*socket.ConnectionInfo
tcp6Listeners []*socket.BindInfo
udp4Binds []*socket.BindInfo
udp6Binds []*socket.BindInfo
)
func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
// FIXME: repeatable once
connections, listeners, err := getTCP4Table()
if err != nil {
log.Warningf("state: failed to get TCP4 socket table: %s", err)
return
}
tcp4Connections = connections
tcp4Listeners = listeners
return tcp4Connections, tcp4Listeners
}
func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
connections, listeners, err := getTCP6Table()
if err != nil {
log.Warningf("state: failed to get TCP6 socket table: %s", err)
return
}
tcp6Connections = connections
tcp6Listeners = listeners
return tcp6Connections, tcp6Listeners
}
func updateUDP4Table() (binds []*socket.BindInfo) {
binds, err := getUDP4Table()
if err != nil {
log.Warningf("state: failed to get UDP4 socket table: %s", err)
return
}
udp4Binds = binds
return udp4Binds
}
func updateUDP6Table() (binds []*socket.BindInfo) {
binds, err := getUDP6Table()
if err != nil {
log.Warningf("state: failed to get UDP6 socket table: %s", err)
return
}
udp6Binds = binds
return udp6Binds
}

118
network/state/udp.go Normal file
View File

@@ -0,0 +1,118 @@
package state
import (
"context"
"net"
"time"
"github.com/safing/portmaster/network/socket"
)
type udpState struct {
inbound bool
lastSeen time.Time
}
const (
UpdConnStateTTL = 72 * time.Hour
UdpConnStateShortenedTTL = 3 * time.Hour
AggressiveCleaningThreshold = 256
)
var (
udp4States = make(map[string]map[string]*udpState) // locked with udp4Lock
udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock
)
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16) (udpConnState *udpState, ok bool) {
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)]
if ok {
udpConnState, ok = bindMap[makeUDPStateKey(remoteIP, remotePort)]
return
}
return nil, false
}
func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16, pktInbound bool) (connDirection bool) {
localKey := makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)
bindMap, ok := udpStates[localKey]
if !ok {
bindMap = make(map[string]*udpState)
udpStates[localKey] = bindMap
}
remoteKey := makeUDPStateKey(remoteIP, remotePort)
udpConnState, ok := bindMap[remoteKey]
if !ok {
bindMap[remoteKey] = &udpState{
inbound: pktInbound,
lastSeen: time.Now().UTC(),
}
return pktInbound
}
udpConnState.lastSeen = time.Now().UTC()
return udpConnState.inbound
}
func CleanUDPStates(ctx context.Context) {
now := time.Now().UTC()
udp4Lock.Lock()
updateUDP4Table()
cleanStates(ctx, udp4Binds, udp4States, now)
udp4Lock.Unlock()
udp6Lock.Lock()
updateUDP6Table()
cleanStates(ctx, udp6Binds, udp6States, now)
udp6Lock.Unlock()
}
func cleanStates(
ctx context.Context,
binds []*socket.BindInfo,
udpStates map[string]map[string]*udpState,
now time.Time,
) {
// compute thresholds
threshold := now.Add(-UpdConnStateTTL)
shortThreshhold := now.Add(-UdpConnStateShortenedTTL)
// make list of all active keys
bindKeys := make(map[string]struct{})
for _, socketInfo := range binds {
bindKeys[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] = struct{}{}
}
// clean the udp state storage
for localKey, bindMap := range udpStates {
_, active := bindKeys[localKey]
if active {
// clean old entries
for remoteKey, udpConnState := range bindMap {
if udpConnState.lastSeen.Before(threshold) {
delete(bindMap, remoteKey)
}
}
// if there are too many clean more aggressively
if len(bindMap) > AggressiveCleaningThreshold {
for remoteKey, udpConnState := range bindMap {
if udpConnState.lastSeen.Before(shortThreshhold) {
delete(bindMap, remoteKey)
}
}
}
} else {
// delete the whole thing
delete(udpStates, localKey)
}
}
}
func makeUDPStateKey(ip net.IP, port uint16) string {
// This could potentially go wrong, but as all IPs are created by the same source, everything should be fine.
return string(ip) + string(port)
}