diff --git a/firewall/api.go b/firewall/api.go index 5c2f0eed..03fe77d4 100644 --- a/firewall/api.go +++ b/firewall/api.go @@ -141,7 +141,11 @@ func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bo authenticatedPath += string(filepath.Separator) // Get process of request. - proc, _, err := process.GetProcessByConnection(ctx, pktInfo) + pid, _, _ := process.GetPidOfConnection(ctx, pktInfo) + if pid < 0 { + return false, fmt.Errorf(deniedMsgUnidentified, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user + } + proc, err := process.GetOrFindProcess(ctx, pid) if err != nil { log.Tracer(ctx).Debugf("filter: failed to get process of api request: %s", err) originalPid = process.UnidentifiedProcessID diff --git a/firewall/interception.go b/firewall/interception.go index 4999c5b3..775b30b3 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -11,7 +11,6 @@ import ( "github.com/google/gopacket/layers" "github.com/tevino/abool" - "golang.org/x/sync/singleflight" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" @@ -54,7 +53,7 @@ func init() { // TODO: Move interception module to own package (dir). interceptionModule = modules.Register("interception", interceptionPrep, interceptionStart, interceptionStop, "base", "updates", "network", "notifications", "profiles") - network.SetDefaultFirewallHandler(defaultHandler) + network.SetDefaultFirewallHandler(verdictHandler) } func interceptionPrep() error { @@ -120,6 +119,11 @@ func resetAllConnectionVerdicts() { // Re-evaluate all connections. var changedVerdicts int for _, conn := range network.GetAllConnections() { + // Skip incomplete connections. + if !conn.DataIsComplete() { + continue + } + func() { conn.Lock() defer conn.Unlock() @@ -167,15 +171,10 @@ func resetAllConnectionVerdicts() { func interceptionStart() error { getConfig() - - if err := registerMetrics(); err != nil { - return err - } - startAPIAuth() - interceptionModule.StartWorker("stat logger", statLogger) - interceptionModule.StartWorker("packet handler", packetHandler) + interceptionModule.StartServiceWorker("stat logger", 0, statLogger) + interceptionModule.StartServiceWorker("packet handler", 0, packetHandler) return interception.Start() } @@ -196,92 +195,38 @@ func SetNameserverIPMatcher(fn func(ip net.IP) bool) error { return nil } -func handlePacket(ctx context.Context, pkt packet.Packet) { - // Record metrics. - startTime := time.Now() - defer packetHandlingHistogram.UpdateDuration(startTime) - - if fastTrackedPermit(pkt) { +func handlePacket(pkt packet.Packet) { + // First, check for an existing connection. + conn, ok := network.GetConnection(pkt.GetConnectionID()) + if ok { + // Add packet to connection handler queue or apply verdict directly. + conn.HandlePacket(pkt) return } - // Add context tracer and set context on packet. - traceCtx, tracer := log.AddTracer(ctx) - if tracer != nil { - // The trace is submitted in `network.Connection.packetHandler()`. - tracer.Tracef("filter: handling packet: %s", pkt) - } - pkt.SetCtx(traceCtx) + // Else create new incomplete connection from the packet and start the new handler. + conn = network.NewIncompleteConnection(pkt) + conn.Lock() + defer conn.Unlock() + conn.SetFirewallHandler(fastTrackHandler) - // Get connection of packet. - conn, err := getConnection(pkt) - if err != nil { - tracer.Errorf("filter: packet %s dropped: %s", pkt, err) - _ = pkt.Drop() - return - } - - // handle packet + // Let the new connection handler worker handle the packet. conn.HandlePacket(pkt) } -var getConnectionSingleInflight singleflight.Group - -func getConnection(pkt packet.Packet) (*network.Connection, error) { - created := false - - // Create or get connection in single inflight lock in order to prevent duplicates. - newConn, err, shared := getConnectionSingleInflight.Do(pkt.GetConnectionID(), func() (interface{}, error) { - // First, check for an existing connection. - conn, ok := network.GetConnection(pkt.GetConnectionID()) - if ok { - return conn, nil - } - - // Else create new one from the packet. - conn = network.NewConnectionFromFirstPacket(pkt) - conn.Lock() - defer conn.Unlock() - conn.SetFirewallHandler(initialHandler) - created = true - return conn, nil - }) - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - if newConn == nil { - return nil, errors.New("connection getter returned nil") - } - - // Transform and log result. - conn := newConn.(*network.Connection) //nolint:forcetypeassert // Can only be a *network.Connection. - sharedIndicator := "" - if shared { - sharedIndicator = " (shared)" - } - if created { - log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s%s", conn.ID, sharedIndicator) - } else { - log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s%s", conn.ID, sharedIndicator) - } - - return conn, nil -} - // fastTrackedPermit quickly permits certain network critical or internal connections. -func fastTrackedPermit(pkt packet.Packet) (handled bool) { +func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bool) { meta := pkt.Info() // Check if packed was already fast-tracked by the OS integration. if pkt.FastTrackedByIntegration() { log.Debugf("filter: fast-tracked by OS integration: %s", pkt) - return true + return network.VerdictAccept, true } // Check if connection was already blocked. if meta.Dst.Equal(blockedIPv4) || meta.Dst.Equal(blockedIPv6) { - _ = pkt.PermanentBlock() - return true + return network.VerdictBlock, true } // Some programs do a network self-check where they connects to the same @@ -290,8 +235,8 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { if meta.SrcPort == meta.DstPort && meta.Src.Equal(meta.Dst) { log.Debugf("filter: fast-track network self-check: %s", pkt) - _ = pkt.PermanentAccept() - return true + return network.VerdictAccept, true + } switch meta.Protocol { //nolint:exhaustive // Checking for specific values only. @@ -300,8 +245,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { err := pkt.LoadPacketData() if err != nil { log.Debugf("filter: failed to load ICMP packet data: %s", err) - _ = pkt.PermanentAccept() - return true + return network.VerdictAccept, true } // Submit to ICMP listener. @@ -311,8 +255,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { // permanent accept, because then we won't see any future packets of that // connection and thus cannot continue to submit them. log.Debugf("filter: fast-track tracing ICMP/v6: %s", pkt) - _ = pkt.Accept() - return true + return network.VerdictAccept, false } // Handle echo request and replies regularly. @@ -323,20 +266,19 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { switch icmpLayer.TypeCode.Type() { case layers.ICMPv4TypeEchoRequest, layers.ICMPv4TypeEchoReply: - return false + return network.VerdictUndecided, false } case *layers.ICMPv6: switch icmpLayer.TypeCode.Type() { case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply: - return false + return network.VerdictUndecided, false } } // Permit all ICMP/v6 packets that are not echo requests or replies. log.Debugf("filter: fast-track accepting ICMP/v6: %s", pkt) - _ = pkt.PermanentAccept() - return true + return network.VerdictAccept, true case packet.UDP, packet.TCP: switch meta.DstPort { @@ -346,37 +288,36 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { // DHCP and DHCPv6 must be UDP. if meta.Protocol != packet.UDP { - return false + return network.VerdictUndecided, false } // DHCP is only valid in local network scopes. switch netutils.ClassifyIP(meta.Dst) { //nolint:exhaustive // Checking for specific values only. case netutils.HostLocal, netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: default: - return false + return network.VerdictUndecided, false } // Log and permit. log.Debugf("filter: fast-track accepting DHCP: %s", pkt) - _ = pkt.PermanentAccept() - return true + return network.VerdictAccept, true case apiPort: // Always allow direct access to the Portmaster API. // Portmaster API is TCP only. if meta.Protocol != packet.TCP { - return false + return network.VerdictUndecided, false } // Check if the api port is even set. if !apiPortSet { - return false + return network.VerdictUndecided, false } // Must be destined for the API IP. if !meta.Dst.Equal(apiIP) { - return false + return network.VerdictUndecided, false } // Only fast-track local requests. @@ -384,15 +325,14 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { switch { case err != nil: log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err) - return false + return network.VerdictUndecided, false case !isMe: - return false + return network.VerdictUndecided, false } // Log and permit. log.Debugf("filter: fast-track accepting api connection: %s", pkt) - _ = pkt.PermanentAccept() - return true + return network.VerdictAccept, true case 53: // Always allow direct access to the Portmaster Nameserver. @@ -400,12 +340,12 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { // Check if a nameserver IP matcher is set. if !nameserverIPMatcherReady.IsSet() { - return false + return network.VerdictUndecided, false } // Check if packet is destined for a nameserver IP. if !nameserverIPMatcher(meta.Dst) { - return false + return network.VerdictUndecided, false } // Only fast-track local requests. @@ -413,32 +353,76 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { switch { case err != nil: log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err) - return false + return network.VerdictUndecided, false case !isMe: - return false + return network.VerdictUndecided, false } // Log and permit. log.Debugf("filter: fast-track accepting local dns: %s", pkt) - _ = pkt.PermanentAccept() - return true + return network.VerdictAccept, true } case compat.SystemIntegrationCheckProtocol: if pkt.Info().Dst.Equal(compat.SystemIntegrationCheckDstIP) { compat.SubmitSystemIntegrationCheckPacket(pkt) - _ = pkt.Drop() - return true + return network.VerdictDrop, false } } - return false + return network.VerdictUndecided, false } -func initialHandler(conn *network.Connection, pkt packet.Packet) { +func fastTrackHandler(conn *network.Connection, pkt packet.Packet) { + fastTrackedVerdict, permanent := fastTrackedPermit(pkt) + if fastTrackedVerdict != network.VerdictUndecided { + // Set verdict on connection. + conn.Verdict.Active = fastTrackedVerdict + conn.Verdict.Firewall = fastTrackedVerdict + // Apply verdict to (real) packet. + if !pkt.InfoOnly() { + issueVerdict(conn, pkt, fastTrackedVerdict, permanent) + } + // Stop handler if permanent. + if permanent { + conn.SetVerdict(fastTrackedVerdict, "fast-tracked", "", nil) + conn.Verdict.Worst = fastTrackedVerdict + // Do not finalize verdict, as we are missing necessary data. + conn.StopFirewallHandler() + } + + // Do not continue to next handler. + return + } + + // If packet is not fast-tracked, continue with gathering more information. + conn.UpdateFirewallHandler(gatherDataHandler) + gatherDataHandler(conn, pkt) +} + +func gatherDataHandler(conn *network.Connection, pkt packet.Packet) { + // Get process info + _ = conn.GatherConnectionInfo(pkt) + // Errors are informational and are logged to the context. + + // Run this handler again if data is not yet complete. + if !conn.DataIsComplete() { + return + } + + // Continue to filter handler, when connection data is complete. + conn.UpdateFirewallHandler(filterHandler) + filterHandler(conn, pkt) +} + +func filterHandler(conn *network.Connection, pkt packet.Packet) { + // Skip if data is not complete. + if !conn.DataIsComplete() { + return + } + filterConnection := true - log.Tracer(pkt.Ctx()).Trace("filter: handing over to connection-based handler") // Check for special (internal) connection cases. switch { case !conn.Inbound && localPortIsPreAuthenticated(conn.Entity.Protocol, conn.LocalPort): @@ -480,8 +464,8 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { switch { case conn.Inspecting: log.Tracer(pkt.Ctx()).Trace("filter: start inspecting") - conn.SetFirewallHandler(inspectThenVerdict) - inspectThenVerdict(conn, pkt) + conn.SetFirewallHandler(inspectAndVerdictHandler) + inspectAndVerdictHandler(conn, pkt) default: conn.StopFirewallHandler() issueVerdict(conn, pkt, 0, true) @@ -490,6 +474,11 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { // FilterConnection runs all the filtering (and tunneling) procedures. func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet, checkFilter, checkTunnel bool) { + // Skip if data is not complete. + if !conn.DataIsComplete() { + return + } + if checkFilter { if filterEnabled() { log.Tracer(ctx).Trace("filter: starting decision process") @@ -537,12 +526,11 @@ func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet. } } -func defaultHandler(conn *network.Connection, pkt packet.Packet) { - // TODO: `pkt` has an active trace log, which we currently don't submit. +func verdictHandler(conn *network.Connection, pkt packet.Packet) { issueVerdict(conn, pkt, 0, true) } -func inspectThenVerdict(conn *network.Connection, pkt packet.Packet) { +func inspectAndVerdictHandler(conn *network.Connection, pkt packet.Packet) { pktVerdict, continueInspection := inspection.RunInspectors(conn, pkt) if continueInspection { issueVerdict(conn, pkt, pktVerdict, false) @@ -689,10 +677,11 @@ func packetHandler(ctx context.Context) error { case <-ctx.Done(): return nil case pkt := <-interception.Packets: - interceptionModule.StartWorker("initial packet handler", func(workerCtx context.Context) error { - handlePacket(workerCtx, pkt) - return nil - }) + if pkt != nil { + handlePacket(pkt) + } else { + return errors.New("received nil packet from interception") + } } } } diff --git a/firewall/interception/ebpf/packet.go b/firewall/interception/ebpf/packet.go index fffd5558..19895646 100644 --- a/firewall/interception/ebpf/packet.go +++ b/firewall/interception/ebpf/packet.go @@ -13,6 +13,12 @@ type infoPacket struct { pmpacket.Base } +// InfoOnly returns whether the packet is informational only and does not +// represent an actual packet. +func (pkt *infoPacket) InfoOnly() bool { + return true +} + // LoadPacketData does nothing on Linux, as data is always fully parsed. func (pkt *infoPacket) LoadPacketData() error { return fmt.Errorf("can't load data in info only packet") diff --git a/firewall/interception/ebpf/worker.go b/firewall/interception/ebpf/worker.go index 03be5d0d..2fdf50c2 100644 --- a/firewall/interception/ebpf/worker.go +++ b/firewall/interception/ebpf/worker.go @@ -5,11 +5,13 @@ import ( "encoding/binary" "errors" "net" + "time" "unsafe" "github.com/cilium/ebpf/link" "github.com/cilium/ebpf/ringbuf" "github.com/cilium/ebpf/rlimit" + "github.com/safing/portbase/log" "github.com/safing/portmaster/network/packet" ) @@ -17,6 +19,7 @@ import ( //go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang -cflags "-O2 -g -Wall -Werror" -type Event bpf program/monitor.c var stopper chan struct{} +// StartEBPFWorker starts the ebpf worker. func StartEBPFWorker(ch chan packet.Packet) { stopper = make(chan struct{}) go func() { @@ -32,7 +35,7 @@ func StartEBPFWorker(ch chan packet.Packet) { log.Errorf("ebpf: failed to load ebpf object: %s", err) return } - defer objs.Close() + defer objs.Close() //nolint:errcheck // Create a link to the tcp_connect program. linkTCPConnect, err := link.AttachTracing(link.TracingOptions{ @@ -42,7 +45,7 @@ func StartEBPFWorker(ch chan packet.Packet) { log.Errorf("ebpf: failed to attach to tcp_v4_connect: %s ", err) return } - defer linkTCPConnect.Close() + defer linkTCPConnect.Close() //nolint:errcheck // Create a link to the udp_v4_connect program. linkUDPV4, err := link.AttachTracing(link.TracingOptions{ @@ -52,7 +55,7 @@ func StartEBPFWorker(ch chan packet.Packet) { log.Errorf("ebpf: failed to attach to udp_v4_connect: %s ", err) return } - defer linkUDPV4.Close() + defer linkUDPV4.Close() //nolint:errcheck // Create a link to the udp_v6_connect program. linkUDPV6, err := link.AttachTracing(link.TracingOptions{ @@ -62,14 +65,14 @@ func StartEBPFWorker(ch chan packet.Packet) { log.Errorf("ebpf: failed to attach to udp_v6_connect: %s ", err) return } - defer linkUDPV6.Close() + defer linkUDPV6.Close() //nolint:errcheck rd, err := ringbuf.NewReader(objs.bpfMaps.Events) if err != nil { log.Errorf("ebpf: failed to open ring buffer: %s", err) return } - defer rd.Close() + defer rd.Close() //nolint:errcheck go func() { <-stopper @@ -107,7 +110,8 @@ func StartEBPFWorker(ch chan packet.Packet) { DstPort: event.Dport, Src: arrayToIP(event.Saddr, packet.IPVersion(event.IpVersion)), Dst: arrayToIP(event.Daddr, packet.IPVersion(event.IpVersion)), - PID: event.Pid, + PID: int(event.Pid), + SeenAt: time.Now(), } if isEventValid(event) { log.Debugf("ebpf: PID: %d conn: %s:%d -> %s:%d %s %s", info.PID, info.LocalIP(), info.LocalPort(), info.RemoteIP(), info.RemotePort(), info.Version.String(), info.Protocol.String()) @@ -123,6 +127,7 @@ func StartEBPFWorker(ch chan packet.Packet) { }() } +// StopEBPFWorker stops the ebpf worker. func StopEBPFWorker() { close(stopper) } @@ -148,11 +153,12 @@ func isEventValid(event bpfEvent) bool { return true } -// arrayToIP converts IP number array to net.IP +// arrayToIP converts IP number array to net.IP. func arrayToIP(ipNum [4]uint32, ipVersion packet.IPVersion) net.IP { if ipVersion == packet.IPv4 { + // FIXME: maybe convertIPv4 from windowskext package return unsafe.Slice((*byte)(unsafe.Pointer(&ipNum)), 4) - } else { - return unsafe.Slice((*byte)(unsafe.Pointer(&ipNum)), 16) } + // FIXME: maybe use convertIPv6 from windowskext package + return unsafe.Slice((*byte)(unsafe.Pointer(&ipNum)), 16) } diff --git a/firewall/interception/nfq/nfq.go b/firewall/interception/nfq/nfq.go index 9b382944..585ba96e 100644 --- a/firewall/interception/nfq/nfq.go +++ b/firewall/interception/nfq/nfq.go @@ -16,6 +16,7 @@ import ( "github.com/safing/portbase/log" pmpacket "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/process" ) // Queue wraps a nfqueue. @@ -175,10 +176,11 @@ func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int { pkt := &packet{ pktID: *attrs.PacketID, queue: q, - received: time.Now(), verdictSet: make(chan struct{}), verdictPending: abool.New(), } + pkt.Info().PID = process.UndefinedProcessID + pkt.Info().SeenAt = time.Now() if attrs.Payload == nil { // There is not payload. @@ -194,11 +196,11 @@ func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int { select { case q.packets <- pkt: - log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received)) + log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt)) case <-ctx.Done(): return 0 case <-time.After(time.Second): - log.Warningf("nfqueue: failed to queue packet (%s since it was handed over by the kernel)", time.Since(pkt.received)) + log.Warningf("nfqueue: failed to queue packet (%s since it was handed over by the kernel)", time.Since(pkt.Info().SeenAt)) } go func() { @@ -206,7 +208,7 @@ func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int { case <-pkt.verdictSet: case <-time.After(20 * time.Second): - log.Warningf("nfqueue: no verdict set for packet %s (%s -> %s) after %s, dropping", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received)) + log.Warningf("nfqueue: no verdict set for packet %s (%s -> %s) after %s, dropping", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt)) if err := pkt.Drop(); err != nil { log.Warningf("nfqueue: failed to apply default-drop to unveridcted packet %s (%s -> %s)", pkt.ID(), pkt.Info().Src, pkt.Info().Dst) } diff --git a/firewall/interception/nfq/packet.go b/firewall/interception/nfq/packet.go index f3ede898..6dd42186 100644 --- a/firewall/interception/nfq/packet.go +++ b/firewall/interception/nfq/packet.go @@ -55,7 +55,6 @@ func markToString(mark int) string { type packet struct { pmpacket.Base pktID uint32 - received time.Time queue *Queue verdictSet chan struct{} verdictPending *abool.AtomicBool @@ -118,7 +117,7 @@ func (pkt *packet) setMark(mark int) error { } break } - log.Tracer(pkt.Ctx()).Tracef("nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.received)) + log.Tracer(pkt.Ctx()).Tracef("nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.Info().SeenAt)) return nil } diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 7515f471..f3fa33a4 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -8,8 +8,11 @@ import ( "errors" "fmt" "net" + "time" "unsafe" + "github.com/safing/portmaster/process" + "github.com/tevino/abool" "github.com/safing/portbase/log" @@ -103,21 +106,28 @@ func Handler(packets chan packet.Packet) { verdictRequest: packetInfo, verdictSet: abool.NewBool(false), } - info := new.Info() info.Inbound = packetInfo.direction > 0 info.InTunnel = false info.Protocol = packet.IPProtocol(packetInfo.protocol) - info.PID = packetInfo.pid + info.PID = int(packetInfo.pid) + info.SeenAt = time.Now() - // IP version + // Check PID + if info.PID == 0 { + // Windows does not have zero PIDs. + // Set to UndefinedProcessID. + info.PID = process.UndefinedProcessID + } + + // Set IP version if packetInfo.ipV6 == 1 { info.Version = packet.IPv6 } else { info.Version = packet.IPv4 } - // IPs + // Set IPs if info.Version == packet.IPv4 { // IPv4 if info.Inbound { @@ -142,7 +152,7 @@ func Handler(packets chan packet.Packet) { } } - // Ports + // Set Ports if info.Inbound { // Inbound info.SrcPort = packetInfo.remotePort diff --git a/firewall/interception/windowskext/packet.go b/firewall/interception/windowskext/packet.go index daa315f2..282a85a7 100644 --- a/firewall/interception/windowskext/packet.go +++ b/firewall/interception/windowskext/packet.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package windowskext @@ -23,6 +24,12 @@ type Packet struct { lock sync.Mutex } +// InfoOnly returns whether the packet is informational only and does not +// represent an actual packet. +func (pkt *Packet) InfoOnly() bool { + return pkt.verdictRequest.flags&VerdictRequestFlagSocketAuth > 0 +} + // FastTrackedByIntegration returns whether the packet has been fast-track // accepted by the OS integration. func (pkt *Packet) FastTrackedByIntegration() bool { diff --git a/firewall/master.go b/firewall/master.go index ae86f2b8..627dfe91 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -25,19 +25,6 @@ import ( "github.com/safing/portmaster/profile/endpoints" ) -// Call order: -// -// DNS Query: -// 1. DecideOnConnection -// is called when a DNS query is made, may set verdict to Undeterminable to permit a DNS reply. -// is called with a nil packet. -// 2. DecideOnResolvedDNS -// is called to (possibly) filter out A/AAAA records that the filter would deny later. -// -// Network Connection: -// 3. DecideOnConnection -// is called with the first packet of a network connection. - const noReasonOptionKey = "" type deciderFn func(context.Context, *network.Connection, *profile.LayeredProfile, packet.Packet) bool diff --git a/firewall/metrics.go b/firewall/metrics.go deleted file mode 100644 index 30d56b4b..00000000 --- a/firewall/metrics.go +++ /dev/null @@ -1,21 +0,0 @@ -package firewall - -import ( - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/metrics" -) - -var packetHandlingHistogram *metrics.Histogram - -func registerMetrics() (err error) { - packetHandlingHistogram, err = metrics.NewHistogram( - "firewall/handling/duration/seconds", - nil, - &metrics.Options{ - Permission: api.PermitUser, - ExpertiseLevel: config.ExpertiseLevelExpert, - }) - - return err -} diff --git a/netquery/manager.go b/netquery/manager.go index 2a7a70a6..6599d619 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -87,6 +87,9 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect if !ok { return } + if !conn.DataIsComplete() { + continue + } model, err := convertConnection(conn) if err != nil { diff --git a/network/clean.go b/network/clean.go index 6e4d9f83..e4042d7c 100644 --- a/network/clean.go +++ b/network/clean.go @@ -51,6 +51,11 @@ func cleanConnections() (activePIDs map[int]struct{}) { // delete inactive connections switch { + case !conn.DataIsComplete(): + // Step 0: delete old incomplete connections + if conn.Started < deleteOlderThan { + conn.delete() + } case conn.Ended == 0: // Step 1: check if still active exists := state.Exists(&packet.Info{ diff --git a/network/connection.go b/network/connection.go index f9b2379d..fe3e92fe 100644 --- a/network/connection.go +++ b/network/connection.go @@ -7,6 +7,8 @@ import ( "sync" "time" + "github.com/tevino/abool" + "github.com/safing/portbase/database/record" "github.com/safing/portbase/log" "github.com/safing/portmaster/intel" @@ -102,6 +104,8 @@ type Connection struct { //nolint:maligned // TODO: fix alignment // set for connections created from DNS requests. LocalPort is // considered immutable once a connection object has been created. LocalPort uint16 + // PID holds the PID of the owning process. + PID int // Entity describes the remote entity that the connection has been // established to. The entity might be changed or information might // be added to it during the livetime of a connection. Access to @@ -168,6 +172,19 @@ type Connection struct { //nolint:maligned // TODO: fix alignment StopTunnel() error } + // pkgQueue is used to serialize packet handling for a single + // connection and is served by the connections packetHandler. + pktQueue chan packet.Packet + // pktQueueActive signifies whether the packet queue is active and may be written to. + pktQueueActive bool + // pktQueueLock locks access to pktQueueActive and writing to pktQueue. + pktQueueLock sync.Mutex + + // dataComplete signifies that all information about the connection is + // available and an actual packet has been seen. + // As long as this flag is not set, the connection may not be evaluated for + // a verdict and may not be sent to the UI. + dataComplete *abool.AtomicBool // Internal is set to true if the connection is attributed as an // Portmaster internal connection. Internal may be set at different // points and access to it must be guarded by the connection lock. @@ -175,9 +192,6 @@ type Connection struct { //nolint:maligned // TODO: fix alignment // process holds a reference to the actor process. That is, the // process instance that initiated the connection. process *process.Process - // pkgQueue is used to serialize packet handling for a single - // connection and is served by the connections packetHandler. - pktQueue chan packet.Packet // firewallHandler is the firewall handler that is called for // each packet sent to pktQueue. firewallHandler FirewallHandler @@ -250,8 +264,11 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri ipVersion = packet.IPv4 } - // get Process - proc, _, err := process.GetProcessByConnection( + // Get Process. + // FIXME: Find direct or redirected connection and grab the PID from there. + + // Find process by remote IP/Port. + pid, _, _ := process.GetPidOfConnection( ctx, &packet.Info{ Inbound: false, // outbound as we are looking for the process of the source address @@ -261,18 +278,17 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri SrcPort: localPort, // source as in the process we are looking for Dst: nil, // do not record direction DstPort: 0, // do not record direction + PID: process.UndefinedProcessID, }, ) - if err != nil { - log.Tracer(ctx).Debugf("network: failed to find process of dns request for %s: %s", fqdn, err) - proc = process.GetUnidentifiedProcess(ctx) - } + proc, _ := process.GetProcessWithProfile(ctx, pid) timestamp := time.Now().Unix() dnsConn := &Connection{ ID: connID, Type: DNSRequest, Scope: fqdn, + PID: proc.Pid, Entity: &intel.Entity{ Domain: fqdn, CNAME: cnames, @@ -281,6 +297,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri ProcessContext: getProcessContext(ctx, proc), Started: timestamp, Ended: timestamp, + dataComplete: abool.NewBool(true), } // Inherit internal status of profile. @@ -292,6 +309,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri // query. Blocked requests are saved immediately, accepted ones are only // saved if they are not "used" by a connection. + dnsConn.UpdateMeta() return dnsConn } @@ -308,6 +326,7 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname Type: DNSRequest, External: true, Scope: fqdn, + PID: process.NetworkHostProcessID, Entity: &intel.Entity{ Domain: fqdn, CNAME: cnames, @@ -316,6 +335,7 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname ProcessContext: getProcessContext(ctx, remoteHost), Started: timestamp, Ended: timestamp, + dataComplete: abool.NewBool(true), } // Inherit internal status of profile. @@ -327,131 +347,152 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname // query. Blocked requests are saved immediately, accepted ones are only // saved if they are not "used" by a connection. + dnsConn.UpdateMeta() return dnsConn, nil } -// NewConnectionFromFirstPacket returns a new connection based on the given packet. -func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { - // get Process - proc, inbound, err := process.GetProcessByConnection(pkt.Ctx(), pkt.Info()) - if err != nil { - log.Tracer(pkt.Ctx()).Debugf("network: failed to find process of packet %s: %s", pkt, err) - if inbound && !netutils.ClassifyIP(pkt.Info().Dst).IsLocalhost() { - proc = process.GetUnsolicitedProcess(pkt.Ctx()) - } else { - proc = process.GetUnidentifiedProcess(pkt.Ctx()) - } +// NewIncompleteConnection creates a new incomplete connection with only minimal information. +func NewIncompleteConnection(pkt packet.Packet) *Connection { + info := pkt.Info() + + // Create new connection object. + // We do not yet know the direction of the connection for sure, so we can only set minimal information. + conn := &Connection{ + ID: pkt.GetConnectionID(), + Type: IPConnection, + IPVersion: info.Version, + IPProtocol: info.Protocol, + Started: info.SeenAt.Unix(), + PID: info.PID, + dataComplete: abool.NewBool(false), } - // Create the (remote) entity. - entity := &intel.Entity{ - Protocol: uint8(pkt.Info().Protocol), - Port: pkt.Info().RemotePort(), + // Save connection to internal state in order to mitigate creation of + // duplicates. Do not propagate yet, as data is not yet complete. + conn.UpdateMeta() + conns.add(conn) + + return conn +} + +// GatherConnectionInfo gathers information on the process and remote entity. +func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { + // Get PID if not yet available. + // FIXME: Only match for UndefinedProcessID when integrations have been updated. + if conn.PID <= 0 { + // Get process by looking at the system state tables. + // Apply direction as reported from the state tables. + conn.PID, conn.Inbound, _ = process.GetPidOfConnection(pkt.Ctx(), pkt.Info()) + // Errors are informational and are logged to the context. } - entity.SetIP(pkt.Info().RemoteIP()) - entity.SetDstPort(pkt.Info().DstPort) - var scope string - var resolverInfo *resolver.ResolverInfo - var dnsContext *resolver.DNSRequestContext - - if inbound { - switch entity.IPScope { - case netutils.HostLocal: - scope = IncomingHost - case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: - scope = IncomingLAN - case netutils.Global, netutils.GlobalMulticast: - scope = IncomingInternet - - case netutils.Undefined, netutils.Invalid: - fallthrough - default: - scope = IncomingInvalid - } - } else { - - // check if we can find a domain for that IP - ipinfo, err := resolver.GetIPInfo(proc.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String()) + // Get Process and Profile. + if conn.process == nil { + // We got connection from the system. + conn.process, err = process.GetProcessWithProfile(pkt.Ctx(), conn.PID) if err != nil { - // Try again with the global scope, in case DNS went through the system resolver. - ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String()) - } - if err == nil { - lastResolvedDomain := ipinfo.MostRecentDomain() - if lastResolvedDomain != nil { - scope = lastResolvedDomain.Domain - entity.Domain = lastResolvedDomain.Domain - entity.CNAME = lastResolvedDomain.CNAMEs - dnsContext = lastResolvedDomain.DNSRequestContext - resolverInfo = lastResolvedDomain.Resolver - removeOpenDNSRequest(proc.Pid, lastResolvedDomain.Domain) - } + conn.process = nil + err = fmt.Errorf("failed to get process and profile of PID %d: %w", conn.PID, err) + log.Tracer(pkt.Ctx()).Debugf("network: %s", err) + return err } - // check if destination IP is the captive portal's IP - portal := netenv.GetCaptivePortal() - if pkt.Info().RemoteIP().Equal(portal.IP) { - scope = portal.Domain - entity.Domain = portal.Domain - } + // Add process/profile metadata for connection. + conn.ProcessContext = getProcessContext(pkt.Ctx(), conn.process) + conn.ProfileRevisionCounter = conn.process.Profile().RevisionCnt() - if scope == "" { - // outbound direct (possibly P2P) connection - switch entity.IPScope { + // Inherit internal status of profile. + if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil { + conn.Internal = localProfile.Internal + } + } + + // Create remote entity. + if conn.Entity == nil { + // Remote + conn.Entity = &intel.Entity{ + Protocol: uint8(pkt.Info().Protocol), + Port: pkt.Info().RemotePort(), + } + conn.Entity.SetIP(pkt.Info().RemoteIP()) + conn.Entity.SetDstPort(pkt.Info().DstPort) + // Local + conn.SetLocalIP(pkt.Info().LocalIP()) + conn.LocalPort = pkt.Info().LocalPort() + + if conn.Inbound { + switch conn.Entity.IPScope { case netutils.HostLocal: - scope = PeerHost + conn.Scope = IncomingHost case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: - scope = PeerLAN + conn.Scope = IncomingLAN case netutils.Global, netutils.GlobalMulticast: - scope = PeerInternet + conn.Scope = IncomingInternet case netutils.Undefined, netutils.Invalid: fallthrough default: - scope = PeerInvalid + conn.Scope = IncomingInvalid + } + } else { + + // check if we can find a domain for that IP + ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String()) + if err != nil { + // Try again with the global scope, in case DNS went through the system resolver. + ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String()) + } + if err == nil { + lastResolvedDomain := ipinfo.MostRecentDomain() + if lastResolvedDomain != nil { + conn.Scope = lastResolvedDomain.Domain + conn.Entity.Domain = lastResolvedDomain.Domain + conn.Entity.CNAME = lastResolvedDomain.CNAMEs + conn.DNSContext = lastResolvedDomain.DNSRequestContext + conn.Resolver = lastResolvedDomain.Resolver + removeOpenDNSRequest(conn.process.Pid, lastResolvedDomain.Domain) + } + } + + // check if destination IP is the captive portal's IP + portal := netenv.GetCaptivePortal() + if pkt.Info().RemoteIP().Equal(portal.IP) { + conn.Scope = portal.Domain + conn.Entity.Domain = portal.Domain + } + + if conn.Scope == "" { + // outbound direct (possibly P2P) connection + switch conn.Entity.IPScope { + case netutils.HostLocal: + conn.Scope = PeerHost + case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: + conn.Scope = PeerLAN + case netutils.Global, netutils.GlobalMulticast: + conn.Scope = PeerInternet + + case netutils.Undefined, netutils.Invalid: + fallthrough + default: + conn.Scope = PeerInvalid + } } } } - // Create new connection object. - newConn := &Connection{ - ID: pkt.GetConnectionID(), - Type: IPConnection, - Scope: scope, - IPVersion: pkt.Info().Version, - Inbound: inbound, - // local endpoint - IPProtocol: pkt.Info().Protocol, - LocalPort: pkt.Info().LocalPort(), - ProcessContext: getProcessContext(pkt.Ctx(), proc), - DNSContext: dnsContext, - process: proc, - // remote endpoint - Entity: entity, - // resolver used to resolve dns request - Resolver: resolverInfo, - // meta - Started: time.Now().Unix(), - ProfileRevisionCounter: proc.Profile().RevisionCnt(), - } - newConn.SetLocalIP(pkt.Info().LocalIP()) - - // Inherit internal status of profile. - if localProfile := proc.Profile().LocalProfile(); localProfile != nil { - newConn.Internal = localProfile.Internal + // Data collection is only complete with a packet. + if pkt.InfoOnly() { + return nil } - // Save connection to internal state in order to mitigate creation of - // duplicates. Do not propagate yet, as there is no verdict yet. - conns.add(newConn) - - return newConn + // If we have all data and have seen an actual packet, the connection data is complete. + conn.dataComplete.Set() + return nil } // GetConnection fetches a Connection from the database. -func GetConnection(id string) (*Connection, bool) { - return conns.get(id) +func GetConnection(connID string) (*Connection, bool) { + return conns.get(connID) } // GetAllConnections Gets all connection. @@ -563,6 +604,14 @@ func (conn *Connection) VerdictVerb() string { ) } +// DataIsComplete returns whether all information about the connection is +// available and an actual packet has been seen. +// As long as this flag is not set, the connection may not be evaluated for +// a verdict and may not be sent to the UI. +func (conn *Connection) DataIsComplete() bool { + return conn.dataComplete.IsSet() +} + // Process returns the connection's process. func (conn *Connection) Process() *process.Process { return conn.process @@ -579,9 +628,13 @@ func (conn *Connection) SaveWhenFinished() { // Callers must make sure to lock the connection itself before calling // Save(). func (conn *Connection) Save() { - conn.addToMetrics() conn.UpdateMeta() + // Do not save/update until data is complete. + if !conn.DataIsComplete() { + return + } + if !conn.KeyIsSet() { if conn.Type == DNSRequest { conn.SetKey(makeKey(conn.process.Pid, dbScopeDNS, conn.ID)) @@ -592,6 +645,8 @@ func (conn *Connection) Save() { } } + conn.addToMetrics() + // notify database controller dbController.PushUpdate(conn) } @@ -610,29 +665,61 @@ func (conn *Connection) delete() { } conn.Meta().Delete() - dbController.PushUpdate(conn) + + // Notify database controller if data is complete and thus connection was previously exposed. + if conn.DataIsComplete() { + dbController.PushUpdate(conn) + } } // SetFirewallHandler sets the firewall handler for this link, and starts a // worker to handle the packets. // The caller needs to hold a lock on the connection. +// Cannot be called with "nil" handler. Call StopFirewallHandler() instead. func (conn *Connection) SetFirewallHandler(handler FirewallHandler) { - if conn.firewallHandler == nil { - conn.pktQueue = make(chan packet.Packet, 100) + if handler == nil { + return + } + // Start packet handler worker when first handler is set. + if conn.firewallHandler == nil { // start handling module.StartWorker("packet handler", conn.packetHandlerWorker) } + + // Set new handler. conn.firewallHandler = handler + + // Initialize packet queue, if needed. + conn.pktQueueLock.Lock() + defer conn.pktQueueLock.Unlock() + if !conn.pktQueueActive { + conn.pktQueue = make(chan packet.Packet, 100) + conn.pktQueueActive = true + } +} + +// UpdateFirewallHandler sets the firewall handler if it already set and the +// given handler is not nil. +// The caller needs to hold a lock on the connection. +func (conn *Connection) UpdateFirewallHandler(handler FirewallHandler) { + if handler != nil && conn.firewallHandler != nil { + conn.firewallHandler = handler + } } // StopFirewallHandler unsets the firewall handler and stops the handler worker. // The caller needs to hold a lock on the connection. func (conn *Connection) StopFirewallHandler() { + conn.pktQueueLock.Lock() + defer conn.pktQueueLock.Unlock() + + // Unset the firewall handler to revert to the default handler. conn.firewallHandler = nil // Signal the packet handler worker that it can stop. close(conn.pktQueue) + conn.pktQueueActive = false // Unset the packet queue so that it can be freed. conn.pktQueue = nil @@ -640,15 +727,25 @@ func (conn *Connection) StopFirewallHandler() { // HandlePacket queues packet of Link for handling. func (conn *Connection) HandlePacket(pkt packet.Packet) { - conn.Lock() - defer conn.Unlock() + conn.pktQueueLock.Lock() + defer conn.pktQueueLock.Unlock() // execute handler or verdict - if conn.firewallHandler != nil { - conn.pktQueue <- pkt - // TODO: drop if overflowing? + if conn.pktQueueActive { + select { + case conn.pktQueue <- pkt: + default: + log.Debugf( + "filter: dropping packet %s, as there is no space in the connection's handling queue", + pkt, + ) + _ = pkt.Drop() + } } else { defaultFirewallHandler(conn, pkt) + + // Record metrics. + packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt) } } @@ -656,7 +753,12 @@ func (conn *Connection) HandlePacket(pkt packet.Packet) { func (conn *Connection) packetHandlerWorker(ctx context.Context) error { // Copy packet queue, so we can remove the reference from the connection // when we stop the firewall handler. - pktQueue := conn.pktQueue + var pktQueue chan packet.Packet + func() { + conn.pktQueueLock.Lock() + defer conn.pktQueueLock.Unlock() + pktQueue = conn.pktQueue + }() for { select { @@ -664,21 +766,27 @@ func (conn *Connection) packetHandlerWorker(ctx context.Context) error { if pkt == nil { return nil } - packetHandlerHandleConn(conn, pkt) + packetHandlerHandleConn(ctx, conn, pkt) case <-ctx.Done(): - conn.Lock() - defer conn.Unlock() - conn.firewallHandler = nil return nil } } } -func packetHandlerHandleConn(conn *Connection, pkt packet.Packet) { +func packetHandlerHandleConn(ctx context.Context, conn *Connection, pkt packet.Packet) { conn.Lock() defer conn.Unlock() + // Create tracing context. + // Add context tracer and set context on packet. + traceCtx, tracer := log.AddTracer(ctx) + if tracer != nil { + // The trace is submitted in `network.Connection.packetHandler()`. + tracer.Tracef("filter: handling packet: %s", pkt) + } + pkt.SetCtx(traceCtx) + // Handle packet with appropriate handler. if conn.firewallHandler != nil { conn.firewallHandler(conn, pkt) @@ -686,13 +794,22 @@ func packetHandlerHandleConn(conn *Connection, pkt packet.Packet) { defaultFirewallHandler(conn, pkt) } - // Log verdict. - log.Tracer(pkt.Ctx()).Infof("filter: connection %s %s: %s", conn, conn.VerdictVerb(), conn.Reason.Msg) - // Submit trace logs. - log.Tracer(pkt.Ctx()).Submit() + // Record metrics. + packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt) - // Save() itself does not touch any changing data. - // Must not be locked - would deadlock with cleaner functions. + // Log result and submit trace. + switch { + case conn.DataIsComplete(): + tracer.Infof("filter: connection %s %s: %s", conn, conn.VerdictVerb(), conn.Reason.Msg) + case conn.Verdict.Firewall != VerdictUndecided: + tracer.Debugf("filter: connection %s fast-tracked", conn) + default: + tracer.Infof("filter: gathered data on connection %s", conn) + } + // Submit trace logs. + tracer.Submit() + + // Push changes, if there are any. if conn.saveWhenFinished { conn.saveWhenFinished = false conn.Save() diff --git a/network/connection_android.go b/network/connection_android.go index d7479e1f..6fd30157 100644 --- a/network/connection_android.go +++ b/network/connection_android.go @@ -11,6 +11,7 @@ import ( "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/process" "github.com/safing/spn/navigator" + "github.com/tevino/abool" ) // NewDefaultConnection creates a new connection with default values except local and remote IPs and protocols. @@ -25,6 +26,7 @@ func NewDefaultConnection(localIP net.IP, localPort uint16, remoteIP net.IP, rem LocalIP: localIP, LocalIPScope: netutils.Global, LocalPort: localPort, + PID: process.UnidentifiedProcessID, Entity: &intel.Entity{ Protocol: uint8(protocol), IP: remoteIP, @@ -35,6 +37,7 @@ func NewDefaultConnection(localIP net.IP, localPort uint16, remoteIP net.IP, rem VerdictPermanent: false, Tunneled: true, Encrypted: false, + DataComplete: abool.NewBool(true), Internal: false, addedToMetrics: true, // Metrics are not needed for now. This will mark the Connection to be ignored. process: process.GetUnidentifiedProcess(context.Background()), diff --git a/network/database.go b/network/database.go index 159ba1d3..457b2693 100644 --- a/network/database.go +++ b/network/database.go @@ -90,12 +90,12 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { switch scope { case dbScopeDNS: - if r, ok := dnsConns.get(id); ok { - return r, nil + if c, ok := dnsConns.get(id); ok && c.DataIsComplete() { + return c, nil } case dbScopeIP: - if r, ok := conns.get(id); ok { - return r, nil + if c, ok := conns.get(id); ok && c.DataIsComplete() { + return c, nil } case dbScopeNone: if proc, ok := process.GetProcessFromStorage(pid); ok { @@ -143,11 +143,16 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { if scope == dbScopeNone || scope == dbScopeDNS { // dns scopes only for _, dnsConn := range dnsConns.clone() { + if !dnsConn.DataIsComplete() { + continue + } + func() { dnsConn.Lock() defer dnsConn.Unlock() matches = q.Matches(dnsConn) }() + if matches { it.Next <- dnsConn } @@ -157,11 +162,16 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { if scope == dbScopeNone || scope == dbScopeIP { // connections for _, conn := range conns.clone() { + if !conn.DataIsComplete() { + continue + } + func() { conn.Lock() defer conn.Unlock() matches = q.Matches(conn) }() + if matches { it.Next <- conn } diff --git a/network/metrics.go b/network/metrics.go index db5e701e..28a3b046 100644 --- a/network/metrics.go +++ b/network/metrics.go @@ -8,6 +8,7 @@ import ( ) var ( + packetHandlingHistogram *metrics.Histogram blockedOutConnCounter *metrics.Counter encryptedAndTunneledOutConnCounter *metrics.Counter encryptedOutConnCounter *metrics.Counter @@ -15,8 +16,21 @@ var ( outConnCounter *metrics.Counter ) -func registerMetrics() error { - _, err := metrics.NewGauge( +func registerMetrics() (err error) { + // This needed to be moved here, because every packet is now handled by the + // connection handler worker. + packetHandlingHistogram, err = metrics.NewHistogram( + "firewall/handling/duration/seconds", + nil, + &metrics.Options{ + Permission: api.PermitUser, + ExpertiseLevel: config.ExpertiseLevelExpert, + }) + if err != nil { + return err + } + + _, err = metrics.NewGauge( "network/connections/active/total", nil, func() float64 { diff --git a/network/multicast.go b/network/multicast.go index 8a32aea7..d7c8f9a7 100644 --- a/network/multicast.go +++ b/network/multicast.go @@ -15,6 +15,8 @@ func GetMulticastRequestConn(responseConn *Connection, responseFromNet *net.IPNe // Find requesting multicast/broadcast connection. for _, conn := range conns.clone() { switch { + case !conn.DataIsComplete(): + // Ignore connection with incomplete data. case conn.Inbound: // Ignore incoming connections. case conn.Ended != 0: diff --git a/network/packet/packet.go b/network/packet/packet.go index 294a79c2..6034f46f 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -24,6 +24,12 @@ func (pkt *Base) FastTrackedByIntegration() bool { return false } +// InfoOnly returns whether the packet is informational only and does not +// represent an actual packet. +func (pkt *Base) InfoOnly() bool { + return false +} + // SetCtx sets the packet context. func (pkt *Base) SetCtx(ctx context.Context) { pkt.ctx = ctx @@ -107,6 +113,7 @@ func (pkt *Base) GetConnectionID() string { } func (pkt *Base) createConnectionID() { + // TODO: make this ID not depend on the packet direction for better support for forwarded packets. if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { if pkt.info.Inbound { pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) @@ -236,6 +243,7 @@ type Packet interface { RerouteToNameserver() error RerouteToTunnel() error FastTrackedByIntegration() bool + InfoOnly() bool // Info. SetCtx(context.Context) diff --git a/network/packet/packetinfo.go b/network/packet/packetinfo.go index 4d7c997b..37f40216 100644 --- a/network/packet/packetinfo.go +++ b/network/packet/packetinfo.go @@ -2,6 +2,7 @@ package packet import ( "net" + "time" ) // Info holds IP and TCP/UDP header information. @@ -14,7 +15,8 @@ type Info struct { SrcPort, DstPort uint16 Src, Dst net.IP - PID uint32 + PID int + SeenAt time.Time } // LocalIP returns the local IP of the packet. diff --git a/network/ports.go b/network/ports.go index a0982461..ab870ff0 100644 --- a/network/ports.go +++ b/network/ports.go @@ -28,13 +28,18 @@ nextPort: // Check if the generated port is unused. nextConnection: for _, conn := range allConns { - // Skip connection if the protocol does not match the protocol of interest. - if conn.Entity.Protocol != protocol { + switch { + case !conn.DataIsComplete(): + // Skip connection if the data is not complete. continue nextConnection - } - // Skip port if the local port is in dangerous proximity. - // Consecutive port numbers are very common. - if conn.LocalPort <= port && conn.LocalPort >= portRangeStart { + + case conn.Entity.Protocol != protocol: + // Skip connection if the protocol does not match the protocol of interest. + continue nextConnection + + case conn.LocalPort <= port && conn.LocalPort >= portRangeStart: + // Skip port if the local port is in dangerous proximity. + // Consecutive port numbers are very common. continue nextPort } } diff --git a/process/find.go b/process/find.go index beb2a57e..41d1459a 100644 --- a/process/find.go +++ b/process/find.go @@ -14,43 +14,19 @@ import ( "github.com/safing/portmaster/profile" ) -// GetProcessByConnection returns the process that owns the described connection. -func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process *Process, connInbound bool, err error) { +// GetProcessWithProfile returns the process, including the profile. +// Always returns valid data. +// Errors are logged and returned for information or special handling purposes. +func GetProcessWithProfile(ctx context.Context, pid int) (process *Process, err error) { if !enableProcessDetection() { log.Tracer(ctx).Tracef("process: process detection disabled") - return GetUnidentifiedProcess(ctx), pktInfo.Inbound, nil - } - - // Use fast search for inbound packets, as the listening socket should - // already be there for a while now. - fastSearch := pktInfo.Inbound - - var pid int - if pktInfo.PID == 0 { - log.Tracer(ctx).Tracef("process: getting pid from system network state") - pid, connInbound, err = state.Lookup(pktInfo, fastSearch) - if err != nil { - log.Tracer(ctx).Tracef("process: failed to find PID of connection: %s", err) - return nil, pktInfo.Inbound, err - } - } else { - log.Tracer(ctx).Tracef("process: pid already set in packet (by ebpf or kext)") - pid = int(pktInfo.PID) - } - - // Fallback to special profiles if PID could not be found. - if pid == UndefinedProcessID { - if connInbound { - pid = UnsolicitedProcessID - } else { - pid = UnidentifiedProcessID - } + return GetUnidentifiedProcess(ctx), nil } process, err = GetOrFindProcess(ctx, pid) if err != nil { - log.Tracer(ctx).Debugf("process: failed to find (primary) process with PID: %s", err) - return nil, connInbound, err + log.Tracer(ctx).Warningf("process: failed to find process with PID: %s", err) + return GetUnidentifiedProcess(ctx), err } changed, err := process.GetProfile(ctx) @@ -62,7 +38,46 @@ func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process process.Save() } - return process, connInbound, nil + return process, nil +} + +// GetPidOfConnection returns the PID of the process that owns the described connection. +// Always returns valid data. +// Errors are logged and returned for information or special handling purposes. +func GetPidOfConnection(ctx context.Context, pktInfo *packet.Info) (pid int, connInbound bool, err error) { + if !enableProcessDetection() { + return UnidentifiedProcessID, pktInfo.Inbound, nil + } + + // Use fast search for inbound packets, as the listening socket should + // already be there for a while now. + fastSearch := pktInfo.Inbound + connInbound = pktInfo.Inbound + + // FIXME: Only match for UndefinedProcessID when integrations have been updated. + if pktInfo.PID <= 0 { + log.Tracer(ctx).Tracef("process: getting pid from system network state") + pid, connInbound, err = state.Lookup(pktInfo, fastSearch) + if err != nil { + err = fmt.Errorf("failed to find PID of connection: %w", err) + log.Tracer(ctx).Tracef("process: %s", err) + pid = UndefinedProcessID + } + } else { + log.Tracer(ctx).Tracef("process: pid already set in packet (by ebpf or kext)") + pid = pktInfo.PID + } + + // Fallback to special profiles if PID could not be found. + if pid == UndefinedProcessID { + if connInbound && !netutils.ClassifyIP(pktInfo.Dst).IsLocalhost() { + pid = UnsolicitedProcessID + } else { + pid = UnidentifiedProcessID + } + } + + return pid, connInbound, err } // GetNetworkHost returns a *Process that represents a host on the network. @@ -113,7 +128,12 @@ func GetProcessByRequestOrigin(ar *api.Request) (*Process, error) { SrcPort: remotePort, // source as in the process we are looking for } - proc, _, err := GetProcessByConnection(ar.Context(), pkt) + pid, _, err := GetPidOfConnection(ar.Context(), pkt) + if err != nil { + return nil, err + } + + proc, err := GetProcessWithProfile(ar.Context(), pid) if err != nil { return nil, err } diff --git a/process/process.go b/process/process.go index 3d386b37..3f2779f9 100644 --- a/process/process.go +++ b/process/process.go @@ -217,7 +217,7 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (* } // UID - // net yet implemented for windows + // TODO: implemented for windows if onLinux { var uids []int32 uids, err = pInfo.UidsWithContext(ctx)