From a3e298b1bc359b300d446ce90c34d474f39977ac Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 5 Dec 2021 14:05:48 +0100 Subject: [PATCH] Block successive failing queries --- nameserver/failing.go | 135 +++++++++++++++++++++++++++++++++++++++ nameserver/nameserver.go | 85 +++++++++++++++++------- 2 files changed, 197 insertions(+), 23 deletions(-) create mode 100644 nameserver/failing.go diff --git a/nameserver/failing.go b/nameserver/failing.go new file mode 100644 index 00000000..033af308 --- /dev/null +++ b/nameserver/failing.go @@ -0,0 +1,135 @@ +package nameserver + +import ( + "sync" + "time" + + "github.com/safing/portmaster/netenv" + "github.com/safing/portmaster/resolver" +) + +type failingQuery struct { + // Until specifies until when the query should be regarded as failing. + Until time.Time + + // Keep specifies until when the entry stays in the map; cleanFailingQueries deletes it once this time has passed. + Keep time.Time + + // Times specifies how often this query failed. + Times int + + // Err holds the error the query first failed with; later failures of the same query do not replace it. + Err error +} + +const ( + failingDelay = 900 * time.Millisecond + failingBaseDuration = 900 * time.Millisecond + failingFactorDuration = 500 * time.Millisecond + failingMaxDuration = 30 * time.Second + failingKeepAddedDuration = 10 * time.Second +) + +var ( + failingQueries = make(map[string]*failingQuery) + failingQueriesLock sync.RWMutex + failingQueriesNetworkChangedFlag = netenv.GetNetworkChangedFlag() +) + +func checkIfQueryIsFailing(q *resolver.Query) (failingErr error, failingUntil *time.Time) { + // If the network changed, drop all recorded failures and report this query as not failing. + if failingQueriesNetworkChangedFlag.IsSet() { + failingQueriesNetworkChangedFlag.Refresh() + + failingQueriesLock.Lock() + defer failingQueriesLock.Unlock() + + // Compiler optimized map reset. + for key, _ := range failingQueries { + delete(failingQueries, key) + } + + return nil, nil + } + + failingQueriesLock.RLock() + defer failingQueriesLock.RUnlock() + + // Quickly return if map is empty.
+ if len(failingQueries) == 0 { + return nil, nil + } + + // Check if query failed recently. + failing, ok := failingQueries[q.ID()] + if !ok { + return nil, nil + } + + // Check if failing query should still be regarded as failing. + if time.Now().After(failing.Until) { + return nil, nil + } + + // Return the failing error and a pointer to the map entry's own Until field (not a copy). + return failing.Err, &failing.Until +} + +func addFailingQuery(q *resolver.Query, err error) { + // Check if we were given an error. + if err == nil { + return + } + + // Exclude reverse and mDNS queries, as they fail _often_ and are usually not + // retried quickly. This exclusion is currently disabled (commented out): + // if strings.HasSuffix(q.FQDN, ".in-addr.arpa.") || + // strings.HasSuffix(q.FQDN, ".ip6.arpa.") || + // strings.HasSuffix(q.FQDN, ".local.") { + // return + // } + + failingQueriesLock.Lock() + defer failingQueriesLock.Unlock() + + failing, ok := failingQueries[q.ID()] + if !ok { + failing = &failingQuery{Err: err} + failingQueries[q.ID()] = failing + } + + // Calculate fail duration: base duration plus one factor step per earlier failure, capped at the maximum. + // Initial fail duration will be at 900ms, perfect for a normal retry after 1s, + // but not any earlier. + failDuration := failingBaseDuration + time.Duration(failing.Times)*failingFactorDuration + if failDuration > failingMaxDuration { + failDuration = failingMaxDuration + } + + // Update failing query.
+ failing.Times++ + failing.Until = time.Now().Add(failDuration) + failing.Keep = failing.Until.Add(failingKeepAddedDuration) +} + +func cleanFailingQueries(maxRemove, maxMiss int) { + failingQueriesLock.Lock() + defer failingQueriesLock.Unlock() + + now := time.Now() + for key, failing := range failingQueries { + if now.After(failing.Keep) { + delete(failingQueries, key) + + maxRemove-- + if maxRemove == 0 { + return + } + } else { + maxMiss-- + if maxMiss == 0 { + return + } + } + } +} diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 6c09e601..d9c12c63 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -35,11 +35,29 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) startTime := time.Now() defer requestsHistogram.UpdateDuration(startTime) - // Only process first question, that's how everyone does it. - if len(request.Question) == 0 { - return errors.New("missing question") + // Check Question: exactly one question is accepted; queries with none or several are refused. + var originalQuestion dns.Question + switch len(request.Question) { + case 0: + log.Warning("nameserver: received query without question") + return sendResponse(ctx, w, request, nsutil.Refused("no question provided")) + case 1: + originalQuestion = request.Question[0] + default: + log.Warningf( + "nameserver: received query with multiple questions, first is %s.%s", + request.Question[0].Name, + dns.Type(request.Question[0].Qtype), + ) + return sendResponse(ctx, w, request, nsutil.Refused("multiple question provided")) + } + + // Check the Query Class. + if originalQuestion.Qclass != dns.ClassINET { + // We only serve IN records. + log.Warningf("nameserver: received unsupported qclass %d question for %s", originalQuestion.Qclass, originalQuestion.Name) + return sendResponse(ctx, w, request, nsutil.Refused("unsupported qclass")) } - originalQuestion := request.Question[0] // Check if we are handling a non-standard query name.
var nonStandardQuestionFormat bool @@ -57,21 +75,15 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) // Get remote address of request. remoteAddr, ok := w.RemoteAddr().(*net.UDPAddr) if !ok { - log.Warningf("nameserver: failed to get remote address of request for %s%s, ignoring", q.FQDN, q.QType) - return nil + log.Warningf("nameserver: failed to get remote address of dns query: is type %+T", w.RemoteAddr()) + return sendResponse(ctx, w, request, nsutil.Refused("unsupported transport")) } - // log.Errorf("DEBUG: nameserver: handling new request for %s from %s:%d", q.ID(), remoteAddr.IP, remoteAddr.Port) // Start context tracer for context-aware logging. ctx, tracer := log.AddTracer(ctx) defer tracer.Submit() tracer.Tracef("nameserver: handling new request for %s from %s:%d", q.ID(), remoteAddr.IP, remoteAddr.Port) - // Check if there are more than one question. - if len(request.Question) > 1 { - tracer.Warningf("nameserver: received more than one question from (%s:%d), first question is %s", remoteAddr.IP, remoteAddr.Port, q.ID()) - } - // Setup quick reply function. reply := func(responder nsutil.Responder, rrProviders ...nsutil.RRProvider) error { err := sendResponse(ctx, w, request, responder, rrProviders...) @@ -82,13 +94,6 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) return nil } - // Check the Query Class. - if originalQuestion.Qclass != dns.ClassINET { - // we only serve IN records, return nxdomain - tracer.Warningf("nameserver: only IN record requests are supported but received QClass %d, returning NXDOMAIN", originalQuestion.Qclass) - return reply(nsutil.Refused("unsupported qclass")) - } - // Handle request for localhost and the hostname. 
+ if strings.HasSuffix(q.FQDN, "localhost.") || q.FQDN == hostname { tracer.Tracef("nameserver: returning localhost records") @@ -101,11 +106,43 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) return reply(nsutil.Refused("invalid domain")) } - // Authenticate request - only requests from the local host, but with any of its IPs, are allowed. + // Check if query is failing. + // Some software retries failing queries excessively. This might not be a + // problem normally, but handling a request is pretty expensive for the + // Portmaster, as it has to find out who sent the query. If we know the query + // will fail with a very high probability, it is beneficial to just kill the + // query for some time before doing any expensive work. + defer cleanFailingQueries(10, 3) + failingErr, failingUntil := checkIfQueryIsFailing(q) + if failingErr != nil { + remainingFailingDuration := time.Until(*failingUntil) + tracer.Debugf("nameserver: returning previous error for %s: %s", q.ID(), failingErr) + + // Delay the response a bit in order to mitigate request flooding; the handler sleeps before replying. + if remainingFailingDuration < failingDelay { + // Delay only for the remaining fail duration, as it is shorter than the default delay. + tracer.Tracef("nameserver: delaying failing lookup until end of fail duration for %s", remainingFailingDuration.Round(time.Millisecond)) + time.Sleep(remainingFailingDuration) + return reply(nsutil.ServerFailure( + "internal error: "+failingErr.Error(), + "delayed failing query to mitigate request flooding", + )) + } + // Delay for default duration. + tracer.Tracef("nameserver: delaying failing lookup for %s", failingDelay.Round(time.Millisecond)) + time.Sleep(failingDelay) + return reply(nsutil.ServerFailure( + "internal error: "+failingErr.Error(), + "delayed failing query to mitigate request flooding", + fmt.Sprintf("error is cached for another %s", remainingFailingDuration.Round(time.Millisecond)), + )) + } + + // Check if the request is local.
local, err := netenv.IsMyIP(remoteAddr.IP) if err != nil { tracer.Warningf("nameserver: failed to check if request for %s is local: %s", q.ID(), err) - return nil // Do no reply, drop request immediately. + return reply(nsutil.ServerFailure("internal error: failed to check if request is local")) } // Create connection ID for dns request. @@ -127,12 +164,12 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) conn, err = network.NewConnectionFromExternalDNSRequest(ctx, q.FQDN, nil, connID, remoteAddr.IP) if err != nil { tracer.Warningf("nameserver: failed to get host/profile for request for %s%s: %s", q.FQDN, q.QType, err) - return nil // Do no reply, drop request immediately. + return reply(nsutil.ServerFailure("internal error: failed to get profile")) } default: tracer.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType) - return nil // Do no reply, drop request immediately. + return reply(nsutil.Refused("external queries are not permitted")) } conn.Lock() defer conn.Unlock() @@ -226,6 +263,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) log.Tracer(ctx).Debugf("nameserver: device is offline, using backup cache for %s", q.ID()) default: tracer.Warningf("nameserver: failed to resolve %s: %s", q.ID(), err) + addFailingQuery(q, err) return reply(nsutil.ServerFailure("internal error: " + err.Error())) } } @@ -233,6 +271,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) switch { case rrCache == nil: tracer.Warning("nameserver: received successful, but empty reply from resolver") + addFailingQuery(q, errors.New("emptry reply from resolver")) return reply(nsutil.ServerFailure("internal error: empty reply")) case rrCache.RCode == dns.RcodeNameError: // Return now if NXDomain.