Revamp connection handling flow to fix race condition and support info-only packets

This commit is contained in:
Daniel
2023-06-21 15:31:45 +02:00
parent 83b084959e
commit 8a09ba6045
22 changed files with 527 additions and 349 deletions

View File

@@ -11,7 +11,6 @@ import (
"github.com/google/gopacket/layers"
"github.com/tevino/abool"
"golang.org/x/sync/singleflight"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
@@ -54,7 +53,7 @@ func init() {
// TODO: Move interception module to own package (dir).
interceptionModule = modules.Register("interception", interceptionPrep, interceptionStart, interceptionStop, "base", "updates", "network", "notifications", "profiles")
network.SetDefaultFirewallHandler(defaultHandler)
network.SetDefaultFirewallHandler(verdictHandler)
}
func interceptionPrep() error {
@@ -120,6 +119,11 @@ func resetAllConnectionVerdicts() {
// Re-evaluate all connections.
var changedVerdicts int
for _, conn := range network.GetAllConnections() {
// Skip incomplete connections.
if !conn.DataIsComplete() {
continue
}
func() {
conn.Lock()
defer conn.Unlock()
@@ -167,15 +171,10 @@ func resetAllConnectionVerdicts() {
func interceptionStart() error {
getConfig()
if err := registerMetrics(); err != nil {
return err
}
startAPIAuth()
interceptionModule.StartWorker("stat logger", statLogger)
interceptionModule.StartWorker("packet handler", packetHandler)
interceptionModule.StartServiceWorker("stat logger", 0, statLogger)
interceptionModule.StartServiceWorker("packet handler", 0, packetHandler)
return interception.Start()
}
@@ -196,92 +195,38 @@ func SetNameserverIPMatcher(fn func(ip net.IP) bool) error {
return nil
}
func handlePacket(ctx context.Context, pkt packet.Packet) {
// Record metrics.
startTime := time.Now()
defer packetHandlingHistogram.UpdateDuration(startTime)
if fastTrackedPermit(pkt) {
func handlePacket(pkt packet.Packet) {
// First, check for an existing connection.
conn, ok := network.GetConnection(pkt.GetConnectionID())
if ok {
// Add packet to connection handler queue or apply verdict directly.
conn.HandlePacket(pkt)
return
}
// Add context tracer and set context on packet.
traceCtx, tracer := log.AddTracer(ctx)
if tracer != nil {
// The trace is submitted in `network.Connection.packetHandler()`.
tracer.Tracef("filter: handling packet: %s", pkt)
}
pkt.SetCtx(traceCtx)
// Else create new incomplete connection from the packet and start the new handler.
conn = network.NewIncompleteConnection(pkt)
conn.Lock()
defer conn.Unlock()
conn.SetFirewallHandler(fastTrackHandler)
// Get connection of packet.
conn, err := getConnection(pkt)
if err != nil {
tracer.Errorf("filter: packet %s dropped: %s", pkt, err)
_ = pkt.Drop()
return
}
// handle packet
// Let the new connection handler worker handle the packet.
conn.HandlePacket(pkt)
}
var getConnectionSingleInflight singleflight.Group
func getConnection(pkt packet.Packet) (*network.Connection, error) {
created := false
// Create or get connection in single inflight lock in order to prevent duplicates.
newConn, err, shared := getConnectionSingleInflight.Do(pkt.GetConnectionID(), func() (interface{}, error) {
// First, check for an existing connection.
conn, ok := network.GetConnection(pkt.GetConnectionID())
if ok {
return conn, nil
}
// Else create new one from the packet.
conn = network.NewConnectionFromFirstPacket(pkt)
conn.Lock()
defer conn.Unlock()
conn.SetFirewallHandler(initialHandler)
created = true
return conn, nil
})
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}
if newConn == nil {
return nil, errors.New("connection getter returned nil")
}
// Transform and log result.
conn := newConn.(*network.Connection) //nolint:forcetypeassert // Can only be a *network.Connection.
sharedIndicator := ""
if shared {
sharedIndicator = " (shared)"
}
if created {
log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s%s", conn.ID, sharedIndicator)
} else {
log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s%s", conn.ID, sharedIndicator)
}
return conn, nil
}
// fastTrackedPermit quickly permits certain network critical or internal connections.
func fastTrackedPermit(pkt packet.Packet) (handled bool) {
func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bool) {
meta := pkt.Info()
// Check if packed was already fast-tracked by the OS integration.
if pkt.FastTrackedByIntegration() {
log.Debugf("filter: fast-tracked by OS integration: %s", pkt)
return true
return network.VerdictAccept, true
}
// Check if connection was already blocked.
if meta.Dst.Equal(blockedIPv4) || meta.Dst.Equal(blockedIPv6) {
_ = pkt.PermanentBlock()
return true
return network.VerdictBlock, true
}
// Some programs do a network self-check where they connects to the same
@@ -290,8 +235,8 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
if meta.SrcPort == meta.DstPort &&
meta.Src.Equal(meta.Dst) {
log.Debugf("filter: fast-track network self-check: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
}
switch meta.Protocol { //nolint:exhaustive // Checking for specific values only.
@@ -300,8 +245,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
err := pkt.LoadPacketData()
if err != nil {
log.Debugf("filter: failed to load ICMP packet data: %s", err)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
}
// Submit to ICMP listener.
@@ -311,8 +255,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
// permanent accept, because then we won't see any future packets of that
// connection and thus cannot continue to submit them.
log.Debugf("filter: fast-track tracing ICMP/v6: %s", pkt)
_ = pkt.Accept()
return true
return network.VerdictAccept, false
}
// Handle echo request and replies regularly.
@@ -323,20 +266,19 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
switch icmpLayer.TypeCode.Type() {
case layers.ICMPv4TypeEchoRequest,
layers.ICMPv4TypeEchoReply:
return false
return network.VerdictUndecided, false
}
case *layers.ICMPv6:
switch icmpLayer.TypeCode.Type() {
case layers.ICMPv6TypeEchoRequest,
layers.ICMPv6TypeEchoReply:
return false
return network.VerdictUndecided, false
}
}
// Permit all ICMP/v6 packets that are not echo requests or replies.
log.Debugf("filter: fast-track accepting ICMP/v6: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
case packet.UDP, packet.TCP:
switch meta.DstPort {
@@ -346,37 +288,36 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
// DHCP and DHCPv6 must be UDP.
if meta.Protocol != packet.UDP {
return false
return network.VerdictUndecided, false
}
// DHCP is only valid in local network scopes.
switch netutils.ClassifyIP(meta.Dst) { //nolint:exhaustive // Checking for specific values only.
case netutils.HostLocal, netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
default:
return false
return network.VerdictUndecided, false
}
// Log and permit.
log.Debugf("filter: fast-track accepting DHCP: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
case apiPort:
// Always allow direct access to the Portmaster API.
// Portmaster API is TCP only.
if meta.Protocol != packet.TCP {
return false
return network.VerdictUndecided, false
}
// Check if the api port is even set.
if !apiPortSet {
return false
return network.VerdictUndecided, false
}
// Must be destined for the API IP.
if !meta.Dst.Equal(apiIP) {
return false
return network.VerdictUndecided, false
}
// Only fast-track local requests.
@@ -384,15 +325,14 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
switch {
case err != nil:
log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
return false
return network.VerdictUndecided, false
case !isMe:
return false
return network.VerdictUndecided, false
}
// Log and permit.
log.Debugf("filter: fast-track accepting api connection: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
case 53:
// Always allow direct access to the Portmaster Nameserver.
@@ -400,12 +340,12 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
// Check if a nameserver IP matcher is set.
if !nameserverIPMatcherReady.IsSet() {
return false
return network.VerdictUndecided, false
}
// Check if packet is destined for a nameserver IP.
if !nameserverIPMatcher(meta.Dst) {
return false
return network.VerdictUndecided, false
}
// Only fast-track local requests.
@@ -413,32 +353,76 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
switch {
case err != nil:
log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
return false
return network.VerdictUndecided, false
case !isMe:
return false
return network.VerdictUndecided, false
}
// Log and permit.
log.Debugf("filter: fast-track accepting local dns: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
}
case compat.SystemIntegrationCheckProtocol:
if pkt.Info().Dst.Equal(compat.SystemIntegrationCheckDstIP) {
compat.SubmitSystemIntegrationCheckPacket(pkt)
_ = pkt.Drop()
return true
return network.VerdictDrop, false
}
}
return false
return network.VerdictUndecided, false
}
func initialHandler(conn *network.Connection, pkt packet.Packet) {
func fastTrackHandler(conn *network.Connection, pkt packet.Packet) {
fastTrackedVerdict, permanent := fastTrackedPermit(pkt)
if fastTrackedVerdict != network.VerdictUndecided {
// Set verdict on connection.
conn.Verdict.Active = fastTrackedVerdict
conn.Verdict.Firewall = fastTrackedVerdict
// Apply verdict to (real) packet.
if !pkt.InfoOnly() {
issueVerdict(conn, pkt, fastTrackedVerdict, permanent)
}
// Stop handler if permanent.
if permanent {
conn.SetVerdict(fastTrackedVerdict, "fast-tracked", "", nil)
conn.Verdict.Worst = fastTrackedVerdict
// Do not finalize verdict, as we are missing necessary data.
conn.StopFirewallHandler()
}
// Do not continue to next handler.
return
}
// If packet is not fast-tracked, continue with gathering more information.
conn.UpdateFirewallHandler(gatherDataHandler)
gatherDataHandler(conn, pkt)
}
func gatherDataHandler(conn *network.Connection, pkt packet.Packet) {
// Get process info
_ = conn.GatherConnectionInfo(pkt)
// Errors are informational and are logged to the context.
// Run this handler again if data is not yet complete.
if !conn.DataIsComplete() {
return
}
// Continue to filter handler, when connection data is complete.
conn.UpdateFirewallHandler(filterHandler)
filterHandler(conn, pkt)
}
func filterHandler(conn *network.Connection, pkt packet.Packet) {
// Skip if data is not complete.
if !conn.DataIsComplete() {
return
}
filterConnection := true
log.Tracer(pkt.Ctx()).Trace("filter: handing over to connection-based handler")
// Check for special (internal) connection cases.
switch {
case !conn.Inbound && localPortIsPreAuthenticated(conn.Entity.Protocol, conn.LocalPort):
@@ -480,8 +464,8 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
switch {
case conn.Inspecting:
log.Tracer(pkt.Ctx()).Trace("filter: start inspecting")
conn.SetFirewallHandler(inspectThenVerdict)
inspectThenVerdict(conn, pkt)
conn.SetFirewallHandler(inspectAndVerdictHandler)
inspectAndVerdictHandler(conn, pkt)
default:
conn.StopFirewallHandler()
issueVerdict(conn, pkt, 0, true)
@@ -490,6 +474,11 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
// FilterConnection runs all the filtering (and tunneling) procedures.
func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet, checkFilter, checkTunnel bool) {
// Skip if data is not complete.
if !conn.DataIsComplete() {
return
}
if checkFilter {
if filterEnabled() {
log.Tracer(ctx).Trace("filter: starting decision process")
@@ -537,12 +526,11 @@ func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.
}
}
func defaultHandler(conn *network.Connection, pkt packet.Packet) {
// TODO: `pkt` has an active trace log, which we currently don't submit.
func verdictHandler(conn *network.Connection, pkt packet.Packet) {
issueVerdict(conn, pkt, 0, true)
}
func inspectThenVerdict(conn *network.Connection, pkt packet.Packet) {
func inspectAndVerdictHandler(conn *network.Connection, pkt packet.Packet) {
pktVerdict, continueInspection := inspection.RunInspectors(conn, pkt)
if continueInspection {
issueVerdict(conn, pkt, pktVerdict, false)
@@ -689,10 +677,11 @@ func packetHandler(ctx context.Context) error {
case <-ctx.Done():
return nil
case pkt := <-interception.Packets:
interceptionModule.StartWorker("initial packet handler", func(workerCtx context.Context) error {
handlePacket(workerCtx, pkt)
return nil
})
if pkt != nil {
handlePacket(pkt)
} else {
return errors.New("received nil packet from interception")
}
}
}
}