Merge pull request #1075 from safing/feature/reconnect-resolver-on-spn-connect
Reconnect resolvers on SPN connect and other fixes
This commit is contained in:
@@ -116,7 +116,11 @@ func decideOnConnection(ctx context.Context, conn *network.Connection, pkt packe
|
|||||||
case profile.DefaultActionPermit:
|
case profile.DefaultActionPermit:
|
||||||
conn.Accept("allowed by default action", profile.CfgOptionDefaultActionKey)
|
conn.Accept("allowed by default action", profile.CfgOptionDefaultActionKey)
|
||||||
case profile.DefaultActionAsk:
|
case profile.DefaultActionAsk:
|
||||||
prompt(ctx, conn, pkt)
|
// Only prompt if there has not been a decision already.
|
||||||
|
// This prevents prompts from being created when re-evaluating connections.
|
||||||
|
if conn.Verdict.Firewall == network.VerdictUndecided {
|
||||||
|
prompt(ctx, conn)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
conn.Deny("blocked by default action", profile.CfgOptionDefaultActionKey)
|
conn.Deny("blocked by default action", profile.CfgOptionDefaultActionKey)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"github.com/safing/portbase/notifications"
|
"github.com/safing/portbase/notifications"
|
||||||
"github.com/safing/portmaster/intel"
|
"github.com/safing/portmaster/intel"
|
||||||
"github.com/safing/portmaster/network"
|
"github.com/safing/portmaster/network"
|
||||||
"github.com/safing/portmaster/network/packet"
|
|
||||||
"github.com/safing/portmaster/profile"
|
"github.com/safing/portmaster/profile"
|
||||||
"github.com/safing/portmaster/profile/endpoints"
|
"github.com/safing/portmaster/profile/endpoints"
|
||||||
)
|
)
|
||||||
@@ -47,9 +46,9 @@ type promptProfile struct {
|
|||||||
LinkedPath string
|
LinkedPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
func prompt(ctx context.Context, conn *network.Connection, pkt packet.Packet) {
|
func prompt(ctx context.Context, conn *network.Connection) {
|
||||||
// Create notification.
|
// Create notification.
|
||||||
n := createPrompt(ctx, conn, pkt)
|
n := createPrompt(ctx, conn)
|
||||||
if n == nil {
|
if n == nil {
|
||||||
// createPrompt returns nil when no further action should be taken.
|
// createPrompt returns nil when no further action should be taken.
|
||||||
return
|
return
|
||||||
@@ -81,11 +80,11 @@ func prompt(ctx context.Context, conn *network.Connection, pkt packet.Packet) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// promptIDPrefix is an identifier for privacy filter prompts. This is also use
|
// promptIDPrefix is an identifier for privacy filter prompts. This is also used
|
||||||
// in the UI, so don't change!
|
// in the UI, so don't change!
|
||||||
const promptIDPrefix = "filter:prompt"
|
const promptIDPrefix = "filter:prompt"
|
||||||
|
|
||||||
func createPrompt(ctx context.Context, conn *network.Connection, pkt packet.Packet) (n *notifications.Notification) {
|
func createPrompt(ctx context.Context, conn *network.Connection) (n *notifications.Notification) {
|
||||||
expires := time.Now().Add(time.Duration(askTimeout()) * time.Second).Unix()
|
expires := time.Now().Add(time.Duration(askTimeout()) * time.Second).Unix()
|
||||||
|
|
||||||
// Get local profile.
|
// Get local profile.
|
||||||
@@ -110,7 +109,7 @@ func createPrompt(ctx context.Context, conn *network.Connection, pkt packet.Pack
|
|||||||
promptIDPrefix,
|
promptIDPrefix,
|
||||||
localProfile.ID,
|
localProfile.ID,
|
||||||
conn.Inbound,
|
conn.Inbound,
|
||||||
pkt.Info().RemoteIP(),
|
conn.Entity.IP,
|
||||||
)
|
)
|
||||||
default: // connection to domain
|
default: // connection to domain
|
||||||
nID = fmt.Sprintf(
|
nID = fmt.Sprintf(
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
|
|||||||
// This is different from the Firewall verdict in order to guarantee proper
|
// This is different from the Firewall verdict in order to guarantee proper
|
||||||
// transition between verdicts that need the connection to be re-established.
|
// transition between verdicts that need the connection to be re-established.
|
||||||
Active Verdict
|
Active Verdict
|
||||||
// Firewall holsd the last (most recent) decision by the firewall.
|
// Firewall holds the last (most recent) decision by the firewall.
|
||||||
Firewall Verdict
|
Firewall Verdict
|
||||||
}
|
}
|
||||||
// Reason holds information justifying the verdict, as well as additional
|
// Reason holds information justifying the verdict, as well as additional
|
||||||
|
|||||||
@@ -57,6 +57,20 @@ func start() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Force resolvers to reconnect when SPN has connected.
|
||||||
|
if err := module.RegisterEventHook(
|
||||||
|
"captain",
|
||||||
|
"spn connect", // Defined by captain.SPNConnectedEvent
|
||||||
|
"force resolver reconnect",
|
||||||
|
func(ctx context.Context, _ any) error {
|
||||||
|
ForceResolverReconnect(ctx)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
); err != nil {
|
||||||
|
// This module does not depend on the SPN/Captain module, and probably should not.
|
||||||
|
log.Warningf("resolvers: failed to register event hook for captain/spn-connect: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
// reload after config change
|
// reload after config change
|
||||||
prevNameservers := strings.Join(configuredNameServers(), " ")
|
prevNameservers := strings.Join(configuredNameServers(), " ")
|
||||||
err = module.RegisterEventHook(
|
err = module.RegisterEventHook(
|
||||||
|
|||||||
@@ -150,6 +150,8 @@ func (er *envResolverConn) IsFailing() bool {
|
|||||||
|
|
||||||
func (er *envResolverConn) ResetFailure() {}
|
func (er *envResolverConn) ResetFailure() {}
|
||||||
|
|
||||||
|
func (er *envResolverConn) ForceReconnect(_ context.Context) {}
|
||||||
|
|
||||||
// QueryPortmasterEnv queries the environment resolver directly.
|
// QueryPortmasterEnv queries the environment resolver directly.
|
||||||
func QueryPortmasterEnv(ctx context.Context, q *Query) (*RRCache, error) {
|
func QueryPortmasterEnv(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
return envResolver.Conn.Query(ctx, q)
|
return envResolver.Conn.Query(ctx, q)
|
||||||
|
|||||||
@@ -8,15 +8,19 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
"github.com/safing/portbase/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPSResolver is a resolver using just a single tcp connection with pipelining.
|
// HTTPSResolver is a resolver using just a single tcp connection with pipelining.
|
||||||
type HTTPSResolver struct {
|
type HTTPSResolver struct {
|
||||||
BasicResolverConn
|
BasicResolverConn
|
||||||
Client *http.Client
|
client *http.Client
|
||||||
|
clientLock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPSQuery holds the query information for a hTTPSResolverConn.
|
// HTTPSQuery holds the query information for a hTTPSResolverConn.
|
||||||
@@ -40,23 +44,13 @@ func (tq *HTTPSQuery) MakeCacheRecord(reply *dns.Msg, resolverInfo *ResolverInfo
|
|||||||
|
|
||||||
// NewHTTPSResolver returns a new HTTPSResolver.
|
// NewHTTPSResolver returns a new HTTPSResolver.
|
||||||
func NewHTTPSResolver(resolver *Resolver) *HTTPSResolver {
|
func NewHTTPSResolver(resolver *Resolver) *HTTPSResolver {
|
||||||
tr := &http.Transport{
|
|
||||||
TLSClientConfig: &tls.Config{
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
ServerName: resolver.Info.Domain,
|
|
||||||
// TODO: use portbase rng
|
|
||||||
},
|
|
||||||
IdleConnTimeout: 3 * time.Minute,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{Transport: tr}
|
|
||||||
newResolver := &HTTPSResolver{
|
newResolver := &HTTPSResolver{
|
||||||
BasicResolverConn: BasicResolverConn{
|
BasicResolverConn: BasicResolverConn{
|
||||||
resolver: resolver,
|
resolver: resolver,
|
||||||
},
|
},
|
||||||
Client: client,
|
|
||||||
}
|
}
|
||||||
newResolver.BasicResolverConn.init()
|
newResolver.BasicResolverConn.init()
|
||||||
|
newResolver.refreshClient()
|
||||||
return newResolver
|
return newResolver
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,7 +80,13 @@ func (hr *HTTPSResolver) Query(ctx context.Context, q *Query) (*RRCache, error)
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := hr.Client.Do(request)
|
// Lock client for usage.
|
||||||
|
hr.clientLock.RLock()
|
||||||
|
defer hr.clientLock.RUnlock()
|
||||||
|
|
||||||
|
// TODO: Check age of client and force a refresh similar to the TCP resolver.
|
||||||
|
|
||||||
|
resp, err := hr.client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -124,3 +124,35 @@ func (hr *HTTPSResolver) Query(ctx context.Context, q *Query) (*RRCache, error)
|
|||||||
// TODO: check if reply.Answer is valid
|
// TODO: check if reply.Answer is valid
|
||||||
return newRecord, nil
|
return newRecord, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForceReconnect forces the resolver to re-establish the connection to the server.
|
||||||
|
func (hr *HTTPSResolver) ForceReconnect(ctx context.Context) {
|
||||||
|
hr.refreshClient()
|
||||||
|
log.Tracer(ctx).Tracef("resolver: created new HTTP client for %s", hr.resolver)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hr *HTTPSResolver) refreshClient() {
|
||||||
|
// Lock client for changing.
|
||||||
|
hr.clientLock.Lock()
|
||||||
|
defer hr.clientLock.Unlock()
|
||||||
|
|
||||||
|
// Attempt to close connection of previous client.
|
||||||
|
if hr.client != nil {
|
||||||
|
hr.client.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new client.
|
||||||
|
tr := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: hr.resolver.Info.Domain,
|
||||||
|
// TODO: use portbase rng
|
||||||
|
},
|
||||||
|
IdleConnTimeout: 1 * time.Minute,
|
||||||
|
TLSHandshakeTimeout: defaultConnectTimeout,
|
||||||
|
}
|
||||||
|
hr.client = &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
Timeout: maxRequestTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ func (mrc *mDNSResolverConn) IsFailing() bool {
|
|||||||
|
|
||||||
func (mrc *mDNSResolverConn) ResetFailure() {}
|
func (mrc *mDNSResolverConn) ResetFailure() {}
|
||||||
|
|
||||||
|
func (mrc *mDNSResolverConn) ForceReconnect(_ context.Context) {}
|
||||||
|
|
||||||
type savedQuestion struct {
|
type savedQuestion struct {
|
||||||
question dns.Question
|
question dns.Question
|
||||||
expires time.Time
|
expires time.Time
|
||||||
|
|||||||
@@ -96,3 +96,7 @@ func (pr *PlainResolver) Query(ctx context.Context, q *Query) (*RRCache, error)
|
|||||||
// TODO: check if reply.Answer is valid
|
// TODO: check if reply.Answer is valid
|
||||||
return newRecord, nil
|
return newRecord, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForceReconnect forces the resolver to re-establish the connection to the server.
|
||||||
|
// Does nothing for PlainResolver, as every request uses its own connection.
|
||||||
|
func (pr *PlainResolver) ForceReconnect(_ context.Context) {}
|
||||||
|
|||||||
@@ -236,11 +236,29 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
|||||||
return tq.MakeCacheRecord(reply, tr.resolver.Info), nil
|
return tq.MakeCacheRecord(reply, tr.resolver.Info), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForceReconnect forces the resolver to re-establish the connection to the server.
|
||||||
|
func (tr *TCPResolver) ForceReconnect(ctx context.Context) {
|
||||||
|
tr.Lock()
|
||||||
|
defer tr.Unlock()
|
||||||
|
|
||||||
|
// Do nothing if no connection is available.
|
||||||
|
if tr.resolverConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the abandoned to force a new connection on next request.
|
||||||
|
// This will leave the previous connection and handler running until all requests are handled.
|
||||||
|
tr.resolverConn.abandoned.Set()
|
||||||
|
|
||||||
|
log.Tracer(ctx).Tracef("resolver: marked %s for reconnecting", tr.resolver)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdown cleanly shuts down the resolver connection.
|
||||||
|
// Must only be called once.
|
||||||
func (trc *tcpResolverConn) shutdown() {
|
func (trc *tcpResolverConn) shutdown() {
|
||||||
// Set abandoned status and close connection to the DNS server.
|
// Set abandoned status and close connection to the DNS server.
|
||||||
if trc.abandoned.SetToIf(false, true) {
|
trc.abandoned.Set()
|
||||||
_ = trc.conn.Close()
|
_ = trc.conn.Close()
|
||||||
}
|
|
||||||
|
|
||||||
// Close all response channels for in-flight queries.
|
// Close all response channels for in-flight queries.
|
||||||
for _, tq := range trc.inFlightQueries {
|
for _, tq := range trc.inFlightQueries {
|
||||||
@@ -320,7 +338,7 @@ func (trc *tcpResolverConn) handler(workerCtx context.Context) error {
|
|||||||
|
|
||||||
// If we are ready to recycle and we have no in-flight queries, we can
|
// If we are ready to recycle and we have no in-flight queries, we can
|
||||||
// shutdown the connection and create a new one for the next query.
|
// shutdown the connection and create a new one for the next query.
|
||||||
if readyToRecycle {
|
if readyToRecycle || trc.abandoned.IsSet() {
|
||||||
if len(trc.inFlightQueries) == 0 {
|
if len(trc.inFlightQueries) == 0 {
|
||||||
log.Debugf("resolver: recycling connection to %s", trc.resolverInfo.DescriptiveName())
|
log.Debugf("resolver: recycling connection to %s", trc.resolverInfo.DescriptiveName())
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -212,6 +212,7 @@ type ResolverConn interface { //nolint:golint // TODO
|
|||||||
ReportFailure()
|
ReportFailure()
|
||||||
IsFailing() bool
|
IsFailing() bool
|
||||||
ResetFailure()
|
ResetFailure()
|
||||||
|
ForceReconnect(ctx context.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BasicResolverConn implements ResolverConn for standard dns clients.
|
// BasicResolverConn implements ResolverConn for standard dns clients.
|
||||||
|
|||||||
@@ -479,6 +479,8 @@ func setScopedResolvers(resolvers []*Resolver) {
|
|||||||
for _, resolver := range resolvers {
|
for _, resolver := range resolvers {
|
||||||
if resolver.Info.IPScope.IsLAN() {
|
if resolver.Info.IPScope.IsLAN() {
|
||||||
localResolvers = append(localResolvers, resolver)
|
localResolvers = append(localResolvers, resolver)
|
||||||
|
} else if _, err := netenv.GetLocalNetwork(resolver.Info.IP); err != nil {
|
||||||
|
localResolvers = append(localResolvers, resolver)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resolver.Info.Source == ServerSourceOperatingSystem {
|
if resolver.Info.Source == ServerSourceOperatingSystem {
|
||||||
@@ -570,3 +572,18 @@ func IsResolverAddress(ip net.IP, port uint16) bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForceResolverReconnect forces all resolvers to reconnect.
|
||||||
|
func ForceResolverReconnect(ctx context.Context) {
|
||||||
|
resolversLock.RLock()
|
||||||
|
defer resolversLock.RUnlock()
|
||||||
|
|
||||||
|
ctx, tracer := log.AddTracer(ctx)
|
||||||
|
defer tracer.Submit()
|
||||||
|
|
||||||
|
tracer.Trace("resolver: forcing all active resolvers to reconnect")
|
||||||
|
for _, r := range globalResolvers {
|
||||||
|
r.Conn.ForceReconnect(ctx)
|
||||||
|
}
|
||||||
|
tracer.Info("resolver: all active resolvers were forced to reconnect")
|
||||||
|
}
|
||||||
|
|||||||
@@ -195,7 +195,6 @@ func GetResolversInScope(ctx context.Context, q *Query) (selected []*Resolver, p
|
|||||||
if domainInScope(q.dotPrefixedFQDN, specialUseDomains) ||
|
if domainInScope(q.dotPrefixedFQDN, specialUseDomains) ||
|
||||||
domainInScope(q.dotPrefixedFQDN, specialServiceDomains) {
|
domainInScope(q.dotPrefixedFQDN, specialServiceDomains) {
|
||||||
selected = addResolvers(ctx, q, selected, localResolvers)
|
selected = addResolvers(ctx, q, selected, localResolvers)
|
||||||
selected = addResolvers(ctx, q, selected, systemResolvers)
|
|
||||||
return selected, "special", true
|
return selected, "special", true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user