diff --git a/firewall/interception.go b/firewall/interception.go index 592e6c0a..1e985d0b 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -70,7 +70,7 @@ func interceptionPrep() error { err := interceptionModule.RegisterEventHook( "config", configChangeEvent, - "firewall config change event", + "reset connection verdicts", func(ctx context.Context, _ interface{}) error { resetPersistentVerdicts() return nil @@ -84,7 +84,7 @@ func interceptionPrep() error { err = interceptionModule.RegisterEventHook( "profiles", profileConfigChangeEvent, - "firewall profile change event", + "reset connection verdicts", func(ctx context.Context, _ interface{}) error { resetPersistentVerdicts() return nil @@ -99,7 +99,7 @@ func interceptionPrep() error { err = interceptionModule.RegisterEventHook( "captain", onSPNConnectEvent, - "firewall spn connect event", + "reset connection verdicts", func(ctx context.Context, _ interface{}) error { resetPersistentVerdicts() return nil @@ -118,30 +118,30 @@ func interceptionPrep() error { func resetPersistentVerdicts() { // Resetting will force all the connection to be evaluated by the firewall again - // this will set new verdicts if configuration was update or spn has been disabled or enabled - log.Info("interception: resetting all connections") + // this will set new verdicts if configuration was update or spn has been disabled or enabled. + log.Info("interception: reevaluating all connection trough the firewall") - // reset all connection firewall handlers. This will tell the master to rerun the firewall checks + // reset all connection firewall handlers. This will tell the master to rerun the firewall checks. for _, conn := range network.GetAllConnections() { - conn.Lock() isSPNConnection := captain.IsExcepted(conn.Entity.IP) && conn.Process().Pid == ownPID - // mark all non SPN connections to be processed by the firewall + // mark all non SPN connections to be processed by the firewall. if !isSPNConnection { + conn.Lock() conn.SetFirewallHandler(initialHandler) - // Don't keep the previous tunneled value + // Don't keep the previous tunneled value. conn.Tunneled = false // Reset entity if it exists. if conn.Entity != nil { conn.Entity.ResetLists() } + conn.Unlock() } - conn.Unlock() } err := interception.ResetVerdictOfAllConnections() if err != nil { - log.Errorf("interception: failed to reset connections verdict: %s", err) + log.Errorf("interception: failed to remove persistent verdicts: %s", err) } } @@ -177,8 +177,6 @@ func SetNameserverIPMatcher(fn func(ip net.IP) bool) error { } func handlePacket(ctx context.Context, pkt packet.Packet) { - // log.Errorf("DEBUG: firewall: handling packet %s", pkt) - // Record metrics. startTime := time.Now() defer packetHandlingHistogram.UpdateDuration(startTime) @@ -222,7 +220,9 @@ func getConnection(pkt packet.Packet) (*network.Connection, error) { // Else create new one from the packet. conn = network.NewConnectionFromFirstPacket(pkt) + conn.Lock() conn.SetFirewallHandler(initialHandler) + conn.Unlock() created = true return conn, nil }) @@ -248,27 +248,6 @@ func getConnection(pkt packet.Packet) (*network.Connection, error) { return conn, nil } -func getConnectionByID(id string) (*network.Connection, error) { - // Create or get connection in single inflight lock in order to prevent duplicates. - connPtr, _, _ := getConnectionSingleInflight.Do(id, func() (interface{}, error) { - // First, check for an existing connection. - conn, ok := network.GetConnection(id) - if ok { - return conn, nil - } - - // Else return nil - return nil, nil - }) - - if connPtr == nil { - return nil, errors.New("connection does not exist") - } - - connection := connPtr.(*network.Connection) //nolint:forcetypeassert // Can only be a *network.Connection. - return connection, nil -} - // fastTrackedPermit quickly permits certain network critical or internal connections. func fastTrackedPermit(pkt packet.Packet) (handled bool) { meta := pkt.Info()