Revamp connection handling flow to fix race condition and support info-only packets

This commit is contained in:
Daniel
2023-06-21 15:31:45 +02:00
parent 83b084959e
commit 8a09ba6045
22 changed files with 527 additions and 349 deletions

View File

@@ -14,43 +14,19 @@ import (
"github.com/safing/portmaster/profile"
)
// GetProcessByConnection returns the process that owns the described connection.
func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process *Process, connInbound bool, err error) {
// GetProcessWithProfile returns the process, including the profile.
// Always returns valid data.
// Errors are logged and returned for information or special handling purposes.
func GetProcessWithProfile(ctx context.Context, pid int) (process *Process, err error) {
if !enableProcessDetection() {
log.Tracer(ctx).Tracef("process: process detection disabled")
return GetUnidentifiedProcess(ctx), pktInfo.Inbound, nil
}
// Use fast search for inbound packets, as the listening socket should
// already be there for a while now.
fastSearch := pktInfo.Inbound
var pid int
if pktInfo.PID == 0 {
log.Tracer(ctx).Tracef("process: getting pid from system network state")
pid, connInbound, err = state.Lookup(pktInfo, fastSearch)
if err != nil {
log.Tracer(ctx).Tracef("process: failed to find PID of connection: %s", err)
return nil, pktInfo.Inbound, err
}
} else {
log.Tracer(ctx).Tracef("process: pid already set in packet (by ebpf or kext)")
pid = int(pktInfo.PID)
}
// Fallback to special profiles if PID could not be found.
if pid == UndefinedProcessID {
if connInbound {
pid = UnsolicitedProcessID
} else {
pid = UnidentifiedProcessID
}
return GetUnidentifiedProcess(ctx), nil
}
process, err = GetOrFindProcess(ctx, pid)
if err != nil {
log.Tracer(ctx).Debugf("process: failed to find (primary) process with PID: %s", err)
return nil, connInbound, err
log.Tracer(ctx).Warningf("process: failed to find process with PID: %s", err)
return GetUnidentifiedProcess(ctx), err
}
changed, err := process.GetProfile(ctx)
@@ -62,7 +38,46 @@ func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process
process.Save()
}
return process, connInbound, nil
return process, nil
}
// GetPidOfConnection returns the PID of the process that owns the described connection.
// Always returns valid data.
// Errors are logged and returned for information or special handling purposes.
func GetPidOfConnection(ctx context.Context, pktInfo *packet.Info) (pid int, connInbound bool, err error) {
if !enableProcessDetection() {
return UnidentifiedProcessID, pktInfo.Inbound, nil
}
// Use fast search for inbound packets, as the listening socket should
// already be there for a while now.
fastSearch := pktInfo.Inbound
connInbound = pktInfo.Inbound
// FIXME: Only match for UndefinedProcessID when integrations have been updated.
if pktInfo.PID <= 0 {
log.Tracer(ctx).Tracef("process: getting pid from system network state")
pid, connInbound, err = state.Lookup(pktInfo, fastSearch)
if err != nil {
err = fmt.Errorf("failed to find PID of connection: %w", err)
log.Tracer(ctx).Tracef("process: %s", err)
pid = UndefinedProcessID
}
} else {
log.Tracer(ctx).Tracef("process: pid already set in packet (by ebpf or kext)")
pid = pktInfo.PID
}
// Fallback to special profiles if PID could not be found.
if pid == UndefinedProcessID {
if connInbound && !netutils.ClassifyIP(pktInfo.Dst).IsLocalhost() {
pid = UnsolicitedProcessID
} else {
pid = UnidentifiedProcessID
}
}
return pid, connInbound, err
}
// GetNetworkHost returns a *Process that represents a host on the network.
@@ -113,7 +128,12 @@ func GetProcessByRequestOrigin(ar *api.Request) (*Process, error) {
SrcPort: remotePort, // source as in the process we are looking for
}
proc, _, err := GetProcessByConnection(ar.Context(), pkt)
pid, _, err := GetPidOfConnection(ar.Context(), pkt)
if err != nil {
return nil, err
}
proc, err := GetProcessWithProfile(ar.Context(), pid)
if err != nil {
return nil, err
}

View File

@@ -217,7 +217,7 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (*
}
// UID
// net yet implemented for windows
// TODO: implemented for windows
if onLinux {
var uids []int32
uids, err = pInfo.UidsWithContext(ctx)