wip: migrate to mono-repo. SPN has already been moved to spn/

This commit is contained in:
Patrick Pacher
2024-03-15 11:55:13 +01:00
parent b30fd00ccf
commit 8579430db9
577 changed files with 35981 additions and 818 deletions

View File

@@ -0,0 +1,107 @@
//go:build linux
package proc
import (
"errors"
"io/fs"
"os"
"strconv"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/socket"
)
// GetPID returns the already existing pid of the given socket info or searches for it.
// This also acts as a getter for socket.Info.PID, as locking for that occurs here.
func GetPID(socketInfo socket.Info) (pid int) {
// Get currently assigned PID to the socket info.
currentPid := socketInfo.GetPID()
// If the current PID already is valid (ie. not unidentified), return it immediately.
if currentPid != socket.UndefinedProcessID {
return currentPid
}
// Find PID for the given UID and inode.
pid = findPID(socketInfo.GetUIDandInode())
// Set the newly found PID on the socket info.
socketInfo.SetPID(pid)
// Return found PID.
return pid
}
// findPID returns the pid of the given uid and socket inode.
func findPID(uid, inode int) (pid int) {
socketName := "socket:[" + strconv.Itoa(inode) + "]"
// Always update pid table (it has a call limiter anyway)
updatePids()
// Get all pids for the given uid.
pids, ok := getPidsByUser(uid)
if !ok {
return socket.UndefinedProcessID
}
// Look through the PIDs in reverse order, because higher/newer PIDs will be more likely to
// be searched for.
for j := len(pids) - 1; j >= 0; j-- {
if pidHasSocket(pids[j], socketName) {
return pids[j]
}
}
return socket.UndefinedProcessID
}
func pidHasSocket(pid int, socketName string) bool {
socketBase := "/proc/" + strconv.Itoa(pid) + "/fd"
entries := readDirNames(socketBase)
if len(entries) == 0 {
return false
}
socketBase += "/"
// Look through the FDs in reverse order, because higher/newer FDs will be
// more likely to be searched for.
for i := len(entries) - 1; i >= 0; i-- {
link, err := os.Readlink(socketBase + entries[i])
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Warningf("proc: failed to read link /proc/%d/fd/%s: %s", pid, entries[i], err)
}
continue
}
if link == socketName {
return true
}
}
return false
}
// readDirNames only reads the directory names. Using os.ReadDir() would call `lstat` on every
// resulting directory name, which we don't need. This function will be called a lot, so we should
// refrain from unnecessary work.
func readDirNames(dir string) (names []string) {
file, err := os.Open(dir)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Warningf("proc: could not open directory %s: %s", dir, err)
}
return
}
defer func() {
_ = file.Close()
}()
names, err = file.Readdirnames(0)
if err != nil {
log.Warningf("proc: could not get entries from directory %s: %s", dir, err)
return []string{}
}
return
}

View File

@@ -0,0 +1,81 @@
//go:build linux
package proc
import (
"errors"
"io/fs"
"os"
"strconv"
"sync"
"syscall"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
)
var (
// pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs.
pidsByUser = make(map[int][]int)
pidsByUserLock sync.RWMutex
fetchPidsByUser = utils.NewCallLimiter(10 * time.Millisecond)
)
// getPidsByUser returns the cached PIDs for the given UID.
func getPidsByUser(uid int) (pids []int, ok bool) {
pidsByUserLock.RLock()
defer pidsByUserLock.RUnlock()
pids, ok = pidsByUser[uid]
return
}
// updatePids fetches and creates a new pidsByUser map using a call limiter.
func updatePids() {
fetchPidsByUser.Do(func() {
newPidsByUser := make(map[int][]int)
pidCnt := 0
entries := readDirNames("/proc")
if len(entries) == 0 {
log.Warning("proc: found no PIDs in /proc")
return
}
entryLoop:
for _, entry := range entries {
pid, err := strconv.ParseInt(entry, 10, 32)
if err != nil {
continue entryLoop
}
statData, err := os.Stat("/proc/" + strconv.FormatInt(pid, 10))
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Warningf("proc: could not stat /proc/%d: %s", pid, err)
}
continue entryLoop
}
sys, ok := statData.Sys().(*syscall.Stat_t)
if !ok {
log.Warningf("proc: unable to parse /proc/%d: wrong type", pid)
continue entryLoop
}
pids, ok := newPidsByUser[int(sys.Uid)]
if ok {
newPidsByUser[int(sys.Uid)] = append(pids, int(pid))
} else {
newPidsByUser[int(sys.Uid)] = []int{int(pid)}
}
pidCnt++
}
// log.Tracef("proc: updated PID table with %d entries", pidCnt)
pidsByUserLock.Lock()
defer pidsByUserLock.Unlock()
pidsByUser = newPidsByUser
})
}

View File

@@ -0,0 +1,244 @@
//go:build linux
package proc
import (
"bufio"
"encoding/hex"
"fmt"
"net"
"os"
"strconv"
"strings"
"unicode"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/socket"
)
/*
1. find socket inode
- by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP?
- /proc/net/{tcp|udp}[6]
2. get list of processes of uid
3. find socket inode in process fds
- if not found, refresh map of uid->pids
- if not found, check ALL pids: maybe euid != uid
4. gather process info
Cache every step!
*/
// Network Related Constants.
const (
TCP4 uint8 = iota
UDP4
TCP6
UDP6
ICMP4
ICMP6
tcp4ProcFile = "/proc/net/tcp"
tcp6ProcFile = "/proc/net/tcp6"
udp4ProcFile = "/proc/net/udp"
udp6ProcFile = "/proc/net/udp6"
tcpListenStateHex = "0A"
)
// GetTCP4Table returns the system table for IPv4 TCP activity.
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP4, tcp4ProcFile)
}
// GetTCP6Table returns the system table for IPv6 TCP activity.
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP6, tcp6ProcFile)
}
// GetUDP4Table returns the system table for IPv4 UDP activity.
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP4, udp4ProcFile)
return
}
// GetUDP6Table returns the system table for IPv6 UDP activity.
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP6, udp6ProcFile)
return
}
const (
// hint: we split fields by multiple delimiters, see procDelimiter
fieldIndexLocalIP = 1
fieldIndexLocalPort = 2
fieldIndexRemoteIP = 3
fieldIndexRemotePort = 4
fieldIndexUID = 11
fieldIndexInode = 13
)
func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
var ipConverter func(string) net.IP
switch stack {
case TCP4, UDP4:
ipConverter = convertIPv4
case TCP6, UDP6:
ipConverter = convertIPv6
default:
return nil, nil, fmt.Errorf("unsupported table stack: %d", stack)
}
// open file
socketData, err := os.Open(procFile)
if err != nil {
return nil, nil, err
}
defer func() {
_ = socketData.Close()
}()
// file scanner
scanner := bufio.NewScanner(socketData)
scanner.Split(bufio.ScanLines)
// parse
scanner.Scan() // skip first row
for scanner.Scan() {
fields := strings.FieldsFunc(scanner.Text(), procDelimiter)
if len(fields) < 14 {
// log.Tracef("proc: too short: %s", fields)
continue
}
localIP := ipConverter(fields[fieldIndexLocalIP])
if localIP == nil {
continue
}
localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16)
if err != nil {
log.Warningf("proc: could not parse port: %s", err)
continue
}
uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32)
// log.Tracef("uid: %s", fields[fieldIndexUID])
if err != nil {
log.Warningf("proc: could not parse uid %s: %s", fields[11], err)
continue
}
inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32)
// log.Tracef("inode: %s", fields[fieldIndexInode])
if err != nil {
log.Warningf("proc: could not parse inode %s: %s", fields[13], err)
continue
}
switch stack {
case UDP4, UDP6:
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
PID: socket.UndefinedProcessID,
UID: int(uid),
Inode: int(inode),
})
case TCP4, TCP6:
if fields[5] == tcpListenStateHex {
// listener
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
PID: socket.UndefinedProcessID,
UID: int(uid),
Inode: int(inode),
})
} else {
// connection
remoteIP := ipConverter(fields[fieldIndexRemoteIP])
if remoteIP == nil {
continue
}
remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16)
if err != nil {
log.Warningf("proc: could not parse port: %s", err)
continue
}
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
Remote: socket.Address{
IP: remoteIP,
Port: uint16(remotePort),
},
PID: socket.UndefinedProcessID,
UID: int(uid),
Inode: int(inode),
})
}
}
}
return connections, binds, nil
}
func procDelimiter(c rune) bool {
return unicode.IsSpace(c) || c == ':'
}
func convertIPv4(data string) net.IP {
// Decode and bullshit check the data length.
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("proc: could not parse IPv4 %s: %s", data, err)
return nil
}
if len(decoded) != 4 {
log.Warningf("proc: decoded IPv4 %s has wrong length", decoded)
return nil
}
// Build the IPv4 address with the reversed byte order.
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
return ip
}
func convertIPv6(data string) net.IP {
// Decode and bullshit check the data length.
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("proc: could not parse IPv6 %s: %s", data, err)
return nil
}
if len(decoded) != 16 {
log.Warningf("proc: decoded IPv6 %s has wrong length", decoded)
return nil
}
// Build the IPv6 address with the translated byte order.
for i := 0; i < 16; i += 4 {
decoded[i], decoded[i+1], decoded[i+2], decoded[i+3] = decoded[i+3], decoded[i+2], decoded[i+1], decoded[i]
}
ip := net.IP(decoded)
return ip
}

View File

@@ -0,0 +1,62 @@
//go:build linux
package proc
import (
"fmt"
"testing"
)
func TestSockets(t *testing.T) {
t.Parallel()
connections, listeners, err := GetTCP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 4 connections:")
for _, connection := range connections {
pid := GetPID(connection)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 4 listeners:")
for _, listener := range listeners {
pid := GetPID(listener)
fmt.Printf("%d: %+v\n", pid, listener)
}
connections, listeners, err = GetTCP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 6 connections:")
for _, connection := range connections {
pid := GetPID(connection)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 6 listeners:")
for _, listener := range listeners {
pid := GetPID(listener)
fmt.Printf("%d: %+v\n", pid, listener)
}
binds, err := GetUDP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 4 binds:")
for _, bind := range binds {
pid := GetPID(bind)
fmt.Printf("%d: %+v\n", pid, bind)
}
binds, err = GetUDP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 6 binds:")
for _, bind := range binds {
pid := GetPID(bind)
fmt.Printf("%d: %+v\n", pid, bind)
}
}