Update all connection feature flags on account update

This commit is contained in:
Daniel
2023-08-07 16:49:12 +02:00
parent 6c4a77a205
commit 37fbc4b74c
3 changed files with 32 additions and 9 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/safing/portbase/modules/subsystems"
_ "github.com/safing/portmaster/core"
"github.com/safing/portmaster/network"
"github.com/safing/spn/access"
)
var module *modules.Module
@@ -57,7 +58,7 @@ func prep() error {
},
)
if err != nil {
log.Errorf("interception: failed registering event hook: %s", err)
log.Errorf("filter: failed to register event hook: %s", err)
}
// Reset connections every time profile changes
@@ -71,7 +72,7 @@ func prep() error {
},
)
if err != nil {
log.Errorf("failed registering event hook: %s", err)
log.Errorf("filter: failed to register event hook: %s", err)
}
// Reset connections when spn is connected
@@ -86,7 +87,22 @@ func prep() error {
},
)
if err != nil {
log.Errorf("failed registering event hook: %s", err)
log.Errorf("filter: failed to register event hook: %s", err)
}
// Reset connections when account is updated.
// This will not change verdicts, but will update the feature flags on connections.
err = module.RegisterEventHook(
"access",
access.AccountUpdateEvent,
"update connection feature flags",
func(ctx context.Context, _ interface{}) error {
resetAllConnectionVerdicts()
return nil
},
)
if err != nil {
log.Errorf("filter: failed to register event hook: %s", err)
}
if err := registerConfig(); err != nil {

View File

@@ -23,6 +23,7 @@ import (
"github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/reference"
"github.com/safing/spn/access"
)
var (
@@ -61,6 +62,11 @@ func resetAllConnectionVerdicts() {
conn.Lock()
defer conn.Unlock()
// Update feature flags.
if err := conn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
tracer.Warningf("network: failed to update connection feature flags: %s", err)
}
// Skip internal connections:
// - Pre-authenticated connections from Portmaster
// - Redirected DNS requests

View File

@@ -340,7 +340,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
if localProfile := proc.Profile().LocalProfile(); localProfile != nil {
dnsConn.Internal = localProfile.Internal
if err := dnsConn.updateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
if err := dnsConn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
log.Tracer(ctx).Warningf("network: failed to check for enabled features: %s", err)
}
}
@@ -383,7 +383,7 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname
if localProfile := remoteHost.Profile().LocalProfile(); localProfile != nil {
dnsConn.Internal = localProfile.Internal
if err := dnsConn.updateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
if err := dnsConn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
log.Tracer(ctx).Warningf("network: failed to check for enabled features: %s", err)
}
}
@@ -507,7 +507,7 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil {
conn.Internal = localProfile.Internal
if err := conn.updateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
if err := conn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
log.Tracer(pkt.Ctx()).Warningf("network: connection %s failed to check for enabled features: %s", conn, err)
}
}
@@ -578,12 +578,13 @@ func (conn *Connection) SetLocalIP(ip net.IP) {
conn.LocalIPScope = netutils.GetIPScope(ip)
}
// updateFeatures checks which connection related features may be used and sets
// UpdateFeatures checks which connection related features may be used and sets
// the flags accordingly.
func (conn *Connection) updateFeatures() error {
// The caller must hold a lock on the connection.
func (conn *Connection) UpdateFeatures() error {
// Get user.
user, err := access.GetUser()
if err != nil {
if err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
return err
}