wip: migrate to mono-repo. SPN has already been moved to spn/

This commit is contained in:
Patrick Pacher
2024-03-15 11:55:13 +01:00
parent b30fd00ccf
commit 8579430db9
577 changed files with 35981 additions and 818 deletions

340
service/network/api.go Normal file
View File

@@ -0,0 +1,340 @@
package network
import (
"fmt"
"net/http"
"sort"
"strconv"
"strings"
"time"
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/utils/debug"
"github.com/safing/portmaster/service/network/state"
"github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/service/resolver"
"github.com/safing/portmaster/service/status"
"github.com/safing/portmaster/service/updates"
)
func registerAPIEndpoints() error {
if err := api.RegisterEndpoint(api.Endpoint{
Path: "debug/network",
Read: api.PermitUser,
BelongsTo: module,
DataFunc: debugInfo,
Name: "Get Network Debug Information",
Description: "Returns network debugging information, similar to debug/core, but with connection data.",
Parameters: []api.Parameter{
{
Method: http.MethodGet,
Field: "style",
Value: "github",
Description: "Specify the formatting style. The default is simple markdown formatting.",
},
{
Method: http.MethodGet,
Field: "profile",
Value: "<Source>/<ID>",
Description: "Specify a profile source and ID for which network connection should be reported.",
},
{
Method: http.MethodGet,
Field: "where",
Value: "<query>",
Description: "Specify a query to limit the connections included in the report. The default is to include all connections.",
},
},
}); err != nil {
return err
}
if err := api.RegisterEndpoint(api.Endpoint{
Path: "debug/network/state",
Read: api.PermitUser,
BelongsTo: module,
StructFunc: func(ar *api.Request) (i interface{}, err error) {
return state.GetInfo(), nil
},
Name: "Get Network State Table Data",
Description: "Returns the current network state tables from the OS.",
}); err != nil {
return err
}
return nil
}
// debugInfo returns the debugging information for support requests.
func debugInfo(ar *api.Request) (data []byte, err error) {
// Create debug information helper.
di := new(debug.Info)
di.Style = ar.Request.URL.Query().Get("style")
// Add debug information.
// Very basic information at the start.
di.AddVersionInfo()
di.AddPlatformInfo(ar.Context())
// Errors and unexpected logs.
di.AddLastReportedModuleError()
di.AddLastUnexpectedLogs()
// Network Connections.
AddNetworkDebugData(
di,
ar.Request.URL.Query().Get("profile"),
ar.Request.URL.Query().Get("where"),
)
// Status Information from various modules.
status.AddToDebugInfo(di)
// captain.AddToDebugInfo(di) // TODO: Cannot use due to import loop.
resolver.AddToDebugInfo(di)
config.AddToDebugInfo(di)
// Detailed information.
updates.AddToDebugInfo(di)
// compat.AddToDebugInfo(di) // TODO: Cannot use due to interception import requirement which we don't want for SPN Hubs.
di.AddGoroutineStack()
// Return data.
return di.Bytes(), nil
}
// AddNetworkDebugData adds the network debug data of the given profile to the debug data.
func AddNetworkDebugData(di *debug.Info, profile, where string) {
// Prepend where prefix to query if necessary.
if where != "" && !strings.HasPrefix(where, "where ") {
where = "where " + where
}
// Build query.
q, err := query.ParseQuery("query network: " + where)
if err != nil {
di.AddSection(
"Network: Debug Failed",
debug.NoFlags,
fmt.Sprintf("Failed to build query: %s", err),
)
return
}
// Get iterator.
it, err := dbController.Query(q, true, true)
if err != nil {
di.AddSection(
"Network: Debug Failed",
debug.NoFlags,
fmt.Sprintf("Failed to run query: %s", err),
)
return
}
// Collect matching connections.
var ( //nolint:prealloc // We don't know the size.
debugConns []*Connection
accepted int
total int
)
for maybeConn := range it.Next {
// Switch to correct type.
conn, ok := maybeConn.(*Connection)
if !ok {
continue
}
// Check if the profile matches
if profile != "" {
found := false
// Get layer IDs and search for a match.
layerIDs := conn.Process().Profile().LayerIDs
for _, layerID := range layerIDs {
if profile == layerID {
found = true
break
}
}
// Skip if the profile does not match.
if !found {
continue
}
}
// Count.
total++
switch conn.Verdict { //nolint:exhaustive
case VerdictAccept,
VerdictRerouteToNameserver,
VerdictRerouteToTunnel:
accepted++
}
// Add to list.
debugConns = append(debugConns, conn)
}
// Add it all.
di.AddSection(
fmt.Sprintf(
"Network: %d/%d Connections",
accepted,
total,
),
debug.UseCodeSection|debug.AddContentLineBreaks,
buildNetworkDebugInfoData(debugConns),
)
}
func buildNetworkDebugInfoData(debugConns []*Connection) string {
// Sort
sort.Sort(connectionsByGroup(debugConns))
// Format lines
var buf strings.Builder
currentPID := process.UndefinedProcessID
for _, conn := range debugConns {
conn.Lock()
// Add process infomration if it differs from previous connection.
if currentPID != conn.ProcessContext.PID {
if currentPID != process.UndefinedProcessID {
buf.WriteString("\n\n\n")
}
buf.WriteString("ProfileName: " + conn.ProcessContext.ProfileName)
buf.WriteString("\nProfile: " + conn.ProcessContext.Profile)
buf.WriteString("\nSource: " + conn.ProcessContext.Source)
buf.WriteString("\nProcessName: " + conn.ProcessContext.ProcessName)
buf.WriteString("\nBinaryPath: " + conn.ProcessContext.BinaryPath)
buf.WriteString("\nCmdLine: " + conn.ProcessContext.CmdLine)
buf.WriteString("\nPID: " + strconv.Itoa(conn.ProcessContext.PID))
buf.WriteString("\n")
// Set current PID in order to not print the process information again.
currentPID = conn.ProcessContext.PID
}
// Add connection.
buf.WriteString("\n")
buf.WriteString(conn.debugInfoLine())
conn.Unlock()
}
return buf.String()
}
func (conn *Connection) debugInfoLine() string {
var connectionData string
if conn.Type == IPConnection {
// Format IP/Port pair for connections.
connectionData = fmt.Sprintf(
"% 15s:%- 5s %s % 15s:%- 5s",
conn.LocalIP,
strconv.Itoa(int(conn.LocalPort)),
conn.fmtProtocolAndDirectionComponent(conn.IPProtocol.String()),
conn.Entity.IP,
strconv.Itoa(int(conn.Entity.Port)),
)
} else {
// Leave empty for DNS Requests.
connectionData = " "
}
return fmt.Sprintf(
"% 14s %s%- 25s %s-%s P#%d [%s] %s - by %s @ %s",
conn.VerdictVerb(),
connectionData,
conn.fmtDomainComponent(),
time.Unix(conn.Started, 0).Format("15:04:05"),
conn.fmtEndTimeComponent(),
conn.ProcessContext.PID,
conn.fmtFlagsComponent(),
conn.Reason.Msg,
conn.Reason.OptionKey,
conn.fmtReasonProfileComponent(),
)
}
func (conn *Connection) fmtDomainComponent() string {
if conn.Entity.Domain != "" {
return " to " + conn.Entity.Domain
}
return ""
}
func (conn *Connection) fmtProtocolAndDirectionComponent(protocol string) string {
if conn.Inbound {
return "<" + protocol
}
return protocol + ">"
}
func (conn *Connection) fmtFlagsComponent() string {
var f string
if conn.Internal {
f += "I"
}
if conn.Encrypted {
f += "E"
}
if conn.Tunneled {
f += "T"
}
if len(conn.activeInspectors) > 0 {
f += "A"
}
if conn.addedToMetrics {
f += "M"
}
return f
}
func (conn *Connection) fmtEndTimeComponent() string {
if conn.Ended == 0 {
return " " // Use same width as a timestamp.
}
return time.Unix(conn.Ended, 0).Format("15:04:05")
}
func (conn *Connection) fmtReasonProfileComponent() string {
if conn.Reason.Profile == "" {
return "global"
}
return conn.Reason.Profile
}
type connectionsByGroup []*Connection
func (a connectionsByGroup) Len() int { return len(a) }
func (a connectionsByGroup) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a connectionsByGroup) Less(i, j int) bool {
// Sort by:
// 1. Profile ID
if a[i].ProcessContext.Profile != a[j].ProcessContext.Profile {
return a[i].ProcessContext.Profile < a[j].ProcessContext.Profile
}
// 2. Process Binary
if a[i].ProcessContext.BinaryPath != a[j].ProcessContext.BinaryPath {
return a[i].ProcessContext.BinaryPath < a[j].ProcessContext.BinaryPath
}
// 3. Process ID
if a[i].ProcessContext.PID != a[j].ProcessContext.PID {
return a[i].ProcessContext.PID < a[j].ProcessContext.PID
}
// 4. Started
return a[i].Started < a[j].Started
}

149
service/network/api_test.go Normal file
View File

@@ -0,0 +1,149 @@
package network
import (
"fmt"
"net"
"testing"
"github.com/safing/portmaster/service/intel"
)
func TestDebugInfoLineFormatting(t *testing.T) {
t.Parallel()
for _, conn := range connectionTestData {
fmt.Println(conn.debugInfoLine())
}
}
func TestDebugInfoFormatting(t *testing.T) {
t.Parallel()
fmt.Println(buildNetworkDebugInfoData(connectionTestData))
}
var connectionTestData = []*Connection{
{
ID: "17-255.255.255.255-29810-192.168.0.23-40672",
Scope: "IL",
IPVersion: 4,
Inbound: true,
IPProtocol: 17,
LocalIP: net.ParseIP("255.255.255.255"),
LocalPort: 29810,
Entity: &intel.Entity{
Protocol: 17,
Port: 40672,
Domain: "",
ReverseDomain: "",
IP: net.ParseIP("192.168.0.23"),
Country: "",
ASN: 0,
},
Verdict: 2,
Reason: Reason{
Msg: "incoming connection blocked by default",
OptionKey: "filter/serviceEndpoints",
Profile: "",
},
Started: 1614010349,
Ended: 1614010350,
VerdictPermanent: true,
Inspecting: false,
Tunneled: false,
Encrypted: false,
ProcessContext: ProcessContext{
ProcessName: "Unidentified Processes",
ProfileName: "Unidentified Processes",
BinaryPath: "",
PID: -1,
Profile: "_unidentified",
Source: "local",
},
Internal: false,
ProfileRevisionCounter: 1,
},
{
ID: "6-192.168.0.176-55216-13.32.6.15-80",
Scope: "PI",
IPVersion: 4,
Inbound: false,
IPProtocol: 6,
LocalIP: net.ParseIP("192.168.0.176"),
LocalPort: 55216,
Entity: &intel.Entity{
Protocol: 6,
Port: 80,
Domain: "",
ReverseDomain: "",
IP: net.ParseIP("13.32.6.15"),
Country: "DE",
ASN: 16509,
},
Verdict: 2,
Reason: Reason{
Msg: "default permit",
OptionKey: "filter/defaultAction",
Profile: "",
},
Started: 1614010475,
Ended: 1614010565,
VerdictPermanent: true,
Inspecting: false,
Tunneled: false,
Encrypted: false,
ProcessContext: ProcessContext{
ProcessName: "NetworkManager",
ProfileName: "Network Manager",
BinaryPath: "/usr/sbin/NetworkManager",
PID: 1273,
Profile: "3a9b0eb5-c7fe-4bc7-9b93-a90f4ff84b5b",
Source: "local",
},
Internal: true,
ProfileRevisionCounter: 1,
},
{
ID: "6-192.168.0.176-49982-142.250.74.211-443",
Scope: "pkg.go.dev.",
IPVersion: 4,
Inbound: false,
IPProtocol: 6,
LocalIP: net.ParseIP("192.168.0.176"),
LocalPort: 49982,
Entity: &intel.Entity{
Protocol: 6,
Port: 443,
Domain: "pkg.go.dev.",
ReverseDomain: "",
CNAME: []string{
"ghs.googlehosted.com.",
},
IP: net.ParseIP("142.250.74.211"),
Country: "US",
ASN: 15169,
},
Verdict: 2,
Reason: Reason{
Msg: "default permit",
OptionKey: "filter/defaultAction",
Profile: "",
},
Started: 1614010415,
Ended: 1614010745,
VerdictPermanent: true,
Inspecting: false,
Tunneled: false,
Encrypted: false,
ProcessContext: ProcessContext{
ProcessName: "firefox",
ProfileName: "Firefox",
BinaryPath: "/usr/bin/firefox",
PID: 5710,
Profile: "74b30392-9e4d-4157-83a9-fffafd3e2bde",
Source: "local",
},
Internal: false,
ProfileRevisionCounter: 1,
},
}

132
service/network/clean.go Normal file
View File

@@ -0,0 +1,132 @@
package network
import (
"context"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/network/state"
"github.com/safing/portmaster/service/process"
)
const (
// DeleteConnsAfterEndedThreshold defines the amount of time after which
// ended connections should be removed from the internal connection state.
DeleteConnsAfterEndedThreshold = 10 * time.Minute
// DeleteIncompleteConnsAfterStartedThreshold defines the amount of time after
// which incomplete connections should be removed from the internal
// connection state.
DeleteIncompleteConnsAfterStartedThreshold = 1 * time.Minute
cleanerTickDuration = 5 * time.Second
)
func connectionCleaner(ctx context.Context) error {
ticker := module.NewSleepyTicker(cleanerTickDuration, 0)
for {
select {
case <-ctx.Done():
ticker.Stop()
return nil
case <-ticker.Wait():
// clean connections and processes
activePIDs := cleanConnections()
process.CleanProcessStorage(activePIDs)
// clean udp connection states
state.CleanUDPStates(ctx)
}
}
}
func cleanConnections() (activePIDs map[int]struct{}) {
activePIDs = make(map[int]struct{})
_ = module.RunMicroTask("clean connections", 0, func(ctx context.Context) error {
now := time.Now().UTC()
nowUnix := now.Unix()
ignoreNewer := nowUnix - 1
deleteOlderThan := now.Add(-DeleteConnsAfterEndedThreshold).Unix()
deleteIncompleteOlderThan := now.Add(-DeleteIncompleteConnsAfterStartedThreshold).Unix()
// network connections
for _, conn := range conns.clone() {
conn.Lock()
// delete inactive connections
switch {
case conn.Started >= ignoreNewer:
// Skip very fresh connections to evade edge cases.
case !conn.DataIsComplete():
// Step 0: delete old incomplete connections
if conn.Started < deleteIncompleteOlderThan {
// Stop the firewall handler, in case one is running.
conn.StopFirewallHandler()
// Remove connection from state.
conn.delete()
}
case conn.Ended == 0:
// Step 1: check if still active
exists := state.Exists(&packet.Info{
Inbound: false, // src == local
Version: conn.IPVersion,
Protocol: conn.IPProtocol,
Src: conn.LocalIP,
SrcPort: conn.LocalPort,
Dst: conn.Entity.IP,
DstPort: conn.Entity.Port,
PID: process.UndefinedProcessID,
SeenAt: time.Unix(conn.Started, 0), // State tables will be updated if older than this.
}, now)
// Step 2: mark as ended
if !exists {
conn.Ended = nowUnix
// Stop the firewall handler, in case one is running.
conn.StopFirewallHandler()
// Save to database.
conn.Save()
}
// If the connection has an associated process, add its PID to the active PID list.
if conn.process != nil {
activePIDs[conn.process.Pid] = struct{}{}
}
case conn.Ended < deleteOlderThan:
// Step 3: delete
// DEBUG:
// log.Tracef("network.clean: deleted %s (ended at %s)", conn.DatabaseKey(), time.Unix(conn.Ended, 0))
// Remove connection from state.
conn.delete()
}
conn.Unlock()
}
// dns requests
for _, conn := range dnsConns.clone() {
conn.Lock()
// delete old dns connections
if conn.Ended < deleteOlderThan {
log.Tracef("network.clean: deleted %s (ended at %s)", conn.DatabaseKey(), time.Unix(conn.Ended, 0))
conn.delete()
}
conn.Unlock()
}
// rerouted dns requests
cleanDNSRequestConnections()
return nil
})
return activePIDs
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,68 @@
package network
import (
"context"
"fmt"
"net"
"time"
"github.com/safing/portmaster/service/intel"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/spn/navigator"
"github.com/tevino/abool"
)
// NewDefaultConnection creates a new connection with default values except local and remote IPs and protocols.
func NewDefaultConnection(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, ipVersion packet.IPVersion, protocol packet.IPProtocol) *Connection {
connInfo := &Connection{
ID: fmt.Sprintf("%s-%s-%d-%s-%d", protocol.String(), localIP, localPort, remoteIP, remotePort),
Type: IPConnection,
External: false,
IPVersion: ipVersion,
Inbound: false,
IPProtocol: protocol,
LocalIP: localIP,
LocalIPScope: netutils.Global,
LocalPort: localPort,
PID: process.UnidentifiedProcessID,
Entity: (&intel.Entity{
IP: remoteIP,
Protocol: uint8(protocol),
Port: remotePort,
}).Init(0),
Resolver: nil,
Started: time.Now().Unix(),
VerdictPermanent: false,
Tunneled: true,
Encrypted: false,
DataComplete: abool.NewBool(true),
Internal: false,
addedToMetrics: true, // Metrics are not needed for now. This will mark the Connection to be ignored.
process: process.GetUnidentifiedProcess(context.Background()),
}
// TODO: Quick fix for the SPN.
// Use inspection framework for proper encryption detection.
switch connInfo.Entity.DstPort() {
case
22, // SSH
443, // HTTPS
465, // SMTP-SSL
853, // DoT
993, // IMAP-SSL
995: // POP3-SSL
connInfo.Encrypted = true
}
var layeredProfile = connInfo.process.Profile()
connInfo.TunnelOpts = &navigator.Options{
HubPolicies: layeredProfile.StackedExitHubPolicies(),
CheckHubExitPolicyWith: connInfo.Entity,
RequireTrustedDestinationHubs: !connInfo.Encrypted,
RoutingProfile: layeredProfile.SPNRoutingAlgorithm(),
}
return connInfo
}

View File

@@ -0,0 +1,97 @@
package network
import (
"strings"
"sync"
)
type connectionStore struct {
rw sync.RWMutex
items map[string]*Connection
}
func newConnectionStore() *connectionStore {
return &connectionStore{
items: make(map[string]*Connection, 100),
}
}
func (cs *connectionStore) add(conn *Connection) {
cs.rw.Lock()
defer cs.rw.Unlock()
cs.items[conn.ID] = conn
}
func (cs *connectionStore) delete(conn *Connection) {
cs.rw.Lock()
defer cs.rw.Unlock()
delete(cs.items, conn.ID)
}
func (cs *connectionStore) get(id string) (*Connection, bool) {
cs.rw.RLock()
defer cs.rw.RUnlock()
conn, ok := cs.items[id]
return conn, ok
}
// findByPrefix returns the first connection where the key matches the given prefix.
// If the prefix matches multiple entries, the result is not deterministic.
func (cs *connectionStore) findByPrefix(prefix string) (*Connection, bool) { //nolint:unused
cs.rw.RLock()
defer cs.rw.RUnlock()
for key, conn := range cs.items {
if strings.HasPrefix(key, prefix) {
return conn, true
}
}
return nil, false
}
func (cs *connectionStore) clone() map[string]*Connection {
cs.rw.RLock()
defer cs.rw.RUnlock()
m := make(map[string]*Connection, len(cs.items))
for key, conn := range cs.items {
m[key] = conn
}
return m
}
func (cs *connectionStore) list() []*Connection {
cs.rw.RLock()
defer cs.rw.RUnlock()
l := make([]*Connection, 0, len(cs.items))
for _, conn := range cs.items {
l = append(l, conn)
}
return l
}
func (cs *connectionStore) len() int { //nolint:unused // TODO: Clean up if still unused.
cs.rw.RLock()
defer cs.rw.RUnlock()
return len(cs.items)
}
func (cs *connectionStore) active() int {
// Clone and count all active connections.
var cnt int
for _, conn := range cs.clone() {
conn.Lock()
if conn.Ended != 0 {
cnt++
}
conn.Unlock()
}
return cnt
}

202
service/network/database.go Normal file
View File

@@ -0,0 +1,202 @@
package network
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/storage"
"github.com/safing/portmaster/service/process"
)
const (
dbScopeNone = ""
dbScopeDNS = "dns"
dbScopeIP = "ip"
)
var (
dbController *database.Controller
dnsConns = newConnectionStore()
conns = newConnectionStore()
)
// StorageInterface provices a storage.Interface to the
// configuration manager.
type StorageInterface struct {
storage.InjectBase
}
// Database prefixes:
// Processes: network:tree/<PID>
// DNS Requests: network:tree/<PID>/dns/<ID>
// IP Connections: network:tree/<PID>/ip/<ID>
func makeKey(pid int, scope, id string) string {
if scope == "" {
return "network:tree/" + strconv.Itoa(pid)
}
return fmt.Sprintf("network:tree/%d/%s/%s", pid, scope, id)
}
func parseDBKey(key string) (processKey string, scope, id string, ok bool) {
// Split into segments.
segments := strings.Split(key, "/")
// Keys have 2 or 4 segments.
switch len(segments) {
case 4:
id = segments[3]
fallthrough
case 3:
scope = segments[2]
// Sanity check.
switch scope {
case dbScopeNone, dbScopeDNS, dbScopeIP:
// Parsed id matches possible values.
// The empty string is for matching a trailing slash for in query prefix.
// TODO: For queries, also prefixes of these values are valid.
default:
// Unknown scope.
return "", "", "", false
}
fallthrough
case 2:
processKey = segments[1]
return processKey, scope, id, true
case 1:
// This is a valid query prefix, but not process ID was given.
return "", "", "", true
default:
return "", "", "", false
}
}
// Get returns a database record.
func (s *StorageInterface) Get(key string) (record.Record, error) {
// Parse key and check if valid.
pid, scope, id, ok := parseDBKey(strings.TrimPrefix(key, "network:"))
if !ok || pid == "" {
return nil, storage.ErrNotFound
}
switch scope {
case dbScopeDNS:
if c, ok := dnsConns.get(id); ok && c.DataIsComplete() {
return c, nil
}
case dbScopeIP:
if c, ok := conns.get(id); ok && c.DataIsComplete() {
return c, nil
}
case dbScopeNone:
if proc, ok := process.GetProcessFromStorage(pid); ok {
return proc, nil
}
}
return nil, storage.ErrNotFound
}
// Query returns a an iterator for the supplied query.
func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
it := iterator.New()
module.StartWorker("connection query", func(_ context.Context) error {
s.processQuery(q, it)
return nil
})
return it, nil
}
func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
var matches bool
pid, scope, _, ok := parseDBKey(q.DatabaseKeyPrefix())
if !ok {
it.Finish(nil)
return
}
if pid == "" {
// processes
for _, proc := range process.All() {
func() {
proc.Lock()
defer proc.Unlock()
matches = q.Matches(proc)
}()
if matches {
it.Next <- proc
}
}
}
if scope == dbScopeNone || scope == dbScopeDNS {
// dns scopes only
for _, dnsConn := range dnsConns.clone() {
if !dnsConn.DataIsComplete() {
continue
}
func() {
dnsConn.Lock()
defer dnsConn.Unlock()
matches = q.Matches(dnsConn)
}()
if matches {
it.Next <- dnsConn
}
}
}
if scope == dbScopeNone || scope == dbScopeIP {
// connections
for _, conn := range conns.clone() {
if !conn.DataIsComplete() {
continue
}
func() {
conn.Lock()
defer conn.Unlock()
matches = q.Matches(conn)
}()
if matches {
it.Next <- conn
}
}
}
it.Finish(nil)
}
func registerAsDatabase() error {
_, err := database.Register(&database.Database{
Name: "network",
Description: "Network and Firewall Data",
StorageType: "injected",
})
if err != nil {
return err
}
controller, err := database.InjectDatabase("network", &StorageInterface{})
if err != nil {
return err
}
dbController = controller
process.SetDBController(dbController)
return nil
}

258
service/network/dns.go Normal file
View File

@@ -0,0 +1,258 @@
package network
import (
"context"
"fmt"
"strconv"
"sync"
"time"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/nameserver/nsutil"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/service/resolver"
)
var (
dnsRequestConnections = make(map[string]*Connection) // key: <protocol>-<local ip>-<local port>
dnsRequestConnectionsLock sync.RWMutex
openDNSRequests = make(map[string]*Connection) // key: <pid>/<fqdn>
openDNSRequestsLock sync.Mutex
supportedDomainToIPRecordTypes = []uint16{
dns.TypeA,
dns.TypeAAAA,
dns.TypeSVCB,
dns.TypeHTTPS,
}
)
const (
// writeOpenDNSRequestsTickDuration defines the interval in which open dns
// requests are written.
writeOpenDNSRequestsTickDuration = 5 * time.Second
// openDNSRequestLimit defines the duration after which DNS requests without
// a following connection are logged.
openDNSRequestLimit = 3 * time.Second
)
func getDNSRequestConnectionKey(packetInfo *packet.Info) (id string, ok bool) {
// We only support protocols with ports.
if packetInfo.SrcPort == 0 {
return "", false
}
return fmt.Sprintf("%d-%s-%d", packetInfo.Protocol, packetInfo.Src, packetInfo.SrcPort), true
}
// SaveDNSRequestConnection saves a dns request connection for later retrieval.
func SaveDNSRequestConnection(conn *Connection, pkt packet.Packet) {
// Check connection.
if conn.PID == process.UndefinedProcessID || conn.PID == process.SystemProcessID {
// When re-injecting packets on Windows, they are reported with kernel PID (4).
log.Tracer(pkt.Ctx()).Tracef("network: not saving dns request connection because the PID is undefined/kernel")
return
}
// Create key.
key, ok := getDNSRequestConnectionKey(pkt.Info())
if !ok {
log.Tracer(pkt.Ctx()).Debugf("network: not saving dns request connection %s because the protocol is not supported", pkt)
return
}
// Add or update DNS request connection.
log.Tracer(pkt.Ctx()).Tracef("network: saving %s with PID %d as dns request connection for fast DNS request attribution", pkt, conn.PID)
dnsRequestConnectionsLock.Lock()
defer dnsRequestConnectionsLock.Unlock()
dnsRequestConnections[key] = conn
}
// GetDNSRequestConnection returns a saved dns request connection.
func GetDNSRequestConnection(packetInfo *packet.Info) (conn *Connection, ok bool) {
// Make key.
key, ok := getDNSRequestConnectionKey(packetInfo)
if !ok {
return nil, false
}
// Get and return
dnsRequestConnectionsLock.RLock()
defer dnsRequestConnectionsLock.RUnlock()
conn, ok = dnsRequestConnections[key]
return conn, ok
}
// deleteDNSRequestConnection removes a connection from the dns request connections.
func deleteDNSRequestConnection(packetInfo *packet.Info) { //nolint:unused,deadcode
dnsRequestConnectionsLock.Lock()
defer dnsRequestConnectionsLock.Unlock()
key, ok := getDNSRequestConnectionKey(packetInfo)
if ok {
delete(dnsRequestConnections, key)
}
}
// cleanDNSRequestConnections deletes old DNS request connections.
func cleanDNSRequestConnections() {
deleteOlderThan := time.Now().Unix() - 3
dnsRequestConnectionsLock.Lock()
defer dnsRequestConnectionsLock.Unlock()
for key, conn := range dnsRequestConnections {
conn.Lock()
if conn.Ended > 0 && conn.Ended < deleteOlderThan {
delete(dnsRequestConnections, key)
}
conn.Unlock()
}
}
// IsSupportDNSRecordType returns whether the given DSN RR type is supported
// by the network package, as in the requests are specially handled and can be
// "merged" into the resulting connection.
func IsSupportDNSRecordType(rrType uint16) bool {
return slices.Contains[[]uint16, uint16](supportedDomainToIPRecordTypes, rrType)
}
func getDNSRequestCacheKey(pid int, fqdn string, qType uint16) string {
return strconv.Itoa(pid) + "/" + fqdn + dns.Type(qType).String()
}
func removeOpenDNSRequest(pid int, fqdn string) {
openDNSRequestsLock.Lock()
defer openDNSRequestsLock.Unlock()
// Delete PID-specific requests.
for _, dnsType := range supportedDomainToIPRecordTypes {
delete(openDNSRequests, getDNSRequestCacheKey(pid, fqdn, dnsType))
}
// If process is known, also check for non-attributed requests.
if pid != process.UnidentifiedProcessID {
for _, dnsType := range supportedDomainToIPRecordTypes {
delete(openDNSRequests, getDNSRequestCacheKey(process.UnidentifiedProcessID, fqdn, dnsType))
}
}
}
// SaveOpenDNSRequest saves a dns request connection that was allowed to proceed.
func SaveOpenDNSRequest(q *resolver.Query, rrCache *resolver.RRCache, conn *Connection) {
// Only save requests that actually went out (or triggered an async resolve) to reduce clutter.
if rrCache == nil || (rrCache.ServedFromCache && !rrCache.RequestingNew) {
return
}
// Try to "merge" supported requests into the resulting connection.
// Save others immediately.
if !IsSupportDNSRecordType(uint16(q.QType)) {
conn.Save()
return
}
openDNSRequestsLock.Lock()
defer openDNSRequestsLock.Unlock()
// Do not check for an existing open DNS request, as duplicates in such quick
// succession are not worth keeping.
// DNS queries are usually retried pretty quick.
// Save to open dns requests.
key := getDNSRequestCacheKey(conn.process.Pid, conn.Entity.Domain, uint16(q.QType))
openDNSRequests[key] = conn
}
func openDNSRequestWriter(ctx context.Context) error {
ticker := module.NewSleepyTicker(writeOpenDNSRequestsTickDuration, 0)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil
case <-ticker.Wait():
writeOpenDNSRequestsToDB()
}
}
}
func writeOpenDNSRequestsToDB() {
openDNSRequestsLock.Lock()
defer openDNSRequestsLock.Unlock()
threshold := time.Now().Add(-openDNSRequestLimit).Unix()
for id, conn := range openDNSRequests {
func() {
conn.Lock()
defer conn.Unlock()
if conn.Ended < threshold {
conn.Save()
delete(openDNSRequests, id)
}
}()
}
}
// ReplyWithDNS creates a new reply to the given request with the data from the RRCache, and additional informational records.
func (conn *Connection) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns.Msg {
// Select request responder.
switch conn.Verdict {
case VerdictBlock:
return nsutil.BlockIP().ReplyWithDNS(ctx, request)
case VerdictDrop:
return nil // Do not respond to request.
case VerdictFailed:
return nsutil.BlockIP().ReplyWithDNS(ctx, request)
case VerdictUndecided, VerdictUndeterminable,
VerdictAccept, VerdictRerouteToNameserver, VerdictRerouteToTunnel:
fallthrough
default:
reply := nsutil.ServerFailure().ReplyWithDNS(ctx, request)
nsutil.AddMessagesToReply(ctx, reply, log.ErrorLevel, "INTERNAL ERROR: incorrect use of Connection DNS Responder")
return reply
}
}
// GetExtraRRs returns a slice of RRs with additional informational records.
func (conn *Connection) GetExtraRRs(ctx context.Context, request *dns.Msg) []dns.RR {
// Select level to add the verdict record with.
var level log.Severity
switch conn.Verdict {
case VerdictFailed:
level = log.ErrorLevel
case VerdictUndecided, VerdictUndeterminable,
VerdictAccept, VerdictBlock, VerdictDrop,
VerdictRerouteToNameserver, VerdictRerouteToTunnel:
fallthrough
default:
level = log.InfoLevel
}
// Create resource record with verdict and reason.
rr, err := nsutil.MakeMessageRecord(level, fmt.Sprintf("%s: %s", conn.VerdictVerb(), conn.Reason.Msg))
if err != nil {
log.Tracer(ctx).Warningf("filter: failed to add informational record to reply: %s", err)
return nil
}
extra := []dns.RR{rr}
// Add additional records from Reason.Context.
if rrProvider, ok := conn.Reason.Context.(nsutil.RRProvider); ok {
rrs := rrProvider.GetExtraRRs(ctx, request)
extra = append(extra, rrs...)
}
return extra
}

View File

@@ -0,0 +1,71 @@
//go:build windows
package iphelper
import (
"sync"
"github.com/safing/portmaster/service/network/socket"
)
var (
ipHelper *IPHelper
// lock locks access to the whole DLL.
// TODO: It's unproven if we can access the iphlpapi.dll concurrently, especially as we might be encountering various versions of the DLL. In the future, we could possibly investigate and improve performance here.
lock sync.RWMutex
)
// GetTCP4Table returns the system table for IPv4 TCP activity.
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, nil, err
}
return ipHelper.getTable(IPv4, TCP)
}
// GetTCP6Table returns the system table for IPv6 TCP activity.
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, nil, err
}
return ipHelper.getTable(IPv6, TCP)
}
// GetUDP4Table returns the system table for IPv4 UDP activity.
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, err
}
_, binds, err = ipHelper.getTable(IPv4, UDP)
return
}
// GetUDP6Table returns the system table for IPv6 UDP activity.
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, err
}
_, binds, err = ipHelper.getTable(IPv6, UDP)
return
}

View File

@@ -0,0 +1,63 @@
//go:build windows
package iphelper
import (
"errors"
"fmt"
"github.com/tevino/abool"
"golang.org/x/sys/windows"
)
var (
errInvalid = errors.New("IPHelper not initialized or broken")
)
// IPHelper represents a subset of the Windows iphlpapi.dll.
type IPHelper struct {
dll *windows.LazyDLL
getExtendedTCPTable *windows.LazyProc
getExtendedUDPTable *windows.LazyProc
valid *abool.AtomicBool
}
func checkIPHelper() (err error) {
if ipHelper == nil {
ipHelper, err = New()
return err
}
return nil
}
// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded).
func New() (*IPHelper, error) {
new := &IPHelper{}
new.valid = abool.NewBool(false)
var err error
// load dll
new.dll = windows.NewLazySystemDLL("iphlpapi.dll")
err = new.dll.Load()
if err != nil {
return nil, err
}
// load functions
new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable")
err = new.getExtendedTCPTable.Find()
if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err)
}
new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable")
err = new.getExtendedUDPTable.Find()
if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err)
}
new.valid.Set()
return new, nil
}

View File

@@ -0,0 +1,398 @@
//go:build windows
package iphelper
import (
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
"unsafe"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/socket"
"golang.org/x/sys/windows"
)
// Windows API constants
const (
iphelperTCPTableOwnerPIDAll uintptr = 5
iphelperUDPTableOwnerPID uintptr = 1
iphelperTCPStateListen uint32 = 2
winErrInsufficientBuffer = uintptr(windows.ERROR_INSUFFICIENT_BUFFER)
winErrInvalidParameter = uintptr(windows.ERROR_INVALID_PARAMETER)
)
type iphelperTCPTable struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx
numEntries uint32
table [maxStateTableEntries]iphelperTCPRow
}
type iphelperTCPRow struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366913(v=vs.85).aspx
state uint32
localAddr uint32
localPort uint32
remoteAddr uint32
remotePort uint32
owningPid uint32
}
type iphelperTCP6Table struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366905(v=vs.85).aspx
numEntries uint32
table [maxStateTableEntries]iphelperTCP6Row
}
type iphelperTCP6Row struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366896(v=vs.85).aspx
localAddr [16]byte
_ uint32 // localScopeID
localPort uint32
remoteAddr [16]byte
_ uint32 // remoteScopeID
remotePort uint32
state uint32
owningPid uint32
}
type iphelperUDPTable struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366932(v=vs.85).aspx
numEntries uint32
table [maxStateTableEntries]iphelperUDPRow
}
type iphelperUDPRow struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366928(v=vs.85).aspx
localAddr uint32
localPort uint32
owningPid uint32
}
type iphelperUDP6Table struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366925(v=vs.85).aspx
numEntries uint32
table [maxStateTableEntries]iphelperUDP6Row
}
type iphelperUDP6Row struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366923(v=vs.85).aspx
localAddr [16]byte
_ uint32 // localScopeID
localPort uint32
owningPid uint32
}
// IP and Protocol constants
const (
IPv4 uint8 = 4
IPv6 uint8 = 6
TCP uint8 = 6
UDP uint8 = 17
)
type learningBufSize struct {
sync.Mutex
size int
usesLeft int
useFor int
start int
max int
}
func newLearningBufSize(start, max, ttl int) *learningBufSize {
return &learningBufSize{
size: start,
usesLeft: ttl,
useFor: ttl,
start: start,
max: max,
}
}
const (
startBufSize = 1024
// bufSizeUsageTTL defines how often a buffer size is used before it is
// shrunk again.
bufSizeUsageTTL = 100
// maxBufSize is the maximum size we will allocate for responses. This was
// previously set at 65k, which was too little for some production cases.
maxBufSize = 1048576 // 2^20B, 1MB
// maxStateTableEntries is the maximum supported amount of entries of the
// state tables.
// This is never allocated, but just casted to from an unsafe pointer.
maxStateTableEntries = 65535
)
var (
tcp4BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL)
udp4BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL)
tcp6BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL)
udp6BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL)
)
func (lbf *learningBufSize) getBufSize() int {
lbf.Lock()
defer lbf.Unlock()
// using bufSize
lbf.usesLeft--
// check if we want to reset
if lbf.usesLeft <= 0 {
// decrease
lbf.size /= 2
// not too little
if lbf.size < lbf.start {
lbf.size = lbf.start
}
// reset TTL counter
lbf.usesLeft = lbf.useFor
}
return lbf.size
}
func (lbf *learningBufSize) increaseBufSize(minSize int) int {
lbf.Lock()
defer lbf.Unlock()
// increase
lbf.size *= 2
// increase until we reach the minimum size
for lbf.size < minSize {
lbf.size *= 2
}
// not too much
if lbf.size > lbf.max {
lbf.size = lbf.max
}
// reset TTL counter
lbf.usesLeft = lbf.useFor
// return new bufSize
return lbf.size
}
// getTable returns the current connection state table of Windows of the given protocol and IP version.
func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { //nolint:gocognit,gocycle // TODO
// docs: https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
if !ipHelper.valid.IsSet() {
return nil, nil, errInvalid
}
var afClass int
var lbf *learningBufSize
switch ipVersion {
case IPv4:
afClass = windows.AF_INET
if protocol == TCP {
lbf = tcp4BufSize
} else {
lbf = udp4BufSize
}
case IPv6:
afClass = windows.AF_INET6
if protocol == TCP {
lbf = tcp6BufSize
} else {
lbf = udp6BufSize
}
default:
return nil, nil, errors.New("invalid protocol")
}
// try max 5 times
maxTries := 5
usedBufSize := lbf.getBufSize()
var buf []byte
for i := 1; i <= maxTries; i++ {
bufSizeParam := usedBufSize
buf = make([]byte, bufSizeParam)
var r1 uintptr
switch protocol {
case TCP:
r1, _, err = ipHelper.getExtendedTCPTable.Call(
uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable
uintptr(unsafe.Pointer(&bufSizeParam)), // _Inout_ PDWORD pdwSize
0, // _In_ BOOL bOrder
uintptr(afClass), // _In_ ULONG ulAf
iphelperTCPTableOwnerPIDAll, // _In_ TCP_TABLE_CLASS TableClass
0, // _In_ ULONG Reserved
)
case UDP:
r1, _, err = ipHelper.getExtendedUDPTable.Call(
uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable,
uintptr(unsafe.Pointer(&bufSizeParam)), // _Inout_ PDWORD pdwSize,
0, // _In_ BOOL bOrder,
uintptr(afClass), // _In_ ULONG ulAf,
iphelperUDPTableOwnerPID, // _In_ UDP_TABLE_CLASS TableClass,
0, // _In_ ULONG Reserved
)
}
switch r1 {
case winErrInsufficientBuffer:
if i >= maxTries {
return nil, nil, fmt.Errorf(
"insufficient buffer error (tried %d times): provided %d bytes; required %d bytes - [NT 0x%X] %s",
i, usedBufSize, bufSizeParam, r1, err,
)
}
// bufSizeParam was modified by ipHelper.getExtended*Table to hold the
// required buffer size.
usedBufSize = lbf.increaseBufSize(bufSizeParam)
case winErrInvalidParameter:
return nil, nil, fmt.Errorf("invalid parameter: [NT 0x%X] %s", r1, err)
case windows.NO_ERROR:
// success
break
default:
return nil, nil, fmt.Errorf("unexpected error: [NT 0x%X] %s", r1, err)
}
}
// parse output
switch {
case protocol == TCP && ipVersion == IPv4:
tcpTable := (*iphelperTCPTable)(unsafe.Pointer(&buf[0]))
// Check if we got more entries than supported.
tableEntries := tcpTable.numEntries
if tableEntries > maxStateTableEntries {
tableEntries = maxStateTableEntries
log.Warningf("network/iphelper: received TCPv4 table with more entries than supported: %d/%d", tcpTable.numEntries, maxStateTableEntries)
}
// Cap table to actual entries.
table := tcpTable.table[:tableEntries]
for _, row := range table {
if row.state == iphelperTCPStateListen {
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
} else {
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: convertPort(row.localPort),
},
Remote: socket.Address{
IP: convertIPv4(row.remoteAddr),
Port: convertPort(row.remotePort),
},
PID: int(row.owningPid),
})
}
}
case protocol == TCP && ipVersion == IPv6:
tcpTable := (*iphelperTCP6Table)(unsafe.Pointer(&buf[0]))
// Check if we got more entries than supported.
tableEntries := tcpTable.numEntries
if tableEntries > maxStateTableEntries {
tableEntries = maxStateTableEntries
log.Warningf("network/iphelper: received TCPv6 table with more entries than supported: %d/%d", tcpTable.numEntries, maxStateTableEntries)
}
// Cap table to actual entries.
table := tcpTable.table[:tableEntries]
for _, row := range table {
if row.state == iphelperTCPStateListen {
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
} else {
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: convertPort(row.localPort),
},
Remote: socket.Address{
IP: net.IP(row.remoteAddr[:]),
Port: convertPort(row.remotePort),
},
PID: int(row.owningPid),
})
}
}
case protocol == UDP && ipVersion == IPv4:
udpTable := (*iphelperUDPTable)(unsafe.Pointer(&buf[0]))
// Check if we got more entries than supported.
tableEntries := udpTable.numEntries
if tableEntries > maxStateTableEntries {
tableEntries = maxStateTableEntries
log.Warningf("network/iphelper: received UDPv4 table with more entries than supported: %d/%d", udpTable.numEntries, maxStateTableEntries)
}
// Cap table to actual entries.
table := udpTable.table[:tableEntries]
for _, row := range table {
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
}
case protocol == UDP && ipVersion == IPv6:
udpTable := (*iphelperUDP6Table)(unsafe.Pointer(&buf[0]))
// Check if we got more entries than supported.
tableEntries := udpTable.numEntries
if tableEntries > maxStateTableEntries {
tableEntries = maxStateTableEntries
log.Warningf("network/iphelper: received UDPv6 table with more entries than supported: %d/%d", udpTable.numEntries, maxStateTableEntries)
}
// Cap table to actual entries.
table := udpTable.table[:tableEntries]
for _, row := range table {
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
}
}
return connections, binds, nil
}
// convertIPv4 as needed for iphlpapi.dll
func convertIPv4(input uint32) net.IP {
addressBuf := make([]byte, 4)
binary.LittleEndian.PutUint32(addressBuf, input)
return net.IP(addressBuf)
}
// convertPort converts ports received from iphlpapi.dll
func convertPort(input uint32) uint16 {
return uint16(input>>8 | input<<8)
}

View File

@@ -0,0 +1,54 @@
//go:build windows
package iphelper
import (
"fmt"
"testing"
)
func TestSockets(t *testing.T) {
connections, listeners, err := GetTCP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 4 connections:")
for _, connection := range connections {
fmt.Printf("%+v\n", connection)
}
fmt.Println("\nTCP 4 listeners:")
for _, listener := range listeners {
fmt.Printf("%+v\n", listener)
}
connections, listeners, err = GetTCP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 6 connections:")
for _, connection := range connections {
fmt.Printf("%+v\n", connection)
}
fmt.Println("\nTCP 6 listeners:")
for _, listener := range listeners {
fmt.Printf("%+v\n", listener)
}
binds, err := GetUDP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 4 binds:")
for _, bind := range binds {
fmt.Printf("%+v\n", bind)
}
binds, err = GetUDP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 6 binds:")
for _, bind := range binds {
fmt.Printf("%+v\n", bind)
}
}

172
service/network/metrics.go Normal file
View File

@@ -0,0 +1,172 @@
package network
import (
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/metrics"
"github.com/safing/portmaster/service/process"
)
var (
packetHandlingHistogram *metrics.Histogram
blockedOutConnCounter *metrics.Counter
encryptedAndTunneledOutConnCounter *metrics.Counter
encryptedOutConnCounter *metrics.Counter
tunneledOutConnCounter *metrics.Counter
outConnCounter *metrics.Counter
)
func registerMetrics() (err error) {
// This needed to be moved here, because every packet is now handled by the
// connection handler worker.
packetHandlingHistogram, err = metrics.NewHistogram(
"firewall/handling/duration/seconds",
nil,
&metrics.Options{
Permission: api.PermitUser,
ExpertiseLevel: config.ExpertiseLevelExpert,
})
if err != nil {
return err
}
_, err = metrics.NewGauge(
"network/connections/active/total",
nil,
func() float64 {
return float64(conns.active())
},
&metrics.Options{
InternalID: "active_connections",
Permission: api.PermitUser,
ExpertiseLevel: config.ExpertiseLevelUser,
})
if err != nil {
return err
}
connCounterID := "network/connections/total"
connCounterOpts := &metrics.Options{
Name: "Connections",
Permission: api.PermitUser,
ExpertiseLevel: config.ExpertiseLevelUser,
Persist: true,
}
blockedOutConnCounter, err = metrics.NewCounter(
connCounterID,
map[string]string{
"direction": "out",
"blocked": "true",
},
&metrics.Options{
Name: "Connections",
InternalID: "blocked_outgoing_connections",
Permission: api.PermitUser,
ExpertiseLevel: config.ExpertiseLevelUser,
Persist: true,
},
)
if err != nil {
return err
}
encryptedAndTunneledOutConnCounter, err = metrics.NewCounter(
connCounterID,
map[string]string{
"direction": "out",
"encrypted": "true",
"tunneled": "true",
},
connCounterOpts,
)
if err != nil {
return err
}
encryptedOutConnCounter, err = metrics.NewCounter(
connCounterID,
map[string]string{
"direction": "out",
"encrypted": "true",
},
connCounterOpts,
)
if err != nil {
return err
}
tunneledOutConnCounter, err = metrics.NewCounter(
connCounterID,
map[string]string{
"direction": "out",
"tunneled": "true",
},
connCounterOpts,
)
if err != nil {
return err
}
outConnCounter, err = metrics.NewCounter(
connCounterID,
map[string]string{
"direction": "out",
},
connCounterOpts,
)
if err != nil {
return err
}
return nil
}
func (conn *Connection) addToMetrics() {
if conn.addedToMetrics {
return
}
// Don't count requests serviced to the network,
// as we have an incomplete view here.
if conn.Process() != nil &&
conn.Process().Pid == process.NetworkHostProcessID {
return
}
// Only count outgoing connections for now.
if conn.Inbound {
return
}
// Check the verdict.
switch conn.Verdict { //nolint:exhaustive // Not critical.
case VerdictBlock, VerdictDrop:
blockedOutConnCounter.Inc()
conn.addedToMetrics = true
return
case VerdictAccept, VerdictRerouteToTunnel:
// Continue to next section.
default:
// Connection is not counted.
return
}
// Only count successful connections, not DNS requests.
if conn.Type == DNSRequest {
return
}
// Select counter based on attributes.
switch {
case conn.Encrypted && conn.Tunneled:
encryptedAndTunneledOutConnCounter.Inc()
case conn.Encrypted:
encryptedOutConnCounter.Inc()
case conn.Tunneled:
tunneledOutConnCounter.Inc()
default:
outConnCounter.Inc()
}
conn.addedToMetrics = true
}

151
service/network/module.go Normal file
View File

@@ -0,0 +1,151 @@
package network
import (
"context"
"fmt"
"strings"
"sync"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network/state"
"github.com/safing/portmaster/service/profile"
)
var (
module *modules.Module
defaultFirewallHandler FirewallHandler
)
// Events.
var (
ConnectionReattributedEvent = "connection re-attributed"
)
func init() {
module = modules.Register("network", prep, start, nil, "base", "netenv", "processes")
module.RegisterEvent(ConnectionReattributedEvent, false)
}
// SetDefaultFirewallHandler sets the default firewall handler.
func SetDefaultFirewallHandler(handler FirewallHandler) {
if defaultFirewallHandler == nil {
defaultFirewallHandler = handler
}
}
func prep() error {
if netenv.IPv6Enabled() {
state.EnableTCPDualStack()
state.EnableUDPDualStack()
}
return registerAPIEndpoints()
}
func start() error {
err := registerAsDatabase()
if err != nil {
return err
}
if err := registerMetrics(); err != nil {
return err
}
module.StartServiceWorker("clean connections", 0, connectionCleaner)
module.StartServiceWorker("write open dns requests", 0, openDNSRequestWriter)
if err := module.RegisterEventHook(
"profiles",
profile.DeletedEvent,
"re-attribute connections from deleted profile",
reAttributeConnections,
); err != nil {
return err
}
return nil
}
var reAttributionLock sync.Mutex
// reAttributeConnections finds all connections of a deleted profile and re-attributes them.
// Expected event data: scoped profile ID.
func reAttributeConnections(_ context.Context, eventData any) error {
profileID, ok := eventData.(string)
if !ok {
return fmt.Errorf("event data is not a string: %v", eventData)
}
profileSource, profileID, ok := strings.Cut(profileID, "/")
if !ok {
return fmt.Errorf("event data does not seem to be a scoped profile ID: %v", eventData)
}
// Hold a lock for re-attribution, to prevent simultaneous processing of the
// same connections and make logging cleaner.
reAttributionLock.Lock()
defer reAttributionLock.Unlock()
// Create tracing context.
ctx, tracer := log.AddTracer(context.Background())
defer tracer.Submit()
tracer.Infof("network: re-attributing connections from deleted profile %s/%s", profileSource, profileID)
// Count and log how many connections were re-attributed.
var reAttributed int
// Re-attribute connections.
for _, conn := range conns.clone() {
if reAttributeConnection(ctx, conn, profileID, profileSource) {
reAttributed++
tracer.Debugf("filter: re-attributed %s to %s", conn, conn.process.PrimaryProfileID)
}
}
// Re-attribute dns connections.
for _, conn := range dnsConns.clone() {
if reAttributeConnection(ctx, conn, profileID, profileSource) {
reAttributed++
tracer.Debugf("filter: re-attributed %s to %s", conn, conn.process.PrimaryProfileID)
}
}
tracer.Infof("filter: re-attributed %d connections", reAttributed)
return nil
}
func reAttributeConnection(ctx context.Context, conn *Connection, profileID, profileSource string) (reAttributed bool) {
// Lock the connection before checking anything to avoid a race condition with connection data collection.
conn.Lock()
defer conn.Unlock()
// Check if the connection has the profile we are looking for.
switch {
case !conn.DataIsComplete():
return false
case conn.ProcessContext.Profile != profileID:
return false
case conn.ProcessContext.Source != profileSource:
return false
}
// Attempt to assign new profile.
err := conn.process.RefetchProfile(ctx)
if err != nil {
log.Tracer(ctx).Warningf("network: failed to refetch profile for %s: %s", conn, err)
return false
}
// Set the new process context.
conn.ProcessContext = getProcessContext(ctx, conn.process)
conn.Save()
// Trigger event for re-attribution.
module.TriggerEvent(ConnectionReattributedEvent, conn.ID)
log.Tracer(ctx).Debugf("filter: re-attributed %s to %s", conn, conn.process.PrimaryProfileID)
return true
}

View File

@@ -0,0 +1,45 @@
package network
import (
"net"
"github.com/safing/portmaster/service/network/netutils"
)
// GetMulticastRequestConn searches for and returns the requesting connnection
// of a possible multicast/broadcast response.
func GetMulticastRequestConn(responseConn *Connection, responseFromNet *net.IPNet) *Connection {
// Calculate the broadcast address the query would have gone to.
responseNetBroadcastIP := netutils.GetBroadcastAddress(responseFromNet.IP, responseFromNet.Mask)
// Find requesting multicast/broadcast connection.
for _, conn := range conns.clone() {
switch {
case !conn.DataIsComplete():
// Ignore connection with incomplete data.
case conn.Inbound:
// Ignore incoming connections.
case conn.Ended != 0:
// Ignore ended connections.
case conn.Entity.Protocol != responseConn.Entity.Protocol:
// Ignore on protocol mismatch.
case conn.LocalPort != responseConn.LocalPort:
// Ignore on local port mismatch.
case !conn.LocalIP.Equal(responseConn.LocalIP):
// Ignore on local IP mismatch.
case !conn.Process().Equal(responseConn.Process()):
// Ignore if processes mismatch.
case conn.Entity.IPScope == netutils.LocalMulticast &&
(responseConn.Entity.IPScope == netutils.LinkLocal ||
responseConn.Entity.IPScope == netutils.SiteLocal):
// We found a (possibly routed) multicast request that matches the response!
return conn
case conn.Entity.IP.Equal(responseNetBroadcastIP) &&
responseFromNet.Contains(conn.LocalIP):
// We found a (link local) broadcast request that matches the response!
return conn
}
}
return nil
}

View File

@@ -0,0 +1,60 @@
package netutils
import (
"errors"
"net"
"strconv"
"github.com/safing/portmaster/service/network/packet"
)
var errInvalidIP = errors.New("invalid IP address")
// IPPortFromAddr extracts or parses the IP address and port contained in the given address.
func IPPortFromAddr(addr net.Addr) (ip net.IP, port uint16, err error) {
// Convert addr to IP if needed.
switch v := addr.(type) {
case *net.TCPAddr:
return v.IP, uint16(v.Port), nil
case *net.UDPAddr:
return v.IP, uint16(v.Port), nil
case *net.IPAddr:
return v.IP, 0, nil
case *net.UnixAddr:
return nil, 0, errors.New("unix addresses don't have IPs")
default:
return ParseIPPort(addr.String())
}
}
// ProtocolFromNetwork returns the protocol from the given net, as used in the "net" golang stdlib.
func ProtocolFromNetwork(net string) (protocol packet.IPProtocol) {
switch net {
case "tcp", "tcp4", "tcp6":
return packet.TCP
case "udp", "udp4", "udp6":
return packet.UDP
default:
return 0
}
}
// ParseIPPort parses a <ip>:port formatted address.
func ParseIPPort(address string) (net.IP, uint16, error) {
ipString, portString, err := net.SplitHostPort(address)
if err != nil {
return nil, 0, err
}
ip := net.ParseIP(ipString)
if ip == nil {
return nil, 0, errInvalidIP
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, 0, err
}
return ip, uint16(port), nil
}

View File

@@ -0,0 +1,99 @@
package netutils
import (
"fmt"
"net"
"regexp"
"strings"
"github.com/miekg/dns"
)
var (
cleanDomainRegex = regexp.MustCompile(
`^` + // match beginning
`(` + // start subdomain group
`(xn--)?` + // idn prefix
`[a-z0-9_-]{1,63}` + // main chunk
`\.` + // ending with a dot
`)*` + // end subdomain group, allow any number of subdomains
`(xn--)?` + // TLD idn prefix
`[a-z0-9_-]{1,63}` + // TLD main chunk with at least one character (for custom ones)
`\.` + // ending with a dot
`$`, // match end
)
// dnsSDDomainRegex is a lot more lax to better suit the allowed characters in DNS-SD.
// Not all characters have been allowed - some special characters were
// removed to reduce the general attack surface.
dnsSDDomainRegex = regexp.MustCompile(
// Start of charset selection.
`^[` +
// Printable ASCII (character code 32-127), excluding some special characters.
` !#$%&()*+,\-\./0-9:;=?@A-Z[\\\]^_\a-z{|}~` +
// Only latin characters from extended ASCII (character code 128-255).
`ŠŒŽšœžŸ¡¿ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿ` +
// End of charset selection.
`]*$`,
)
)
// IsValidFqdn returns whether the given string is a valid fqdn.
func IsValidFqdn(fqdn string) bool {
// root zone
if fqdn == "." {
return true
}
// check max length
if len(fqdn) > 256 {
return false
}
// IsFqdn checks if a domain name is fully qualified.
if !dns.IsFqdn(fqdn) {
return false
}
// Use special check for .local domains to support DNS-SD.
if strings.HasSuffix(fqdn, ".local.") {
return dnsSDDomainRegex.MatchString(fqdn)
}
// check with regex
if !cleanDomainRegex.MatchString(fqdn) {
return false
}
// IsDomainName checks if s is a valid domain name, it returns the number of
// labels and true, when a domain name is valid. Note that non fully qualified
// domain name is considered valid, in this case the last label is counted in
// the number of labels. When false is returned the number of labels is not
// defined. Also note that this function is extremely liberal; almost any
// string is a valid domain name as the DNS is 8 bit protocol. It checks if each
// label fits in 63 characters and that the entire name will fit into the 255
// octet wire format limit.
_, ok := dns.IsDomainName(fqdn)
return ok
}
// IPsToRRs transforms the given IPs to resource records.
func IPsToRRs(domain string, ips []net.IP) ([]dns.RR, error) {
records := make([]dns.RR, 0, len(ips))
var rr dns.RR
var err error
for _, ip := range ips {
if ip.To4() != nil {
rr, err = dns.NewRR(fmt.Sprintf("%s 17 IN A %s", domain, ip))
} else {
rr, err = dns.NewRR(fmt.Sprintf("%s 17 IN AAAA %s", domain, ip))
}
if err != nil {
return nil, fmt.Errorf("failed to create record for %s: %w", ip, err)
}
records = append(records, rr)
}
return records, nil
}

View File

@@ -0,0 +1,47 @@
package netutils
import "testing"
func testDomainValidity(t *testing.T, domain string, isValid bool) {
t.Helper()
if IsValidFqdn(domain) != isValid {
t.Errorf("domain %s failed check: was valid=%v, expected valid=%v", domain, IsValidFqdn(domain), isValid)
}
}
func TestDNSValidation(t *testing.T) {
t.Parallel()
// valid
testDomainValidity(t, ".", true)
testDomainValidity(t, "at.", true)
testDomainValidity(t, "orf.at.", true)
testDomainValidity(t, "www.orf.at.", true)
testDomainValidity(t, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.u.v.x.y.z.example.org.", true)
testDomainValidity(t, "a_a.com.", true)
testDomainValidity(t, "a-a.com.", true)
testDomainValidity(t, "a_a.com.", true)
testDomainValidity(t, "a-a.com.", true)
testDomainValidity(t, "xn--a.com.", true)
testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasd.com.", true)
// maybe valid
testDomainValidity(t, "-.com.", true)
testDomainValidity(t, "_.com.", true)
testDomainValidity(t, "a_.com.", true)
testDomainValidity(t, "a-.com.", true)
testDomainValidity(t, "_a.com.", true)
testDomainValidity(t, "-a.com.", true)
// invalid
testDomainValidity(t, ".com.", false)
testDomainValidity(t, ".com.", false)
testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
testDomainValidity(t, "asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.as.com.", false)
// real world examples
testDomainValidity(t, "iuqerfsodp9ifjaposdfjhgosurijfaewrwergwea.com.", true)
}

View File

@@ -0,0 +1,160 @@
package netutils
import "net"
// IPScope is the scope of the IP address.
type IPScope int8
// Defined IP Scopes.
const (
Invalid IPScope = iota - 1
Undefined
HostLocal
LinkLocal
SiteLocal
Global
LocalMulticast
GlobalMulticast
)
// ClassifyIP returns the network scope of the given IP address.
// Deprecated: Please use the new GetIPScope instead.
func ClassifyIP(ip net.IP) IPScope {
return GetIPScope(ip)
}
// GetIPScope returns the network scope of the given IP address.
func GetIPScope(ip net.IP) IPScope { //nolint:gocognit
if ip4 := ip.To4(); ip4 != nil {
// IPv4
switch {
case ip4[0] == 0 && ip4[1] == 0 && ip4[2] == 0 && ip4[3] == 0:
// 0.0.0.0/32
return LocalMulticast // Used as source for L2 based protocols with no L3 addressing.
case ip4[0] == 0:
// 0.0.0.0/8
return Invalid
case ip4[0] == 10:
// 10.0.0.0/8 (RFC1918)
return SiteLocal
case ip4[0] == 100 && ip4[1]&0b11000000 == 64:
// 100.64.0.0/10 (RFC6598)
return SiteLocal
case ip4[0] == 127:
// 127.0.0.0/8 (RFC1918)
return HostLocal
case ip4[0] == 169 && ip4[1] == 254:
// 169.254.0.0/16 (RFC3927)
return LinkLocal
case ip4[0] == 172 && ip4[1]&0b11110000 == 16:
// 172.16.0.0/12 (RFC1918)
return SiteLocal
case ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2:
// 192.0.2.0/24 (TEST-NET-1, RFC5737)
return Invalid
case ip4[0] == 192 && ip4[1] == 168:
// 192.168.0.0/16 (RFC1918)
return SiteLocal
case ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100:
// 198.51.100.0/24 (TEST-NET-2, RFC5737)
return Invalid
case ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113:
// 203.0.113.0/24 (TEST-NET-3, RFC5737)
return Invalid
case ip4[0] == 224:
// 224.0.0.0/8 (RFC5771)
return LocalMulticast
case ip4[0] == 233 && ip4[1] == 252 && ip4[2] == 0:
// 233.252.0.0/24 (MCAST-TEST-NET; RFC5771, RFC6676)
return Invalid
case ip4[0] >= 225 && ip4[0] <= 238:
// 225.0.0.0/8 - 238.0.0.0/8 (RFC5771)
return GlobalMulticast
case ip4[0] == 239:
// 239.0.0.0/8 (RFC2365)
return LocalMulticast
case ip4[0] == 255 && ip4[1] == 255 && ip4[2] == 255 && ip4[3] == 255:
// 255.255.255.255/32
return LocalMulticast
case ip4[0] >= 240:
// 240.0.0.0/8 - 255.0.0.0/8 (minus 255.255.255.255/32)
return Invalid
default:
return Global
}
} else if len(ip) == net.IPv6len {
// IPv6
switch {
case ip.Equal(net.IPv6zero):
return Invalid
case ip.Equal(net.IPv6loopback):
return HostLocal
case ip[0]&0xfe == 0xfc:
// fc00::/7
return SiteLocal
case ip[0] == 0xfe && ip[1]&0xc0 == 0x80:
// fe80::/10
return LinkLocal
case ip[0] == 0xff && ip[1] <= 0x05:
// ff00::/16 - ff05::/16
return LocalMulticast
case ip[0] == 0xff:
// other ff00::/8
return GlobalMulticast
default:
return Global
}
}
return Invalid
}
// IsLocalhost returns whether the IP refers to the host itself.
func (scope IPScope) IsLocalhost() bool {
return scope == HostLocal
}
// IsLAN returns true if the scope is site-local or link-local.
func (scope IPScope) IsLAN() bool {
switch scope { //nolint:exhaustive // Looking for something specific.
case SiteLocal, LinkLocal, LocalMulticast:
return true
default:
return false
}
}
// IsGlobal returns true if the scope is global.
func (scope IPScope) IsGlobal() bool {
switch scope { //nolint:exhaustive // Looking for something specific.
case Global, GlobalMulticast:
return true
default:
return false
}
}
// GetBroadcastAddress returns the broadcast address of the given IP and network mask.
// If a mixed IPv4/IPv6 input is given, it returns nil.
func GetBroadcastAddress(ip net.IP, netMask net.IPMask) net.IP {
// Convert to standard v4.
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
mask := net.IP(netMask)
if ip4Mask := mask.To4(); ip4Mask != nil {
mask = ip4Mask
}
// Check for mixed v4/v6 input.
if len(ip) != len(mask) {
return nil
}
// Merge to broadcast address
n := len(ip)
broadcastAddress := make(net.IP, n)
for i := 0; i < n; i++ {
broadcastAddress[i] = ip[i] | ^mask[i]
}
return broadcastAddress
}

View File

@@ -0,0 +1,51 @@
package netutils
import (
"net"
"testing"
)
func TestIPScope(t *testing.T) {
t.Parallel()
testScope(t, net.IPv4(71, 87, 113, 211), Global)
testScope(t, net.IPv4(127, 0, 0, 1), HostLocal)
testScope(t, net.IPv4(127, 255, 255, 1), HostLocal)
testScope(t, net.IPv4(192, 168, 172, 24), SiteLocal)
testScope(t, net.IPv4(172, 15, 1, 1), Global)
testScope(t, net.IPv4(172, 16, 1, 1), SiteLocal)
testScope(t, net.IPv4(172, 31, 1, 1), SiteLocal)
testScope(t, net.IPv4(172, 32, 1, 1), Global)
}
func testScope(t *testing.T, ip net.IP, expectedScope IPScope) {
t.Helper()
c := GetIPScope(ip)
if c != expectedScope {
t.Errorf("%s is %s, expected %s", ip, scopeName(c), scopeName(expectedScope))
}
}
func scopeName(c IPScope) string {
switch c {
case Invalid:
return "invalid"
case Undefined:
return "undefined"
case HostLocal:
return "hostLocal"
case LinkLocal:
return "linkLocal"
case SiteLocal:
return "siteLocal"
case Global:
return "global"
case LocalMulticast:
return "localMulticast"
case GlobalMulticast:
return "globalMulticast"
default:
return "undefined"
}
}

View File

@@ -0,0 +1,51 @@
package netutils
import (
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/tcpassembly"
)
// SimpleStreamAssemblerManager is a simple manager for github.com/google/gopacket/tcpassembly.
type SimpleStreamAssemblerManager struct {
InitLock sync.Mutex
lastAssembler *SimpleStreamAssembler
}
// New returns a new stream assembler.
func (m *SimpleStreamAssemblerManager) New(net, transport gopacket.Flow) tcpassembly.Stream {
assembler := new(SimpleStreamAssembler)
m.lastAssembler = assembler
return assembler
}
// GetLastAssembler returns the newest created stream assembler.
func (m *SimpleStreamAssemblerManager) GetLastAssembler() *SimpleStreamAssembler {
return m.lastAssembler
}
// SimpleStreamAssembler is a simple assembler for github.com/google/gopacket/tcpassembly.
type SimpleStreamAssembler struct {
Cumulated []byte
CumulatedLen int
Complete bool
}
// NewSimpleStreamAssembler returns a new SimpleStreamAssembler.
func NewSimpleStreamAssembler() *SimpleStreamAssembler {
return &SimpleStreamAssembler{}
}
// Reassembled implements tcpassembly.Stream's Reassembled function.
func (a *SimpleStreamAssembler) Reassembled(reassembly []tcpassembly.Reassembly) {
for _, entry := range reassembly {
a.Cumulated = append(a.Cumulated, entry.Bytes...)
}
a.CumulatedLen = len(a.Cumulated)
}
// ReassemblyComplete implements tcpassembly.Stream's ReassemblyComplete function.
func (a *SimpleStreamAssembler) ReassemblyComplete() {
a.Complete = true
}

View File

@@ -0,0 +1,35 @@
package packet
import "fmt"
// BandwidthUpdate holds an update to the seen bandwidth of a connection.
type BandwidthUpdate struct {
ConnID string
BytesReceived uint64
BytesSent uint64
Method BandwidthUpdateMethod
}
// BandwidthUpdateMethod defines how the bandwidth data of a bandwidth update should be interpreted.
type BandwidthUpdateMethod uint8
// Bandwidth Update Methods.
const (
Absolute BandwidthUpdateMethod = iota
Additive
)
func (bu *BandwidthUpdate) String() string {
return fmt.Sprintf("%s: %dB recv | %dB sent [%s]", bu.ConnID, bu.BytesReceived, bu.BytesSent, bu.Method)
}
func (bum BandwidthUpdateMethod) String() string {
switch bum {
case Absolute:
return "absolute"
case Additive:
return "additive"
default:
return "unknown"
}
}

View File

@@ -0,0 +1,117 @@
package packet
import (
"errors"
"fmt"
)
// Basic Types.
type (
// IPVersion represents an IP version.
IPVersion uint8
// IPProtocol represents an IP protocol.
IPProtocol uint8
// Verdict describes the decision on a packet.
Verdict uint8
)
// Basic Constants.
const (
IPv4 = IPVersion(4)
IPv6 = IPVersion(6)
InBound = true
OutBound = false
ICMP = IPProtocol(1)
IGMP = IPProtocol(2)
TCP = IPProtocol(6)
UDP = IPProtocol(17)
ICMPv6 = IPProtocol(58)
UDPLite = IPProtocol(136)
RAW = IPProtocol(255)
AnyHostInternalProtocol61 = IPProtocol(61)
)
// Verdicts.
const (
DROP Verdict = iota
BLOCK
ACCEPT
STOLEN
QUEUE
REPEAT
STOP
)
// ErrFailedToLoadPayload is returned by GetPayload if it failed for an unspecified reason, or is not implemented on the current system.
var ErrFailedToLoadPayload = errors.New("could not load packet payload")
// ByteSize returns the byte size of the ip (IPv4 = 4 bytes, IPv6 = 16).
func (v IPVersion) ByteSize() int {
switch v {
case IPv4:
return 4
case IPv6:
return 16
}
return 0
}
// String returns the string representation of the IP version: "IPv4" or "IPv6".
func (v IPVersion) String() string {
switch v {
case IPv4:
return "IPv4"
case IPv6:
return "IPv6"
}
return fmt.Sprintf("<unknown ip version, %d>", uint8(v))
}
// String returns the string representation (abbreviation) of the protocol.
func (p IPProtocol) String() string {
switch p {
case RAW:
return "RAW"
case TCP:
return "TCP"
case UDP:
return "UDP"
case UDPLite:
return "UDPLite"
case ICMP:
return "ICMP"
case ICMPv6:
return "ICMPv6"
case IGMP:
return "IGMP"
case AnyHostInternalProtocol61:
fallthrough
default:
return fmt.Sprintf("<unknown protocol, %d>", uint8(p))
}
}
// String returns the string representation of the verdict.
func (v Verdict) String() string {
switch v {
case DROP:
return "DROP"
case BLOCK:
return "BLOCK"
case ACCEPT:
return "ACCEPT"
case STOLEN:
return "STOLEN"
case QUEUE:
return "QUEUE"
case REPEAT:
return "REPEAT"
case STOP:
return "STOP"
default:
return fmt.Sprintf("<unsupported verdict, %d>", uint8(v))
}
}

View File

@@ -0,0 +1,77 @@
package packet
import (
"errors"
"fmt"
)
// InfoPacket does not represent an actual packet, but only holds metadata.
// Implements the packet.Packet interface.
type InfoPacket struct {
Base
}
// NewInfoPacket returns a new InfoPacket with the given info.
func NewInfoPacket(info Info) *InfoPacket {
return &InfoPacket{
Base{
info: info,
},
}
}
// InfoOnly returns whether the packet is informational only and does not
// represent an actual packet.
func (pkt *InfoPacket) InfoOnly() bool {
return true
}
// LoadPacketData does nothing on Linux, as data is always fully parsed.
func (pkt *InfoPacket) LoadPacketData() error {
return fmt.Errorf("%w: info-only packet", ErrFailedToLoadPayload)
}
// ErrInfoOnlyPacket is returned for unsupported operations on an info-only packet.
var ErrInfoOnlyPacket = errors.New("info-only packet")
// Accept does nothing on an info-only packet.
func (pkt *InfoPacket) Accept() error {
return ErrInfoOnlyPacket
}
// Block does nothing on an info-only packet.
func (pkt *InfoPacket) Block() error {
return ErrInfoOnlyPacket
}
// Drop does nothing on an info-only packet.
func (pkt *InfoPacket) Drop() error {
return ErrInfoOnlyPacket
}
// PermanentAccept does nothing on an info-only packet.
func (pkt *InfoPacket) PermanentAccept() error {
return ErrInfoOnlyPacket
}
// PermanentBlock does nothing on an info-only packet.
func (pkt *InfoPacket) PermanentBlock() error {
return ErrInfoOnlyPacket
}
// PermanentDrop does nothing on an info-only packet.
func (pkt *InfoPacket) PermanentDrop() error {
return ErrInfoOnlyPacket
}
// RerouteToNameserver does nothing on an info-only packet.
func (pkt *InfoPacket) RerouteToNameserver() error {
return ErrInfoOnlyPacket
}
// RerouteToTunnel does nothing on an info-only packet.
func (pkt *InfoPacket) RerouteToTunnel() error {
return ErrInfoOnlyPacket
}
var _ Packet = &InfoPacket{}

View File

@@ -0,0 +1,266 @@
package packet
import (
"context"
"fmt"
"net"
"github.com/google/gopacket"
)
// Base is a base structure for satisfying the Packet interface.
type Base struct {
ctx context.Context
info Info
connID string
layers gopacket.Packet
layer3Data []byte
layer5Data []byte
}
// FastTrackedByIntegration returns whether the packet has been fast-track
// accepted by the OS integration.
func (pkt *Base) FastTrackedByIntegration() bool {
return false
}
// InfoOnly returns whether the packet is informational only and does not
// represent an actual packet.
func (pkt *Base) InfoOnly() bool {
return false
}
// ExpectInfo returns whether the next packet is expected to be informational only.
func (pkt *Base) ExpectInfo() bool {
return false
}
// SetCtx sets the packet context.
func (pkt *Base) SetCtx(ctx context.Context) {
pkt.ctx = ctx
}
// Ctx returns the packet context.
func (pkt *Base) Ctx() context.Context {
return pkt.ctx
}
// Info returns the packet Info.
func (pkt *Base) Info() *Info {
return &pkt.info
}
// SetPacketInfo sets a new packet Info. This must only used when initializing the packet structure.
func (pkt *Base) SetPacketInfo(packetInfo Info) {
pkt.info = packetInfo
}
// SetInbound sets a the packet direction to inbound. This must only used when initializing the packet structure.
func (pkt *Base) SetInbound() {
pkt.info.Inbound = true
}
// SetOutbound sets a the packet direction to outbound. This must only used when initializing the packet structure.
func (pkt *Base) SetOutbound() {
pkt.info.Inbound = false
}
// IsInbound checks if the packet is inbound.
func (pkt *Base) IsInbound() bool {
return pkt.info.Inbound
}
// IsOutbound checks if the packet is outbound.
func (pkt *Base) IsOutbound() bool {
return !pkt.info.Inbound
}
// HasPorts checks if the packet has a protocol that uses ports.
func (pkt *Base) HasPorts() bool {
switch pkt.info.Protocol {
case TCP:
return true
case UDP, UDPLite:
return true
case ICMP, ICMPv6, IGMP, RAW, AnyHostInternalProtocol61:
fallthrough
default:
return false
}
}
// LoadPacketData loads packet data from the integration, if not yet done.
func (pkt *Base) LoadPacketData() error {
return ErrFailedToLoadPayload
}
// Layers returns the parsed layer data.
func (pkt *Base) Layers() gopacket.Packet {
return pkt.layers
}
// Raw returns the raw Layer 3 Network Data.
func (pkt *Base) Raw() []byte {
return pkt.layer3Data
}
// Payload returns the raw Layer 5 Network Data.
func (pkt *Base) Payload() []byte {
return pkt.layer5Data
}
// GetConnectionID returns the link ID for this packet.
func (pkt *Base) GetConnectionID() string {
if pkt.connID == "" {
pkt.connID = pkt.info.CreateConnectionID()
}
return pkt.connID
}
// MatchesAddress checks if a the packet matches a given endpoint (remote or local) in protocol, network and port.
//
// Comparison matrix:
//
// ====== IN OUT
//
// Local Dst Src
// Remote Src Dst
// .
func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.IPNet, port uint16) bool {
if pkt.info.Protocol != protocol {
return false
}
if pkt.info.Inbound != remote {
if !network.Contains(pkt.info.Src) {
return false
}
if pkt.info.SrcPort != port {
return false
}
} else {
if !network.Contains(pkt.info.Dst) {
return false
}
if pkt.info.DstPort != port {
return false
}
}
return true
}
// MatchesIP checks if a the packet matches a given endpoint (remote or local) IP.
//
// Comparison matrix:
//
// ====== IN OUT
//
// Local Dst Src
// Remote Src Dst
// .
func (pkt *Base) MatchesIP(endpoint bool, network *net.IPNet) bool {
if pkt.info.Inbound != endpoint {
if network.Contains(pkt.info.Src) {
return true
}
} else {
if network.Contains(pkt.info.Dst) {
return true
}
}
return false
}
// FORMATTING
func (pkt *Base) String() string {
return pkt.FmtPacket()
}
// FmtPacket returns the most important information about the packet as a string.
func (pkt *Base) FmtPacket() string {
if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP {
if pkt.info.Inbound {
return fmt.Sprintf("IN %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort)
}
return fmt.Sprintf("OUT %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort)
}
if pkt.info.Inbound {
return fmt.Sprintf("IN %s %s <-> %s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src)
}
return fmt.Sprintf("OUT %s %s <-> %s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst)
}
// FmtProtocol returns the protocol as a string.
func (pkt *Base) FmtProtocol() string {
return pkt.info.Protocol.String()
}
// FmtRemoteIP returns the remote IP address as a string.
func (pkt *Base) FmtRemoteIP() string {
if pkt.info.Inbound {
return pkt.info.Src.String()
}
return pkt.info.Dst.String()
}
// FmtRemotePort returns the remote port as a string.
func (pkt *Base) FmtRemotePort() string {
if pkt.info.SrcPort != 0 {
if pkt.info.Inbound {
return fmt.Sprintf("%d", pkt.info.SrcPort)
}
return fmt.Sprintf("%d", pkt.info.DstPort)
}
return "-"
}
// FmtRemoteAddress returns the full remote address (protocol, IP, port) as a string.
func (pkt *Base) FmtRemoteAddress() string {
return fmt.Sprintf("%s:%s:%s", pkt.info.Protocol.String(), pkt.FmtRemoteIP(), pkt.FmtRemotePort())
}
// Packet is an interface to a network packet to provide object behavior the same across all systems.
type Packet interface {
// Verdicts.
Accept() error
Block() error
Drop() error
PermanentAccept() error
PermanentBlock() error
PermanentDrop() error
RerouteToNameserver() error
RerouteToTunnel() error
FastTrackedByIntegration() bool
InfoOnly() bool
ExpectInfo() bool
// Info.
SetCtx(context.Context)
Ctx() context.Context
Info() *Info
SetPacketInfo(Info)
IsInbound() bool
IsOutbound() bool
SetInbound()
SetOutbound()
HasPorts() bool
GetConnectionID() string
// Payload.
LoadPacketData() error
Layers() gopacket.Packet
Raw() []byte
Payload() []byte
// Matching.
MatchesAddress(bool, IPProtocol, *net.IPNet, uint16) bool
MatchesIP(bool, *net.IPNet) bool
// Formatting.
String() string
FmtPacket() string
FmtProtocol() string
FmtRemoteIP() string
FmtRemotePort() string
FmtRemoteAddress() string
}

View File

@@ -0,0 +1,76 @@
package packet
import (
"fmt"
"net"
"time"
)
// Info holds IP and TCP/UDP header information.
type Info struct {
Inbound bool
InTunnel bool
Version IPVersion
Protocol IPProtocol
SrcPort, DstPort uint16
Src, Dst net.IP
PID int
SeenAt time.Time
}
// LocalIP returns the local IP of the packet.
func (pi *Info) LocalIP() net.IP {
if pi.Inbound {
return pi.Dst
}
return pi.Src
}
// RemoteIP returns the remote IP of the packet.
func (pi *Info) RemoteIP() net.IP {
if pi.Inbound {
return pi.Src
}
return pi.Dst
}
// LocalPort returns the local port of the packet.
func (pi *Info) LocalPort() uint16 {
if pi.Inbound {
return pi.DstPort
}
return pi.SrcPort
}
// RemotePort returns the remote port of the packet.
func (pi *Info) RemotePort() uint16 {
if pi.Inbound {
return pi.SrcPort
}
return pi.DstPort
}
// CreateConnectionID creates a connection ID.
// In most circumstances, this method should not be used directly, but
// packet.GetConnectionID() should be called instead.
func (pi *Info) CreateConnectionID() string {
return CreateConnectionID(pi.Protocol, pi.Src, pi.SrcPort, pi.Dst, pi.DstPort, pi.Inbound)
}
// CreateConnectionID creates a connection ID.
func CreateConnectionID(protocol IPProtocol, src net.IP, srcPort uint16, dst net.IP, dstPort uint16, inbound bool) string {
// TODO: make this ID not depend on the packet direction for better support for forwarded packets.
if protocol == TCP || protocol == UDP {
if inbound {
return fmt.Sprintf("%d-%s-%d-%s-%d", protocol, dst, dstPort, src, srcPort)
}
return fmt.Sprintf("%d-%s-%d-%s-%d", protocol, src, srcPort, dst, dstPort)
}
if inbound {
return fmt.Sprintf("%d-%s-%s", protocol, dst, src)
}
return fmt.Sprintf("%d-%s-%s", protocol, src, dst)
}

View File

@@ -0,0 +1,160 @@
package packet
import (
"errors"
"fmt"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
var layerType2IPProtocol map[gopacket.LayerType]IPProtocol
func genIPProtocolFromLayerType() {
layerType2IPProtocol = make(map[gopacket.LayerType]IPProtocol)
for k, v := range layers.IPProtocolMetadata {
layerType2IPProtocol[v.LayerType] = IPProtocol(k)
}
}
func parseIPv4(packet gopacket.Packet, info *Info) error {
if ipv4, ok := packet.NetworkLayer().(*layers.IPv4); ok {
info.Version = IPv4
info.Src = ipv4.SrcIP
info.Dst = ipv4.DstIP
info.Protocol = IPProtocol(ipv4.Protocol)
}
return nil
}
func parseIPv6(packet gopacket.Packet, info *Info) error {
if ipv6, ok := packet.NetworkLayer().(*layers.IPv6); ok {
info.Version = IPv6
info.Src = ipv6.SrcIP
info.Dst = ipv6.DstIP
// we set Protocol to NextHeader as a fallback. If TCP or
// UDP layers are detected (somewhere in the list of options)
// the Protocol field is adjusted correctly.
info.Protocol = IPProtocol(ipv6.NextHeader)
}
return nil
}
func parseTCP(packet gopacket.Packet, info *Info) error {
if tcp, ok := packet.TransportLayer().(*layers.TCP); ok {
info.Protocol = TCP
info.SrcPort = uint16(tcp.SrcPort)
info.DstPort = uint16(tcp.DstPort)
}
return nil
}
func parseUDP(packet gopacket.Packet, info *Info) error {
if udp, ok := packet.TransportLayer().(*layers.UDP); ok {
info.Protocol = UDP
info.SrcPort = uint16(udp.SrcPort)
info.DstPort = uint16(udp.DstPort)
}
return nil
}
/*
func parseUDPLite(packet gopacket.Packet, info *Info) error {
if udpLite, ok := packet.TransportLayer().(*layers.UDPLite); ok {
info.Protocol = UDPLite
info.SrcPort = uint16(udpLite.SrcPort)
info.DstPort = uint16(udpLite.DstPort)
}
return nil
}
*/
func parseICMPv4(packet gopacket.Packet, info *Info) error {
if icmp, ok := packet.Layer(layers.LayerTypeICMPv4).(*layers.ICMPv4); ok {
info.Protocol = ICMP
_ = icmp
}
return nil
}
func parseICMPv6(packet gopacket.Packet, info *Info) error {
if icmp6, ok := packet.Layer(layers.LayerTypeICMPv6).(*layers.ICMPv6); ok {
info.Protocol = ICMPv6
_ = icmp6
}
return nil
}
func parseIGMP(packet gopacket.Packet, info *Info) error {
// gopacket uses LayerTypeIGMP for v1, v2 and v3 and may thus
// either return layers.IGMP or layers.IGMPv1or2
if layer := packet.Layer(layers.LayerTypeIGMP); layer != nil {
info.Protocol = IGMP
}
return nil
}
func checkError(packet gopacket.Packet, info *Info) error {
// Check for known unparseable before checking the error layer.
if info.Protocol == AnyHostInternalProtocol61 {
return nil
}
if err := packet.ErrorLayer(); err != nil {
return err.Error()
}
return nil
}
// Parse parses an IP packet and saves the information in the given packet object.
func Parse(packetData []byte, pktBase *Base) (err error) {
if len(packetData) == 0 {
return errors.New("empty packet")
}
pktBase.layer3Data = packetData
ipVersion := packetData[0] >> 4
var networkLayerType gopacket.LayerType
switch ipVersion {
case 4:
networkLayerType = layers.LayerTypeIPv4
case 6:
networkLayerType = layers.LayerTypeIPv6
default:
return fmt.Errorf("unknown IP version or network protocol: %02x", ipVersion)
}
packet := gopacket.NewPacket(packetData, networkLayerType, gopacket.DecodeOptions{
Lazy: true,
NoCopy: true,
})
availableDecoders := []func(gopacket.Packet, *Info) error{
parseIPv4,
parseIPv6,
parseTCP,
parseUDP,
// parseUDPLite, // We don't yet support udplite.
parseICMPv4,
parseICMPv6,
parseIGMP,
checkError,
}
for _, dec := range availableDecoders {
if err := dec(packet, pktBase.Info()); err != nil {
return err
}
}
pktBase.layers = packet
if transport := packet.TransportLayer(); transport != nil {
pktBase.layer5Data = transport.LayerPayload()
}
return nil
}
func init() {
genIPProtocolFromLayerType()
}

57
service/network/ports.go Normal file
View File

@@ -0,0 +1,57 @@
package network
import (
"github.com/safing/portbase/log"
"github.com/safing/portbase/rng"
)
// GetUnusedLocalPort returns a local port of the specified protocol that is
// currently unused and is unlikely to be used within the next seconds.
func GetUnusedLocalPort(protocol uint8) (port uint16, ok bool) {
allConns := conns.clone()
tries := 1000
// Try up to 1000 times to find an unused port.
nextPort:
for i := 0; i < tries; i++ {
// Generate random port between 10000 and 65535
rN, err := rng.Number(55535)
if err != nil {
log.Warningf("network: failed to generate random port: %s", err)
return 0, false
}
port := uint16(rN + 10000)
// Shrink range when we chew through the tries.
portRangeStart := port - 10
// Check if the generated port is unused.
nextConnection:
for _, conn := range allConns {
switch {
case !conn.DataIsComplete():
// Skip connection if the data is not complete.
continue nextConnection
case conn.Entity.Protocol != protocol:
// Skip connection if the protocol does not match the protocol of interest.
continue nextConnection
case conn.LocalPort <= port && conn.LocalPort >= portRangeStart:
// Skip port if the local port is in dangerous proximity.
// Consecutive port numbers are very common.
continue nextPort
}
}
// Log if it took more than 10 attempts.
if i >= 10 {
log.Warningf("network: took %d attempts to find a suitable unused port for pre-auth", i+1)
}
// The checks have passed. We have found a good unused port.
return port, true
}
return 0, false
}

View File

@@ -0,0 +1,107 @@
//go:build linux
package proc
import (
"errors"
"io/fs"
"os"
"strconv"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/socket"
)
// GetPID returns the already existing pid of the given socket info or searches for it.
// This also acts as a getter for socket.Info.PID, as locking for that occurs here.
func GetPID(socketInfo socket.Info) (pid int) {
// Get currently assigned PID to the socket info.
currentPid := socketInfo.GetPID()
// If the current PID already is valid (ie. not unidentified), return it immediately.
if currentPid != socket.UndefinedProcessID {
return currentPid
}
// Find PID for the given UID and inode.
pid = findPID(socketInfo.GetUIDandInode())
// Set the newly found PID on the socket info.
socketInfo.SetPID(pid)
// Return found PID.
return pid
}
// findPID returns the pid of the given uid and socket inode.
func findPID(uid, inode int) (pid int) {
socketName := "socket:[" + strconv.Itoa(inode) + "]"
// Always update pid table (it has a call limiter anyway)
updatePids()
// Get all pids for the given uid.
pids, ok := getPidsByUser(uid)
if !ok {
return socket.UndefinedProcessID
}
// Look through the PIDs in reverse order, because higher/newer PIDs will be more likely to
// be searched for.
for j := len(pids) - 1; j >= 0; j-- {
if pidHasSocket(pids[j], socketName) {
return pids[j]
}
}
return socket.UndefinedProcessID
}
func pidHasSocket(pid int, socketName string) bool {
socketBase := "/proc/" + strconv.Itoa(pid) + "/fd"
entries := readDirNames(socketBase)
if len(entries) == 0 {
return false
}
socketBase += "/"
// Look through the FDs in reverse order, because higher/newer FDs will be
// more likely to be searched for.
for i := len(entries) - 1; i >= 0; i-- {
link, err := os.Readlink(socketBase + entries[i])
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Warningf("proc: failed to read link /proc/%d/fd/%s: %s", pid, entries[i], err)
}
continue
}
if link == socketName {
return true
}
}
return false
}
// readDirNames only reads the directory names. Using os.ReadDir() would call `lstat` on every
// resulting directory name, which we don't need. This function will be called a lot, so we should
// refrain from unnecessary work.
func readDirNames(dir string) (names []string) {
file, err := os.Open(dir)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Warningf("proc: could not open directory %s: %s", dir, err)
}
return
}
defer func() {
_ = file.Close()
}()
names, err = file.Readdirnames(0)
if err != nil {
log.Warningf("proc: could not get entries from directory %s: %s", dir, err)
return []string{}
}
return
}

View File

@@ -0,0 +1,81 @@
//go:build linux
package proc
import (
"errors"
"io/fs"
"os"
"strconv"
"sync"
"syscall"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
)
var (
// pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs.
pidsByUser = make(map[int][]int)
pidsByUserLock sync.RWMutex
fetchPidsByUser = utils.NewCallLimiter(10 * time.Millisecond)
)
// getPidsByUser returns the cached PIDs for the given UID.
func getPidsByUser(uid int) (pids []int, ok bool) {
pidsByUserLock.RLock()
defer pidsByUserLock.RUnlock()
pids, ok = pidsByUser[uid]
return
}
// updatePids fetches and creates a new pidsByUser map using a call limiter.
func updatePids() {
fetchPidsByUser.Do(func() {
newPidsByUser := make(map[int][]int)
pidCnt := 0
entries := readDirNames("/proc")
if len(entries) == 0 {
log.Warning("proc: found no PIDs in /proc")
return
}
entryLoop:
for _, entry := range entries {
pid, err := strconv.ParseInt(entry, 10, 32)
if err != nil {
continue entryLoop
}
statData, err := os.Stat("/proc/" + strconv.FormatInt(pid, 10))
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Warningf("proc: could not stat /proc/%d: %s", pid, err)
}
continue entryLoop
}
sys, ok := statData.Sys().(*syscall.Stat_t)
if !ok {
log.Warningf("proc: unable to parse /proc/%d: wrong type", pid)
continue entryLoop
}
pids, ok := newPidsByUser[int(sys.Uid)]
if ok {
newPidsByUser[int(sys.Uid)] = append(pids, int(pid))
} else {
newPidsByUser[int(sys.Uid)] = []int{int(pid)}
}
pidCnt++
}
// log.Tracef("proc: updated PID table with %d entries", pidCnt)
pidsByUserLock.Lock()
defer pidsByUserLock.Unlock()
pidsByUser = newPidsByUser
})
}

View File

@@ -0,0 +1,244 @@
//go:build linux
package proc
import (
"bufio"
"encoding/hex"
"fmt"
"net"
"os"
"strconv"
"strings"
"unicode"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/socket"
)
/*
1. find socket inode
- by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP?
- /proc/net/{tcp|udp}[6]
2. get list of processes of uid
3. find socket inode in process fds
- if not found, refresh map of uid->pids
- if not found, check ALL pids: maybe euid != uid
4. gather process info
Cache every step!
*/
// Network Related Constants.
const (
TCP4 uint8 = iota
UDP4
TCP6
UDP6
ICMP4
ICMP6
tcp4ProcFile = "/proc/net/tcp"
tcp6ProcFile = "/proc/net/tcp6"
udp4ProcFile = "/proc/net/udp"
udp6ProcFile = "/proc/net/udp6"
tcpListenStateHex = "0A"
)
// GetTCP4Table returns the system table for IPv4 TCP activity.
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP4, tcp4ProcFile)
}
// GetTCP6Table returns the system table for IPv6 TCP activity.
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP6, tcp6ProcFile)
}
// GetUDP4Table returns the system table for IPv4 UDP activity.
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP4, udp4ProcFile)
return
}
// GetUDP6Table returns the system table for IPv6 UDP activity.
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP6, udp6ProcFile)
return
}
const (
// hint: we split fields by multiple delimiters, see procDelimiter
fieldIndexLocalIP = 1
fieldIndexLocalPort = 2
fieldIndexRemoteIP = 3
fieldIndexRemotePort = 4
fieldIndexUID = 11
fieldIndexInode = 13
)
func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
var ipConverter func(string) net.IP
switch stack {
case TCP4, UDP4:
ipConverter = convertIPv4
case TCP6, UDP6:
ipConverter = convertIPv6
default:
return nil, nil, fmt.Errorf("unsupported table stack: %d", stack)
}
// open file
socketData, err := os.Open(procFile)
if err != nil {
return nil, nil, err
}
defer func() {
_ = socketData.Close()
}()
// file scanner
scanner := bufio.NewScanner(socketData)
scanner.Split(bufio.ScanLines)
// parse
scanner.Scan() // skip first row
for scanner.Scan() {
fields := strings.FieldsFunc(scanner.Text(), procDelimiter)
if len(fields) < 14 {
// log.Tracef("proc: too short: %s", fields)
continue
}
localIP := ipConverter(fields[fieldIndexLocalIP])
if localIP == nil {
continue
}
localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16)
if err != nil {
log.Warningf("proc: could not parse port: %s", err)
continue
}
uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32)
// log.Tracef("uid: %s", fields[fieldIndexUID])
if err != nil {
log.Warningf("proc: could not parse uid %s: %s", fields[11], err)
continue
}
inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32)
// log.Tracef("inode: %s", fields[fieldIndexInode])
if err != nil {
log.Warningf("proc: could not parse inode %s: %s", fields[13], err)
continue
}
switch stack {
case UDP4, UDP6:
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
PID: socket.UndefinedProcessID,
UID: int(uid),
Inode: int(inode),
})
case TCP4, TCP6:
if fields[5] == tcpListenStateHex {
// listener
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
PID: socket.UndefinedProcessID,
UID: int(uid),
Inode: int(inode),
})
} else {
// connection
remoteIP := ipConverter(fields[fieldIndexRemoteIP])
if remoteIP == nil {
continue
}
remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16)
if err != nil {
log.Warningf("proc: could not parse port: %s", err)
continue
}
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
Remote: socket.Address{
IP: remoteIP,
Port: uint16(remotePort),
},
PID: socket.UndefinedProcessID,
UID: int(uid),
Inode: int(inode),
})
}
}
}
return connections, binds, nil
}
func procDelimiter(c rune) bool {
return unicode.IsSpace(c) || c == ':'
}
func convertIPv4(data string) net.IP {
// Decode and bullshit check the data length.
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("proc: could not parse IPv4 %s: %s", data, err)
return nil
}
if len(decoded) != 4 {
log.Warningf("proc: decoded IPv4 %s has wrong length", decoded)
return nil
}
// Build the IPv4 address with the reversed byte order.
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
return ip
}
func convertIPv6(data string) net.IP {
// Decode and bullshit check the data length.
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("proc: could not parse IPv6 %s: %s", data, err)
return nil
}
if len(decoded) != 16 {
log.Warningf("proc: decoded IPv6 %s has wrong length", decoded)
return nil
}
// Build the IPv6 address with the translated byte order.
for i := 0; i < 16; i += 4 {
decoded[i], decoded[i+1], decoded[i+2], decoded[i+3] = decoded[i+3], decoded[i+2], decoded[i+1], decoded[i]
}
ip := net.IP(decoded)
return ip
}

View File

@@ -0,0 +1,62 @@
//go:build linux
package proc
import (
"fmt"
"testing"
)
func TestSockets(t *testing.T) {
t.Parallel()
connections, listeners, err := GetTCP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 4 connections:")
for _, connection := range connections {
pid := GetPID(connection)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 4 listeners:")
for _, listener := range listeners {
pid := GetPID(listener)
fmt.Printf("%d: %+v\n", pid, listener)
}
connections, listeners, err = GetTCP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 6 connections:")
for _, connection := range connections {
pid := GetPID(connection)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 6 listeners:")
for _, listener := range listeners {
pid := GetPID(listener)
fmt.Printf("%d: %+v\n", pid, listener)
}
binds, err := GetUDP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 4 binds:")
for _, bind := range binds {
pid := GetPID(bind)
fmt.Printf("%d: %+v\n", pid, bind)
}
binds, err = GetUDP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 6 binds:")
for _, bind := range binds {
pid := GetPID(bind)
fmt.Printf("%d: %+v\n", pid, bind)
}
}

View File

@@ -0,0 +1,86 @@
package reference
import (
"strconv"
"strings"
)
var (
portNames = map[uint16]string{
20: "FTP-DATA",
21: "FTP",
22: "SSH",
23: "TELNET",
25: "SMTP",
43: "WHOIS",
53: "DNS",
67: "DHCP_SERVER",
68: "DHCP_CLIENT",
69: "TFTP",
80: "HTTP",
110: "POP3",
123: "NTP",
143: "IMAP",
161: "SNMP",
179: "BGP",
194: "IRC",
389: "LDAP",
443: "HTTPS",
445: "SMB",
587: "SMTP_ALT",
465: "SMTP_SSL",
993: "IMAP_SSL",
995: "POP3_SSL",
}
portNumbers = map[string]uint16{
"FTP-DATA": 20,
"FTP": 21,
"SSH": 22,
"TELNET": 23,
"SMTP": 25,
"WHOIS": 43,
"DNS": 53,
"DHCP-SERVER": 67,
"DHCP_SERVER": 67,
"DHCP-CLIENT": 68,
"DHCP_CLIENT": 68,
"TFTP": 69,
"HTTP": 80,
"POP3": 110,
"NTP": 123,
"IMAP": 143,
"SNMP": 161,
"BGP": 179,
"IRC": 194,
"LDAP": 389,
"HTTPS": 443,
"SMB": 445,
"SMTP-ALT": 587,
"SMTP_ALT": 587,
"SMTP-SSL": 465,
"SMTP_SSL": 465,
"IMAP-SSL": 993,
"IMAP_SSL": 993,
"POP3-SSL": 995,
"POP3_SSL": 995,
}
)
// GetPortName returns the name of a port number.
func GetPortName(port uint16) (name string) {
name, ok := portNames[port]
if ok {
return name
}
return strconv.Itoa(int(port))
}
// GetPortNumber returns the number of a port name.
func GetPortNumber(port string) (number uint16, ok bool) {
number, ok = portNumbers[strings.ToUpper(port)]
if ok {
return number, true
}
return 0, false
}

View File

@@ -0,0 +1,75 @@
package reference
import (
"strconv"
"strings"
)
var (
protocolNames = map[uint8]string{
1: "ICMP",
2: "IGMP",
6: "TCP",
17: "UDP",
27: "RDP",
58: "ICMP6",
33: "DCCP",
136: "UDP-LITE",
}
protocolNumbers = map[string]uint8{
"ICMP": 1,
"IGMP": 2,
"TCP": 6,
"UDP": 17,
"RDP": 27,
"DCCP": 33,
"ICMP6": 58,
"UDP-LITE": 136,
}
)
// GetProtocolName returns the name of a IP protocol number.
func GetProtocolName(protocol uint8) (name string) {
name, ok := protocolNames[protocol]
if ok {
return name
}
return strconv.Itoa(int(protocol))
}
// GetProtocolNumber returns the number of a IP protocol name.
func GetProtocolNumber(protocol string) (number uint8, ok bool) {
number, ok = protocolNumbers[strings.ToUpper(protocol)]
if ok {
return number, true
}
return 0, false
}
// IsPacketProtocol returns whether the given protocol number is a known packet based protocol.
// Note: Not fully complete. Calling IsPacketProtocol() does not equal calling !IsStreamProtocol().
func IsPacketProtocol(protocol uint8) bool {
switch protocol {
case 1, // ICMP
17, // UDP
27, // RDP
58, // ICMP6
33, // DCCP
136: // UDP-LITE
return true
default:
return false
}
}
// IsStreamProtocol returns whether the given protocol number is a known stream based protocol.
// Note: Not fully complete. Calling IsPacketProtocol() does not equal calling !IsStreamProtocol().
func IsStreamProtocol(protocol uint8) bool {
switch protocol {
case 6: // TCP
return true
default:
return false
}
}

View File

@@ -0,0 +1,120 @@
package socket
import (
"net"
"sync"
)
const (
// UndefinedProcessID signifies that the process ID is unknown.
// It must match portmaster/process.UndefinedProcessID
// It is duplicated here because of import loops.
UndefinedProcessID = -1
)
// ConnectionInfo holds socket information returned by the system.
type ConnectionInfo struct {
sync.Mutex
Local Address
Remote Address
PID int
UID int
Inode int
}
// BindInfo holds socket information returned by the system.
type BindInfo struct {
sync.Mutex
Local Address
PID int
UID int
Inode int
ListensAny bool
}
// Address is an IP + Port pair.
type Address struct {
IP net.IP
Port uint16
}
// Info is a generic interface to both ConnectionInfo and BindInfo.
type Info interface {
GetPID() int
SetPID(int)
GetUID() int
GetUIDandInode() (int, int)
}
// GetPID returns the PID.
func (i *ConnectionInfo) GetPID() int {
i.Lock()
defer i.Unlock()
return i.PID
}
// SetPID sets the PID to the given value.
func (i *ConnectionInfo) SetPID(pid int) {
i.Lock()
defer i.Unlock()
i.PID = pid
}
// GetUID returns the UID.
func (i *ConnectionInfo) GetUID() int {
i.Lock()
defer i.Unlock()
return i.UID
}
// GetUIDandInode returns the UID and Inode.
func (i *ConnectionInfo) GetUIDandInode() (int, int) {
i.Lock()
defer i.Unlock()
return i.UID, i.Inode
}
// GetPID returns the PID.
func (i *BindInfo) GetPID() int {
i.Lock()
defer i.Unlock()
return i.PID
}
// SetPID sets the PID to the given value.
func (i *BindInfo) SetPID(pid int) {
i.Lock()
defer i.Unlock()
i.PID = pid
}
// GetUID returns the UID.
func (i *BindInfo) GetUID() int {
i.Lock()
defer i.Unlock()
return i.UID
}
// GetUIDandInode returns the UID and Inode.
func (i *BindInfo) GetUIDandInode() (int, int) {
i.Lock()
defer i.Unlock()
return i.UID, i.Inode
}
// Compile time checks.
var (
_ Info = new(ConnectionInfo)
_ Info = new(BindInfo)
)

View File

@@ -0,0 +1,102 @@
package state
import (
"time"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/network/socket"
)
const (
// UDPConnectionTTL defines the duration after which unseen UDP connections are regarded as ended.
UDPConnectionTTL = 10 * time.Minute
)
// Exists checks if the given connection is present in the system state tables.
func Exists(pktInfo *packet.Info, now time.Time) (exists bool) {
// TODO: create lookup maps before running a flurry of Exists() checks.
switch {
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
return tcp4Table.exists(pktInfo)
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP:
return tcp6Table.exists(pktInfo)
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP:
return udp4Table.exists(pktInfo, now)
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP:
return udp6Table.exists(pktInfo, now)
default:
return false
}
}
func (table *tcpTable) exists(pktInfo *packet.Info) (exists bool) {
// Update tables if older than the connection that is checked.
if table.lastUpdateAt.Load() < pktInfo.SeenAt.UnixNano() {
table.updateTables()
}
table.lock.RLock()
defer table.lock.RUnlock()
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
remoteIP := pktInfo.RemoteIP()
remotePort := pktInfo.RemotePort()
// search connections
for _, socketInfo := range table.connections {
if localPort == socketInfo.Local.Port &&
remotePort == socketInfo.Remote.Port &&
remoteIP.Equal(socketInfo.Remote.IP) &&
localIP.Equal(socketInfo.Local.IP) {
return true
}
}
return false
}
func (table *udpTable) exists(pktInfo *packet.Info, now time.Time) (exists bool) {
// Update tables if older than the connection that is checked.
if table.lastUpdateAt.Load() < pktInfo.SeenAt.UnixNano() {
table.updateTables()
}
table.lock.RLock()
defer table.lock.RUnlock()
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
remoteIP := pktInfo.RemoteIP()
remotePort := pktInfo.RemotePort()
connThreshhold := now.Add(-UDPConnectionTTL)
// search binds
for _, socketInfo := range table.binds {
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
udpConnState, ok := table.getConnState(socketInfo, socket.Address{
IP: remoteIP,
Port: remotePort,
})
switch {
case !ok:
return false
case udpConnState.lastSeen.After(connThreshhold):
return true
default:
return false
}
}
}
return false
}

View File

@@ -0,0 +1,38 @@
package state
import (
"sync"
"github.com/safing/portbase/database/record"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network/socket"
)
// Info holds network state information as provided by the system.
type Info struct {
record.Base
sync.Mutex
TCP4Connections []*socket.ConnectionInfo
TCP4Listeners []*socket.BindInfo
TCP6Connections []*socket.ConnectionInfo
TCP6Listeners []*socket.BindInfo
UDP4Binds []*socket.BindInfo
UDP6Binds []*socket.BindInfo
}
// GetInfo returns all system state tables. The returned data must not be modified.
func GetInfo() *Info {
info := &Info{}
info.TCP4Connections, info.TCP4Listeners = tcp4Table.updateTables()
info.UDP4Binds = udp4Table.updateTables()
if netenv.IPv6Enabled() {
info.TCP6Connections, info.TCP6Listeners = tcp6Table.updateTables()
info.UDP6Binds = udp6Table.updateTables()
}
info.UpdateMeta()
return info
}

View File

@@ -0,0 +1,264 @@
package state
import (
"errors"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/network/socket"
)
// - TCP
// - Outbound: Match listeners (in!), then connections (out!)
// - Inbound: Match listeners (in!), then connections (out!)
// - Clean via connections
// - UDP
// - Any connection: match specific local address or zero IP
// - In or out: save direction of first packet:
// - map[<local udp bind ip+port>]map[<remote ip+port>]{direction, lastSeen}
// - only clean if <local udp bind ip+port> is removed by OS
// - limit <remote ip+port> to 256 entries?
// - clean <remote ip+port> after 72hrs?
// - switch direction to outbound if outbound packet is seen?
// - IP: Unidentified Process
// Errors.
var (
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
ErrPIDNotFound = errors.New("could not find pid for socket inode")
)
const (
lookupTries = 5
fastLookupTries = 2
)
// Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound.
func Lookup(pktInfo *packet.Info, fast bool) (pid int, inbound bool, err error) {
// auto-detect version
if pktInfo.Version == 0 {
if ip := pktInfo.LocalIP().To4(); ip != nil {
pktInfo.Version = packet.IPv4
} else {
pktInfo.Version = packet.IPv6
}
}
switch {
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
return tcp4Table.lookup(pktInfo, fast)
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP:
return tcp6Table.lookup(pktInfo, fast)
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP:
return udp4Table.lookup(pktInfo, fast)
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP:
return udp6Table.lookup(pktInfo, fast)
default:
return socket.UndefinedProcessID, pktInfo.Inbound, errors.New("unsupported protocol for finding process")
}
}
func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) (
pid int,
inbound bool,
err error,
) {
// Prepare variables.
var (
connections []*socket.ConnectionInfo
listeners []*socket.BindInfo
dualStackConnections []*socket.ConnectionInfo
dualStackListeners []*socket.BindInfo
)
// Search for the socket until found.
for i := 1; i <= lookupTries; i++ {
// Use existing tables for first check if packet was seen after last table update.
if i == 1 && pktInfo.SeenAt.UnixNano() >= table.lastUpdateAt.Load() {
connections, listeners = table.getCurrentTables()
} else {
connections, listeners = table.updateTables()
}
// Check tables for socket.
socketInfo, inbound := findTCPSocket(pktInfo, connections, listeners)
// If there's a match, check if we have the PID and return.
if socketInfo != nil {
return CheckPID(socketInfo, inbound)
}
// DUAL-STACK
// Skip if dualStack is not enabled.
if table.dualStack == nil {
continue
}
// Use existing tables for first check if packet was seen after last table update.
if i == 1 && pktInfo.SeenAt.UnixNano() >= table.dualStack.lastUpdateAt.Load() {
dualStackConnections, dualStackListeners = table.dualStack.getCurrentTables()
} else {
dualStackConnections, dualStackListeners = table.dualStack.updateTables()
}
// Check tables for socket.
socketInfo, inbound = findTCPSocket(pktInfo, dualStackConnections, dualStackListeners)
// If there's a match, check if we have the PID and return.
if socketInfo != nil {
return CheckPID(socketInfo, inbound)
}
// Search less if we want to be fast.
if fast && i >= fastLookupTries {
break
}
}
return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound
}
func findTCPSocket(
pktInfo *packet.Info,
connections []*socket.ConnectionInfo,
listeners []*socket.BindInfo,
) (
socketInfo socket.Info,
inbound bool,
) {
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
// always search listeners first
for _, socketInfo := range listeners {
if localPort == socketInfo.Local.Port &&
(socketInfo.ListensAny || localIP.Equal(socketInfo.Local.IP)) {
return socketInfo, true
}
}
remoteIP := pktInfo.RemoteIP()
remotePort := pktInfo.RemotePort()
// search connections
for _, socketInfo := range connections {
if localPort == socketInfo.Local.Port &&
remotePort == socketInfo.Remote.Port &&
remoteIP.Equal(socketInfo.Remote.IP) &&
localIP.Equal(socketInfo.Local.IP) {
return socketInfo, false
}
}
return nil, false
}
func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) (
pid int,
inbound bool,
err error,
) {
// TODO: Currently broadcast/multicast scopes are not checked, so we might
// attribute an incoming broadcast/multicast packet to the wrong process if
// there are multiple processes listening on the same local port, but
// binding to different addresses. This highly unusual for clients.
isInboundMulticast := pktInfo.Inbound && netutils.GetIPScope(pktInfo.LocalIP()) == netutils.LocalMulticast
// Prepare variables.
var (
binds []*socket.BindInfo
dualStackBinds []*socket.BindInfo
)
// Search for the socket until found.
for i := 1; i <= lookupTries; i++ {
// Get or update tables.
if i == 1 && pktInfo.SeenAt.UnixNano() >= table.lastUpdateAt.Load() {
binds = table.getCurrentTables()
} else {
binds = table.updateTables()
}
// Check tables for socket.
socketInfo := findUDPSocket(pktInfo, binds, isInboundMulticast)
// If there's a match, do some last checks and return.
if socketInfo != nil {
// If there is no remote port, do check for the direction of the
// connection. This will be the case for pure checking functions
// that do not want to change direction state.
if pktInfo.RemotePort() == 0 {
return CheckPID(socketInfo, pktInfo.Inbound)
}
// Get (and save) the direction of the connection.
connInbound := table.getDirection(socketInfo, pktInfo)
// Check we have the PID and return.
return CheckPID(socketInfo, connInbound)
}
// DUAL-STACK
// Skip if dualStack is not enabled.
if table.dualStack == nil {
continue
}
// Get or update tables.
if i == 1 && pktInfo.SeenAt.UnixNano() >= table.lastUpdateAt.Load() {
dualStackBinds = table.dualStack.getCurrentTables()
} else {
dualStackBinds = table.dualStack.updateTables()
}
// Check tables for socket.
socketInfo = findUDPSocket(pktInfo, dualStackBinds, isInboundMulticast)
// If there's a match, do some last checks and return.
if socketInfo != nil {
// If there is no remote port, do check for the direction of the
// connection. This will be the case for pure checking functions
// that do not want to change direction state.
if pktInfo.RemotePort() == 0 {
return CheckPID(socketInfo, pktInfo.Inbound)
}
// Get (and save) the direction of the connection.
connInbound := table.getDirection(socketInfo, pktInfo)
// Check we have the PID and return.
return CheckPID(socketInfo, connInbound)
}
// Search less if we want to be fast.
if fast && i >= fastLookupTries {
break
}
}
return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound
}
func findUDPSocket(pktInfo *packet.Info, binds []*socket.BindInfo, isInboundMulticast bool) (socketInfo *socket.BindInfo) {
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
// search binds
for _, socketInfo := range binds {
if localPort == socketInfo.Local.Port &&
(socketInfo.ListensAny || // zero IP (dual-stack)
isInboundMulticast || // inbound broadcast, multicast
localIP.Equal(socketInfo.Local.IP)) {
return socketInfo
}
}
return nil
}

View File

@@ -0,0 +1,46 @@
//go:build !windows && !linux
// +build !windows,!linux
package state
import (
"time"
"github.com/safing/portbase/config"
"github.com/safing/portmaster/service/network/socket"
)
func init() {
// This increases performance on unsupported system.
// It's not critical at all and does not break anything if it fails.
go func() {
// Wait for one minute before we set the default value, as we
// currently cannot easily integrate into the startup procedure.
time.Sleep(1 * time.Minute)
// We cannot use process.CfgOptionEnableProcessDetectionKey, because of an import loop.
config.SetDefaultConfigOption("core/enableProcessDetection", false)
}()
}
func getTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return nil, nil, nil
}
func getTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return nil, nil, nil
}
func getUDP4Table() (binds []*socket.BindInfo, err error) {
return nil, nil
}
func getUDP6Table() (binds []*socket.BindInfo, err error) {
return nil, nil
}
// CheckPID checks the if socket info already has a PID and if not, tries to find it.
// Depending on the OS, this might be a no-op.
func CheckPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) {
return socketInfo.GetPID(), connInbound, nil
}

View File

@@ -0,0 +1,40 @@
package state
import (
"time"
"github.com/safing/portmaster/service/network/proc"
"github.com/safing/portmaster/service/network/socket"
)
var (
getTCP4Table = proc.GetTCP4Table
getTCP6Table = proc.GetTCP6Table
getUDP4Table = proc.GetUDP4Table
getUDP6Table = proc.GetUDP6Table
checkPIDTries = 5
checkPIDBaseWaitTime = 5 * time.Millisecond
)
// CheckPID checks the if socket info already has a PID and if not, tries to find it.
// Depending on the OS, this might be a no-op.
func CheckPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) {
for i := 1; i <= checkPIDTries; i++ {
// look for PID
pid = proc.GetPID(socketInfo)
if pid != socket.UndefinedProcessID {
// if we found a PID, return
break
}
// every time, except for the last iteration
if i < checkPIDTries {
// we found no PID, we could have been too fast, give the kernel some time to think
// back off timer: with 5ms baseWaitTime: 5, 10, 15, 20, 25 - 75ms in total
time.Sleep(time.Duration(i) * checkPIDBaseWaitTime)
}
}
return pid, connInbound, nil
}

View File

@@ -0,0 +1,19 @@
package state
import (
"github.com/safing/portmaster/service/network/iphelper"
"github.com/safing/portmaster/service/network/socket"
)
var (
getTCP4Table = iphelper.GetTCP4Table
getTCP6Table = iphelper.GetTCP6Table
getUDP4Table = iphelper.GetUDP4Table
getUDP6Table = iphelper.GetUDP6Table
)
// CheckPID checks the if socket info already has a PID and if not, tries to find it.
// Depending on the OS, this might be a no-op.
func CheckPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) {
return socketInfo.GetPID(), connInbound, nil
}

View File

@@ -0,0 +1,91 @@
package state
import (
"net"
"sync"
"sync/atomic"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
"github.com/safing/portmaster/service/network/socket"
)
const (
minDurationBetweenTableUpdates = 10 * time.Millisecond
)
type tcpTable struct {
version int
connections []*socket.ConnectionInfo
listeners []*socket.BindInfo
lock sync.RWMutex
// lastUpdateAt stores the time when the tables where last updated as unix nanoseconds.
lastUpdateAt atomic.Int64
fetchLimiter *utils.CallLimiter
fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error)
dualStack *tcpTable
}
var (
tcp6Table = &tcpTable{
version: 6,
fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
fetchTable: getTCP6Table,
}
tcp4Table = &tcpTable{
version: 4,
fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
fetchTable: getTCP4Table,
}
)
// EnableTCPDualStack adds the TCP6 table to the TCP4 table as a dual-stack.
// Must be called before any lookup operation.
func EnableTCPDualStack() {
tcp4Table.dualStack = tcp6Table
}
func (table *tcpTable) getCurrentTables() (
connections []*socket.ConnectionInfo,
listeners []*socket.BindInfo,
) {
table.lock.RLock()
defer table.lock.RUnlock()
return table.connections, table.listeners
}
func (table *tcpTable) updateTables() (
connections []*socket.ConnectionInfo,
listeners []*socket.BindInfo,
) {
// Fetch tables.
table.fetchLimiter.Do(func() {
// Fetch new tables from system.
connections, listeners, err := table.fetchTable()
if err != nil {
log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err)
return
}
// Pre-check for any listeners.
for _, bindInfo := range listeners {
bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero)
}
// Apply new tables.
table.lock.Lock()
defer table.lock.Unlock()
table.connections = connections
table.listeners = listeners
table.lastUpdateAt.Store(time.Now().UnixNano())
})
return table.getCurrentTables()
}

View File

@@ -0,0 +1,210 @@
package state
import (
"context"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/network/socket"
)
type udpTable struct {
version int
binds []*socket.BindInfo
lock sync.RWMutex
// lastUpdateAt stores the time when the tables where last updated as unix nanoseconds.
lastUpdateAt atomic.Int64
fetchLimiter *utils.CallLimiter
fetchTable func() (binds []*socket.BindInfo, err error)
states map[string]map[string]*udpState
statesLock sync.Mutex
dualStack *udpTable
}
type udpState struct {
inbound bool
lastSeen time.Time
}
const (
// UDPConnStateTTL is the maximum time a udp connection state is held.
UDPConnStateTTL = 72 * time.Hour
// UDPConnStateShortenedTTL is a shortened maximum time a udp connection state is held, if there more entries than defined by AggressiveCleaningThreshold.
UDPConnStateShortenedTTL = 3 * time.Hour
// AggressiveCleaningThreshold defines the soft limit of udp connection state held per udp socket.
AggressiveCleaningThreshold = 256
)
var (
udp6Table = &udpTable{
version: 6,
fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
fetchTable: getUDP6Table,
states: make(map[string]map[string]*udpState),
}
udp4Table = &udpTable{
version: 4,
fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
fetchTable: getUDP4Table,
states: make(map[string]map[string]*udpState),
}
)
// EnableUDPDualStack adds the UDP6 table to the UDP4 table as a dual-stack.
// Must be called before any lookup operation.
func EnableUDPDualStack() {
udp4Table.dualStack = udp6Table
}
func (table *udpTable) getCurrentTables() (binds []*socket.BindInfo) {
table.lock.RLock()
defer table.lock.RUnlock()
return table.binds
}
func (table *udpTable) updateTables() (binds []*socket.BindInfo) {
// Fetch tables.
table.fetchLimiter.Do(func() {
// Fetch new tables from system.
binds, err := table.fetchTable()
if err != nil {
log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err)
return
}
// Pre-check for any listeners.
for _, bindInfo := range binds {
bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero)
}
// Apply new tables.
table.lock.Lock()
defer table.lock.Unlock()
table.binds = binds
table.lastUpdateAt.Store(time.Now().UnixNano())
})
return table.getCurrentTables()
}
// CleanUDPStates cleans the udp connection states which save connection directions.
func CleanUDPStates(_ context.Context) {
now := time.Now().UTC()
udp4Table.updateTables()
udp4Table.cleanStates(now)
if netenv.IPv6Enabled() {
udp6Table.updateTables()
udp6Table.cleanStates(now)
}
}
func (table *udpTable) getConnState(
socketInfo *socket.BindInfo,
remoteAddress socket.Address,
) (udpConnState *udpState, ok bool) {
table.statesLock.Lock()
defer table.statesLock.Unlock()
bindMap, ok := table.states[makeUDPStateKey(socketInfo.Local)]
if ok {
udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)]
return
}
return nil, false
}
func (table *udpTable) getDirection(
socketInfo *socket.BindInfo,
pktInfo *packet.Info,
) (connDirection bool) {
table.statesLock.Lock()
defer table.statesLock.Unlock()
localKey := makeUDPStateKey(socketInfo.Local)
bindMap, ok := table.states[localKey]
if !ok {
bindMap = make(map[string]*udpState)
table.states[localKey] = bindMap
}
remoteKey := makeUDPStateKey(socket.Address{
IP: pktInfo.RemoteIP(),
Port: pktInfo.RemotePort(),
})
udpConnState, ok := bindMap[remoteKey]
if !ok {
bindMap[remoteKey] = &udpState{
inbound: pktInfo.Inbound,
lastSeen: time.Now().UTC(),
}
return pktInfo.Inbound
}
udpConnState.lastSeen = time.Now().UTC()
return udpConnState.inbound
}
func (table *udpTable) cleanStates(now time.Time) {
// compute thresholds
threshold := now.Add(-UDPConnStateTTL)
shortThreshhold := now.Add(-UDPConnStateShortenedTTL)
// make lookup map of all active keys
bindKeys := make(map[string]struct{})
table.lock.RLock()
for _, socketInfo := range table.binds {
bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{}
}
table.lock.RUnlock()
table.statesLock.Lock()
defer table.statesLock.Unlock()
// clean the udp state storage
for localKey, bindMap := range table.states {
if _, active := bindKeys[localKey]; active {
// clean old entries
for remoteKey, udpConnState := range bindMap {
if udpConnState.lastSeen.Before(threshold) {
delete(bindMap, remoteKey)
}
}
// if there are too many clean more aggressively
if len(bindMap) > AggressiveCleaningThreshold {
for remoteKey, udpConnState := range bindMap {
if udpConnState.lastSeen.Before(shortThreshhold) {
delete(bindMap, remoteKey)
}
}
}
} else {
// delete the whole thing
delete(table.states, localKey)
}
}
}
func makeUDPStateKey(address socket.Address) string {
// This could potentially go wrong, but as all IPs are created by the same source, everything should be fine.
return string(address.IP) + strconv.Itoa(int(address.Port))
}

83
service/network/status.go Normal file
View File

@@ -0,0 +1,83 @@
package network
// Verdict describes the decision made about a connection or link.
type Verdict int8
// All possible verdicts that can be applied to a network
// connection.
const (
// VerdictUndecided is the default status of new connections.
VerdictUndecided Verdict = 0
VerdictUndeterminable Verdict = 1
VerdictAccept Verdict = 2
VerdictBlock Verdict = 3
VerdictDrop Verdict = 4
VerdictRerouteToNameserver Verdict = 5
VerdictRerouteToTunnel Verdict = 6
VerdictFailed Verdict = 7
)
func (v Verdict) String() string {
switch v {
case VerdictUndecided:
return "<Undecided>"
case VerdictUndeterminable:
return "<Undeterminable>"
case VerdictAccept:
return "Accept"
case VerdictBlock:
return "Block"
case VerdictDrop:
return "Drop"
case VerdictRerouteToNameserver:
return "RerouteToNameserver"
case VerdictRerouteToTunnel:
return "RerouteToTunnel"
case VerdictFailed:
return "Failed"
default:
return "<INVALID VERDICT>"
}
}
// Verb returns the verdict as a past tense verb.
func (v Verdict) Verb() string {
switch v {
case VerdictUndecided:
return "undecided"
case VerdictUndeterminable:
return "undeterminable"
case VerdictAccept:
return "accepted"
case VerdictBlock:
return "blocked"
case VerdictDrop:
return "dropped"
case VerdictRerouteToNameserver:
return "redirected to nameserver"
case VerdictRerouteToTunnel:
return "tunneled"
case VerdictFailed:
return "failed"
default:
return "invalid"
}
}
// Packet Directions.
const (
Inbound = true
Outbound = false
)
// Non-Domain Scopes.
const (
IncomingHost = "IH"
IncomingLAN = "IL"
IncomingInternet = "II"
IncomingInvalid = "IX"
PeerHost = "PH"
PeerLAN = "PL"
PeerInternet = "PI"
PeerInvalid = "PX"
)