wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
340
service/network/api.go
Normal file
340
service/network/api.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/database/query"
|
||||
"github.com/safing/portbase/utils/debug"
|
||||
"github.com/safing/portmaster/service/network/state"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
"github.com/safing/portmaster/service/status"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
)
|
||||
|
||||
func registerAPIEndpoints() error {
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "debug/network",
|
||||
Read: api.PermitUser,
|
||||
BelongsTo: module,
|
||||
DataFunc: debugInfo,
|
||||
Name: "Get Network Debug Information",
|
||||
Description: "Returns network debugging information, similar to debug/core, but with connection data.",
|
||||
Parameters: []api.Parameter{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Field: "style",
|
||||
Value: "github",
|
||||
Description: "Specify the formatting style. The default is simple markdown formatting.",
|
||||
},
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Field: "profile",
|
||||
Value: "<Source>/<ID>",
|
||||
Description: "Specify a profile source and ID for which network connection should be reported.",
|
||||
},
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Field: "where",
|
||||
Value: "<query>",
|
||||
Description: "Specify a query to limit the connections included in the report. The default is to include all connections.",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "debug/network/state",
|
||||
Read: api.PermitUser,
|
||||
BelongsTo: module,
|
||||
StructFunc: func(ar *api.Request) (i interface{}, err error) {
|
||||
return state.GetInfo(), nil
|
||||
},
|
||||
Name: "Get Network State Table Data",
|
||||
Description: "Returns the current network state tables from the OS.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// debugInfo returns the debugging information for support requests.
|
||||
func debugInfo(ar *api.Request) (data []byte, err error) {
|
||||
// Create debug information helper.
|
||||
di := new(debug.Info)
|
||||
di.Style = ar.Request.URL.Query().Get("style")
|
||||
|
||||
// Add debug information.
|
||||
|
||||
// Very basic information at the start.
|
||||
di.AddVersionInfo()
|
||||
di.AddPlatformInfo(ar.Context())
|
||||
|
||||
// Errors and unexpected logs.
|
||||
di.AddLastReportedModuleError()
|
||||
di.AddLastUnexpectedLogs()
|
||||
|
||||
// Network Connections.
|
||||
AddNetworkDebugData(
|
||||
di,
|
||||
ar.Request.URL.Query().Get("profile"),
|
||||
ar.Request.URL.Query().Get("where"),
|
||||
)
|
||||
|
||||
// Status Information from various modules.
|
||||
status.AddToDebugInfo(di)
|
||||
// captain.AddToDebugInfo(di) // TODO: Cannot use due to import loop.
|
||||
resolver.AddToDebugInfo(di)
|
||||
config.AddToDebugInfo(di)
|
||||
|
||||
// Detailed information.
|
||||
updates.AddToDebugInfo(di)
|
||||
// compat.AddToDebugInfo(di) // TODO: Cannot use due to interception import requirement which we don't want for SPN Hubs.
|
||||
di.AddGoroutineStack()
|
||||
|
||||
// Return data.
|
||||
return di.Bytes(), nil
|
||||
}
|
||||
|
||||
// AddNetworkDebugData adds the network debug data of the given profile to the debug data.
|
||||
func AddNetworkDebugData(di *debug.Info, profile, where string) {
|
||||
// Prepend where prefix to query if necessary.
|
||||
if where != "" && !strings.HasPrefix(where, "where ") {
|
||||
where = "where " + where
|
||||
}
|
||||
|
||||
// Build query.
|
||||
q, err := query.ParseQuery("query network: " + where)
|
||||
if err != nil {
|
||||
di.AddSection(
|
||||
"Network: Debug Failed",
|
||||
debug.NoFlags,
|
||||
fmt.Sprintf("Failed to build query: %s", err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Get iterator.
|
||||
it, err := dbController.Query(q, true, true)
|
||||
if err != nil {
|
||||
di.AddSection(
|
||||
"Network: Debug Failed",
|
||||
debug.NoFlags,
|
||||
fmt.Sprintf("Failed to run query: %s", err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Collect matching connections.
|
||||
var ( //nolint:prealloc // We don't know the size.
|
||||
debugConns []*Connection
|
||||
accepted int
|
||||
total int
|
||||
)
|
||||
|
||||
for maybeConn := range it.Next {
|
||||
// Switch to correct type.
|
||||
conn, ok := maybeConn.(*Connection)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if the profile matches
|
||||
if profile != "" {
|
||||
found := false
|
||||
|
||||
// Get layer IDs and search for a match.
|
||||
layerIDs := conn.Process().Profile().LayerIDs
|
||||
for _, layerID := range layerIDs {
|
||||
if profile == layerID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Skip if the profile does not match.
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Count.
|
||||
total++
|
||||
switch conn.Verdict { //nolint:exhaustive
|
||||
case VerdictAccept,
|
||||
VerdictRerouteToNameserver,
|
||||
VerdictRerouteToTunnel:
|
||||
|
||||
accepted++
|
||||
}
|
||||
|
||||
// Add to list.
|
||||
debugConns = append(debugConns, conn)
|
||||
}
|
||||
|
||||
// Add it all.
|
||||
di.AddSection(
|
||||
fmt.Sprintf(
|
||||
"Network: %d/%d Connections",
|
||||
accepted,
|
||||
total,
|
||||
),
|
||||
debug.UseCodeSection|debug.AddContentLineBreaks,
|
||||
buildNetworkDebugInfoData(debugConns),
|
||||
)
|
||||
}
|
||||
|
||||
func buildNetworkDebugInfoData(debugConns []*Connection) string {
|
||||
// Sort
|
||||
sort.Sort(connectionsByGroup(debugConns))
|
||||
|
||||
// Format lines
|
||||
var buf strings.Builder
|
||||
currentPID := process.UndefinedProcessID
|
||||
for _, conn := range debugConns {
|
||||
conn.Lock()
|
||||
|
||||
// Add process infomration if it differs from previous connection.
|
||||
if currentPID != conn.ProcessContext.PID {
|
||||
if currentPID != process.UndefinedProcessID {
|
||||
buf.WriteString("\n\n\n")
|
||||
}
|
||||
buf.WriteString("ProfileName: " + conn.ProcessContext.ProfileName)
|
||||
buf.WriteString("\nProfile: " + conn.ProcessContext.Profile)
|
||||
buf.WriteString("\nSource: " + conn.ProcessContext.Source)
|
||||
buf.WriteString("\nProcessName: " + conn.ProcessContext.ProcessName)
|
||||
buf.WriteString("\nBinaryPath: " + conn.ProcessContext.BinaryPath)
|
||||
buf.WriteString("\nCmdLine: " + conn.ProcessContext.CmdLine)
|
||||
buf.WriteString("\nPID: " + strconv.Itoa(conn.ProcessContext.PID))
|
||||
buf.WriteString("\n")
|
||||
|
||||
// Set current PID in order to not print the process information again.
|
||||
currentPID = conn.ProcessContext.PID
|
||||
}
|
||||
|
||||
// Add connection.
|
||||
buf.WriteString("\n")
|
||||
buf.WriteString(conn.debugInfoLine())
|
||||
|
||||
conn.Unlock()
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (conn *Connection) debugInfoLine() string {
|
||||
var connectionData string
|
||||
if conn.Type == IPConnection {
|
||||
// Format IP/Port pair for connections.
|
||||
connectionData = fmt.Sprintf(
|
||||
"% 15s:%- 5s %s % 15s:%- 5s",
|
||||
conn.LocalIP,
|
||||
strconv.Itoa(int(conn.LocalPort)),
|
||||
conn.fmtProtocolAndDirectionComponent(conn.IPProtocol.String()),
|
||||
conn.Entity.IP,
|
||||
strconv.Itoa(int(conn.Entity.Port)),
|
||||
)
|
||||
} else {
|
||||
// Leave empty for DNS Requests.
|
||||
connectionData = " "
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"% 14s %s%- 25s %s-%s P#%d [%s] %s - by %s @ %s",
|
||||
conn.VerdictVerb(),
|
||||
connectionData,
|
||||
conn.fmtDomainComponent(),
|
||||
time.Unix(conn.Started, 0).Format("15:04:05"),
|
||||
conn.fmtEndTimeComponent(),
|
||||
conn.ProcessContext.PID,
|
||||
conn.fmtFlagsComponent(),
|
||||
conn.Reason.Msg,
|
||||
conn.Reason.OptionKey,
|
||||
conn.fmtReasonProfileComponent(),
|
||||
)
|
||||
}
|
||||
|
||||
func (conn *Connection) fmtDomainComponent() string {
|
||||
if conn.Entity.Domain != "" {
|
||||
return " to " + conn.Entity.Domain
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (conn *Connection) fmtProtocolAndDirectionComponent(protocol string) string {
|
||||
if conn.Inbound {
|
||||
return "<" + protocol
|
||||
}
|
||||
return protocol + ">"
|
||||
}
|
||||
|
||||
func (conn *Connection) fmtFlagsComponent() string {
|
||||
var f string
|
||||
|
||||
if conn.Internal {
|
||||
f += "I"
|
||||
}
|
||||
if conn.Encrypted {
|
||||
f += "E"
|
||||
}
|
||||
if conn.Tunneled {
|
||||
f += "T"
|
||||
}
|
||||
if len(conn.activeInspectors) > 0 {
|
||||
f += "A"
|
||||
}
|
||||
if conn.addedToMetrics {
|
||||
f += "M"
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (conn *Connection) fmtEndTimeComponent() string {
|
||||
if conn.Ended == 0 {
|
||||
return " " // Use same width as a timestamp.
|
||||
}
|
||||
return time.Unix(conn.Ended, 0).Format("15:04:05")
|
||||
}
|
||||
|
||||
func (conn *Connection) fmtReasonProfileComponent() string {
|
||||
if conn.Reason.Profile == "" {
|
||||
return "global"
|
||||
}
|
||||
return conn.Reason.Profile
|
||||
}
|
||||
|
||||
type connectionsByGroup []*Connection
|
||||
|
||||
func (a connectionsByGroup) Len() int { return len(a) }
|
||||
func (a connectionsByGroup) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a connectionsByGroup) Less(i, j int) bool {
|
||||
// Sort by:
|
||||
|
||||
// 1. Profile ID
|
||||
if a[i].ProcessContext.Profile != a[j].ProcessContext.Profile {
|
||||
return a[i].ProcessContext.Profile < a[j].ProcessContext.Profile
|
||||
}
|
||||
|
||||
// 2. Process Binary
|
||||
if a[i].ProcessContext.BinaryPath != a[j].ProcessContext.BinaryPath {
|
||||
return a[i].ProcessContext.BinaryPath < a[j].ProcessContext.BinaryPath
|
||||
}
|
||||
|
||||
// 3. Process ID
|
||||
if a[i].ProcessContext.PID != a[j].ProcessContext.PID {
|
||||
return a[i].ProcessContext.PID < a[j].ProcessContext.PID
|
||||
}
|
||||
|
||||
// 4. Started
|
||||
return a[i].Started < a[j].Started
|
||||
}
|
||||
149
service/network/api_test.go
Normal file
149
service/network/api_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/service/intel"
|
||||
)
|
||||
|
||||
func TestDebugInfoLineFormatting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, conn := range connectionTestData {
|
||||
fmt.Println(conn.debugInfoLine())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDebugInfoFormatting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fmt.Println(buildNetworkDebugInfoData(connectionTestData))
|
||||
}
|
||||
|
||||
var connectionTestData = []*Connection{
|
||||
{
|
||||
ID: "17-255.255.255.255-29810-192.168.0.23-40672",
|
||||
Scope: "IL",
|
||||
IPVersion: 4,
|
||||
Inbound: true,
|
||||
IPProtocol: 17,
|
||||
LocalIP: net.ParseIP("255.255.255.255"),
|
||||
LocalPort: 29810,
|
||||
Entity: &intel.Entity{
|
||||
Protocol: 17,
|
||||
Port: 40672,
|
||||
Domain: "",
|
||||
ReverseDomain: "",
|
||||
IP: net.ParseIP("192.168.0.23"),
|
||||
Country: "",
|
||||
ASN: 0,
|
||||
},
|
||||
Verdict: 2,
|
||||
Reason: Reason{
|
||||
Msg: "incoming connection blocked by default",
|
||||
OptionKey: "filter/serviceEndpoints",
|
||||
Profile: "",
|
||||
},
|
||||
Started: 1614010349,
|
||||
Ended: 1614010350,
|
||||
VerdictPermanent: true,
|
||||
Inspecting: false,
|
||||
Tunneled: false,
|
||||
Encrypted: false,
|
||||
ProcessContext: ProcessContext{
|
||||
ProcessName: "Unidentified Processes",
|
||||
ProfileName: "Unidentified Processes",
|
||||
BinaryPath: "",
|
||||
PID: -1,
|
||||
Profile: "_unidentified",
|
||||
Source: "local",
|
||||
},
|
||||
Internal: false,
|
||||
ProfileRevisionCounter: 1,
|
||||
},
|
||||
{
|
||||
ID: "6-192.168.0.176-55216-13.32.6.15-80",
|
||||
Scope: "PI",
|
||||
IPVersion: 4,
|
||||
Inbound: false,
|
||||
IPProtocol: 6,
|
||||
LocalIP: net.ParseIP("192.168.0.176"),
|
||||
LocalPort: 55216,
|
||||
Entity: &intel.Entity{
|
||||
Protocol: 6,
|
||||
Port: 80,
|
||||
Domain: "",
|
||||
ReverseDomain: "",
|
||||
IP: net.ParseIP("13.32.6.15"),
|
||||
Country: "DE",
|
||||
ASN: 16509,
|
||||
},
|
||||
Verdict: 2,
|
||||
Reason: Reason{
|
||||
Msg: "default permit",
|
||||
OptionKey: "filter/defaultAction",
|
||||
Profile: "",
|
||||
},
|
||||
Started: 1614010475,
|
||||
Ended: 1614010565,
|
||||
VerdictPermanent: true,
|
||||
Inspecting: false,
|
||||
Tunneled: false,
|
||||
Encrypted: false,
|
||||
ProcessContext: ProcessContext{
|
||||
ProcessName: "NetworkManager",
|
||||
ProfileName: "Network Manager",
|
||||
BinaryPath: "/usr/sbin/NetworkManager",
|
||||
PID: 1273,
|
||||
Profile: "3a9b0eb5-c7fe-4bc7-9b93-a90f4ff84b5b",
|
||||
Source: "local",
|
||||
},
|
||||
Internal: true,
|
||||
ProfileRevisionCounter: 1,
|
||||
},
|
||||
{
|
||||
ID: "6-192.168.0.176-49982-142.250.74.211-443",
|
||||
Scope: "pkg.go.dev.",
|
||||
IPVersion: 4,
|
||||
Inbound: false,
|
||||
IPProtocol: 6,
|
||||
LocalIP: net.ParseIP("192.168.0.176"),
|
||||
LocalPort: 49982,
|
||||
Entity: &intel.Entity{
|
||||
Protocol: 6,
|
||||
Port: 443,
|
||||
Domain: "pkg.go.dev.",
|
||||
ReverseDomain: "",
|
||||
CNAME: []string{
|
||||
"ghs.googlehosted.com.",
|
||||
},
|
||||
IP: net.ParseIP("142.250.74.211"),
|
||||
Country: "US",
|
||||
ASN: 15169,
|
||||
},
|
||||
Verdict: 2,
|
||||
Reason: Reason{
|
||||
Msg: "default permit",
|
||||
OptionKey: "filter/defaultAction",
|
||||
Profile: "",
|
||||
},
|
||||
Started: 1614010415,
|
||||
Ended: 1614010745,
|
||||
VerdictPermanent: true,
|
||||
Inspecting: false,
|
||||
Tunneled: false,
|
||||
Encrypted: false,
|
||||
ProcessContext: ProcessContext{
|
||||
ProcessName: "firefox",
|
||||
ProfileName: "Firefox",
|
||||
BinaryPath: "/usr/bin/firefox",
|
||||
PID: 5710,
|
||||
Profile: "74b30392-9e4d-4157-83a9-fffafd3e2bde",
|
||||
Source: "local",
|
||||
},
|
||||
Internal: false,
|
||||
ProfileRevisionCounter: 1,
|
||||
},
|
||||
}
|
||||
132
service/network/clean.go
Normal file
132
service/network/clean.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/network/state"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
)
|
||||
|
||||
const (
|
||||
// DeleteConnsAfterEndedThreshold defines the amount of time after which
|
||||
// ended connections should be removed from the internal connection state.
|
||||
DeleteConnsAfterEndedThreshold = 10 * time.Minute
|
||||
|
||||
// DeleteIncompleteConnsAfterStartedThreshold defines the amount of time after
|
||||
// which incomplete connections should be removed from the internal
|
||||
// connection state.
|
||||
DeleteIncompleteConnsAfterStartedThreshold = 1 * time.Minute
|
||||
|
||||
cleanerTickDuration = 5 * time.Second
|
||||
)
|
||||
|
||||
func connectionCleaner(ctx context.Context) error {
|
||||
ticker := module.NewSleepyTicker(cleanerTickDuration, 0)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return nil
|
||||
case <-ticker.Wait():
|
||||
// clean connections and processes
|
||||
activePIDs := cleanConnections()
|
||||
process.CleanProcessStorage(activePIDs)
|
||||
|
||||
// clean udp connection states
|
||||
state.CleanUDPStates(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanConnections() (activePIDs map[int]struct{}) {
|
||||
activePIDs = make(map[int]struct{})
|
||||
|
||||
_ = module.RunMicroTask("clean connections", 0, func(ctx context.Context) error {
|
||||
now := time.Now().UTC()
|
||||
nowUnix := now.Unix()
|
||||
ignoreNewer := nowUnix - 1
|
||||
deleteOlderThan := now.Add(-DeleteConnsAfterEndedThreshold).Unix()
|
||||
deleteIncompleteOlderThan := now.Add(-DeleteIncompleteConnsAfterStartedThreshold).Unix()
|
||||
|
||||
// network connections
|
||||
for _, conn := range conns.clone() {
|
||||
conn.Lock()
|
||||
|
||||
// delete inactive connections
|
||||
switch {
|
||||
case conn.Started >= ignoreNewer:
|
||||
// Skip very fresh connections to evade edge cases.
|
||||
case !conn.DataIsComplete():
|
||||
// Step 0: delete old incomplete connections
|
||||
if conn.Started < deleteIncompleteOlderThan {
|
||||
// Stop the firewall handler, in case one is running.
|
||||
conn.StopFirewallHandler()
|
||||
// Remove connection from state.
|
||||
conn.delete()
|
||||
}
|
||||
case conn.Ended == 0:
|
||||
// Step 1: check if still active
|
||||
exists := state.Exists(&packet.Info{
|
||||
Inbound: false, // src == local
|
||||
Version: conn.IPVersion,
|
||||
Protocol: conn.IPProtocol,
|
||||
Src: conn.LocalIP,
|
||||
SrcPort: conn.LocalPort,
|
||||
Dst: conn.Entity.IP,
|
||||
DstPort: conn.Entity.Port,
|
||||
PID: process.UndefinedProcessID,
|
||||
SeenAt: time.Unix(conn.Started, 0), // State tables will be updated if older than this.
|
||||
}, now)
|
||||
|
||||
// Step 2: mark as ended
|
||||
if !exists {
|
||||
conn.Ended = nowUnix
|
||||
|
||||
// Stop the firewall handler, in case one is running.
|
||||
conn.StopFirewallHandler()
|
||||
|
||||
// Save to database.
|
||||
conn.Save()
|
||||
}
|
||||
|
||||
// If the connection has an associated process, add its PID to the active PID list.
|
||||
if conn.process != nil {
|
||||
activePIDs[conn.process.Pid] = struct{}{}
|
||||
}
|
||||
case conn.Ended < deleteOlderThan:
|
||||
// Step 3: delete
|
||||
// DEBUG:
|
||||
// log.Tracef("network.clean: deleted %s (ended at %s)", conn.DatabaseKey(), time.Unix(conn.Ended, 0))
|
||||
|
||||
// Remove connection from state.
|
||||
conn.delete()
|
||||
}
|
||||
|
||||
conn.Unlock()
|
||||
}
|
||||
|
||||
// dns requests
|
||||
for _, conn := range dnsConns.clone() {
|
||||
conn.Lock()
|
||||
|
||||
// delete old dns connections
|
||||
if conn.Ended < deleteOlderThan {
|
||||
log.Tracef("network.clean: deleted %s (ended at %s)", conn.DatabaseKey(), time.Unix(conn.Ended, 0))
|
||||
conn.delete()
|
||||
}
|
||||
|
||||
conn.Unlock()
|
||||
}
|
||||
|
||||
// rerouted dns requests
|
||||
cleanDNSRequestConnections()
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return activePIDs
|
||||
}
|
||||
1068
service/network/connection.go
Normal file
1068
service/network/connection.go
Normal file
File diff suppressed because it is too large
Load Diff
68
service/network/connection_android.go
Normal file
68
service/network/connection_android.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/intel"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// NewDefaultConnection creates a new connection with default values except local and remote IPs and protocols.
|
||||
func NewDefaultConnection(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, ipVersion packet.IPVersion, protocol packet.IPProtocol) *Connection {
|
||||
connInfo := &Connection{
|
||||
ID: fmt.Sprintf("%s-%s-%d-%s-%d", protocol.String(), localIP, localPort, remoteIP, remotePort),
|
||||
Type: IPConnection,
|
||||
External: false,
|
||||
IPVersion: ipVersion,
|
||||
Inbound: false,
|
||||
IPProtocol: protocol,
|
||||
LocalIP: localIP,
|
||||
LocalIPScope: netutils.Global,
|
||||
LocalPort: localPort,
|
||||
PID: process.UnidentifiedProcessID,
|
||||
Entity: (&intel.Entity{
|
||||
IP: remoteIP,
|
||||
Protocol: uint8(protocol),
|
||||
Port: remotePort,
|
||||
}).Init(0),
|
||||
Resolver: nil,
|
||||
Started: time.Now().Unix(),
|
||||
VerdictPermanent: false,
|
||||
Tunneled: true,
|
||||
Encrypted: false,
|
||||
DataComplete: abool.NewBool(true),
|
||||
Internal: false,
|
||||
addedToMetrics: true, // Metrics are not needed for now. This will mark the Connection to be ignored.
|
||||
process: process.GetUnidentifiedProcess(context.Background()),
|
||||
}
|
||||
|
||||
// TODO: Quick fix for the SPN.
|
||||
// Use inspection framework for proper encryption detection.
|
||||
switch connInfo.Entity.DstPort() {
|
||||
case
|
||||
22, // SSH
|
||||
443, // HTTPS
|
||||
465, // SMTP-SSL
|
||||
853, // DoT
|
||||
993, // IMAP-SSL
|
||||
995: // POP3-SSL
|
||||
connInfo.Encrypted = true
|
||||
}
|
||||
|
||||
var layeredProfile = connInfo.process.Profile()
|
||||
connInfo.TunnelOpts = &navigator.Options{
|
||||
HubPolicies: layeredProfile.StackedExitHubPolicies(),
|
||||
CheckHubExitPolicyWith: connInfo.Entity,
|
||||
RequireTrustedDestinationHubs: !connInfo.Encrypted,
|
||||
RoutingProfile: layeredProfile.SPNRoutingAlgorithm(),
|
||||
}
|
||||
|
||||
return connInfo
|
||||
}
|
||||
97
service/network/connection_store.go
Normal file
97
service/network/connection_store.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type connectionStore struct {
|
||||
rw sync.RWMutex
|
||||
items map[string]*Connection
|
||||
}
|
||||
|
||||
func newConnectionStore() *connectionStore {
|
||||
return &connectionStore{
|
||||
items: make(map[string]*Connection, 100),
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *connectionStore) add(conn *Connection) {
|
||||
cs.rw.Lock()
|
||||
defer cs.rw.Unlock()
|
||||
|
||||
cs.items[conn.ID] = conn
|
||||
}
|
||||
|
||||
func (cs *connectionStore) delete(conn *Connection) {
|
||||
cs.rw.Lock()
|
||||
defer cs.rw.Unlock()
|
||||
|
||||
delete(cs.items, conn.ID)
|
||||
}
|
||||
|
||||
func (cs *connectionStore) get(id string) (*Connection, bool) {
|
||||
cs.rw.RLock()
|
||||
defer cs.rw.RUnlock()
|
||||
|
||||
conn, ok := cs.items[id]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
// findByPrefix returns the first connection where the key matches the given prefix.
|
||||
// If the prefix matches multiple entries, the result is not deterministic.
|
||||
func (cs *connectionStore) findByPrefix(prefix string) (*Connection, bool) { //nolint:unused
|
||||
cs.rw.RLock()
|
||||
defer cs.rw.RUnlock()
|
||||
|
||||
for key, conn := range cs.items {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
return conn, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (cs *connectionStore) clone() map[string]*Connection {
|
||||
cs.rw.RLock()
|
||||
defer cs.rw.RUnlock()
|
||||
|
||||
m := make(map[string]*Connection, len(cs.items))
|
||||
for key, conn := range cs.items {
|
||||
m[key] = conn
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (cs *connectionStore) list() []*Connection {
|
||||
cs.rw.RLock()
|
||||
defer cs.rw.RUnlock()
|
||||
|
||||
l := make([]*Connection, 0, len(cs.items))
|
||||
for _, conn := range cs.items {
|
||||
l = append(l, conn)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (cs *connectionStore) len() int { //nolint:unused // TODO: Clean up if still unused.
|
||||
cs.rw.RLock()
|
||||
defer cs.rw.RUnlock()
|
||||
|
||||
return len(cs.items)
|
||||
}
|
||||
|
||||
func (cs *connectionStore) active() int {
|
||||
// Clone and count all active connections.
|
||||
var cnt int
|
||||
for _, conn := range cs.clone() {
|
||||
conn.Lock()
|
||||
if conn.Ended != 0 {
|
||||
cnt++
|
||||
}
|
||||
conn.Unlock()
|
||||
}
|
||||
|
||||
return cnt
|
||||
}
|
||||
202
service/network/database.go
Normal file
202
service/network/database.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/iterator"
|
||||
"github.com/safing/portbase/database/query"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/database/storage"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
)
|
||||
|
||||
const (
|
||||
dbScopeNone = ""
|
||||
dbScopeDNS = "dns"
|
||||
dbScopeIP = "ip"
|
||||
)
|
||||
|
||||
var (
|
||||
dbController *database.Controller
|
||||
|
||||
dnsConns = newConnectionStore()
|
||||
conns = newConnectionStore()
|
||||
)
|
||||
|
||||
// StorageInterface provices a storage.Interface to the
|
||||
// configuration manager.
|
||||
type StorageInterface struct {
|
||||
storage.InjectBase
|
||||
}
|
||||
|
||||
// Database prefixes:
|
||||
// Processes: network:tree/<PID>
|
||||
// DNS Requests: network:tree/<PID>/dns/<ID>
|
||||
// IP Connections: network:tree/<PID>/ip/<ID>
|
||||
|
||||
func makeKey(pid int, scope, id string) string {
|
||||
if scope == "" {
|
||||
return "network:tree/" + strconv.Itoa(pid)
|
||||
}
|
||||
return fmt.Sprintf("network:tree/%d/%s/%s", pid, scope, id)
|
||||
}
|
||||
|
||||
func parseDBKey(key string) (processKey string, scope, id string, ok bool) {
|
||||
// Split into segments.
|
||||
segments := strings.Split(key, "/")
|
||||
|
||||
// Keys have 2 or 4 segments.
|
||||
switch len(segments) {
|
||||
case 4:
|
||||
id = segments[3]
|
||||
|
||||
fallthrough
|
||||
case 3:
|
||||
scope = segments[2]
|
||||
// Sanity check.
|
||||
switch scope {
|
||||
case dbScopeNone, dbScopeDNS, dbScopeIP:
|
||||
// Parsed id matches possible values.
|
||||
// The empty string is for matching a trailing slash for in query prefix.
|
||||
// TODO: For queries, also prefixes of these values are valid.
|
||||
default:
|
||||
// Unknown scope.
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
fallthrough
|
||||
case 2:
|
||||
processKey = segments[1]
|
||||
return processKey, scope, id, true
|
||||
case 1:
|
||||
// This is a valid query prefix, but not process ID was given.
|
||||
return "", "", "", true
|
||||
default:
|
||||
return "", "", "", false
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a database record.
|
||||
func (s *StorageInterface) Get(key string) (record.Record, error) {
|
||||
// Parse key and check if valid.
|
||||
pid, scope, id, ok := parseDBKey(strings.TrimPrefix(key, "network:"))
|
||||
if !ok || pid == "" {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
switch scope {
|
||||
case dbScopeDNS:
|
||||
if c, ok := dnsConns.get(id); ok && c.DataIsComplete() {
|
||||
return c, nil
|
||||
}
|
||||
case dbScopeIP:
|
||||
if c, ok := conns.get(id); ok && c.DataIsComplete() {
|
||||
return c, nil
|
||||
}
|
||||
case dbScopeNone:
|
||||
if proc, ok := process.GetProcessFromStorage(pid); ok {
|
||||
return proc, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
// Query returns a an iterator for the supplied query.
|
||||
func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
it := iterator.New()
|
||||
|
||||
module.StartWorker("connection query", func(_ context.Context) error {
|
||||
s.processQuery(q, it)
|
||||
return nil
|
||||
})
|
||||
|
||||
return it, nil
|
||||
}
|
||||
|
||||
func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
|
||||
var matches bool
|
||||
pid, scope, _, ok := parseDBKey(q.DatabaseKeyPrefix())
|
||||
if !ok {
|
||||
it.Finish(nil)
|
||||
return
|
||||
}
|
||||
|
||||
if pid == "" {
|
||||
// processes
|
||||
for _, proc := range process.All() {
|
||||
func() {
|
||||
proc.Lock()
|
||||
defer proc.Unlock()
|
||||
matches = q.Matches(proc)
|
||||
}()
|
||||
if matches {
|
||||
it.Next <- proc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if scope == dbScopeNone || scope == dbScopeDNS {
|
||||
// dns scopes only
|
||||
for _, dnsConn := range dnsConns.clone() {
|
||||
if !dnsConn.DataIsComplete() {
|
||||
continue
|
||||
}
|
||||
|
||||
func() {
|
||||
dnsConn.Lock()
|
||||
defer dnsConn.Unlock()
|
||||
matches = q.Matches(dnsConn)
|
||||
}()
|
||||
|
||||
if matches {
|
||||
it.Next <- dnsConn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if scope == dbScopeNone || scope == dbScopeIP {
|
||||
// connections
|
||||
for _, conn := range conns.clone() {
|
||||
if !conn.DataIsComplete() {
|
||||
continue
|
||||
}
|
||||
|
||||
func() {
|
||||
conn.Lock()
|
||||
defer conn.Unlock()
|
||||
matches = q.Matches(conn)
|
||||
}()
|
||||
|
||||
if matches {
|
||||
it.Next <- conn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it.Finish(nil)
|
||||
}
|
||||
|
||||
func registerAsDatabase() error {
|
||||
_, err := database.Register(&database.Database{
|
||||
Name: "network",
|
||||
Description: "Network and Firewall Data",
|
||||
StorageType: "injected",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
controller, err := database.InjectDatabase("network", &StorageInterface{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbController = controller
|
||||
process.SetDBController(dbController)
|
||||
return nil
|
||||
}
|
||||
258
service/network/dns.go
Normal file
258
service/network/dns.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/nameserver/nsutil"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
)
|
||||
|
||||
var (
|
||||
dnsRequestConnections = make(map[string]*Connection) // key: <protocol>-<local ip>-<local port>
|
||||
dnsRequestConnectionsLock sync.RWMutex
|
||||
|
||||
openDNSRequests = make(map[string]*Connection) // key: <pid>/<fqdn>
|
||||
openDNSRequestsLock sync.Mutex
|
||||
|
||||
supportedDomainToIPRecordTypes = []uint16{
|
||||
dns.TypeA,
|
||||
dns.TypeAAAA,
|
||||
dns.TypeSVCB,
|
||||
dns.TypeHTTPS,
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
// writeOpenDNSRequestsTickDuration defines the interval in which open dns
|
||||
// requests are written.
|
||||
writeOpenDNSRequestsTickDuration = 5 * time.Second
|
||||
|
||||
// openDNSRequestLimit defines the duration after which DNS requests without
|
||||
// a following connection are logged.
|
||||
openDNSRequestLimit = 3 * time.Second
|
||||
)
|
||||
|
||||
func getDNSRequestConnectionKey(packetInfo *packet.Info) (id string, ok bool) {
|
||||
// We only support protocols with ports.
|
||||
if packetInfo.SrcPort == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%d-%s-%d", packetInfo.Protocol, packetInfo.Src, packetInfo.SrcPort), true
|
||||
}
|
||||
|
||||
// SaveDNSRequestConnection saves a dns request connection for later retrieval.
|
||||
func SaveDNSRequestConnection(conn *Connection, pkt packet.Packet) {
|
||||
// Check connection.
|
||||
if conn.PID == process.UndefinedProcessID || conn.PID == process.SystemProcessID {
|
||||
// When re-injecting packets on Windows, they are reported with kernel PID (4).
|
||||
log.Tracer(pkt.Ctx()).Tracef("network: not saving dns request connection because the PID is undefined/kernel")
|
||||
return
|
||||
}
|
||||
|
||||
// Create key.
|
||||
key, ok := getDNSRequestConnectionKey(pkt.Info())
|
||||
if !ok {
|
||||
log.Tracer(pkt.Ctx()).Debugf("network: not saving dns request connection %s because the protocol is not supported", pkt)
|
||||
return
|
||||
}
|
||||
|
||||
// Add or update DNS request connection.
|
||||
log.Tracer(pkt.Ctx()).Tracef("network: saving %s with PID %d as dns request connection for fast DNS request attribution", pkt, conn.PID)
|
||||
dnsRequestConnectionsLock.Lock()
|
||||
defer dnsRequestConnectionsLock.Unlock()
|
||||
dnsRequestConnections[key] = conn
|
||||
}
|
||||
|
||||
// GetDNSRequestConnection returns a saved dns request connection.
|
||||
func GetDNSRequestConnection(packetInfo *packet.Info) (conn *Connection, ok bool) {
|
||||
// Make key.
|
||||
key, ok := getDNSRequestConnectionKey(packetInfo)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Get and return
|
||||
dnsRequestConnectionsLock.RLock()
|
||||
defer dnsRequestConnectionsLock.RUnlock()
|
||||
|
||||
conn, ok = dnsRequestConnections[key]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
// deleteDNSRequestConnection removes a connection from the dns request connections.
|
||||
func deleteDNSRequestConnection(packetInfo *packet.Info) { //nolint:unused,deadcode
|
||||
dnsRequestConnectionsLock.Lock()
|
||||
defer dnsRequestConnectionsLock.Unlock()
|
||||
|
||||
key, ok := getDNSRequestConnectionKey(packetInfo)
|
||||
if ok {
|
||||
delete(dnsRequestConnections, key)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanDNSRequestConnections deletes old DNS request connections.
|
||||
func cleanDNSRequestConnections() {
|
||||
deleteOlderThan := time.Now().Unix() - 3
|
||||
|
||||
dnsRequestConnectionsLock.Lock()
|
||||
defer dnsRequestConnectionsLock.Unlock()
|
||||
|
||||
for key, conn := range dnsRequestConnections {
|
||||
conn.Lock()
|
||||
|
||||
if conn.Ended > 0 && conn.Ended < deleteOlderThan {
|
||||
delete(dnsRequestConnections, key)
|
||||
}
|
||||
|
||||
conn.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// IsSupportDNSRecordType returns whether the given DSN RR type is supported
|
||||
// by the network package, as in the requests are specially handled and can be
|
||||
// "merged" into the resulting connection.
|
||||
func IsSupportDNSRecordType(rrType uint16) bool {
|
||||
return slices.Contains[[]uint16, uint16](supportedDomainToIPRecordTypes, rrType)
|
||||
}
|
||||
|
||||
func getDNSRequestCacheKey(pid int, fqdn string, qType uint16) string {
|
||||
return strconv.Itoa(pid) + "/" + fqdn + dns.Type(qType).String()
|
||||
}
|
||||
|
||||
func removeOpenDNSRequest(pid int, fqdn string) {
|
||||
openDNSRequestsLock.Lock()
|
||||
defer openDNSRequestsLock.Unlock()
|
||||
|
||||
// Delete PID-specific requests.
|
||||
for _, dnsType := range supportedDomainToIPRecordTypes {
|
||||
delete(openDNSRequests, getDNSRequestCacheKey(pid, fqdn, dnsType))
|
||||
}
|
||||
|
||||
// If process is known, also check for non-attributed requests.
|
||||
if pid != process.UnidentifiedProcessID {
|
||||
for _, dnsType := range supportedDomainToIPRecordTypes {
|
||||
delete(openDNSRequests, getDNSRequestCacheKey(process.UnidentifiedProcessID, fqdn, dnsType))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SaveOpenDNSRequest saves a dns request connection that was allowed to proceed.
|
||||
func SaveOpenDNSRequest(q *resolver.Query, rrCache *resolver.RRCache, conn *Connection) {
|
||||
// Only save requests that actually went out (or triggered an async resolve) to reduce clutter.
|
||||
if rrCache == nil || (rrCache.ServedFromCache && !rrCache.RequestingNew) {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to "merge" supported requests into the resulting connection.
|
||||
// Save others immediately.
|
||||
if !IsSupportDNSRecordType(uint16(q.QType)) {
|
||||
conn.Save()
|
||||
return
|
||||
}
|
||||
|
||||
openDNSRequestsLock.Lock()
|
||||
defer openDNSRequestsLock.Unlock()
|
||||
|
||||
// Do not check for an existing open DNS request, as duplicates in such quick
|
||||
// succession are not worth keeping.
|
||||
// DNS queries are usually retried pretty quick.
|
||||
|
||||
// Save to open dns requests.
|
||||
key := getDNSRequestCacheKey(conn.process.Pid, conn.Entity.Domain, uint16(q.QType))
|
||||
openDNSRequests[key] = conn
|
||||
}
|
||||
|
||||
func openDNSRequestWriter(ctx context.Context) error {
|
||||
ticker := module.NewSleepyTicker(writeOpenDNSRequestsTickDuration, 0)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-ticker.Wait():
|
||||
writeOpenDNSRequestsToDB()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeOpenDNSRequestsToDB() {
|
||||
openDNSRequestsLock.Lock()
|
||||
defer openDNSRequestsLock.Unlock()
|
||||
|
||||
threshold := time.Now().Add(-openDNSRequestLimit).Unix()
|
||||
for id, conn := range openDNSRequests {
|
||||
func() {
|
||||
conn.Lock()
|
||||
defer conn.Unlock()
|
||||
|
||||
if conn.Ended < threshold {
|
||||
conn.Save()
|
||||
delete(openDNSRequests, id)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// ReplyWithDNS creates a new reply to the given request with the data from the RRCache, and additional informational records.
|
||||
func (conn *Connection) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns.Msg {
|
||||
// Select request responder.
|
||||
switch conn.Verdict {
|
||||
case VerdictBlock:
|
||||
return nsutil.BlockIP().ReplyWithDNS(ctx, request)
|
||||
case VerdictDrop:
|
||||
return nil // Do not respond to request.
|
||||
case VerdictFailed:
|
||||
return nsutil.BlockIP().ReplyWithDNS(ctx, request)
|
||||
case VerdictUndecided, VerdictUndeterminable,
|
||||
VerdictAccept, VerdictRerouteToNameserver, VerdictRerouteToTunnel:
|
||||
fallthrough
|
||||
default:
|
||||
reply := nsutil.ServerFailure().ReplyWithDNS(ctx, request)
|
||||
nsutil.AddMessagesToReply(ctx, reply, log.ErrorLevel, "INTERNAL ERROR: incorrect use of Connection DNS Responder")
|
||||
return reply
|
||||
}
|
||||
}
|
||||
|
||||
// GetExtraRRs returns a slice of RRs with additional informational records.
|
||||
func (conn *Connection) GetExtraRRs(ctx context.Context, request *dns.Msg) []dns.RR {
|
||||
// Select level to add the verdict record with.
|
||||
var level log.Severity
|
||||
switch conn.Verdict {
|
||||
case VerdictFailed:
|
||||
level = log.ErrorLevel
|
||||
case VerdictUndecided, VerdictUndeterminable,
|
||||
VerdictAccept, VerdictBlock, VerdictDrop,
|
||||
VerdictRerouteToNameserver, VerdictRerouteToTunnel:
|
||||
fallthrough
|
||||
default:
|
||||
level = log.InfoLevel
|
||||
}
|
||||
|
||||
// Create resource record with verdict and reason.
|
||||
rr, err := nsutil.MakeMessageRecord(level, fmt.Sprintf("%s: %s", conn.VerdictVerb(), conn.Reason.Msg))
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Warningf("filter: failed to add informational record to reply: %s", err)
|
||||
return nil
|
||||
}
|
||||
extra := []dns.RR{rr}
|
||||
|
||||
// Add additional records from Reason.Context.
|
||||
if rrProvider, ok := conn.Reason.Context.(nsutil.RRProvider); ok {
|
||||
rrs := rrProvider.GetExtraRRs(ctx, request)
|
||||
extra = append(extra, rrs...)
|
||||
}
|
||||
|
||||
return extra
|
||||
}
|
||||
71
service/network/iphelper/get.go
Normal file
71
service/network/iphelper/get.go
Normal file
@@ -0,0 +1,71 @@
|
||||
//go:build windows
|
||||
|
||||
package iphelper
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
var (
|
||||
ipHelper *IPHelper
|
||||
|
||||
// lock locks access to the whole DLL.
|
||||
// TODO: It's unproven if we can access the iphlpapi.dll concurrently, especially as we might be encountering various versions of the DLL. In the future, we could possibly investigate and improve performance here.
|
||||
lock sync.RWMutex
|
||||
)
|
||||
|
||||
// GetTCP4Table returns the system table for IPv4 TCP activity.
|
||||
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err = checkIPHelper()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return ipHelper.getTable(IPv4, TCP)
|
||||
}
|
||||
|
||||
// GetTCP6Table returns the system table for IPv6 TCP activity.
|
||||
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err = checkIPHelper()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return ipHelper.getTable(IPv6, TCP)
|
||||
}
|
||||
|
||||
// GetUDP4Table returns the system table for IPv4 UDP activity.
|
||||
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err = checkIPHelper()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, binds, err = ipHelper.getTable(IPv4, UDP)
|
||||
return
|
||||
}
|
||||
|
||||
// GetUDP6Table returns the system table for IPv6 UDP activity.
|
||||
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err = checkIPHelper()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, binds, err = ipHelper.getTable(IPv6, UDP)
|
||||
return
|
||||
}
|
||||
63
service/network/iphelper/iphelper.go
Normal file
63
service/network/iphelper/iphelper.go
Normal file
@@ -0,0 +1,63 @@
|
||||
//go:build windows
|
||||
|
||||
package iphelper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalid = errors.New("IPHelper not initialized or broken")
|
||||
)
|
||||
|
||||
// IPHelper represents a subset of the Windows iphlpapi.dll.
|
||||
type IPHelper struct {
|
||||
dll *windows.LazyDLL
|
||||
|
||||
getExtendedTCPTable *windows.LazyProc
|
||||
getExtendedUDPTable *windows.LazyProc
|
||||
|
||||
valid *abool.AtomicBool
|
||||
}
|
||||
|
||||
func checkIPHelper() (err error) {
|
||||
if ipHelper == nil {
|
||||
ipHelper, err = New()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded).
|
||||
func New() (*IPHelper, error) {
|
||||
|
||||
new := &IPHelper{}
|
||||
new.valid = abool.NewBool(false)
|
||||
var err error
|
||||
|
||||
// load dll
|
||||
new.dll = windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
err = new.dll.Load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// load functions
|
||||
new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable")
|
||||
err = new.getExtendedTCPTable.Find()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err)
|
||||
}
|
||||
new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable")
|
||||
err = new.getExtendedUDPTable.Find()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err)
|
||||
}
|
||||
|
||||
new.valid.Set()
|
||||
return new, nil
|
||||
}
|
||||
398
service/network/iphelper/tables.go
Normal file
398
service/network/iphelper/tables.go
Normal file
@@ -0,0 +1,398 @@
|
||||
//go:build windows
|
||||
|
||||
package iphelper
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// Windows API constants
|
||||
const (
|
||||
iphelperTCPTableOwnerPIDAll uintptr = 5
|
||||
iphelperUDPTableOwnerPID uintptr = 1
|
||||
iphelperTCPStateListen uint32 = 2
|
||||
|
||||
winErrInsufficientBuffer = uintptr(windows.ERROR_INSUFFICIENT_BUFFER)
|
||||
winErrInvalidParameter = uintptr(windows.ERROR_INVALID_PARAMETER)
|
||||
)
|
||||
|
||||
type iphelperTCPTable struct {
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx
|
||||
numEntries uint32
|
||||
table [maxStateTableEntries]iphelperTCPRow
|
||||
}
|
||||
|
||||
type iphelperTCPRow struct {
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366913(v=vs.85).aspx
|
||||
state uint32
|
||||
localAddr uint32
|
||||
localPort uint32
|
||||
remoteAddr uint32
|
||||
remotePort uint32
|
||||
owningPid uint32
|
||||
}
|
||||
|
||||
type iphelperTCP6Table struct {
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366905(v=vs.85).aspx
|
||||
numEntries uint32
|
||||
table [maxStateTableEntries]iphelperTCP6Row
|
||||
}
|
||||
|
||||
type iphelperTCP6Row struct {
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366896(v=vs.85).aspx
|
||||
localAddr [16]byte
|
||||
_ uint32 // localScopeID
|
||||
localPort uint32
|
||||
remoteAddr [16]byte
|
||||
_ uint32 // remoteScopeID
|
||||
remotePort uint32
|
||||
state uint32
|
||||
owningPid uint32
|
||||
}
|
||||
|
||||
type iphelperUDPTable struct {
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366932(v=vs.85).aspx
|
||||
numEntries uint32
|
||||
table [maxStateTableEntries]iphelperUDPRow
|
||||
}
|
||||
|
||||
type iphelperUDPRow struct {
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366928(v=vs.85).aspx
|
||||
localAddr uint32
|
||||
localPort uint32
|
||||
owningPid uint32
|
||||
}
|
||||
|
||||
type iphelperUDP6Table struct {
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366925(v=vs.85).aspx
|
||||
numEntries uint32
|
||||
table [maxStateTableEntries]iphelperUDP6Row
|
||||
}
|
||||
|
||||
type iphelperUDP6Row struct {
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366923(v=vs.85).aspx
|
||||
localAddr [16]byte
|
||||
_ uint32 // localScopeID
|
||||
localPort uint32
|
||||
owningPid uint32
|
||||
}
|
||||
|
||||
// IP and Protocol constants
|
||||
const (
|
||||
IPv4 uint8 = 4
|
||||
IPv6 uint8 = 6
|
||||
|
||||
TCP uint8 = 6
|
||||
UDP uint8 = 17
|
||||
)
|
||||
|
||||
type learningBufSize struct {
|
||||
sync.Mutex
|
||||
|
||||
size int
|
||||
usesLeft int
|
||||
useFor int
|
||||
start int
|
||||
max int
|
||||
}
|
||||
|
||||
func newLearningBufSize(start, max, ttl int) *learningBufSize {
|
||||
return &learningBufSize{
|
||||
size: start,
|
||||
usesLeft: ttl,
|
||||
useFor: ttl,
|
||||
start: start,
|
||||
max: max,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
startBufSize = 1024
|
||||
|
||||
// bufSizeUsageTTL defines how often a buffer size is used before it is
|
||||
// shrunk again.
|
||||
bufSizeUsageTTL = 100
|
||||
|
||||
// maxBufSize is the maximum size we will allocate for responses. This was
|
||||
// previously set at 65k, which was too little for some production cases.
|
||||
maxBufSize = 1048576 // 2^20B, 1MB
|
||||
|
||||
// maxStateTableEntries is the maximum supported amount of entries of the
|
||||
// state tables.
|
||||
// This is never allocated, but just casted to from an unsafe pointer.
|
||||
maxStateTableEntries = 65535
|
||||
)
|
||||
|
||||
var (
|
||||
tcp4BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL)
|
||||
udp4BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL)
|
||||
tcp6BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL)
|
||||
udp6BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL)
|
||||
)
|
||||
|
||||
func (lbf *learningBufSize) getBufSize() int {
|
||||
lbf.Lock()
|
||||
defer lbf.Unlock()
|
||||
|
||||
// using bufSize
|
||||
lbf.usesLeft--
|
||||
// check if we want to reset
|
||||
if lbf.usesLeft <= 0 {
|
||||
// decrease
|
||||
lbf.size /= 2
|
||||
// not too little
|
||||
if lbf.size < lbf.start {
|
||||
lbf.size = lbf.start
|
||||
}
|
||||
// reset TTL counter
|
||||
lbf.usesLeft = lbf.useFor
|
||||
}
|
||||
|
||||
return lbf.size
|
||||
}
|
||||
|
||||
func (lbf *learningBufSize) increaseBufSize(minSize int) int {
|
||||
lbf.Lock()
|
||||
defer lbf.Unlock()
|
||||
|
||||
// increase
|
||||
lbf.size *= 2
|
||||
// increase until we reach the minimum size
|
||||
for lbf.size < minSize {
|
||||
lbf.size *= 2
|
||||
}
|
||||
// not too much
|
||||
if lbf.size > lbf.max {
|
||||
lbf.size = lbf.max
|
||||
}
|
||||
// reset TTL counter
|
||||
lbf.usesLeft = lbf.useFor
|
||||
// return new bufSize
|
||||
return lbf.size
|
||||
}
|
||||
|
||||
// getTable returns the current connection state table of Windows of the given protocol and IP version.
|
||||
func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { //nolint:gocognit,gocycle // TODO
|
||||
// docs: https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
|
||||
|
||||
if !ipHelper.valid.IsSet() {
|
||||
return nil, nil, errInvalid
|
||||
}
|
||||
|
||||
var afClass int
|
||||
var lbf *learningBufSize
|
||||
switch ipVersion {
|
||||
case IPv4:
|
||||
afClass = windows.AF_INET
|
||||
if protocol == TCP {
|
||||
lbf = tcp4BufSize
|
||||
} else {
|
||||
lbf = udp4BufSize
|
||||
}
|
||||
case IPv6:
|
||||
afClass = windows.AF_INET6
|
||||
if protocol == TCP {
|
||||
lbf = tcp6BufSize
|
||||
} else {
|
||||
lbf = udp6BufSize
|
||||
}
|
||||
default:
|
||||
return nil, nil, errors.New("invalid protocol")
|
||||
}
|
||||
|
||||
// try max 5 times
|
||||
maxTries := 5
|
||||
usedBufSize := lbf.getBufSize()
|
||||
var buf []byte
|
||||
|
||||
for i := 1; i <= maxTries; i++ {
|
||||
bufSizeParam := usedBufSize
|
||||
buf = make([]byte, bufSizeParam)
|
||||
var r1 uintptr
|
||||
|
||||
switch protocol {
|
||||
case TCP:
|
||||
r1, _, err = ipHelper.getExtendedTCPTable.Call(
|
||||
uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable
|
||||
uintptr(unsafe.Pointer(&bufSizeParam)), // _Inout_ PDWORD pdwSize
|
||||
0, // _In_ BOOL bOrder
|
||||
uintptr(afClass), // _In_ ULONG ulAf
|
||||
iphelperTCPTableOwnerPIDAll, // _In_ TCP_TABLE_CLASS TableClass
|
||||
0, // _In_ ULONG Reserved
|
||||
)
|
||||
case UDP:
|
||||
r1, _, err = ipHelper.getExtendedUDPTable.Call(
|
||||
uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable,
|
||||
uintptr(unsafe.Pointer(&bufSizeParam)), // _Inout_ PDWORD pdwSize,
|
||||
0, // _In_ BOOL bOrder,
|
||||
uintptr(afClass), // _In_ ULONG ulAf,
|
||||
iphelperUDPTableOwnerPID, // _In_ UDP_TABLE_CLASS TableClass,
|
||||
0, // _In_ ULONG Reserved
|
||||
)
|
||||
}
|
||||
|
||||
switch r1 {
|
||||
case winErrInsufficientBuffer:
|
||||
if i >= maxTries {
|
||||
return nil, nil, fmt.Errorf(
|
||||
"insufficient buffer error (tried %d times): provided %d bytes; required %d bytes - [NT 0x%X] %s",
|
||||
i, usedBufSize, bufSizeParam, r1, err,
|
||||
)
|
||||
}
|
||||
// bufSizeParam was modified by ipHelper.getExtended*Table to hold the
|
||||
// required buffer size.
|
||||
usedBufSize = lbf.increaseBufSize(bufSizeParam)
|
||||
case winErrInvalidParameter:
|
||||
return nil, nil, fmt.Errorf("invalid parameter: [NT 0x%X] %s", r1, err)
|
||||
case windows.NO_ERROR:
|
||||
// success
|
||||
break
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unexpected error: [NT 0x%X] %s", r1, err)
|
||||
}
|
||||
}
|
||||
|
||||
// parse output
|
||||
switch {
|
||||
case protocol == TCP && ipVersion == IPv4:
|
||||
|
||||
tcpTable := (*iphelperTCPTable)(unsafe.Pointer(&buf[0]))
|
||||
// Check if we got more entries than supported.
|
||||
tableEntries := tcpTable.numEntries
|
||||
if tableEntries > maxStateTableEntries {
|
||||
tableEntries = maxStateTableEntries
|
||||
log.Warningf("network/iphelper: received TCPv4 table with more entries than supported: %d/%d", tcpTable.numEntries, maxStateTableEntries)
|
||||
}
|
||||
// Cap table to actual entries.
|
||||
table := tcpTable.table[:tableEntries]
|
||||
|
||||
for _, row := range table {
|
||||
if row.state == iphelperTCPStateListen {
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: convertIPv4(row.localAddr),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
} else {
|
||||
connections = append(connections, &socket.ConnectionInfo{
|
||||
Local: socket.Address{
|
||||
IP: convertIPv4(row.localAddr),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
Remote: socket.Address{
|
||||
IP: convertIPv4(row.remoteAddr),
|
||||
Port: convertPort(row.remotePort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
case protocol == TCP && ipVersion == IPv6:
|
||||
|
||||
tcpTable := (*iphelperTCP6Table)(unsafe.Pointer(&buf[0]))
|
||||
// Check if we got more entries than supported.
|
||||
tableEntries := tcpTable.numEntries
|
||||
if tableEntries > maxStateTableEntries {
|
||||
tableEntries = maxStateTableEntries
|
||||
log.Warningf("network/iphelper: received TCPv6 table with more entries than supported: %d/%d", tcpTable.numEntries, maxStateTableEntries)
|
||||
}
|
||||
// Cap table to actual entries.
|
||||
table := tcpTable.table[:tableEntries]
|
||||
|
||||
for _, row := range table {
|
||||
if row.state == iphelperTCPStateListen {
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: net.IP(row.localAddr[:]),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
} else {
|
||||
connections = append(connections, &socket.ConnectionInfo{
|
||||
Local: socket.Address{
|
||||
IP: net.IP(row.localAddr[:]),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
Remote: socket.Address{
|
||||
IP: net.IP(row.remoteAddr[:]),
|
||||
Port: convertPort(row.remotePort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
case protocol == UDP && ipVersion == IPv4:
|
||||
|
||||
udpTable := (*iphelperUDPTable)(unsafe.Pointer(&buf[0]))
|
||||
// Check if we got more entries than supported.
|
||||
tableEntries := udpTable.numEntries
|
||||
if tableEntries > maxStateTableEntries {
|
||||
tableEntries = maxStateTableEntries
|
||||
log.Warningf("network/iphelper: received UDPv4 table with more entries than supported: %d/%d", udpTable.numEntries, maxStateTableEntries)
|
||||
}
|
||||
// Cap table to actual entries.
|
||||
table := udpTable.table[:tableEntries]
|
||||
|
||||
for _, row := range table {
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: convertIPv4(row.localAddr),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
}
|
||||
|
||||
case protocol == UDP && ipVersion == IPv6:
|
||||
|
||||
udpTable := (*iphelperUDP6Table)(unsafe.Pointer(&buf[0]))
|
||||
// Check if we got more entries than supported.
|
||||
tableEntries := udpTable.numEntries
|
||||
if tableEntries > maxStateTableEntries {
|
||||
tableEntries = maxStateTableEntries
|
||||
log.Warningf("network/iphelper: received UDPv6 table with more entries than supported: %d/%d", udpTable.numEntries, maxStateTableEntries)
|
||||
}
|
||||
// Cap table to actual entries.
|
||||
table := udpTable.table[:tableEntries]
|
||||
|
||||
for _, row := range table {
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: net.IP(row.localAddr[:]),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return connections, binds, nil
|
||||
}
|
||||
|
||||
// convertIPv4 as needed for iphlpapi.dll
|
||||
func convertIPv4(input uint32) net.IP {
|
||||
addressBuf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(addressBuf, input)
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
// convertPort converts ports received from iphlpapi.dll
|
||||
func convertPort(input uint32) uint16 {
|
||||
return uint16(input>>8 | input<<8)
|
||||
}
|
||||
54
service/network/iphelper/tables_test.go
Normal file
54
service/network/iphelper/tables_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
//go:build windows
|
||||
|
||||
package iphelper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSockets(t *testing.T) {
|
||||
connections, listeners, err := GetTCP4Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nTCP 4 connections:")
|
||||
for _, connection := range connections {
|
||||
fmt.Printf("%+v\n", connection)
|
||||
}
|
||||
fmt.Println("\nTCP 4 listeners:")
|
||||
for _, listener := range listeners {
|
||||
fmt.Printf("%+v\n", listener)
|
||||
}
|
||||
|
||||
connections, listeners, err = GetTCP6Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nTCP 6 connections:")
|
||||
for _, connection := range connections {
|
||||
fmt.Printf("%+v\n", connection)
|
||||
}
|
||||
fmt.Println("\nTCP 6 listeners:")
|
||||
for _, listener := range listeners {
|
||||
fmt.Printf("%+v\n", listener)
|
||||
}
|
||||
|
||||
binds, err := GetUDP4Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nUDP 4 binds:")
|
||||
for _, bind := range binds {
|
||||
fmt.Printf("%+v\n", bind)
|
||||
}
|
||||
|
||||
binds, err = GetUDP6Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nUDP 6 binds:")
|
||||
for _, bind := range binds {
|
||||
fmt.Printf("%+v\n", bind)
|
||||
}
|
||||
}
|
||||
172
service/network/metrics.go
Normal file
172
service/network/metrics.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/metrics"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
)
|
||||
|
||||
var (
|
||||
packetHandlingHistogram *metrics.Histogram
|
||||
blockedOutConnCounter *metrics.Counter
|
||||
encryptedAndTunneledOutConnCounter *metrics.Counter
|
||||
encryptedOutConnCounter *metrics.Counter
|
||||
tunneledOutConnCounter *metrics.Counter
|
||||
outConnCounter *metrics.Counter
|
||||
)
|
||||
|
||||
func registerMetrics() (err error) {
|
||||
// This needed to be moved here, because every packet is now handled by the
|
||||
// connection handler worker.
|
||||
packetHandlingHistogram, err = metrics.NewHistogram(
|
||||
"firewall/handling/duration/seconds",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Permission: api.PermitUser,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"network/connections/active/total",
|
||||
nil,
|
||||
func() float64 {
|
||||
return float64(conns.active())
|
||||
},
|
||||
&metrics.Options{
|
||||
InternalID: "active_connections",
|
||||
Permission: api.PermitUser,
|
||||
ExpertiseLevel: config.ExpertiseLevelUser,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connCounterID := "network/connections/total"
|
||||
connCounterOpts := &metrics.Options{
|
||||
Name: "Connections",
|
||||
Permission: api.PermitUser,
|
||||
ExpertiseLevel: config.ExpertiseLevelUser,
|
||||
Persist: true,
|
||||
}
|
||||
|
||||
blockedOutConnCounter, err = metrics.NewCounter(
|
||||
connCounterID,
|
||||
map[string]string{
|
||||
"direction": "out",
|
||||
"blocked": "true",
|
||||
},
|
||||
&metrics.Options{
|
||||
Name: "Connections",
|
||||
InternalID: "blocked_outgoing_connections",
|
||||
Permission: api.PermitUser,
|
||||
ExpertiseLevel: config.ExpertiseLevelUser,
|
||||
Persist: true,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encryptedAndTunneledOutConnCounter, err = metrics.NewCounter(
|
||||
connCounterID,
|
||||
map[string]string{
|
||||
"direction": "out",
|
||||
"encrypted": "true",
|
||||
"tunneled": "true",
|
||||
},
|
||||
connCounterOpts,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encryptedOutConnCounter, err = metrics.NewCounter(
|
||||
connCounterID,
|
||||
map[string]string{
|
||||
"direction": "out",
|
||||
"encrypted": "true",
|
||||
},
|
||||
connCounterOpts,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tunneledOutConnCounter, err = metrics.NewCounter(
|
||||
connCounterID,
|
||||
map[string]string{
|
||||
"direction": "out",
|
||||
"tunneled": "true",
|
||||
},
|
||||
connCounterOpts,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
outConnCounter, err = metrics.NewCounter(
|
||||
connCounterID,
|
||||
map[string]string{
|
||||
"direction": "out",
|
||||
},
|
||||
connCounterOpts,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Connection) addToMetrics() {
|
||||
if conn.addedToMetrics {
|
||||
return
|
||||
}
|
||||
|
||||
// Don't count requests serviced to the network,
|
||||
// as we have an incomplete view here.
|
||||
if conn.Process() != nil &&
|
||||
conn.Process().Pid == process.NetworkHostProcessID {
|
||||
return
|
||||
}
|
||||
|
||||
// Only count outgoing connections for now.
|
||||
if conn.Inbound {
|
||||
return
|
||||
}
|
||||
|
||||
// Check the verdict.
|
||||
switch conn.Verdict { //nolint:exhaustive // Not critical.
|
||||
case VerdictBlock, VerdictDrop:
|
||||
blockedOutConnCounter.Inc()
|
||||
conn.addedToMetrics = true
|
||||
return
|
||||
case VerdictAccept, VerdictRerouteToTunnel:
|
||||
// Continue to next section.
|
||||
default:
|
||||
// Connection is not counted.
|
||||
return
|
||||
}
|
||||
|
||||
// Only count successful connections, not DNS requests.
|
||||
if conn.Type == DNSRequest {
|
||||
return
|
||||
}
|
||||
|
||||
// Select counter based on attributes.
|
||||
switch {
|
||||
case conn.Encrypted && conn.Tunneled:
|
||||
encryptedAndTunneledOutConnCounter.Inc()
|
||||
case conn.Encrypted:
|
||||
encryptedOutConnCounter.Inc()
|
||||
case conn.Tunneled:
|
||||
tunneledOutConnCounter.Inc()
|
||||
default:
|
||||
outConnCounter.Inc()
|
||||
}
|
||||
conn.addedToMetrics = true
|
||||
}
|
||||
151
service/network/module.go
Normal file
151
service/network/module.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network/state"
|
||||
"github.com/safing/portmaster/service/profile"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
defaultFirewallHandler FirewallHandler
|
||||
)
|
||||
|
||||
// Events.
|
||||
var (
|
||||
ConnectionReattributedEvent = "connection re-attributed"
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("network", prep, start, nil, "base", "netenv", "processes")
|
||||
module.RegisterEvent(ConnectionReattributedEvent, false)
|
||||
}
|
||||
|
||||
// SetDefaultFirewallHandler sets the default firewall handler.
|
||||
func SetDefaultFirewallHandler(handler FirewallHandler) {
|
||||
if defaultFirewallHandler == nil {
|
||||
defaultFirewallHandler = handler
|
||||
}
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
if netenv.IPv6Enabled() {
|
||||
state.EnableTCPDualStack()
|
||||
state.EnableUDPDualStack()
|
||||
}
|
||||
|
||||
return registerAPIEndpoints()
|
||||
}
|
||||
|
||||
func start() error {
|
||||
err := registerAsDatabase()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := registerMetrics(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
module.StartServiceWorker("clean connections", 0, connectionCleaner)
|
||||
module.StartServiceWorker("write open dns requests", 0, openDNSRequestWriter)
|
||||
|
||||
if err := module.RegisterEventHook(
|
||||
"profiles",
|
||||
profile.DeletedEvent,
|
||||
"re-attribute connections from deleted profile",
|
||||
reAttributeConnections,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var reAttributionLock sync.Mutex
|
||||
|
||||
// reAttributeConnections finds all connections of a deleted profile and re-attributes them.
|
||||
// Expected event data: scoped profile ID.
|
||||
func reAttributeConnections(_ context.Context, eventData any) error {
|
||||
profileID, ok := eventData.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("event data is not a string: %v", eventData)
|
||||
}
|
||||
profileSource, profileID, ok := strings.Cut(profileID, "/")
|
||||
if !ok {
|
||||
return fmt.Errorf("event data does not seem to be a scoped profile ID: %v", eventData)
|
||||
}
|
||||
|
||||
// Hold a lock for re-attribution, to prevent simultaneous processing of the
|
||||
// same connections and make logging cleaner.
|
||||
reAttributionLock.Lock()
|
||||
defer reAttributionLock.Unlock()
|
||||
|
||||
// Create tracing context.
|
||||
ctx, tracer := log.AddTracer(context.Background())
|
||||
defer tracer.Submit()
|
||||
tracer.Infof("network: re-attributing connections from deleted profile %s/%s", profileSource, profileID)
|
||||
|
||||
// Count and log how many connections were re-attributed.
|
||||
var reAttributed int
|
||||
|
||||
// Re-attribute connections.
|
||||
for _, conn := range conns.clone() {
|
||||
if reAttributeConnection(ctx, conn, profileID, profileSource) {
|
||||
reAttributed++
|
||||
tracer.Debugf("filter: re-attributed %s to %s", conn, conn.process.PrimaryProfileID)
|
||||
}
|
||||
}
|
||||
|
||||
// Re-attribute dns connections.
|
||||
for _, conn := range dnsConns.clone() {
|
||||
if reAttributeConnection(ctx, conn, profileID, profileSource) {
|
||||
reAttributed++
|
||||
tracer.Debugf("filter: re-attributed %s to %s", conn, conn.process.PrimaryProfileID)
|
||||
}
|
||||
}
|
||||
|
||||
tracer.Infof("filter: re-attributed %d connections", reAttributed)
|
||||
return nil
|
||||
}
|
||||
|
||||
func reAttributeConnection(ctx context.Context, conn *Connection, profileID, profileSource string) (reAttributed bool) {
|
||||
// Lock the connection before checking anything to avoid a race condition with connection data collection.
|
||||
conn.Lock()
|
||||
defer conn.Unlock()
|
||||
|
||||
// Check if the connection has the profile we are looking for.
|
||||
switch {
|
||||
case !conn.DataIsComplete():
|
||||
return false
|
||||
case conn.ProcessContext.Profile != profileID:
|
||||
return false
|
||||
case conn.ProcessContext.Source != profileSource:
|
||||
return false
|
||||
}
|
||||
|
||||
// Attempt to assign new profile.
|
||||
err := conn.process.RefetchProfile(ctx)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Warningf("network: failed to refetch profile for %s: %s", conn, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Set the new process context.
|
||||
conn.ProcessContext = getProcessContext(ctx, conn.process)
|
||||
conn.Save()
|
||||
|
||||
// Trigger event for re-attribution.
|
||||
module.TriggerEvent(ConnectionReattributedEvent, conn.ID)
|
||||
|
||||
log.Tracer(ctx).Debugf("filter: re-attributed %s to %s", conn, conn.process.PrimaryProfileID)
|
||||
return true
|
||||
}
|
||||
45
service/network/multicast.go
Normal file
45
service/network/multicast.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
)
|
||||
|
||||
// GetMulticastRequestConn searches for and returns the requesting connnection
|
||||
// of a possible multicast/broadcast response.
|
||||
func GetMulticastRequestConn(responseConn *Connection, responseFromNet *net.IPNet) *Connection {
|
||||
// Calculate the broadcast address the query would have gone to.
|
||||
responseNetBroadcastIP := netutils.GetBroadcastAddress(responseFromNet.IP, responseFromNet.Mask)
|
||||
|
||||
// Find requesting multicast/broadcast connection.
|
||||
for _, conn := range conns.clone() {
|
||||
switch {
|
||||
case !conn.DataIsComplete():
|
||||
// Ignore connection with incomplete data.
|
||||
case conn.Inbound:
|
||||
// Ignore incoming connections.
|
||||
case conn.Ended != 0:
|
||||
// Ignore ended connections.
|
||||
case conn.Entity.Protocol != responseConn.Entity.Protocol:
|
||||
// Ignore on protocol mismatch.
|
||||
case conn.LocalPort != responseConn.LocalPort:
|
||||
// Ignore on local port mismatch.
|
||||
case !conn.LocalIP.Equal(responseConn.LocalIP):
|
||||
// Ignore on local IP mismatch.
|
||||
case !conn.Process().Equal(responseConn.Process()):
|
||||
// Ignore if processes mismatch.
|
||||
case conn.Entity.IPScope == netutils.LocalMulticast &&
|
||||
(responseConn.Entity.IPScope == netutils.LinkLocal ||
|
||||
responseConn.Entity.IPScope == netutils.SiteLocal):
|
||||
// We found a (possibly routed) multicast request that matches the response!
|
||||
return conn
|
||||
case conn.Entity.IP.Equal(responseNetBroadcastIP) &&
|
||||
responseFromNet.Contains(conn.LocalIP):
|
||||
// We found a (link local) broadcast request that matches the response!
|
||||
return conn
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
60
service/network/netutils/address.go
Normal file
60
service/network/netutils/address.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package netutils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
var errInvalidIP = errors.New("invalid IP address")
|
||||
|
||||
// IPPortFromAddr extracts or parses the IP address and port contained in the given address.
|
||||
func IPPortFromAddr(addr net.Addr) (ip net.IP, port uint16, err error) {
|
||||
// Convert addr to IP if needed.
|
||||
switch v := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return v.IP, uint16(v.Port), nil
|
||||
case *net.UDPAddr:
|
||||
return v.IP, uint16(v.Port), nil
|
||||
case *net.IPAddr:
|
||||
return v.IP, 0, nil
|
||||
case *net.UnixAddr:
|
||||
return nil, 0, errors.New("unix addresses don't have IPs")
|
||||
default:
|
||||
return ParseIPPort(addr.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ProtocolFromNetwork returns the protocol from the given net, as used in the "net" golang stdlib.
|
||||
func ProtocolFromNetwork(net string) (protocol packet.IPProtocol) {
|
||||
switch net {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
return packet.TCP
|
||||
case "udp", "udp4", "udp6":
|
||||
return packet.UDP
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// ParseIPPort parses a <ip>:port formatted address.
|
||||
func ParseIPPort(address string) (net.IP, uint16, error) {
|
||||
ipString, portString, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
ip := net.ParseIP(ipString)
|
||||
if ip == nil {
|
||||
return nil, 0, errInvalidIP
|
||||
}
|
||||
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return ip, uint16(port), nil
|
||||
}
|
||||
99
service/network/netutils/dns.go
Normal file
99
service/network/netutils/dns.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package netutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var (
|
||||
cleanDomainRegex = regexp.MustCompile(
|
||||
`^` + // match beginning
|
||||
`(` + // start subdomain group
|
||||
`(xn--)?` + // idn prefix
|
||||
`[a-z0-9_-]{1,63}` + // main chunk
|
||||
`\.` + // ending with a dot
|
||||
`)*` + // end subdomain group, allow any number of subdomains
|
||||
`(xn--)?` + // TLD idn prefix
|
||||
`[a-z0-9_-]{1,63}` + // TLD main chunk with at least one character (for custom ones)
|
||||
`\.` + // ending with a dot
|
||||
`$`, // match end
|
||||
)
|
||||
|
||||
// dnsSDDomainRegex is a lot more lax to better suit the allowed characters in DNS-SD.
|
||||
// Not all characters have been allowed - some special characters were
|
||||
// removed to reduce the general attack surface.
|
||||
dnsSDDomainRegex = regexp.MustCompile(
|
||||
// Start of charset selection.
|
||||
`^[` +
|
||||
// Printable ASCII (character code 32-127), excluding some special characters.
|
||||
` !#$%&()*+,\-\./0-9:;=?@A-Z[\\\]^_\a-z{|}~` +
|
||||
// Only latin characters from extended ASCII (character code 128-255).
|
||||
`ŠŒŽšœžŸ¡¿ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿ` +
|
||||
// End of charset selection.
|
||||
`]*$`,
|
||||
)
|
||||
)
|
||||
|
||||
// IsValidFqdn returns whether the given string is a valid fqdn.
|
||||
func IsValidFqdn(fqdn string) bool {
|
||||
// root zone
|
||||
if fqdn == "." {
|
||||
return true
|
||||
}
|
||||
|
||||
// check max length
|
||||
if len(fqdn) > 256 {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsFqdn checks if a domain name is fully qualified.
|
||||
if !dns.IsFqdn(fqdn) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use special check for .local domains to support DNS-SD.
|
||||
if strings.HasSuffix(fqdn, ".local.") {
|
||||
return dnsSDDomainRegex.MatchString(fqdn)
|
||||
}
|
||||
|
||||
// check with regex
|
||||
if !cleanDomainRegex.MatchString(fqdn) {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDomainName checks if s is a valid domain name, it returns the number of
|
||||
// labels and true, when a domain name is valid. Note that non fully qualified
|
||||
// domain name is considered valid, in this case the last label is counted in
|
||||
// the number of labels. When false is returned the number of labels is not
|
||||
// defined. Also note that this function is extremely liberal; almost any
|
||||
// string is a valid domain name as the DNS is 8 bit protocol. It checks if each
|
||||
// label fits in 63 characters and that the entire name will fit into the 255
|
||||
// octet wire format limit.
|
||||
_, ok := dns.IsDomainName(fqdn)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IPsToRRs transforms the given IPs to resource records.
|
||||
func IPsToRRs(domain string, ips []net.IP) ([]dns.RR, error) {
|
||||
records := make([]dns.RR, 0, len(ips))
|
||||
var rr dns.RR
|
||||
var err error
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.To4() != nil {
|
||||
rr, err = dns.NewRR(fmt.Sprintf("%s 17 IN A %s", domain, ip))
|
||||
} else {
|
||||
rr, err = dns.NewRR(fmt.Sprintf("%s 17 IN AAAA %s", domain, ip))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create record for %s: %w", ip, err)
|
||||
}
|
||||
records = append(records, rr)
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
47
service/network/netutils/dns_test.go
Normal file
47
service/network/netutils/dns_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package netutils
|
||||
|
||||
import "testing"
|
||||
|
||||
func testDomainValidity(t *testing.T, domain string, isValid bool) {
|
||||
t.Helper()
|
||||
|
||||
if IsValidFqdn(domain) != isValid {
|
||||
t.Errorf("domain %s failed check: was valid=%v, expected valid=%v", domain, IsValidFqdn(domain), isValid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// valid
|
||||
testDomainValidity(t, ".", true)
|
||||
testDomainValidity(t, "at.", true)
|
||||
testDomainValidity(t, "orf.at.", true)
|
||||
testDomainValidity(t, "www.orf.at.", true)
|
||||
testDomainValidity(t, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.u.v.x.y.z.example.org.", true)
|
||||
testDomainValidity(t, "a_a.com.", true)
|
||||
testDomainValidity(t, "a-a.com.", true)
|
||||
testDomainValidity(t, "a_a.com.", true)
|
||||
testDomainValidity(t, "a-a.com.", true)
|
||||
testDomainValidity(t, "xn--a.com.", true)
|
||||
testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasd.com.", true)
|
||||
|
||||
// maybe valid
|
||||
testDomainValidity(t, "-.com.", true)
|
||||
testDomainValidity(t, "_.com.", true)
|
||||
testDomainValidity(t, "a_.com.", true)
|
||||
testDomainValidity(t, "a-.com.", true)
|
||||
testDomainValidity(t, "_a.com.", true)
|
||||
testDomainValidity(t, "-a.com.", true)
|
||||
|
||||
// invalid
|
||||
testDomainValidity(t, ".com.", false)
|
||||
testDomainValidity(t, ".com.", false)
|
||||
testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
|
||||
testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
|
||||
testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
|
||||
testDomainValidity(t, "asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.as.com.", false)
|
||||
|
||||
// real world examples
|
||||
testDomainValidity(t, "iuqerfsodp9ifjaposdfjhgosurijfaewrwergwea.com.", true)
|
||||
}
|
||||
160
service/network/netutils/ip.go
Normal file
160
service/network/netutils/ip.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package netutils
|
||||
|
||||
import "net"
|
||||
|
||||
// IPScope is the scope of the IP address.
|
||||
type IPScope int8
|
||||
|
||||
// Defined IP Scopes.
|
||||
const (
|
||||
Invalid IPScope = iota - 1
|
||||
Undefined
|
||||
HostLocal
|
||||
LinkLocal
|
||||
SiteLocal
|
||||
Global
|
||||
LocalMulticast
|
||||
GlobalMulticast
|
||||
)
|
||||
|
||||
// ClassifyIP returns the network scope of the given IP address.
|
||||
// Deprecated: Please use the new GetIPScope instead.
|
||||
func ClassifyIP(ip net.IP) IPScope {
|
||||
return GetIPScope(ip)
|
||||
}
|
||||
|
||||
// GetIPScope returns the network scope of the given IP address.
|
||||
func GetIPScope(ip net.IP) IPScope { //nolint:gocognit
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// IPv4
|
||||
switch {
|
||||
case ip4[0] == 0 && ip4[1] == 0 && ip4[2] == 0 && ip4[3] == 0:
|
||||
// 0.0.0.0/32
|
||||
return LocalMulticast // Used as source for L2 based protocols with no L3 addressing.
|
||||
case ip4[0] == 0:
|
||||
// 0.0.0.0/8
|
||||
return Invalid
|
||||
case ip4[0] == 10:
|
||||
// 10.0.0.0/8 (RFC1918)
|
||||
return SiteLocal
|
||||
case ip4[0] == 100 && ip4[1]&0b11000000 == 64:
|
||||
// 100.64.0.0/10 (RFC6598)
|
||||
return SiteLocal
|
||||
case ip4[0] == 127:
|
||||
// 127.0.0.0/8 (RFC1918)
|
||||
return HostLocal
|
||||
case ip4[0] == 169 && ip4[1] == 254:
|
||||
// 169.254.0.0/16 (RFC3927)
|
||||
return LinkLocal
|
||||
case ip4[0] == 172 && ip4[1]&0b11110000 == 16:
|
||||
// 172.16.0.0/12 (RFC1918)
|
||||
return SiteLocal
|
||||
case ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2:
|
||||
// 192.0.2.0/24 (TEST-NET-1, RFC5737)
|
||||
return Invalid
|
||||
case ip4[0] == 192 && ip4[1] == 168:
|
||||
// 192.168.0.0/16 (RFC1918)
|
||||
return SiteLocal
|
||||
case ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100:
|
||||
// 198.51.100.0/24 (TEST-NET-2, RFC5737)
|
||||
return Invalid
|
||||
case ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113:
|
||||
// 203.0.113.0/24 (TEST-NET-3, RFC5737)
|
||||
return Invalid
|
||||
case ip4[0] == 224:
|
||||
// 224.0.0.0/8 (RFC5771)
|
||||
return LocalMulticast
|
||||
case ip4[0] == 233 && ip4[1] == 252 && ip4[2] == 0:
|
||||
// 233.252.0.0/24 (MCAST-TEST-NET; RFC5771, RFC6676)
|
||||
return Invalid
|
||||
case ip4[0] >= 225 && ip4[0] <= 238:
|
||||
// 225.0.0.0/8 - 238.0.0.0/8 (RFC5771)
|
||||
return GlobalMulticast
|
||||
case ip4[0] == 239:
|
||||
// 239.0.0.0/8 (RFC2365)
|
||||
return LocalMulticast
|
||||
case ip4[0] == 255 && ip4[1] == 255 && ip4[2] == 255 && ip4[3] == 255:
|
||||
// 255.255.255.255/32
|
||||
return LocalMulticast
|
||||
case ip4[0] >= 240:
|
||||
// 240.0.0.0/8 - 255.0.0.0/8 (minus 255.255.255.255/32)
|
||||
return Invalid
|
||||
default:
|
||||
return Global
|
||||
}
|
||||
} else if len(ip) == net.IPv6len {
|
||||
// IPv6
|
||||
switch {
|
||||
case ip.Equal(net.IPv6zero):
|
||||
return Invalid
|
||||
case ip.Equal(net.IPv6loopback):
|
||||
return HostLocal
|
||||
case ip[0]&0xfe == 0xfc:
|
||||
// fc00::/7
|
||||
return SiteLocal
|
||||
case ip[0] == 0xfe && ip[1]&0xc0 == 0x80:
|
||||
// fe80::/10
|
||||
return LinkLocal
|
||||
case ip[0] == 0xff && ip[1] <= 0x05:
|
||||
// ff00::/16 - ff05::/16
|
||||
return LocalMulticast
|
||||
case ip[0] == 0xff:
|
||||
// other ff00::/8
|
||||
return GlobalMulticast
|
||||
default:
|
||||
return Global
|
||||
}
|
||||
}
|
||||
return Invalid
|
||||
}
|
||||
|
||||
// IsLocalhost returns whether the IP refers to the host itself.
|
||||
func (scope IPScope) IsLocalhost() bool {
|
||||
return scope == HostLocal
|
||||
}
|
||||
|
||||
// IsLAN returns true if the scope is site-local or link-local.
|
||||
func (scope IPScope) IsLAN() bool {
|
||||
switch scope { //nolint:exhaustive // Looking for something specific.
|
||||
case SiteLocal, LinkLocal, LocalMulticast:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal returns true if the scope is global.
|
||||
func (scope IPScope) IsGlobal() bool {
|
||||
switch scope { //nolint:exhaustive // Looking for something specific.
|
||||
case Global, GlobalMulticast:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetBroadcastAddress returns the broadcast address of the given IP and network mask.
|
||||
// If a mixed IPv4/IPv6 input is given, it returns nil.
|
||||
func GetBroadcastAddress(ip net.IP, netMask net.IPMask) net.IP {
|
||||
// Convert to standard v4.
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
ip = ip4
|
||||
}
|
||||
mask := net.IP(netMask)
|
||||
if ip4Mask := mask.To4(); ip4Mask != nil {
|
||||
mask = ip4Mask
|
||||
}
|
||||
|
||||
// Check for mixed v4/v6 input.
|
||||
if len(ip) != len(mask) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Merge to broadcast address
|
||||
n := len(ip)
|
||||
broadcastAddress := make(net.IP, n)
|
||||
for i := 0; i < n; i++ {
|
||||
broadcastAddress[i] = ip[i] | ^mask[i]
|
||||
}
|
||||
return broadcastAddress
|
||||
}
|
||||
51
service/network/netutils/ip_test.go
Normal file
51
service/network/netutils/ip_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package netutils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPScope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testScope(t, net.IPv4(71, 87, 113, 211), Global)
|
||||
testScope(t, net.IPv4(127, 0, 0, 1), HostLocal)
|
||||
testScope(t, net.IPv4(127, 255, 255, 1), HostLocal)
|
||||
testScope(t, net.IPv4(192, 168, 172, 24), SiteLocal)
|
||||
testScope(t, net.IPv4(172, 15, 1, 1), Global)
|
||||
testScope(t, net.IPv4(172, 16, 1, 1), SiteLocal)
|
||||
testScope(t, net.IPv4(172, 31, 1, 1), SiteLocal)
|
||||
testScope(t, net.IPv4(172, 32, 1, 1), Global)
|
||||
}
|
||||
|
||||
func testScope(t *testing.T, ip net.IP, expectedScope IPScope) {
|
||||
t.Helper()
|
||||
|
||||
c := GetIPScope(ip)
|
||||
if c != expectedScope {
|
||||
t.Errorf("%s is %s, expected %s", ip, scopeName(c), scopeName(expectedScope))
|
||||
}
|
||||
}
|
||||
|
||||
func scopeName(c IPScope) string {
|
||||
switch c {
|
||||
case Invalid:
|
||||
return "invalid"
|
||||
case Undefined:
|
||||
return "undefined"
|
||||
case HostLocal:
|
||||
return "hostLocal"
|
||||
case LinkLocal:
|
||||
return "linkLocal"
|
||||
case SiteLocal:
|
||||
return "siteLocal"
|
||||
case Global:
|
||||
return "global"
|
||||
case LocalMulticast:
|
||||
return "localMulticast"
|
||||
case GlobalMulticast:
|
||||
return "globalMulticast"
|
||||
default:
|
||||
return "undefined"
|
||||
}
|
||||
}
|
||||
51
service/network/netutils/tcpassembly.go
Normal file
51
service/network/netutils/tcpassembly.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package netutils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/tcpassembly"
|
||||
)
|
||||
|
||||
// SimpleStreamAssemblerManager is a simple manager for github.com/google/gopacket/tcpassembly.
|
||||
type SimpleStreamAssemblerManager struct {
|
||||
InitLock sync.Mutex
|
||||
lastAssembler *SimpleStreamAssembler
|
||||
}
|
||||
|
||||
// New returns a new stream assembler.
|
||||
func (m *SimpleStreamAssemblerManager) New(net, transport gopacket.Flow) tcpassembly.Stream {
|
||||
assembler := new(SimpleStreamAssembler)
|
||||
m.lastAssembler = assembler
|
||||
return assembler
|
||||
}
|
||||
|
||||
// GetLastAssembler returns the newest created stream assembler.
|
||||
func (m *SimpleStreamAssemblerManager) GetLastAssembler() *SimpleStreamAssembler {
|
||||
return m.lastAssembler
|
||||
}
|
||||
|
||||
// SimpleStreamAssembler is a simple assembler for github.com/google/gopacket/tcpassembly.
|
||||
type SimpleStreamAssembler struct {
|
||||
Cumulated []byte
|
||||
CumulatedLen int
|
||||
Complete bool
|
||||
}
|
||||
|
||||
// NewSimpleStreamAssembler returns a new SimpleStreamAssembler.
|
||||
func NewSimpleStreamAssembler() *SimpleStreamAssembler {
|
||||
return &SimpleStreamAssembler{}
|
||||
}
|
||||
|
||||
// Reassembled implements tcpassembly.Stream's Reassembled function.
|
||||
func (a *SimpleStreamAssembler) Reassembled(reassembly []tcpassembly.Reassembly) {
|
||||
for _, entry := range reassembly {
|
||||
a.Cumulated = append(a.Cumulated, entry.Bytes...)
|
||||
}
|
||||
a.CumulatedLen = len(a.Cumulated)
|
||||
}
|
||||
|
||||
// ReassemblyComplete implements tcpassembly.Stream's ReassemblyComplete function.
|
||||
func (a *SimpleStreamAssembler) ReassemblyComplete() {
|
||||
a.Complete = true
|
||||
}
|
||||
35
service/network/packet/bandwidth.go
Normal file
35
service/network/packet/bandwidth.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package packet
|
||||
|
||||
import "fmt"
|
||||
|
||||
// BandwidthUpdate holds an update to the seen bandwidth of a connection.
|
||||
type BandwidthUpdate struct {
|
||||
ConnID string
|
||||
BytesReceived uint64
|
||||
BytesSent uint64
|
||||
Method BandwidthUpdateMethod
|
||||
}
|
||||
|
||||
// BandwidthUpdateMethod defines how the bandwidth data of a bandwidth update should be interpreted.
|
||||
type BandwidthUpdateMethod uint8
|
||||
|
||||
// Bandwidth Update Methods.
|
||||
const (
|
||||
Absolute BandwidthUpdateMethod = iota
|
||||
Additive
|
||||
)
|
||||
|
||||
func (bu *BandwidthUpdate) String() string {
|
||||
return fmt.Sprintf("%s: %dB recv | %dB sent [%s]", bu.ConnID, bu.BytesReceived, bu.BytesSent, bu.Method)
|
||||
}
|
||||
|
||||
func (bum BandwidthUpdateMethod) String() string {
|
||||
switch bum {
|
||||
case Absolute:
|
||||
return "absolute"
|
||||
case Additive:
|
||||
return "additive"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
117
service/network/packet/const.go
Normal file
117
service/network/packet/const.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Basic Types.
|
||||
type (
|
||||
// IPVersion represents an IP version.
|
||||
IPVersion uint8
|
||||
// IPProtocol represents an IP protocol.
|
||||
IPProtocol uint8
|
||||
// Verdict describes the decision on a packet.
|
||||
Verdict uint8
|
||||
)
|
||||
|
||||
// Basic Constants.
|
||||
const (
|
||||
IPv4 = IPVersion(4)
|
||||
IPv6 = IPVersion(6)
|
||||
|
||||
InBound = true
|
||||
OutBound = false
|
||||
|
||||
ICMP = IPProtocol(1)
|
||||
IGMP = IPProtocol(2)
|
||||
TCP = IPProtocol(6)
|
||||
UDP = IPProtocol(17)
|
||||
ICMPv6 = IPProtocol(58)
|
||||
UDPLite = IPProtocol(136)
|
||||
RAW = IPProtocol(255)
|
||||
|
||||
AnyHostInternalProtocol61 = IPProtocol(61)
|
||||
)
|
||||
|
||||
// Verdicts.
|
||||
const (
|
||||
DROP Verdict = iota
|
||||
BLOCK
|
||||
ACCEPT
|
||||
STOLEN
|
||||
QUEUE
|
||||
REPEAT
|
||||
STOP
|
||||
)
|
||||
|
||||
// ErrFailedToLoadPayload is returned by GetPayload if it failed for an unspecified reason, or is not implemented on the current system.
|
||||
var ErrFailedToLoadPayload = errors.New("could not load packet payload")
|
||||
|
||||
// ByteSize returns the byte size of the ip (IPv4 = 4 bytes, IPv6 = 16).
|
||||
func (v IPVersion) ByteSize() int {
|
||||
switch v {
|
||||
case IPv4:
|
||||
return 4
|
||||
case IPv6:
|
||||
return 16
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// String returns the string representation of the IP version: "IPv4" or "IPv6".
|
||||
func (v IPVersion) String() string {
|
||||
switch v {
|
||||
case IPv4:
|
||||
return "IPv4"
|
||||
case IPv6:
|
||||
return "IPv6"
|
||||
}
|
||||
return fmt.Sprintf("<unknown ip version, %d>", uint8(v))
|
||||
}
|
||||
|
||||
// String returns the string representation (abbreviation) of the protocol.
|
||||
func (p IPProtocol) String() string {
|
||||
switch p {
|
||||
case RAW:
|
||||
return "RAW"
|
||||
case TCP:
|
||||
return "TCP"
|
||||
case UDP:
|
||||
return "UDP"
|
||||
case UDPLite:
|
||||
return "UDPLite"
|
||||
case ICMP:
|
||||
return "ICMP"
|
||||
case ICMPv6:
|
||||
return "ICMPv6"
|
||||
case IGMP:
|
||||
return "IGMP"
|
||||
case AnyHostInternalProtocol61:
|
||||
fallthrough
|
||||
default:
|
||||
return fmt.Sprintf("<unknown protocol, %d>", uint8(p))
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the string representation of the verdict.
|
||||
func (v Verdict) String() string {
|
||||
switch v {
|
||||
case DROP:
|
||||
return "DROP"
|
||||
case BLOCK:
|
||||
return "BLOCK"
|
||||
case ACCEPT:
|
||||
return "ACCEPT"
|
||||
case STOLEN:
|
||||
return "STOLEN"
|
||||
case QUEUE:
|
||||
return "QUEUE"
|
||||
case REPEAT:
|
||||
return "REPEAT"
|
||||
case STOP:
|
||||
return "STOP"
|
||||
default:
|
||||
return fmt.Sprintf("<unsupported verdict, %d>", uint8(v))
|
||||
}
|
||||
}
|
||||
77
service/network/packet/info_only.go
Normal file
77
service/network/packet/info_only.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// InfoPacket does not represent an actual packet, but only holds metadata.
|
||||
// Implements the packet.Packet interface.
|
||||
type InfoPacket struct {
|
||||
Base
|
||||
}
|
||||
|
||||
// NewInfoPacket returns a new InfoPacket with the given info.
|
||||
func NewInfoPacket(info Info) *InfoPacket {
|
||||
return &InfoPacket{
|
||||
Base{
|
||||
info: info,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// InfoOnly returns whether the packet is informational only and does not
|
||||
// represent an actual packet.
|
||||
func (pkt *InfoPacket) InfoOnly() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// LoadPacketData does nothing on Linux, as data is always fully parsed.
|
||||
func (pkt *InfoPacket) LoadPacketData() error {
|
||||
return fmt.Errorf("%w: info-only packet", ErrFailedToLoadPayload)
|
||||
}
|
||||
|
||||
// ErrInfoOnlyPacket is returned for unsupported operations on an info-only packet.
|
||||
var ErrInfoOnlyPacket = errors.New("info-only packet")
|
||||
|
||||
// Accept does nothing on an info-only packet.
|
||||
func (pkt *InfoPacket) Accept() error {
|
||||
return ErrInfoOnlyPacket
|
||||
}
|
||||
|
||||
// Block does nothing on an info-only packet.
|
||||
func (pkt *InfoPacket) Block() error {
|
||||
return ErrInfoOnlyPacket
|
||||
}
|
||||
|
||||
// Drop does nothing on an info-only packet.
|
||||
func (pkt *InfoPacket) Drop() error {
|
||||
return ErrInfoOnlyPacket
|
||||
}
|
||||
|
||||
// PermanentAccept does nothing on an info-only packet.
|
||||
func (pkt *InfoPacket) PermanentAccept() error {
|
||||
return ErrInfoOnlyPacket
|
||||
}
|
||||
|
||||
// PermanentBlock does nothing on an info-only packet.
|
||||
func (pkt *InfoPacket) PermanentBlock() error {
|
||||
return ErrInfoOnlyPacket
|
||||
}
|
||||
|
||||
// PermanentDrop does nothing on an info-only packet.
|
||||
func (pkt *InfoPacket) PermanentDrop() error {
|
||||
return ErrInfoOnlyPacket
|
||||
}
|
||||
|
||||
// RerouteToNameserver does nothing on an info-only packet.
|
||||
func (pkt *InfoPacket) RerouteToNameserver() error {
|
||||
return ErrInfoOnlyPacket
|
||||
}
|
||||
|
||||
// RerouteToTunnel does nothing on an info-only packet.
|
||||
func (pkt *InfoPacket) RerouteToTunnel() error {
|
||||
return ErrInfoOnlyPacket
|
||||
}
|
||||
|
||||
var _ Packet = &InfoPacket{}
|
||||
266
service/network/packet/packet.go
Normal file
266
service/network/packet/packet.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
)
|
||||
|
||||
// Base is a base structure for satisfying the Packet interface.
|
||||
type Base struct {
|
||||
ctx context.Context
|
||||
info Info
|
||||
connID string
|
||||
layers gopacket.Packet
|
||||
layer3Data []byte
|
||||
layer5Data []byte
|
||||
}
|
||||
|
||||
// FastTrackedByIntegration returns whether the packet has been fast-track
|
||||
// accepted by the OS integration.
|
||||
func (pkt *Base) FastTrackedByIntegration() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// InfoOnly returns whether the packet is informational only and does not
|
||||
// represent an actual packet.
|
||||
func (pkt *Base) InfoOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ExpectInfo returns whether the next packet is expected to be informational only.
|
||||
func (pkt *Base) ExpectInfo() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetCtx sets the packet context.
|
||||
func (pkt *Base) SetCtx(ctx context.Context) {
|
||||
pkt.ctx = ctx
|
||||
}
|
||||
|
||||
// Ctx returns the packet context.
|
||||
func (pkt *Base) Ctx() context.Context {
|
||||
return pkt.ctx
|
||||
}
|
||||
|
||||
// Info returns the packet Info.
|
||||
func (pkt *Base) Info() *Info {
|
||||
return &pkt.info
|
||||
}
|
||||
|
||||
// SetPacketInfo sets a new packet Info. This must only used when initializing the packet structure.
|
||||
func (pkt *Base) SetPacketInfo(packetInfo Info) {
|
||||
pkt.info = packetInfo
|
||||
}
|
||||
|
||||
// SetInbound sets a the packet direction to inbound. This must only used when initializing the packet structure.
|
||||
func (pkt *Base) SetInbound() {
|
||||
pkt.info.Inbound = true
|
||||
}
|
||||
|
||||
// SetOutbound sets a the packet direction to outbound. This must only used when initializing the packet structure.
|
||||
func (pkt *Base) SetOutbound() {
|
||||
pkt.info.Inbound = false
|
||||
}
|
||||
|
||||
// IsInbound checks if the packet is inbound.
|
||||
func (pkt *Base) IsInbound() bool {
|
||||
return pkt.info.Inbound
|
||||
}
|
||||
|
||||
// IsOutbound checks if the packet is outbound.
|
||||
func (pkt *Base) IsOutbound() bool {
|
||||
return !pkt.info.Inbound
|
||||
}
|
||||
|
||||
// HasPorts checks if the packet has a protocol that uses ports.
|
||||
func (pkt *Base) HasPorts() bool {
|
||||
switch pkt.info.Protocol {
|
||||
case TCP:
|
||||
return true
|
||||
case UDP, UDPLite:
|
||||
return true
|
||||
case ICMP, ICMPv6, IGMP, RAW, AnyHostInternalProtocol61:
|
||||
fallthrough
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// LoadPacketData loads packet data from the integration, if not yet done.
|
||||
func (pkt *Base) LoadPacketData() error {
|
||||
return ErrFailedToLoadPayload
|
||||
}
|
||||
|
||||
// Layers returns the parsed layer data.
|
||||
func (pkt *Base) Layers() gopacket.Packet {
|
||||
return pkt.layers
|
||||
}
|
||||
|
||||
// Raw returns the raw Layer 3 Network Data.
|
||||
func (pkt *Base) Raw() []byte {
|
||||
return pkt.layer3Data
|
||||
}
|
||||
|
||||
// Payload returns the raw Layer 5 Network Data.
|
||||
func (pkt *Base) Payload() []byte {
|
||||
return pkt.layer5Data
|
||||
}
|
||||
|
||||
// GetConnectionID returns the link ID for this packet.
|
||||
func (pkt *Base) GetConnectionID() string {
|
||||
if pkt.connID == "" {
|
||||
pkt.connID = pkt.info.CreateConnectionID()
|
||||
}
|
||||
return pkt.connID
|
||||
}
|
||||
|
||||
// MatchesAddress checks if a the packet matches a given endpoint (remote or local) in protocol, network and port.
|
||||
//
|
||||
// Comparison matrix:
|
||||
//
|
||||
// ====== IN OUT
|
||||
//
|
||||
// Local Dst Src
|
||||
// Remote Src Dst
|
||||
// .
|
||||
func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.IPNet, port uint16) bool {
|
||||
if pkt.info.Protocol != protocol {
|
||||
return false
|
||||
}
|
||||
if pkt.info.Inbound != remote {
|
||||
if !network.Contains(pkt.info.Src) {
|
||||
return false
|
||||
}
|
||||
if pkt.info.SrcPort != port {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if !network.Contains(pkt.info.Dst) {
|
||||
return false
|
||||
}
|
||||
if pkt.info.DstPort != port {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// MatchesIP checks if a the packet matches a given endpoint (remote or local) IP.
|
||||
//
|
||||
// Comparison matrix:
|
||||
//
|
||||
// ====== IN OUT
|
||||
//
|
||||
// Local Dst Src
|
||||
// Remote Src Dst
|
||||
// .
|
||||
func (pkt *Base) MatchesIP(endpoint bool, network *net.IPNet) bool {
|
||||
if pkt.info.Inbound != endpoint {
|
||||
if network.Contains(pkt.info.Src) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if network.Contains(pkt.info.Dst) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FORMATTING
|
||||
|
||||
func (pkt *Base) String() string {
|
||||
return pkt.FmtPacket()
|
||||
}
|
||||
|
||||
// FmtPacket returns the most important information about the packet as a string.
|
||||
func (pkt *Base) FmtPacket() string {
|
||||
if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP {
|
||||
if pkt.info.Inbound {
|
||||
return fmt.Sprintf("IN %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort)
|
||||
}
|
||||
return fmt.Sprintf("OUT %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort)
|
||||
}
|
||||
if pkt.info.Inbound {
|
||||
return fmt.Sprintf("IN %s %s <-> %s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src)
|
||||
}
|
||||
return fmt.Sprintf("OUT %s %s <-> %s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst)
|
||||
}
|
||||
|
||||
// FmtProtocol returns the protocol as a string.
|
||||
func (pkt *Base) FmtProtocol() string {
|
||||
return pkt.info.Protocol.String()
|
||||
}
|
||||
|
||||
// FmtRemoteIP returns the remote IP address as a string.
|
||||
func (pkt *Base) FmtRemoteIP() string {
|
||||
if pkt.info.Inbound {
|
||||
return pkt.info.Src.String()
|
||||
}
|
||||
return pkt.info.Dst.String()
|
||||
}
|
||||
|
||||
// FmtRemotePort returns the remote port as a string.
|
||||
func (pkt *Base) FmtRemotePort() string {
|
||||
if pkt.info.SrcPort != 0 {
|
||||
if pkt.info.Inbound {
|
||||
return fmt.Sprintf("%d", pkt.info.SrcPort)
|
||||
}
|
||||
return fmt.Sprintf("%d", pkt.info.DstPort)
|
||||
}
|
||||
return "-"
|
||||
}
|
||||
|
||||
// FmtRemoteAddress returns the full remote address (protocol, IP, port) as a string.
|
||||
func (pkt *Base) FmtRemoteAddress() string {
|
||||
return fmt.Sprintf("%s:%s:%s", pkt.info.Protocol.String(), pkt.FmtRemoteIP(), pkt.FmtRemotePort())
|
||||
}
|
||||
|
||||
// Packet is an interface to a network packet to provide object behavior the same across all systems.
|
||||
type Packet interface {
|
||||
// Verdicts.
|
||||
Accept() error
|
||||
Block() error
|
||||
Drop() error
|
||||
PermanentAccept() error
|
||||
PermanentBlock() error
|
||||
PermanentDrop() error
|
||||
RerouteToNameserver() error
|
||||
RerouteToTunnel() error
|
||||
FastTrackedByIntegration() bool
|
||||
InfoOnly() bool
|
||||
ExpectInfo() bool
|
||||
|
||||
// Info.
|
||||
SetCtx(context.Context)
|
||||
Ctx() context.Context
|
||||
Info() *Info
|
||||
SetPacketInfo(Info)
|
||||
IsInbound() bool
|
||||
IsOutbound() bool
|
||||
SetInbound()
|
||||
SetOutbound()
|
||||
HasPorts() bool
|
||||
GetConnectionID() string
|
||||
|
||||
// Payload.
|
||||
LoadPacketData() error
|
||||
Layers() gopacket.Packet
|
||||
Raw() []byte
|
||||
Payload() []byte
|
||||
|
||||
// Matching.
|
||||
MatchesAddress(bool, IPProtocol, *net.IPNet, uint16) bool
|
||||
MatchesIP(bool, *net.IPNet) bool
|
||||
|
||||
// Formatting.
|
||||
String() string
|
||||
FmtPacket() string
|
||||
FmtProtocol() string
|
||||
FmtRemoteIP() string
|
||||
FmtRemotePort() string
|
||||
FmtRemoteAddress() string
|
||||
}
|
||||
76
service/network/packet/packetinfo.go
Normal file
76
service/network/packet/packetinfo.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Info holds IP and TCP/UDP header information.
|
||||
type Info struct {
|
||||
Inbound bool
|
||||
InTunnel bool
|
||||
|
||||
Version IPVersion
|
||||
Protocol IPProtocol
|
||||
SrcPort, DstPort uint16
|
||||
Src, Dst net.IP
|
||||
|
||||
PID int
|
||||
SeenAt time.Time
|
||||
}
|
||||
|
||||
// LocalIP returns the local IP of the packet.
|
||||
func (pi *Info) LocalIP() net.IP {
|
||||
if pi.Inbound {
|
||||
return pi.Dst
|
||||
}
|
||||
return pi.Src
|
||||
}
|
||||
|
||||
// RemoteIP returns the remote IP of the packet.
|
||||
func (pi *Info) RemoteIP() net.IP {
|
||||
if pi.Inbound {
|
||||
return pi.Src
|
||||
}
|
||||
return pi.Dst
|
||||
}
|
||||
|
||||
// LocalPort returns the local port of the packet.
|
||||
func (pi *Info) LocalPort() uint16 {
|
||||
if pi.Inbound {
|
||||
return pi.DstPort
|
||||
}
|
||||
return pi.SrcPort
|
||||
}
|
||||
|
||||
// RemotePort returns the remote port of the packet.
|
||||
func (pi *Info) RemotePort() uint16 {
|
||||
if pi.Inbound {
|
||||
return pi.SrcPort
|
||||
}
|
||||
return pi.DstPort
|
||||
}
|
||||
|
||||
// CreateConnectionID creates a connection ID.
|
||||
// In most circumstances, this method should not be used directly, but
|
||||
// packet.GetConnectionID() should be called instead.
|
||||
func (pi *Info) CreateConnectionID() string {
|
||||
return CreateConnectionID(pi.Protocol, pi.Src, pi.SrcPort, pi.Dst, pi.DstPort, pi.Inbound)
|
||||
}
|
||||
|
||||
// CreateConnectionID creates a connection ID.
|
||||
func CreateConnectionID(protocol IPProtocol, src net.IP, srcPort uint16, dst net.IP, dstPort uint16, inbound bool) string {
|
||||
// TODO: make this ID not depend on the packet direction for better support for forwarded packets.
|
||||
if protocol == TCP || protocol == UDP {
|
||||
if inbound {
|
||||
return fmt.Sprintf("%d-%s-%d-%s-%d", protocol, dst, dstPort, src, srcPort)
|
||||
}
|
||||
return fmt.Sprintf("%d-%s-%d-%s-%d", protocol, src, srcPort, dst, dstPort)
|
||||
}
|
||||
|
||||
if inbound {
|
||||
return fmt.Sprintf("%d-%s-%s", protocol, dst, src)
|
||||
}
|
||||
return fmt.Sprintf("%d-%s-%s", protocol, src, dst)
|
||||
}
|
||||
160
service/network/packet/parse.go
Normal file
160
service/network/packet/parse.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
var layerType2IPProtocol map[gopacket.LayerType]IPProtocol
|
||||
|
||||
func genIPProtocolFromLayerType() {
|
||||
layerType2IPProtocol = make(map[gopacket.LayerType]IPProtocol)
|
||||
for k, v := range layers.IPProtocolMetadata {
|
||||
layerType2IPProtocol[v.LayerType] = IPProtocol(k)
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPv4(packet gopacket.Packet, info *Info) error {
|
||||
if ipv4, ok := packet.NetworkLayer().(*layers.IPv4); ok {
|
||||
info.Version = IPv4
|
||||
info.Src = ipv4.SrcIP
|
||||
info.Dst = ipv4.DstIP
|
||||
info.Protocol = IPProtocol(ipv4.Protocol)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseIPv6(packet gopacket.Packet, info *Info) error {
|
||||
if ipv6, ok := packet.NetworkLayer().(*layers.IPv6); ok {
|
||||
info.Version = IPv6
|
||||
info.Src = ipv6.SrcIP
|
||||
info.Dst = ipv6.DstIP
|
||||
// we set Protocol to NextHeader as a fallback. If TCP or
|
||||
// UDP layers are detected (somewhere in the list of options)
|
||||
// the Protocol field is adjusted correctly.
|
||||
info.Protocol = IPProtocol(ipv6.NextHeader)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseTCP(packet gopacket.Packet, info *Info) error {
|
||||
if tcp, ok := packet.TransportLayer().(*layers.TCP); ok {
|
||||
info.Protocol = TCP
|
||||
info.SrcPort = uint16(tcp.SrcPort)
|
||||
info.DstPort = uint16(tcp.DstPort)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseUDP(packet gopacket.Packet, info *Info) error {
|
||||
if udp, ok := packet.TransportLayer().(*layers.UDP); ok {
|
||||
info.Protocol = UDP
|
||||
info.SrcPort = uint16(udp.SrcPort)
|
||||
info.DstPort = uint16(udp.DstPort)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
func parseUDPLite(packet gopacket.Packet, info *Info) error {
|
||||
if udpLite, ok := packet.TransportLayer().(*layers.UDPLite); ok {
|
||||
info.Protocol = UDPLite
|
||||
info.SrcPort = uint16(udpLite.SrcPort)
|
||||
info.DstPort = uint16(udpLite.DstPort)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
*/
|
||||
|
||||
func parseICMPv4(packet gopacket.Packet, info *Info) error {
|
||||
if icmp, ok := packet.Layer(layers.LayerTypeICMPv4).(*layers.ICMPv4); ok {
|
||||
info.Protocol = ICMP
|
||||
_ = icmp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseICMPv6(packet gopacket.Packet, info *Info) error {
|
||||
if icmp6, ok := packet.Layer(layers.LayerTypeICMPv6).(*layers.ICMPv6); ok {
|
||||
info.Protocol = ICMPv6
|
||||
_ = icmp6
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseIGMP(packet gopacket.Packet, info *Info) error {
|
||||
// gopacket uses LayerTypeIGMP for v1, v2 and v3 and may thus
|
||||
// either return layers.IGMP or layers.IGMPv1or2
|
||||
if layer := packet.Layer(layers.LayerTypeIGMP); layer != nil {
|
||||
info.Protocol = IGMP
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkError(packet gopacket.Packet, info *Info) error {
|
||||
// Check for known unparseable before checking the error layer.
|
||||
if info.Protocol == AnyHostInternalProtocol61 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := packet.ErrorLayer(); err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse parses an IP packet and saves the information in the given packet object.
|
||||
func Parse(packetData []byte, pktBase *Base) (err error) {
|
||||
if len(packetData) == 0 {
|
||||
return errors.New("empty packet")
|
||||
}
|
||||
pktBase.layer3Data = packetData
|
||||
|
||||
ipVersion := packetData[0] >> 4
|
||||
var networkLayerType gopacket.LayerType
|
||||
|
||||
switch ipVersion {
|
||||
case 4:
|
||||
networkLayerType = layers.LayerTypeIPv4
|
||||
case 6:
|
||||
networkLayerType = layers.LayerTypeIPv6
|
||||
default:
|
||||
return fmt.Errorf("unknown IP version or network protocol: %02x", ipVersion)
|
||||
}
|
||||
|
||||
packet := gopacket.NewPacket(packetData, networkLayerType, gopacket.DecodeOptions{
|
||||
Lazy: true,
|
||||
NoCopy: true,
|
||||
})
|
||||
|
||||
availableDecoders := []func(gopacket.Packet, *Info) error{
|
||||
parseIPv4,
|
||||
parseIPv6,
|
||||
parseTCP,
|
||||
parseUDP,
|
||||
// parseUDPLite, // We don't yet support udplite.
|
||||
parseICMPv4,
|
||||
parseICMPv6,
|
||||
parseIGMP,
|
||||
checkError,
|
||||
}
|
||||
|
||||
for _, dec := range availableDecoders {
|
||||
if err := dec(packet, pktBase.Info()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
pktBase.layers = packet
|
||||
if transport := packet.TransportLayer(); transport != nil {
|
||||
pktBase.layer5Data = transport.LayerPayload()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
genIPProtocolFromLayerType()
|
||||
}
|
||||
57
service/network/ports.go
Normal file
57
service/network/ports.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/rng"
|
||||
)
|
||||
|
||||
// GetUnusedLocalPort returns a local port of the specified protocol that is
|
||||
// currently unused and is unlikely to be used within the next seconds.
|
||||
func GetUnusedLocalPort(protocol uint8) (port uint16, ok bool) {
|
||||
allConns := conns.clone()
|
||||
tries := 1000
|
||||
|
||||
// Try up to 1000 times to find an unused port.
|
||||
nextPort:
|
||||
for i := 0; i < tries; i++ {
|
||||
// Generate random port between 10000 and 65535
|
||||
rN, err := rng.Number(55535)
|
||||
if err != nil {
|
||||
log.Warningf("network: failed to generate random port: %s", err)
|
||||
return 0, false
|
||||
}
|
||||
port := uint16(rN + 10000)
|
||||
|
||||
// Shrink range when we chew through the tries.
|
||||
portRangeStart := port - 10
|
||||
|
||||
// Check if the generated port is unused.
|
||||
nextConnection:
|
||||
for _, conn := range allConns {
|
||||
switch {
|
||||
case !conn.DataIsComplete():
|
||||
// Skip connection if the data is not complete.
|
||||
continue nextConnection
|
||||
|
||||
case conn.Entity.Protocol != protocol:
|
||||
// Skip connection if the protocol does not match the protocol of interest.
|
||||
continue nextConnection
|
||||
|
||||
case conn.LocalPort <= port && conn.LocalPort >= portRangeStart:
|
||||
// Skip port if the local port is in dangerous proximity.
|
||||
// Consecutive port numbers are very common.
|
||||
continue nextPort
|
||||
}
|
||||
}
|
||||
|
||||
// Log if it took more than 10 attempts.
|
||||
if i >= 10 {
|
||||
log.Warningf("network: took %d attempts to find a suitable unused port for pre-auth", i+1)
|
||||
}
|
||||
|
||||
// The checks have passed. We have found a good unused port.
|
||||
return port, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
107
service/network/proc/findpid.go
Normal file
107
service/network/proc/findpid.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//go:build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
// GetPID returns the already existing pid of the given socket info or searches for it.
|
||||
// This also acts as a getter for socket.Info.PID, as locking for that occurs here.
|
||||
func GetPID(socketInfo socket.Info) (pid int) {
|
||||
// Get currently assigned PID to the socket info.
|
||||
currentPid := socketInfo.GetPID()
|
||||
|
||||
// If the current PID already is valid (ie. not unidentified), return it immediately.
|
||||
if currentPid != socket.UndefinedProcessID {
|
||||
return currentPid
|
||||
}
|
||||
|
||||
// Find PID for the given UID and inode.
|
||||
pid = findPID(socketInfo.GetUIDandInode())
|
||||
|
||||
// Set the newly found PID on the socket info.
|
||||
socketInfo.SetPID(pid)
|
||||
|
||||
// Return found PID.
|
||||
return pid
|
||||
}
|
||||
|
||||
// findPID returns the pid of the given uid and socket inode.
|
||||
func findPID(uid, inode int) (pid int) {
|
||||
socketName := "socket:[" + strconv.Itoa(inode) + "]"
|
||||
|
||||
// Always update pid table (it has a call limiter anyway)
|
||||
updatePids()
|
||||
|
||||
// Get all pids for the given uid.
|
||||
pids, ok := getPidsByUser(uid)
|
||||
if !ok {
|
||||
return socket.UndefinedProcessID
|
||||
}
|
||||
|
||||
// Look through the PIDs in reverse order, because higher/newer PIDs will be more likely to
|
||||
// be searched for.
|
||||
for j := len(pids) - 1; j >= 0; j-- {
|
||||
if pidHasSocket(pids[j], socketName) {
|
||||
return pids[j]
|
||||
}
|
||||
}
|
||||
|
||||
return socket.UndefinedProcessID
|
||||
}
|
||||
|
||||
func pidHasSocket(pid int, socketName string) bool {
|
||||
socketBase := "/proc/" + strconv.Itoa(pid) + "/fd"
|
||||
entries := readDirNames(socketBase)
|
||||
if len(entries) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
socketBase += "/"
|
||||
// Look through the FDs in reverse order, because higher/newer FDs will be
|
||||
// more likely to be searched for.
|
||||
for i := len(entries) - 1; i >= 0; i-- {
|
||||
link, err := os.Readlink(socketBase + entries[i])
|
||||
if err != nil {
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
log.Warningf("proc: failed to read link /proc/%d/fd/%s: %s", pid, entries[i], err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if link == socketName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// readDirNames only reads the directory names. Using os.ReadDir() would call `lstat` on every
|
||||
// resulting directory name, which we don't need. This function will be called a lot, so we should
|
||||
// refrain from unnecessary work.
|
||||
func readDirNames(dir string) (names []string) {
|
||||
file, err := os.Open(dir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
log.Warningf("proc: could not open directory %s: %s", dir, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = file.Close()
|
||||
}()
|
||||
|
||||
names, err = file.Readdirnames(0)
|
||||
if err != nil {
|
||||
log.Warningf("proc: could not get entries from directory %s: %s", dir, err)
|
||||
return []string{}
|
||||
}
|
||||
return
|
||||
}
|
||||
81
service/network/proc/pids_by_user.go
Normal file
81
service/network/proc/pids_by_user.go
Normal file
@@ -0,0 +1,81 @@
|
||||
//go:build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
// pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs.
|
||||
pidsByUser = make(map[int][]int)
|
||||
pidsByUserLock sync.RWMutex
|
||||
fetchPidsByUser = utils.NewCallLimiter(10 * time.Millisecond)
|
||||
)
|
||||
|
||||
// getPidsByUser returns the cached PIDs for the given UID.
|
||||
func getPidsByUser(uid int) (pids []int, ok bool) {
|
||||
pidsByUserLock.RLock()
|
||||
defer pidsByUserLock.RUnlock()
|
||||
|
||||
pids, ok = pidsByUser[uid]
|
||||
return
|
||||
}
|
||||
|
||||
// updatePids fetches and creates a new pidsByUser map using a call limiter.
|
||||
func updatePids() {
|
||||
fetchPidsByUser.Do(func() {
|
||||
newPidsByUser := make(map[int][]int)
|
||||
pidCnt := 0
|
||||
|
||||
entries := readDirNames("/proc")
|
||||
if len(entries) == 0 {
|
||||
log.Warning("proc: found no PIDs in /proc")
|
||||
return
|
||||
}
|
||||
|
||||
entryLoop:
|
||||
for _, entry := range entries {
|
||||
pid, err := strconv.ParseInt(entry, 10, 32)
|
||||
if err != nil {
|
||||
continue entryLoop
|
||||
}
|
||||
|
||||
statData, err := os.Stat("/proc/" + strconv.FormatInt(pid, 10))
|
||||
if err != nil {
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
log.Warningf("proc: could not stat /proc/%d: %s", pid, err)
|
||||
}
|
||||
continue entryLoop
|
||||
}
|
||||
sys, ok := statData.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
log.Warningf("proc: unable to parse /proc/%d: wrong type", pid)
|
||||
continue entryLoop
|
||||
}
|
||||
|
||||
pids, ok := newPidsByUser[int(sys.Uid)]
|
||||
if ok {
|
||||
newPidsByUser[int(sys.Uid)] = append(pids, int(pid))
|
||||
} else {
|
||||
newPidsByUser[int(sys.Uid)] = []int{int(pid)}
|
||||
}
|
||||
pidCnt++
|
||||
}
|
||||
|
||||
// log.Tracef("proc: updated PID table with %d entries", pidCnt)
|
||||
|
||||
pidsByUserLock.Lock()
|
||||
defer pidsByUserLock.Unlock()
|
||||
pidsByUser = newPidsByUser
|
||||
})
|
||||
}
|
||||
244
service/network/proc/tables.go
Normal file
244
service/network/proc/tables.go
Normal file
@@ -0,0 +1,244 @@
|
||||
//go:build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
/*
|
||||
|
||||
1. find socket inode
|
||||
- by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP?
|
||||
- /proc/net/{tcp|udp}[6]
|
||||
|
||||
2. get list of processes of uid
|
||||
|
||||
3. find socket inode in process fds
|
||||
- if not found, refresh map of uid->pids
|
||||
- if not found, check ALL pids: maybe euid != uid
|
||||
|
||||
4. gather process info
|
||||
|
||||
Cache every step!
|
||||
|
||||
*/
|
||||
|
||||
// Network Related Constants.
|
||||
const (
|
||||
TCP4 uint8 = iota
|
||||
UDP4
|
||||
TCP6
|
||||
UDP6
|
||||
ICMP4
|
||||
ICMP6
|
||||
|
||||
tcp4ProcFile = "/proc/net/tcp"
|
||||
tcp6ProcFile = "/proc/net/tcp6"
|
||||
udp4ProcFile = "/proc/net/udp"
|
||||
udp6ProcFile = "/proc/net/udp6"
|
||||
|
||||
tcpListenStateHex = "0A"
|
||||
)
|
||||
|
||||
// GetTCP4Table returns the system table for IPv4 TCP activity.
|
||||
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
return getTableFromSource(TCP4, tcp4ProcFile)
|
||||
}
|
||||
|
||||
// GetTCP6Table returns the system table for IPv6 TCP activity.
|
||||
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
return getTableFromSource(TCP6, tcp6ProcFile)
|
||||
}
|
||||
|
||||
// GetUDP4Table returns the system table for IPv4 UDP activity.
|
||||
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
|
||||
_, binds, err = getTableFromSource(UDP4, udp4ProcFile)
|
||||
return
|
||||
}
|
||||
|
||||
// GetUDP6Table returns the system table for IPv6 UDP activity.
|
||||
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
|
||||
_, binds, err = getTableFromSource(UDP6, udp6ProcFile)
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
// hint: we split fields by multiple delimiters, see procDelimiter
|
||||
fieldIndexLocalIP = 1
|
||||
fieldIndexLocalPort = 2
|
||||
fieldIndexRemoteIP = 3
|
||||
fieldIndexRemotePort = 4
|
||||
fieldIndexUID = 11
|
||||
fieldIndexInode = 13
|
||||
)
|
||||
|
||||
func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
|
||||
var ipConverter func(string) net.IP
|
||||
switch stack {
|
||||
case TCP4, UDP4:
|
||||
ipConverter = convertIPv4
|
||||
case TCP6, UDP6:
|
||||
ipConverter = convertIPv6
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported table stack: %d", stack)
|
||||
}
|
||||
|
||||
// open file
|
||||
socketData, err := os.Open(procFile)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = socketData.Close()
|
||||
}()
|
||||
|
||||
// file scanner
|
||||
scanner := bufio.NewScanner(socketData)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
// parse
|
||||
scanner.Scan() // skip first row
|
||||
for scanner.Scan() {
|
||||
fields := strings.FieldsFunc(scanner.Text(), procDelimiter)
|
||||
if len(fields) < 14 {
|
||||
// log.Tracef("proc: too short: %s", fields)
|
||||
continue
|
||||
}
|
||||
|
||||
localIP := ipConverter(fields[fieldIndexLocalIP])
|
||||
if localIP == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16)
|
||||
if err != nil {
|
||||
log.Warningf("proc: could not parse port: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32)
|
||||
// log.Tracef("uid: %s", fields[fieldIndexUID])
|
||||
if err != nil {
|
||||
log.Warningf("proc: could not parse uid %s: %s", fields[11], err)
|
||||
continue
|
||||
}
|
||||
|
||||
inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32)
|
||||
// log.Tracef("inode: %s", fields[fieldIndexInode])
|
||||
if err != nil {
|
||||
log.Warningf("proc: could not parse inode %s: %s", fields[13], err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch stack {
|
||||
case UDP4, UDP6:
|
||||
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: localIP,
|
||||
Port: uint16(localPort),
|
||||
},
|
||||
PID: socket.UndefinedProcessID,
|
||||
UID: int(uid),
|
||||
Inode: int(inode),
|
||||
})
|
||||
|
||||
case TCP4, TCP6:
|
||||
|
||||
if fields[5] == tcpListenStateHex {
|
||||
// listener
|
||||
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: localIP,
|
||||
Port: uint16(localPort),
|
||||
},
|
||||
PID: socket.UndefinedProcessID,
|
||||
UID: int(uid),
|
||||
Inode: int(inode),
|
||||
})
|
||||
} else {
|
||||
// connection
|
||||
|
||||
remoteIP := ipConverter(fields[fieldIndexRemoteIP])
|
||||
if remoteIP == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16)
|
||||
if err != nil {
|
||||
log.Warningf("proc: could not parse port: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
connections = append(connections, &socket.ConnectionInfo{
|
||||
Local: socket.Address{
|
||||
IP: localIP,
|
||||
Port: uint16(localPort),
|
||||
},
|
||||
Remote: socket.Address{
|
||||
IP: remoteIP,
|
||||
Port: uint16(remotePort),
|
||||
},
|
||||
PID: socket.UndefinedProcessID,
|
||||
UID: int(uid),
|
||||
Inode: int(inode),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return connections, binds, nil
|
||||
}
|
||||
|
||||
func procDelimiter(c rune) bool {
|
||||
return unicode.IsSpace(c) || c == ':'
|
||||
}
|
||||
|
||||
func convertIPv4(data string) net.IP {
|
||||
// Decode and bullshit check the data length.
|
||||
decoded, err := hex.DecodeString(data)
|
||||
if err != nil {
|
||||
log.Warningf("proc: could not parse IPv4 %s: %s", data, err)
|
||||
return nil
|
||||
}
|
||||
if len(decoded) != 4 {
|
||||
log.Warningf("proc: decoded IPv4 %s has wrong length", decoded)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build the IPv4 address with the reversed byte order.
|
||||
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
|
||||
return ip
|
||||
}
|
||||
|
||||
func convertIPv6(data string) net.IP {
|
||||
// Decode and bullshit check the data length.
|
||||
decoded, err := hex.DecodeString(data)
|
||||
if err != nil {
|
||||
log.Warningf("proc: could not parse IPv6 %s: %s", data, err)
|
||||
return nil
|
||||
}
|
||||
if len(decoded) != 16 {
|
||||
log.Warningf("proc: decoded IPv6 %s has wrong length", decoded)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build the IPv6 address with the translated byte order.
|
||||
for i := 0; i < 16; i += 4 {
|
||||
decoded[i], decoded[i+1], decoded[i+2], decoded[i+3] = decoded[i+3], decoded[i+2], decoded[i+1], decoded[i]
|
||||
}
|
||||
ip := net.IP(decoded)
|
||||
return ip
|
||||
}
|
||||
62
service/network/proc/tables_test.go
Normal file
62
service/network/proc/tables_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
//go:build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSockets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connections, listeners, err := GetTCP4Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nTCP 4 connections:")
|
||||
for _, connection := range connections {
|
||||
pid := GetPID(connection)
|
||||
fmt.Printf("%d: %+v\n", pid, connection)
|
||||
}
|
||||
fmt.Println("\nTCP 4 listeners:")
|
||||
for _, listener := range listeners {
|
||||
pid := GetPID(listener)
|
||||
fmt.Printf("%d: %+v\n", pid, listener)
|
||||
}
|
||||
|
||||
connections, listeners, err = GetTCP6Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nTCP 6 connections:")
|
||||
for _, connection := range connections {
|
||||
pid := GetPID(connection)
|
||||
fmt.Printf("%d: %+v\n", pid, connection)
|
||||
}
|
||||
fmt.Println("\nTCP 6 listeners:")
|
||||
for _, listener := range listeners {
|
||||
pid := GetPID(listener)
|
||||
fmt.Printf("%d: %+v\n", pid, listener)
|
||||
}
|
||||
|
||||
binds, err := GetUDP4Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nUDP 4 binds:")
|
||||
for _, bind := range binds {
|
||||
pid := GetPID(bind)
|
||||
fmt.Printf("%d: %+v\n", pid, bind)
|
||||
}
|
||||
|
||||
binds, err = GetUDP6Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nUDP 6 binds:")
|
||||
for _, bind := range binds {
|
||||
pid := GetPID(bind)
|
||||
fmt.Printf("%d: %+v\n", pid, bind)
|
||||
}
|
||||
}
|
||||
86
service/network/reference/ports.go
Normal file
86
service/network/reference/ports.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package reference
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
portNames = map[uint16]string{
|
||||
20: "FTP-DATA",
|
||||
21: "FTP",
|
||||
22: "SSH",
|
||||
23: "TELNET",
|
||||
25: "SMTP",
|
||||
43: "WHOIS",
|
||||
53: "DNS",
|
||||
67: "DHCP_SERVER",
|
||||
68: "DHCP_CLIENT",
|
||||
69: "TFTP",
|
||||
80: "HTTP",
|
||||
110: "POP3",
|
||||
123: "NTP",
|
||||
143: "IMAP",
|
||||
161: "SNMP",
|
||||
179: "BGP",
|
||||
194: "IRC",
|
||||
389: "LDAP",
|
||||
443: "HTTPS",
|
||||
445: "SMB",
|
||||
587: "SMTP_ALT",
|
||||
465: "SMTP_SSL",
|
||||
993: "IMAP_SSL",
|
||||
995: "POP3_SSL",
|
||||
}
|
||||
|
||||
portNumbers = map[string]uint16{
|
||||
"FTP-DATA": 20,
|
||||
"FTP": 21,
|
||||
"SSH": 22,
|
||||
"TELNET": 23,
|
||||
"SMTP": 25,
|
||||
"WHOIS": 43,
|
||||
"DNS": 53,
|
||||
"DHCP-SERVER": 67,
|
||||
"DHCP_SERVER": 67,
|
||||
"DHCP-CLIENT": 68,
|
||||
"DHCP_CLIENT": 68,
|
||||
"TFTP": 69,
|
||||
"HTTP": 80,
|
||||
"POP3": 110,
|
||||
"NTP": 123,
|
||||
"IMAP": 143,
|
||||
"SNMP": 161,
|
||||
"BGP": 179,
|
||||
"IRC": 194,
|
||||
"LDAP": 389,
|
||||
"HTTPS": 443,
|
||||
"SMB": 445,
|
||||
"SMTP-ALT": 587,
|
||||
"SMTP_ALT": 587,
|
||||
"SMTP-SSL": 465,
|
||||
"SMTP_SSL": 465,
|
||||
"IMAP-SSL": 993,
|
||||
"IMAP_SSL": 993,
|
||||
"POP3-SSL": 995,
|
||||
"POP3_SSL": 995,
|
||||
}
|
||||
)
|
||||
|
||||
// GetPortName returns the name of a port number.
|
||||
func GetPortName(port uint16) (name string) {
|
||||
name, ok := portNames[port]
|
||||
if ok {
|
||||
return name
|
||||
}
|
||||
return strconv.Itoa(int(port))
|
||||
}
|
||||
|
||||
// GetPortNumber returns the number of a port name.
|
||||
func GetPortNumber(port string) (number uint16, ok bool) {
|
||||
number, ok = portNumbers[strings.ToUpper(port)]
|
||||
if ok {
|
||||
return number, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
75
service/network/reference/protocols.go
Normal file
75
service/network/reference/protocols.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package reference
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
protocolNames = map[uint8]string{
|
||||
1: "ICMP",
|
||||
2: "IGMP",
|
||||
6: "TCP",
|
||||
17: "UDP",
|
||||
27: "RDP",
|
||||
58: "ICMP6",
|
||||
33: "DCCP",
|
||||
136: "UDP-LITE",
|
||||
}
|
||||
|
||||
protocolNumbers = map[string]uint8{
|
||||
"ICMP": 1,
|
||||
"IGMP": 2,
|
||||
"TCP": 6,
|
||||
"UDP": 17,
|
||||
"RDP": 27,
|
||||
"DCCP": 33,
|
||||
"ICMP6": 58,
|
||||
"UDP-LITE": 136,
|
||||
}
|
||||
)
|
||||
|
||||
// GetProtocolName returns the name of a IP protocol number.
|
||||
func GetProtocolName(protocol uint8) (name string) {
|
||||
name, ok := protocolNames[protocol]
|
||||
if ok {
|
||||
return name
|
||||
}
|
||||
return strconv.Itoa(int(protocol))
|
||||
}
|
||||
|
||||
// GetProtocolNumber returns the number of a IP protocol name.
|
||||
func GetProtocolNumber(protocol string) (number uint8, ok bool) {
|
||||
number, ok = protocolNumbers[strings.ToUpper(protocol)]
|
||||
if ok {
|
||||
return number, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// IsPacketProtocol returns whether the given protocol number is a known packet based protocol.
|
||||
// Note: Not fully complete. Calling IsPacketProtocol() does not equal calling !IsStreamProtocol().
|
||||
func IsPacketProtocol(protocol uint8) bool {
|
||||
switch protocol {
|
||||
case 1, // ICMP
|
||||
17, // UDP
|
||||
27, // RDP
|
||||
58, // ICMP6
|
||||
33, // DCCP
|
||||
136: // UDP-LITE
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsStreamProtocol returns whether the given protocol number is a known stream based protocol.
|
||||
// Note: Not fully complete. Calling IsPacketProtocol() does not equal calling !IsStreamProtocol().
|
||||
func IsStreamProtocol(protocol uint8) bool {
|
||||
switch protocol {
|
||||
case 6: // TCP
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
120
service/network/socket/socket.go
Normal file
120
service/network/socket/socket.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package socket
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// UndefinedProcessID signifies that the process ID is unknown.
|
||||
// It must match portmaster/process.UndefinedProcessID
|
||||
// It is duplicated here because of import loops.
|
||||
UndefinedProcessID = -1
|
||||
)
|
||||
|
||||
// ConnectionInfo holds socket information returned by the system.
|
||||
type ConnectionInfo struct {
|
||||
sync.Mutex
|
||||
|
||||
Local Address
|
||||
Remote Address
|
||||
PID int
|
||||
UID int
|
||||
Inode int
|
||||
}
|
||||
|
||||
// BindInfo holds socket information returned by the system.
|
||||
type BindInfo struct {
|
||||
sync.Mutex
|
||||
|
||||
Local Address
|
||||
PID int
|
||||
UID int
|
||||
Inode int
|
||||
|
||||
ListensAny bool
|
||||
}
|
||||
|
||||
// Address is an IP + Port pair.
|
||||
type Address struct {
|
||||
IP net.IP
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// Info is a generic interface to both ConnectionInfo and BindInfo.
|
||||
type Info interface {
|
||||
GetPID() int
|
||||
SetPID(int)
|
||||
GetUID() int
|
||||
GetUIDandInode() (int, int)
|
||||
}
|
||||
|
||||
// GetPID returns the PID.
|
||||
func (i *ConnectionInfo) GetPID() int {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
return i.PID
|
||||
}
|
||||
|
||||
// SetPID sets the PID to the given value.
|
||||
func (i *ConnectionInfo) SetPID(pid int) {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
i.PID = pid
|
||||
}
|
||||
|
||||
// GetUID returns the UID.
|
||||
func (i *ConnectionInfo) GetUID() int {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
return i.UID
|
||||
}
|
||||
|
||||
// GetUIDandInode returns the UID and Inode.
|
||||
func (i *ConnectionInfo) GetUIDandInode() (int, int) {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
return i.UID, i.Inode
|
||||
}
|
||||
|
||||
// GetPID returns the PID.
|
||||
func (i *BindInfo) GetPID() int {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
return i.PID
|
||||
}
|
||||
|
||||
// SetPID sets the PID to the given value.
|
||||
func (i *BindInfo) SetPID(pid int) {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
i.PID = pid
|
||||
}
|
||||
|
||||
// GetUID returns the UID.
|
||||
func (i *BindInfo) GetUID() int {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
return i.UID
|
||||
}
|
||||
|
||||
// GetUIDandInode returns the UID and Inode.
|
||||
func (i *BindInfo) GetUIDandInode() (int, int) {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
return i.UID, i.Inode
|
||||
}
|
||||
|
||||
// Compile time checks.
|
||||
var (
|
||||
_ Info = new(ConnectionInfo)
|
||||
_ Info = new(BindInfo)
|
||||
)
|
||||
102
service/network/state/exists.go
Normal file
102
service/network/state/exists.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
const (
|
||||
// UDPConnectionTTL defines the duration after which unseen UDP connections are regarded as ended.
|
||||
UDPConnectionTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// Exists checks if the given connection is present in the system state tables.
|
||||
func Exists(pktInfo *packet.Info, now time.Time) (exists bool) {
|
||||
// TODO: create lookup maps before running a flurry of Exists() checks.
|
||||
|
||||
switch {
|
||||
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
|
||||
return tcp4Table.exists(pktInfo)
|
||||
|
||||
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP:
|
||||
return tcp6Table.exists(pktInfo)
|
||||
|
||||
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP:
|
||||
return udp4Table.exists(pktInfo, now)
|
||||
|
||||
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP:
|
||||
return udp6Table.exists(pktInfo, now)
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (table *tcpTable) exists(pktInfo *packet.Info) (exists bool) {
|
||||
// Update tables if older than the connection that is checked.
|
||||
if table.lastUpdateAt.Load() < pktInfo.SeenAt.UnixNano() {
|
||||
table.updateTables()
|
||||
}
|
||||
|
||||
table.lock.RLock()
|
||||
defer table.lock.RUnlock()
|
||||
|
||||
localIP := pktInfo.LocalIP()
|
||||
localPort := pktInfo.LocalPort()
|
||||
remoteIP := pktInfo.RemoteIP()
|
||||
remotePort := pktInfo.RemotePort()
|
||||
|
||||
// search connections
|
||||
for _, socketInfo := range table.connections {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
remotePort == socketInfo.Remote.Port &&
|
||||
remoteIP.Equal(socketInfo.Remote.IP) &&
|
||||
localIP.Equal(socketInfo.Local.IP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (table *udpTable) exists(pktInfo *packet.Info, now time.Time) (exists bool) {
|
||||
// Update tables if older than the connection that is checked.
|
||||
if table.lastUpdateAt.Load() < pktInfo.SeenAt.UnixNano() {
|
||||
table.updateTables()
|
||||
}
|
||||
|
||||
table.lock.RLock()
|
||||
defer table.lock.RUnlock()
|
||||
|
||||
localIP := pktInfo.LocalIP()
|
||||
localPort := pktInfo.LocalPort()
|
||||
remoteIP := pktInfo.RemoteIP()
|
||||
remotePort := pktInfo.RemotePort()
|
||||
|
||||
connThreshhold := now.Add(-UDPConnectionTTL)
|
||||
|
||||
// search binds
|
||||
for _, socketInfo := range table.binds {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
|
||||
|
||||
udpConnState, ok := table.getConnState(socketInfo, socket.Address{
|
||||
IP: remoteIP,
|
||||
Port: remotePort,
|
||||
})
|
||||
switch {
|
||||
case !ok:
|
||||
return false
|
||||
case udpConnState.lastSeen.After(connThreshhold):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
38
service/network/state/info.go
Normal file
38
service/network/state/info.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
// Info holds network state information as provided by the system.
|
||||
type Info struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
TCP4Connections []*socket.ConnectionInfo
|
||||
TCP4Listeners []*socket.BindInfo
|
||||
TCP6Connections []*socket.ConnectionInfo
|
||||
TCP6Listeners []*socket.BindInfo
|
||||
UDP4Binds []*socket.BindInfo
|
||||
UDP6Binds []*socket.BindInfo
|
||||
}
|
||||
|
||||
// GetInfo returns all system state tables. The returned data must not be modified.
|
||||
func GetInfo() *Info {
|
||||
info := &Info{}
|
||||
|
||||
info.TCP4Connections, info.TCP4Listeners = tcp4Table.updateTables()
|
||||
info.UDP4Binds = udp4Table.updateTables()
|
||||
|
||||
if netenv.IPv6Enabled() {
|
||||
info.TCP6Connections, info.TCP6Listeners = tcp6Table.updateTables()
|
||||
info.UDP6Binds = udp6Table.updateTables()
|
||||
}
|
||||
|
||||
info.UpdateMeta()
|
||||
return info
|
||||
}
|
||||
264
service/network/state/lookup.go
Normal file
264
service/network/state/lookup.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
// - TCP
|
||||
// - Outbound: Match listeners (in!), then connections (out!)
|
||||
// - Inbound: Match listeners (in!), then connections (out!)
|
||||
// - Clean via connections
|
||||
// - UDP
|
||||
// - Any connection: match specific local address or zero IP
|
||||
// - In or out: save direction of first packet:
|
||||
// - map[<local udp bind ip+port>]map[<remote ip+port>]{direction, lastSeen}
|
||||
// - only clean if <local udp bind ip+port> is removed by OS
|
||||
// - limit <remote ip+port> to 256 entries?
|
||||
// - clean <remote ip+port> after 72hrs?
|
||||
// - switch direction to outbound if outbound packet is seen?
|
||||
// - IP: Unidentified Process
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
|
||||
ErrPIDNotFound = errors.New("could not find pid for socket inode")
|
||||
)
|
||||
|
||||
const (
|
||||
lookupTries = 5
|
||||
fastLookupTries = 2
|
||||
)
|
||||
|
||||
// Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound.
|
||||
func Lookup(pktInfo *packet.Info, fast bool) (pid int, inbound bool, err error) {
|
||||
// auto-detect version
|
||||
if pktInfo.Version == 0 {
|
||||
if ip := pktInfo.LocalIP().To4(); ip != nil {
|
||||
pktInfo.Version = packet.IPv4
|
||||
} else {
|
||||
pktInfo.Version = packet.IPv6
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
|
||||
return tcp4Table.lookup(pktInfo, fast)
|
||||
|
||||
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP:
|
||||
return tcp6Table.lookup(pktInfo, fast)
|
||||
|
||||
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP:
|
||||
return udp4Table.lookup(pktInfo, fast)
|
||||
|
||||
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP:
|
||||
return udp6Table.lookup(pktInfo, fast)
|
||||
|
||||
default:
|
||||
return socket.UndefinedProcessID, pktInfo.Inbound, errors.New("unsupported protocol for finding process")
|
||||
}
|
||||
}
|
||||
|
||||
func (table *tcpTable) lookup(pktInfo *packet.Info, fast bool) (
|
||||
pid int,
|
||||
inbound bool,
|
||||
err error,
|
||||
) {
|
||||
// Prepare variables.
|
||||
var (
|
||||
connections []*socket.ConnectionInfo
|
||||
listeners []*socket.BindInfo
|
||||
|
||||
dualStackConnections []*socket.ConnectionInfo
|
||||
dualStackListeners []*socket.BindInfo
|
||||
)
|
||||
|
||||
// Search for the socket until found.
|
||||
for i := 1; i <= lookupTries; i++ {
|
||||
// Use existing tables for first check if packet was seen after last table update.
|
||||
if i == 1 && pktInfo.SeenAt.UnixNano() >= table.lastUpdateAt.Load() {
|
||||
connections, listeners = table.getCurrentTables()
|
||||
} else {
|
||||
connections, listeners = table.updateTables()
|
||||
}
|
||||
|
||||
// Check tables for socket.
|
||||
socketInfo, inbound := findTCPSocket(pktInfo, connections, listeners)
|
||||
|
||||
// If there's a match, check if we have the PID and return.
|
||||
if socketInfo != nil {
|
||||
return CheckPID(socketInfo, inbound)
|
||||
}
|
||||
|
||||
// DUAL-STACK
|
||||
|
||||
// Skip if dualStack is not enabled.
|
||||
if table.dualStack == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use existing tables for first check if packet was seen after last table update.
|
||||
if i == 1 && pktInfo.SeenAt.UnixNano() >= table.dualStack.lastUpdateAt.Load() {
|
||||
dualStackConnections, dualStackListeners = table.dualStack.getCurrentTables()
|
||||
} else {
|
||||
dualStackConnections, dualStackListeners = table.dualStack.updateTables()
|
||||
}
|
||||
|
||||
// Check tables for socket.
|
||||
socketInfo, inbound = findTCPSocket(pktInfo, dualStackConnections, dualStackListeners)
|
||||
|
||||
// If there's a match, check if we have the PID and return.
|
||||
if socketInfo != nil {
|
||||
return CheckPID(socketInfo, inbound)
|
||||
}
|
||||
|
||||
// Search less if we want to be fast.
|
||||
if fast && i >= fastLookupTries {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound
|
||||
}
|
||||
|
||||
func findTCPSocket(
|
||||
pktInfo *packet.Info,
|
||||
connections []*socket.ConnectionInfo,
|
||||
listeners []*socket.BindInfo,
|
||||
) (
|
||||
socketInfo socket.Info,
|
||||
inbound bool,
|
||||
) {
|
||||
localIP := pktInfo.LocalIP()
|
||||
localPort := pktInfo.LocalPort()
|
||||
|
||||
// always search listeners first
|
||||
for _, socketInfo := range listeners {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
(socketInfo.ListensAny || localIP.Equal(socketInfo.Local.IP)) {
|
||||
return socketInfo, true
|
||||
}
|
||||
}
|
||||
|
||||
remoteIP := pktInfo.RemoteIP()
|
||||
remotePort := pktInfo.RemotePort()
|
||||
|
||||
// search connections
|
||||
for _, socketInfo := range connections {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
remotePort == socketInfo.Remote.Port &&
|
||||
remoteIP.Equal(socketInfo.Remote.IP) &&
|
||||
localIP.Equal(socketInfo.Local.IP) {
|
||||
return socketInfo, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) (
|
||||
pid int,
|
||||
inbound bool,
|
||||
err error,
|
||||
) {
|
||||
// TODO: Currently broadcast/multicast scopes are not checked, so we might
|
||||
// attribute an incoming broadcast/multicast packet to the wrong process if
|
||||
// there are multiple processes listening on the same local port, but
|
||||
// binding to different addresses. This highly unusual for clients.
|
||||
isInboundMulticast := pktInfo.Inbound && netutils.GetIPScope(pktInfo.LocalIP()) == netutils.LocalMulticast
|
||||
|
||||
// Prepare variables.
|
||||
var (
|
||||
binds []*socket.BindInfo
|
||||
dualStackBinds []*socket.BindInfo
|
||||
)
|
||||
|
||||
// Search for the socket until found.
|
||||
for i := 1; i <= lookupTries; i++ {
|
||||
// Get or update tables.
|
||||
if i == 1 && pktInfo.SeenAt.UnixNano() >= table.lastUpdateAt.Load() {
|
||||
binds = table.getCurrentTables()
|
||||
} else {
|
||||
binds = table.updateTables()
|
||||
}
|
||||
|
||||
// Check tables for socket.
|
||||
socketInfo := findUDPSocket(pktInfo, binds, isInboundMulticast)
|
||||
|
||||
// If there's a match, do some last checks and return.
|
||||
if socketInfo != nil {
|
||||
// If there is no remote port, do check for the direction of the
|
||||
// connection. This will be the case for pure checking functions
|
||||
// that do not want to change direction state.
|
||||
if pktInfo.RemotePort() == 0 {
|
||||
return CheckPID(socketInfo, pktInfo.Inbound)
|
||||
}
|
||||
|
||||
// Get (and save) the direction of the connection.
|
||||
connInbound := table.getDirection(socketInfo, pktInfo)
|
||||
|
||||
// Check we have the PID and return.
|
||||
return CheckPID(socketInfo, connInbound)
|
||||
}
|
||||
|
||||
// DUAL-STACK
|
||||
|
||||
// Skip if dualStack is not enabled.
|
||||
if table.dualStack == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get or update tables.
|
||||
if i == 1 && pktInfo.SeenAt.UnixNano() >= table.lastUpdateAt.Load() {
|
||||
dualStackBinds = table.dualStack.getCurrentTables()
|
||||
} else {
|
||||
dualStackBinds = table.dualStack.updateTables()
|
||||
}
|
||||
|
||||
// Check tables for socket.
|
||||
socketInfo = findUDPSocket(pktInfo, dualStackBinds, isInboundMulticast)
|
||||
|
||||
// If there's a match, do some last checks and return.
|
||||
if socketInfo != nil {
|
||||
// If there is no remote port, do check for the direction of the
|
||||
// connection. This will be the case for pure checking functions
|
||||
// that do not want to change direction state.
|
||||
if pktInfo.RemotePort() == 0 {
|
||||
return CheckPID(socketInfo, pktInfo.Inbound)
|
||||
}
|
||||
|
||||
// Get (and save) the direction of the connection.
|
||||
connInbound := table.getDirection(socketInfo, pktInfo)
|
||||
|
||||
// Check we have the PID and return.
|
||||
return CheckPID(socketInfo, connInbound)
|
||||
}
|
||||
|
||||
// Search less if we want to be fast.
|
||||
if fast && i >= fastLookupTries {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return socket.UndefinedProcessID, pktInfo.Inbound, ErrConnectionNotFound
|
||||
}
|
||||
|
||||
func findUDPSocket(pktInfo *packet.Info, binds []*socket.BindInfo, isInboundMulticast bool) (socketInfo *socket.BindInfo) {
|
||||
localIP := pktInfo.LocalIP()
|
||||
localPort := pktInfo.LocalPort()
|
||||
|
||||
// search binds
|
||||
for _, socketInfo := range binds {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
(socketInfo.ListensAny || // zero IP (dual-stack)
|
||||
isInboundMulticast || // inbound broadcast, multicast
|
||||
localIP.Equal(socketInfo.Local.IP)) {
|
||||
return socketInfo
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
46
service/network/state/system_default.go
Normal file
46
service/network/state/system_default.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build !windows && !linux
|
||||
// +build !windows,!linux
|
||||
|
||||
package state
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// This increases performance on unsupported system.
|
||||
// It's not critical at all and does not break anything if it fails.
|
||||
go func() {
|
||||
// Wait for one minute before we set the default value, as we
|
||||
// currently cannot easily integrate into the startup procedure.
|
||||
time.Sleep(1 * time.Minute)
|
||||
|
||||
// We cannot use process.CfgOptionEnableProcessDetectionKey, because of an import loop.
|
||||
config.SetDefaultConfigOption("core/enableProcessDetection", false)
|
||||
}()
|
||||
}
|
||||
|
||||
func getTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func getTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func getUDP4Table() (binds []*socket.BindInfo, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func getUDP6Table() (binds []*socket.BindInfo, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// CheckPID checks the if socket info already has a PID and if not, tries to find it.
|
||||
// Depending on the OS, this might be a no-op.
|
||||
func CheckPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) {
|
||||
return socketInfo.GetPID(), connInbound, nil
|
||||
}
|
||||
40
service/network/state/system_linux.go
Normal file
40
service/network/state/system_linux.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/network/proc"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
var (
|
||||
getTCP4Table = proc.GetTCP4Table
|
||||
getTCP6Table = proc.GetTCP6Table
|
||||
getUDP4Table = proc.GetUDP4Table
|
||||
getUDP6Table = proc.GetUDP6Table
|
||||
|
||||
checkPIDTries = 5
|
||||
checkPIDBaseWaitTime = 5 * time.Millisecond
|
||||
)
|
||||
|
||||
// CheckPID checks the if socket info already has a PID and if not, tries to find it.
|
||||
// Depending on the OS, this might be a no-op.
|
||||
func CheckPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) {
|
||||
for i := 1; i <= checkPIDTries; i++ {
|
||||
// look for PID
|
||||
pid = proc.GetPID(socketInfo)
|
||||
if pid != socket.UndefinedProcessID {
|
||||
// if we found a PID, return
|
||||
break
|
||||
}
|
||||
|
||||
// every time, except for the last iteration
|
||||
if i < checkPIDTries {
|
||||
// we found no PID, we could have been too fast, give the kernel some time to think
|
||||
// back off timer: with 5ms baseWaitTime: 5, 10, 15, 20, 25 - 75ms in total
|
||||
time.Sleep(time.Duration(i) * checkPIDBaseWaitTime)
|
||||
}
|
||||
}
|
||||
|
||||
return pid, connInbound, nil
|
||||
}
|
||||
19
service/network/state/system_windows.go
Normal file
19
service/network/state/system_windows.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/service/network/iphelper"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
var (
|
||||
getTCP4Table = iphelper.GetTCP4Table
|
||||
getTCP6Table = iphelper.GetTCP6Table
|
||||
getUDP4Table = iphelper.GetUDP4Table
|
||||
getUDP6Table = iphelper.GetUDP6Table
|
||||
)
|
||||
|
||||
// CheckPID checks the if socket info already has a PID and if not, tries to find it.
|
||||
// Depending on the OS, this might be a no-op.
|
||||
func CheckPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) {
|
||||
return socketInfo.GetPID(), connInbound, nil
|
||||
}
|
||||
91
service/network/state/tcp.go
Normal file
91
service/network/state/tcp.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/utils"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
const (
|
||||
minDurationBetweenTableUpdates = 10 * time.Millisecond
|
||||
)
|
||||
|
||||
type tcpTable struct {
|
||||
version int
|
||||
|
||||
connections []*socket.ConnectionInfo
|
||||
listeners []*socket.BindInfo
|
||||
lock sync.RWMutex
|
||||
|
||||
// lastUpdateAt stores the time when the tables where last updated as unix nanoseconds.
|
||||
lastUpdateAt atomic.Int64
|
||||
|
||||
fetchLimiter *utils.CallLimiter
|
||||
fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error)
|
||||
|
||||
dualStack *tcpTable
|
||||
}
|
||||
|
||||
var (
|
||||
tcp6Table = &tcpTable{
|
||||
version: 6,
|
||||
fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
|
||||
fetchTable: getTCP6Table,
|
||||
}
|
||||
|
||||
tcp4Table = &tcpTable{
|
||||
version: 4,
|
||||
fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
|
||||
fetchTable: getTCP4Table,
|
||||
}
|
||||
)
|
||||
|
||||
// EnableTCPDualStack adds the TCP6 table to the TCP4 table as a dual-stack.
|
||||
// Must be called before any lookup operation.
|
||||
func EnableTCPDualStack() {
|
||||
tcp4Table.dualStack = tcp6Table
|
||||
}
|
||||
|
||||
func (table *tcpTable) getCurrentTables() (
|
||||
connections []*socket.ConnectionInfo,
|
||||
listeners []*socket.BindInfo,
|
||||
) {
|
||||
table.lock.RLock()
|
||||
defer table.lock.RUnlock()
|
||||
|
||||
return table.connections, table.listeners
|
||||
}
|
||||
|
||||
func (table *tcpTable) updateTables() (
|
||||
connections []*socket.ConnectionInfo,
|
||||
listeners []*socket.BindInfo,
|
||||
) {
|
||||
// Fetch tables.
|
||||
table.fetchLimiter.Do(func() {
|
||||
// Fetch new tables from system.
|
||||
connections, listeners, err := table.fetchTable()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get TCP%d socket table: %s", table.version, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-check for any listeners.
|
||||
for _, bindInfo := range listeners {
|
||||
bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero)
|
||||
}
|
||||
|
||||
// Apply new tables.
|
||||
table.lock.Lock()
|
||||
defer table.lock.Unlock()
|
||||
table.connections = connections
|
||||
table.listeners = listeners
|
||||
table.lastUpdateAt.Store(time.Now().UnixNano())
|
||||
})
|
||||
|
||||
return table.getCurrentTables()
|
||||
}
|
||||
210
service/network/state/udp.go
Normal file
210
service/network/state/udp.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/utils"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/network/socket"
|
||||
)
|
||||
|
||||
type udpTable struct {
|
||||
version int
|
||||
|
||||
binds []*socket.BindInfo
|
||||
lock sync.RWMutex
|
||||
|
||||
// lastUpdateAt stores the time when the tables where last updated as unix nanoseconds.
|
||||
lastUpdateAt atomic.Int64
|
||||
|
||||
fetchLimiter *utils.CallLimiter
|
||||
fetchTable func() (binds []*socket.BindInfo, err error)
|
||||
|
||||
states map[string]map[string]*udpState
|
||||
statesLock sync.Mutex
|
||||
|
||||
dualStack *udpTable
|
||||
}
|
||||
|
||||
type udpState struct {
|
||||
inbound bool
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
// UDPConnStateTTL is the maximum time a udp connection state is held.
|
||||
UDPConnStateTTL = 72 * time.Hour
|
||||
|
||||
// UDPConnStateShortenedTTL is a shortened maximum time a udp connection state is held, if there more entries than defined by AggressiveCleaningThreshold.
|
||||
UDPConnStateShortenedTTL = 3 * time.Hour
|
||||
|
||||
// AggressiveCleaningThreshold defines the soft limit of udp connection state held per udp socket.
|
||||
AggressiveCleaningThreshold = 256
|
||||
)
|
||||
|
||||
var (
|
||||
udp6Table = &udpTable{
|
||||
version: 6,
|
||||
fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
|
||||
fetchTable: getUDP6Table,
|
||||
states: make(map[string]map[string]*udpState),
|
||||
}
|
||||
|
||||
udp4Table = &udpTable{
|
||||
version: 4,
|
||||
fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
|
||||
fetchTable: getUDP4Table,
|
||||
states: make(map[string]map[string]*udpState),
|
||||
}
|
||||
)
|
||||
|
||||
// EnableUDPDualStack adds the UDP6 table to the UDP4 table as a dual-stack.
|
||||
// Must be called before any lookup operation.
|
||||
func EnableUDPDualStack() {
|
||||
udp4Table.dualStack = udp6Table
|
||||
}
|
||||
|
||||
func (table *udpTable) getCurrentTables() (binds []*socket.BindInfo) {
|
||||
table.lock.RLock()
|
||||
defer table.lock.RUnlock()
|
||||
|
||||
return table.binds
|
||||
}
|
||||
|
||||
func (table *udpTable) updateTables() (binds []*socket.BindInfo) {
|
||||
// Fetch tables.
|
||||
table.fetchLimiter.Do(func() {
|
||||
// Fetch new tables from system.
|
||||
binds, err := table.fetchTable()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get UDP%d socket table: %s", table.version, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-check for any listeners.
|
||||
for _, bindInfo := range binds {
|
||||
bindInfo.ListensAny = bindInfo.Local.IP.Equal(net.IPv4zero) || bindInfo.Local.IP.Equal(net.IPv6zero)
|
||||
}
|
||||
|
||||
// Apply new tables.
|
||||
table.lock.Lock()
|
||||
defer table.lock.Unlock()
|
||||
table.binds = binds
|
||||
table.lastUpdateAt.Store(time.Now().UnixNano())
|
||||
})
|
||||
|
||||
return table.getCurrentTables()
|
||||
}
|
||||
|
||||
// CleanUDPStates cleans the udp connection states which save connection directions.
|
||||
func CleanUDPStates(_ context.Context) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
udp4Table.updateTables()
|
||||
udp4Table.cleanStates(now)
|
||||
|
||||
if netenv.IPv6Enabled() {
|
||||
udp6Table.updateTables()
|
||||
udp6Table.cleanStates(now)
|
||||
}
|
||||
}
|
||||
|
||||
func (table *udpTable) getConnState(
|
||||
socketInfo *socket.BindInfo,
|
||||
remoteAddress socket.Address,
|
||||
) (udpConnState *udpState, ok bool) {
|
||||
table.statesLock.Lock()
|
||||
defer table.statesLock.Unlock()
|
||||
|
||||
bindMap, ok := table.states[makeUDPStateKey(socketInfo.Local)]
|
||||
if ok {
|
||||
udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)]
|
||||
return
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (table *udpTable) getDirection(
|
||||
socketInfo *socket.BindInfo,
|
||||
pktInfo *packet.Info,
|
||||
) (connDirection bool) {
|
||||
table.statesLock.Lock()
|
||||
defer table.statesLock.Unlock()
|
||||
|
||||
localKey := makeUDPStateKey(socketInfo.Local)
|
||||
|
||||
bindMap, ok := table.states[localKey]
|
||||
if !ok {
|
||||
bindMap = make(map[string]*udpState)
|
||||
table.states[localKey] = bindMap
|
||||
}
|
||||
|
||||
remoteKey := makeUDPStateKey(socket.Address{
|
||||
IP: pktInfo.RemoteIP(),
|
||||
Port: pktInfo.RemotePort(),
|
||||
})
|
||||
udpConnState, ok := bindMap[remoteKey]
|
||||
if !ok {
|
||||
bindMap[remoteKey] = &udpState{
|
||||
inbound: pktInfo.Inbound,
|
||||
lastSeen: time.Now().UTC(),
|
||||
}
|
||||
return pktInfo.Inbound
|
||||
}
|
||||
|
||||
udpConnState.lastSeen = time.Now().UTC()
|
||||
return udpConnState.inbound
|
||||
}
|
||||
|
||||
func (table *udpTable) cleanStates(now time.Time) {
|
||||
// compute thresholds
|
||||
threshold := now.Add(-UDPConnStateTTL)
|
||||
shortThreshhold := now.Add(-UDPConnStateShortenedTTL)
|
||||
|
||||
// make lookup map of all active keys
|
||||
bindKeys := make(map[string]struct{})
|
||||
table.lock.RLock()
|
||||
for _, socketInfo := range table.binds {
|
||||
bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{}
|
||||
}
|
||||
table.lock.RUnlock()
|
||||
|
||||
table.statesLock.Lock()
|
||||
defer table.statesLock.Unlock()
|
||||
|
||||
// clean the udp state storage
|
||||
for localKey, bindMap := range table.states {
|
||||
if _, active := bindKeys[localKey]; active {
|
||||
// clean old entries
|
||||
for remoteKey, udpConnState := range bindMap {
|
||||
if udpConnState.lastSeen.Before(threshold) {
|
||||
delete(bindMap, remoteKey)
|
||||
}
|
||||
}
|
||||
// if there are too many clean more aggressively
|
||||
if len(bindMap) > AggressiveCleaningThreshold {
|
||||
for remoteKey, udpConnState := range bindMap {
|
||||
if udpConnState.lastSeen.Before(shortThreshhold) {
|
||||
delete(bindMap, remoteKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// delete the whole thing
|
||||
delete(table.states, localKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeUDPStateKey(address socket.Address) string {
|
||||
// This could potentially go wrong, but as all IPs are created by the same source, everything should be fine.
|
||||
return string(address.IP) + strconv.Itoa(int(address.Port))
|
||||
}
|
||||
83
service/network/status.go
Normal file
83
service/network/status.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package network
|
||||
|
||||
// Verdict describes the decision made about a connection or link.
|
||||
type Verdict int8
|
||||
|
||||
// All possible verdicts that can be applied to a network
|
||||
// connection.
|
||||
const (
|
||||
// VerdictUndecided is the default status of new connections.
|
||||
VerdictUndecided Verdict = 0
|
||||
VerdictUndeterminable Verdict = 1
|
||||
VerdictAccept Verdict = 2
|
||||
VerdictBlock Verdict = 3
|
||||
VerdictDrop Verdict = 4
|
||||
VerdictRerouteToNameserver Verdict = 5
|
||||
VerdictRerouteToTunnel Verdict = 6
|
||||
VerdictFailed Verdict = 7
|
||||
)
|
||||
|
||||
func (v Verdict) String() string {
|
||||
switch v {
|
||||
case VerdictUndecided:
|
||||
return "<Undecided>"
|
||||
case VerdictUndeterminable:
|
||||
return "<Undeterminable>"
|
||||
case VerdictAccept:
|
||||
return "Accept"
|
||||
case VerdictBlock:
|
||||
return "Block"
|
||||
case VerdictDrop:
|
||||
return "Drop"
|
||||
case VerdictRerouteToNameserver:
|
||||
return "RerouteToNameserver"
|
||||
case VerdictRerouteToTunnel:
|
||||
return "RerouteToTunnel"
|
||||
case VerdictFailed:
|
||||
return "Failed"
|
||||
default:
|
||||
return "<INVALID VERDICT>"
|
||||
}
|
||||
}
|
||||
|
||||
// Verb returns the verdict as a past tense verb.
|
||||
func (v Verdict) Verb() string {
|
||||
switch v {
|
||||
case VerdictUndecided:
|
||||
return "undecided"
|
||||
case VerdictUndeterminable:
|
||||
return "undeterminable"
|
||||
case VerdictAccept:
|
||||
return "accepted"
|
||||
case VerdictBlock:
|
||||
return "blocked"
|
||||
case VerdictDrop:
|
||||
return "dropped"
|
||||
case VerdictRerouteToNameserver:
|
||||
return "redirected to nameserver"
|
||||
case VerdictRerouteToTunnel:
|
||||
return "tunneled"
|
||||
case VerdictFailed:
|
||||
return "failed"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
}
|
||||
|
||||
// Packet Directions.
|
||||
const (
|
||||
Inbound = true
|
||||
Outbound = false
|
||||
)
|
||||
|
||||
// Non-Domain Scopes.
|
||||
const (
|
||||
IncomingHost = "IH"
|
||||
IncomingLAN = "IL"
|
||||
IncomingInternet = "II"
|
||||
IncomingInvalid = "IX"
|
||||
PeerHost = "PH"
|
||||
PeerLAN = "PL"
|
||||
PeerInternet = "PI"
|
||||
PeerInvalid = "PX"
|
||||
)
|
||||
Reference in New Issue
Block a user