diff --git a/firewall/module.go b/firewall/module.go index dd4dcbaa..1e9925b5 100644 --- a/firewall/module.go +++ b/firewall/module.go @@ -9,6 +9,7 @@ import ( "github.com/safing/portbase/modules/subsystems" _ "github.com/safing/portmaster/core" "github.com/safing/portmaster/network" + "github.com/safing/spn/access" ) var module *modules.Module @@ -57,7 +58,7 @@ func prep() error { }, ) if err != nil { - log.Errorf("interception: failed registering event hook: %s", err) + log.Errorf("filter: failed to register event hook: %s", err) } // Reset connections every time profile changes @@ -71,7 +72,7 @@ func prep() error { }, ) if err != nil { - log.Errorf("failed registering event hook: %s", err) + log.Errorf("filter: failed to register event hook: %s", err) } // Reset connections when spn is connected @@ -86,7 +87,22 @@ func prep() error { }, ) if err != nil { - log.Errorf("failed registering event hook: %s", err) + log.Errorf("filter: failed to register event hook: %s", err) + } + + // Reset connections when account is updated. + // This will not change verdicts, but will update the feature flags on connections. + err = module.RegisterEventHook( + "access", + access.AccountUpdateEvent, + "update connection feature flags", + func(ctx context.Context, _ interface{}) error { + resetAllConnectionVerdicts() + return nil + }, + ) + if err != nil { + log.Errorf("filter: failed to register event hook: %s", err) } if err := registerConfig(); err != nil { diff --git a/firewall/packet_handler.go b/firewall/packet_handler.go index 7c364451..64a22d86 100644 --- a/firewall/packet_handler.go +++ b/firewall/packet_handler.go @@ -23,6 +23,7 @@ import ( "github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/reference" + "github.com/safing/spn/access" ) var ( @@ -61,6 +62,11 @@ func resetAllConnectionVerdicts() { conn.Lock() defer conn.Unlock() + // Update feature flags. + if err := conn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) { + tracer.Warningf("network: failed to update connection feature flags: %s", err) + } + // Skip internal connections: // - Pre-authenticated connections from Portmaster // - Redirected DNS requests diff --git a/network/connection.go b/network/connection.go index dd4d7b22..f2b2c3dc 100644 --- a/network/connection.go +++ b/network/connection.go @@ -340,7 +340,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri if localProfile := proc.Profile().LocalProfile(); localProfile != nil { dnsConn.Internal = localProfile.Internal - if err := dnsConn.updateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) { + if err := dnsConn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) { log.Tracer(ctx).Warningf("network: failed to check for enabled features: %s", err) } } @@ -383,7 +383,7 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname if localProfile := remoteHost.Profile().LocalProfile(); localProfile != nil { dnsConn.Internal = localProfile.Internal - if err := dnsConn.updateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) { + if err := dnsConn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) { log.Tracer(ctx).Warningf("network: failed to check for enabled features: %s", err) } } @@ -507,7 +507,7 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil { conn.Internal = localProfile.Internal - if err := conn.updateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) { + if err := conn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) { log.Tracer(pkt.Ctx()).Warningf("network: connection %s failed to check for enabled features: %s", conn, err) } } @@ -578,12 +578,13 @@ func (conn *Connection) SetLocalIP(ip net.IP) { conn.LocalIPScope = netutils.GetIPScope(ip) } -// updateFeatures checks which connection related features may be used and sets +// UpdateFeatures checks which connection related features may be used and sets // the flags accordingly. -func (conn *Connection) updateFeatures() error { +// The caller must hold a lock on the connection. +func (conn *Connection) UpdateFeatures() error { // Get user. user, err := access.GetUser() - if err != nil { + if err != nil && !errors.Is(err, access.ErrNotLoggedIn) { return err }