From 9ff824967e40c555ed23de54259f36fe3693ceb8 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 7 Apr 2021 16:43:13 +0200 Subject: [PATCH] Get/Create connections in single-inflight lock --- firewall/interception.go | 58 ++++++++++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/firewall/interception.go b/firewall/interception.go index 0a4212d9..4edc790c 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -3,12 +3,14 @@ package firewall import ( "context" "errors" + "fmt" "net" "os" "sync/atomic" "time" "github.com/safing/portmaster/netenv" + "golang.org/x/sync/singleflight" "github.com/tevino/abool" @@ -102,20 +104,60 @@ func handlePacket(ctx context.Context, pkt packet.Packet) { } pkt.SetCtx(traceCtx) - // associate packet to link and handle - conn, ok := network.GetConnection(pkt.GetConnectionID()) - if ok { - tracer.Tracef("filter: assigned to connection %s", conn.ID) - } else { - conn = network.NewConnectionFromFirstPacket(pkt) - tracer.Tracef("filter: created new connection %s", conn.ID) - conn.SetFirewallHandler(initialHandler) + // Get connection of packet. + conn, err := getConnection(pkt) + if err != nil { + tracer.Errorf("filter: packet %s dropped: %s", pkt, err) + _ = pkt.Drop() + return } // handle packet conn.HandlePacket(pkt) } +var getConnectionSingleInflight singleflight.Group + +func getConnection(pkt packet.Packet) (*network.Connection, error) { + created := false + + // Create or get connection in single inflight lock in order to prevent duplicates. + newConn, err, shared := getConnectionSingleInflight.Do(pkt.GetConnectionID(), func() (interface{}, error) { + // First, check for an existing connection. + conn, ok := network.GetConnection(pkt.GetConnectionID()) + if ok { + return conn, nil + } + + // Else create new one from the packet. + conn = network.NewConnectionFromFirstPacket(pkt) + conn.SetFirewallHandler(initialHandler) + created = true + return conn, nil + }) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %s", err) + } + if newConn == nil { + return nil, errors.New("connection getter returned nil") + } + + // Transform and log result. + conn := newConn.(*network.Connection) + switch { + case created && shared: + log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s (shared)", conn.ID) + case created: + log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s", conn.ID) + case shared: + log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s (shared)", conn.ID) + default: + log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s", conn.ID) + } + + return conn, nil +} + // fastTrackedPermit quickly permits certain network criticial or internal connections. func fastTrackedPermit(pkt packet.Packet) (handled bool) { meta := pkt.Info()