wip: migrate to mono-repo. SPN has already been moved to spn/

This commit is contained in:
Patrick Pacher
2024-03-15 11:55:13 +01:00
parent b30fd00ccf
commit 8579430db9
577 changed files with 35981 additions and 818 deletions

View File

@@ -0,0 +1,60 @@
package nameserver
import (
"flag"
"runtime"
"github.com/safing/portbase/config"
"github.com/safing/portmaster/service/core"
)
// CfgDefaultNameserverAddressKey is the config key for the listen address..
const CfgDefaultNameserverAddressKey = "dns/listenAddress"
var (
defaultNameserverAddress = "localhost:53"
nameserverAddress string
nameserverAddressConfig config.StringOption
networkServiceMode config.BoolOption
)
func init() {
// On Windows, packets are redirected to the same interface.
if runtime.GOOS == "windows" {
defaultNameserverAddress = "0.0.0.0:53"
}
flag.StringVar(
&nameserverAddress,
"nameserver-address",
defaultNameserverAddress,
"set default nameserver address; configuration is stronger",
)
}
func registerConfig() error {
err := config.Register(&config.Option{
Name: "Internal DNS Server Listen Address",
Key: CfgDefaultNameserverAddressKey,
Description: "Defines the IP address and port on which the internal DNS Server listens.",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelDeveloper,
ReleaseLevel: config.ReleaseLevelStable,
DefaultValue: nameserverAddress,
ValidationRegex: "^(localhost|[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}|\\[[:0-9A-Fa-f]+\\]):[0-9]{1,5}$",
RequiresRestart: true,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: 514,
config.CategoryAnnotation: "Development",
},
})
if err != nil {
return err
}
nameserverAddressConfig = config.GetAsString(CfgDefaultNameserverAddressKey, nameserverAddress)
networkServiceMode = config.Concurrent.GetAsBool(core.CfgNetworkServiceKey, false)
return nil
}

View File

@@ -0,0 +1,75 @@
package nameserver
import (
"net"
"os"
processInfo "github.com/shirou/gopsutil/process"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/network/state"
)
var commonResolverIPs = []net.IP{
net.IPv4zero,
net.IPv4(127, 0, 0, 1), // default
net.IPv4(127, 0, 0, 53), // some resolvers on Linux
net.IPv6zero,
net.IPv6loopback,
}
func findConflictingProcess(ip net.IP, port uint16) (conflictingProcess *processInfo.Process) {
// Evaluate which IPs to check.
var ipsToCheck []net.IP
if ip.Equal(net.IPv4zero) || ip.Equal(net.IPv6zero) {
ipsToCheck = commonResolverIPs
} else {
ipsToCheck = []net.IP{ip}
}
// Find the conflicting process.
var err error
for _, resolverIP := range ipsToCheck {
conflictingProcess, err = getListeningProcess(resolverIP, port)
switch {
case err != nil:
// Log the error and let the worker try again.
log.Warningf("nameserver: failed to find conflicting service: %s", err)
case conflictingProcess != nil:
// Conflicting service found.
return conflictingProcess
}
}
return nil
}
func getListeningProcess(resolverIP net.IP, resolverPort uint16) (*processInfo.Process, error) {
pid, _, err := state.Lookup(&packet.Info{
Inbound: true,
Version: 0, // auto-detect
Protocol: packet.UDP,
Src: nil, // do not record direction
SrcPort: 0, // do not record direction
Dst: resolverIP,
DstPort: resolverPort,
}, true)
if err != nil {
// there may be nothing listening on :53
return nil, nil //nolint:nilerr // Treat lookup error as "not found".
}
// Ignore if it's us for some reason.
if pid == os.Getpid() {
return nil, nil
}
proc, err := processInfo.NewProcess(int32(pid))
if err != nil {
// Process may have disappeared already.
return nil, err
}
return proc, nil
}

View File

@@ -0,0 +1,135 @@
package nameserver
import (
"sync"
"time"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/resolver"
)
type failingQuery struct {
// Until specifies until when the query should be regarded as failing.
Until time.Time
// Keep specifies until when the failing status shall be kept.
Keep time.Time
// Times specifies how often this query failed.
Times int
// Err holds the error the query failed with.
Err error
}
const (
failingDelay = 900 * time.Millisecond
failingBaseDuration = 900 * time.Millisecond
failingFactorDuration = 500 * time.Millisecond
failingMaxDuration = 30 * time.Second
failingKeepAddedDuration = 10 * time.Second
)
var (
failingQueries = make(map[string]*failingQuery)
failingQueriesLock sync.RWMutex
failingQueriesNetworkChangedFlag = netenv.GetNetworkChangedFlag()
)
func checkIfQueryIsFailing(q *resolver.Query) (failingUntil *time.Time, failingErr error) {
// If the network changed, reset the failed queries.
if failingQueriesNetworkChangedFlag.IsSet() {
failingQueriesNetworkChangedFlag.Refresh()
failingQueriesLock.Lock()
defer failingQueriesLock.Unlock()
// Compiler optimized map reset.
for key := range failingQueries {
delete(failingQueries, key)
}
return nil, nil
}
failingQueriesLock.RLock()
defer failingQueriesLock.RUnlock()
// Quickly return if map is empty.
if len(failingQueries) == 0 {
return nil, nil
}
// Check if query failed recently.
failing, ok := failingQueries[q.ID()]
if !ok {
return nil, nil
}
// Check if failing query should still be regarded as failing.
if time.Now().After(failing.Until) {
return nil, nil
}
// Return failing error and until when it's valid.
return &failing.Until, failing.Err
}
func addFailingQuery(q *resolver.Query, err error) {
// Check if we were given an error.
if err == nil {
return
}
// Exclude reverse and mDNS queries, as they fail _often_ and are usually not
// retried quickly.
// if strings.HasSuffix(q.FQDN, ".in-addr.arpa.") ||
// strings.HasSuffix(q.FQDN, ".ip6.arpa.") ||
// strings.HasSuffix(q.FQDN, ".local.") {
// return
// }
failingQueriesLock.Lock()
defer failingQueriesLock.Unlock()
failing, ok := failingQueries[q.ID()]
if !ok {
failing = &failingQuery{Err: err}
failingQueries[q.ID()] = failing
}
// Calculate fail duration.
// Initial fail duration will be at 900ms, perfect for a normal retry after 1s,
// but not any earlier.
failDuration := failingBaseDuration + time.Duration(failing.Times)*failingFactorDuration
if failDuration > failingMaxDuration {
failDuration = failingMaxDuration
}
// Update failing query.
failing.Times++
failing.Until = time.Now().Add(failDuration)
failing.Keep = failing.Until.Add(failingKeepAddedDuration)
}
func cleanFailingQueries(maxRemove, maxMiss int) {
failingQueriesLock.Lock()
defer failingQueriesLock.Unlock()
now := time.Now()
for key, failing := range failingQueries {
if now.After(failing.Keep) {
delete(failingQueries, key)
maxRemove--
if maxRemove == 0 {
return
}
} else {
maxMiss--
if maxMiss == 0 {
return
}
}
}
}

View File

@@ -0,0 +1,42 @@
package nameserver
import (
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/metrics"
)
var (
requestsHistogram *metrics.Histogram
totalHandledRequests *metrics.Counter
)
func registerMetrics() (err error) {
requestsHistogram, err = metrics.NewHistogram(
"nameserver/request/duration/seconds",
nil,
&metrics.Options{
Permission: api.PermitUser,
ExpertiseLevel: config.ExpertiseLevelExpert,
},
)
if err != nil {
return err
}
totalHandledRequests, err = metrics.NewCounter(
"nameserver/request/total",
nil,
&metrics.Options{
InternalID: "handled_dns_requests",
Permission: api.PermitUser,
ExpertiseLevel: config.ExpertiseLevelExpert,
Persist: true,
},
)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,290 @@
package nameserver
import (
"context"
"fmt"
"net"
"os"
"strconv"
"sync"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portbase/notifications"
"github.com/safing/portmaster/service/compat"
"github.com/safing/portmaster/service/firewall"
"github.com/safing/portmaster/service/netenv"
)
var (
module *modules.Module
stopListeners bool
stopListener1 func() error
stopListener2 func() error
stopListenersLock sync.Mutex
eventIDConflictingService = "nameserver:conflicting-service"
eventIDListenerFailed = "nameserver:listener-failed"
)
func init() {
module = modules.Register("nameserver", prep, start, stop, "core", "resolver")
subsystems.Register(
"dns",
"Secure DNS",
"DNS resolver with scoping and DNS-over-TLS",
module,
"config:dns/",
nil,
)
}
func prep() error {
return registerConfig()
}
func start() error {
if err := registerMetrics(); err != nil {
return err
}
// Get listen addresses.
ip1, ip2, port, err := getListenAddresses(nameserverAddressConfig())
if err != nil {
return fmt.Errorf("failed to parse nameserver listen address: %w", err)
}
// Tell the compat module where we are listening.
compat.SetNameserverListenIP(ip1)
// Get own hostname.
hostname, err = os.Hostname()
if err != nil {
log.Warningf("nameserver: failed to get hostname: %s", err)
}
hostname += "."
// Start listener(s).
if ip2 == nil {
// Start a single listener.
startListener(ip1, port, true)
// Set nameserver matcher in firewall to fast-track dns queries.
if ip1.Equal(net.IPv4zero) || ip1.Equal(net.IPv6zero) {
// Fast track dns queries destined for any of the local IPs.
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
dstIsMe, err := netenv.IsMyIP(ip)
if err != nil {
log.Warningf("nameserver: failed to check if IP %s is local: %s", ip, err)
}
return dstIsMe
})
}
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
return ip.Equal(ip1)
})
}
// Dual listener.
startListener(ip1, port, true)
startListener(ip2, port, false)
// Fast track dns queries destined for one of the listener IPs.
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
return ip.Equal(ip1) || ip.Equal(ip2)
})
}
func startListener(ip net.IP, port uint16, first bool) {
// Start DNS server as service worker.
module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error {
// Create DNS server.
dnsServer := &dns.Server{
Addr: net.JoinHostPort(
ip.String(),
strconv.Itoa(int(port)),
),
Net: "udp",
Handler: dns.HandlerFunc(handleRequestAsWorker),
}
// Register stop function.
func() {
stopListenersLock.Lock()
defer stopListenersLock.Unlock()
// Check if we should stop
if stopListeners {
_ = dnsServer.Shutdown()
dnsServer = nil
return
}
// Register stop function.
if first {
stopListener1 = dnsServer.Shutdown
} else {
stopListener2 = dnsServer.Shutdown
}
}()
// Check if we should stop.
if dnsServer == nil {
return nil
}
// Resolve generic listener error, if primary listener.
if first {
module.Resolve(eventIDListenerFailed)
}
// Start listening.
log.Infof("nameserver: starting to listen on %s", dnsServer.Addr)
err := dnsServer.ListenAndServe()
if err != nil {
// Stop worker without error if we are shutting down.
if module.IsStopping() {
return nil
}
log.Warningf("nameserver: failed to listen on %s: %s", dnsServer.Addr, err)
handleListenError(err, ip, port, first)
}
return err
})
}
func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) {
var n *notifications.Notification
// Create suffix for secondary listener
var secondaryEventIDSuffix string
if !primaryListener {
secondaryEventIDSuffix = "-secondary"
}
// Find a conflicting service.
cfProcess := findConflictingProcess(ip, port)
if cfProcess != nil {
// Report the conflicting process.
// Build conflicting process description.
var cfDescription string
cfName, err := cfProcess.Name()
if err == nil && cfName != "" {
cfDescription = cfName
}
cfExe, err := cfProcess.Exe()
if err == nil && cfDescription != "" {
if cfDescription != "" {
cfDescription += " (" + cfExe + ")"
} else {
cfDescription = cfName
}
}
// Notify user about conflicting service.
n = notifications.Notify(&notifications.Notification{
EventID: eventIDConflictingService + secondaryEventIDSuffix,
Type: notifications.Error,
Title: "Conflicting DNS Software",
Message: fmt.Sprintf(
"Restart Portmaster after you have deactivated or properly configured the conflicting software: %s",
cfDescription,
),
ShowOnSystem: true,
AvailableActions: []*notifications.Action{
{
Text: "Open Docs",
Type: notifications.ActionTypeOpenURL,
Payload: "https://docs.safing.io/portmaster/install/status/software-compatibility",
},
},
})
} else {
// If no conflict is found, report the error directly.
n = notifications.Notify(&notifications.Notification{
EventID: eventIDListenerFailed + secondaryEventIDSuffix,
Type: notifications.Error,
Title: "Secure DNS Error",
Message: fmt.Sprintf(
"The internal DNS server failed. Restart Portmaster to try again. Error: %s",
err,
),
ShowOnSystem: true,
})
}
// Attach error to module, if primary listener.
if primaryListener {
n.AttachToModule(module)
}
}
func stop() error {
stopListenersLock.Lock()
defer stopListenersLock.Unlock()
// Stop listeners.
stopListeners = true
if stopListener1 != nil {
if err := stopListener1(); err != nil {
log.Warningf("nameserver: failed to stop listener1: %s", err)
}
}
if stopListener2 != nil {
if err := stopListener2(); err != nil {
log.Warningf("nameserver: failed to stop listener2: %s", err)
}
}
return nil
}
func getListenAddresses(listenAddress string) (ip1, ip2 net.IP, port uint16, err error) {
// Split host and port.
ipString, portString, err := net.SplitHostPort(listenAddress)
if err != nil {
return nil, nil, 0, fmt.Errorf(
"failed to parse address %s: %w",
listenAddress,
err,
)
}
// Parse the IP address. If the want to listen on localhost, we need to
// listen separately for IPv4 and IPv6.
if ipString == "localhost" {
ip1 = net.IPv4(127, 0, 0, 17)
if netenv.IPv6Enabled() {
ip2 = net.IPv6loopback
} else {
log.Warningf("nameserver: no IPv6 stack detected, disabling IPv6 nameserver listener")
}
} else {
ip1 = net.ParseIP(ipString)
if ip1 == nil {
return nil, nil, 0, fmt.Errorf(
"failed to parse IP %s from %s",
ipString,
listenAddress,
)
}
}
// Parse the port.
port64, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, nil, 0, fmt.Errorf(
"failed to parse port %s from %s: %w",
portString,
listenAddress,
err,
)
}
return ip1, ip2, uint16(port64), nil
}

View File

@@ -0,0 +1,407 @@
package nameserver
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/firewall"
"github.com/safing/portmaster/service/nameserver/nsutil"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/resolver"
)
var hostname string
func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) {
err := module.RunWorker("handle dns request", func(ctx context.Context) error {
return handleRequest(ctx, w, query)
})
if err != nil {
log.Warningf("nameserver: failed to handle dns request: %s", err)
}
}
func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) error { //nolint:maintidx // TODO
// Record metrics.
startTime := time.Now()
defer requestsHistogram.UpdateDuration(startTime)
// Check Question, only process the first, that's how everyone does it.
var originalQuestion dns.Question
switch len(request.Question) {
case 0:
log.Warning("nameserver: received query without question")
return sendResponse(ctx, w, request, nsutil.Refused("no question provided"))
case 1:
originalQuestion = request.Question[0]
default:
log.Warningf(
"nameserver: received query with multiple questions, first is %s.%s",
request.Question[0].Name,
dns.Type(request.Question[0].Qtype),
)
return sendResponse(ctx, w, request, nsutil.Refused("multiple question provided"))
}
// Check the Query Class.
if originalQuestion.Qclass != dns.ClassINET {
// We only serve IN records.
log.Warningf("nameserver: received unsupported qclass %d question for %s", originalQuestion.Qclass, originalQuestion.Name)
return sendResponse(ctx, w, request, nsutil.Refused("unsupported qclass"))
}
// Check if we are handling a non-standard query name.
var nonStandardQuestionFormat bool
lowerCaseQuestion := strings.ToLower(originalQuestion.Name)
if lowerCaseQuestion != originalQuestion.Name {
nonStandardQuestionFormat = true
}
// Create query for the resolver.
q := &resolver.Query{
FQDN: lowerCaseQuestion,
QType: dns.Type(originalQuestion.Qtype),
}
// Get remote address of request.
remoteAddr, ok := w.RemoteAddr().(*net.UDPAddr)
if !ok {
log.Warningf("nameserver: failed to get remote address of dns query: is type %+T", w.RemoteAddr())
return sendResponse(ctx, w, request, nsutil.Refused("unsupported transport"))
}
// Start context tracer for context-aware logging.
ctx, tracer := log.AddTracer(ctx)
defer tracer.Submit()
tracer.Tracef("nameserver: handling new request for %s from %s:%d", q.ID(), remoteAddr.IP, remoteAddr.Port)
// Count request.
totalHandledRequests.Inc()
// Setup quick reply function.
reply := func(responder nsutil.Responder, rrProviders ...nsutil.RRProvider) error {
err := sendResponse(ctx, w, request, responder, rrProviders...)
// Log error here instead of returning it in order to keep the context.
if err != nil {
tracer.Errorf("nameserver: %s", err)
}
return nil
}
// Handle request for localhost and the hostname.
if strings.HasSuffix(q.FQDN, "localhost.") || q.FQDN == hostname {
tracer.Tracef("nameserver: returning localhost records")
return reply(nsutil.Localhost())
}
// Validate domain name.
if !netutils.IsValidFqdn(q.FQDN) {
tracer.Debugf("nameserver: domain name %s is invalid, refusing", q.FQDN)
return reply(nsutil.Refused("invalid domain"))
}
// Get public suffix after validation.
q.InitPublicSuffixData()
// Check if query is failing.
// Some software retries failing queries excessively. This might not be a
// problem normally, but handling a request is pretty expensive for the
// Portmaster, as it has to find out who sent the query. If we know the query
// will fail with a very high probability, it is beneficial to just kill the
// query for some time before doing any expensive work.
defer cleanFailingQueries(10, 3)
failingUntil, failingErr := checkIfQueryIsFailing(q)
if failingErr != nil {
remainingFailingDuration := time.Until(*failingUntil)
tracer.Debugf("nameserver: returning previous error for %s: %s", q.ID(), failingErr)
// Delay the response a bit in order to mitigate request flooding.
if remainingFailingDuration < failingDelay {
// Delay for remainind fail duration.
tracer.Tracef("nameserver: delaying failing lookup until end of fail duration for %s", remainingFailingDuration.Round(time.Millisecond))
time.Sleep(remainingFailingDuration)
return reply(nsutil.ServerFailure(
"internal error: "+failingErr.Error(),
"delayed failing query to mitigate request flooding",
))
}
// Delay for default duration.
tracer.Tracef("nameserver: delaying failing lookup for %s", failingDelay.Round(time.Millisecond))
time.Sleep(failingDelay)
return reply(nsutil.ServerFailure(
"internal error: "+failingErr.Error(),
"delayed failing query to mitigate request flooding",
fmt.Sprintf("error is cached for another %s", remainingFailingDuration.Round(time.Millisecond)),
))
}
// Check if the request is local.
local, err := netenv.IsMyIP(remoteAddr.IP)
if err != nil {
tracer.Warningf("nameserver: failed to check if request for %s is local: %s", q.ID(), err)
return reply(nsutil.ServerFailure("internal error: failed to check if request is local"))
}
// Create connection ID for dns request.
connID := fmt.Sprintf(
"%s-%d-#%d-%s",
remoteAddr.IP,
remoteAddr.Port,
request.Id,
q.ID(),
)
// Get connection for this request. This identifies the process behind the request.
var conn *network.Connection
switch {
case local:
conn = network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, connID, remoteAddr.IP, uint16(remoteAddr.Port))
case networkServiceMode():
conn, err = network.NewConnectionFromExternalDNSRequest(ctx, q.FQDN, nil, connID, remoteAddr.IP)
if err != nil {
tracer.Warningf("nameserver: failed to get host/profile for request for %s%s: %s", q.FQDN, q.QType, err)
return reply(nsutil.ServerFailure("internal error: failed to get profile"))
}
default:
tracer.Warningf("nameserver: external request from %s for %s%s, ignoring", remoteAddr, q.FQDN, q.QType)
return reply(nsutil.Refused("external queries are not permitted"))
}
conn.Lock()
defer conn.Unlock()
// Create reference for the rrCache.
var rrCache *resolver.RRCache
// Once we decided on the connection we might need to save it to the database,
// so we defer that check for now.
defer func() {
// Add metadata to connection.
if rrCache != nil {
conn.DNSContext = rrCache.ToDNSRequestContext()
conn.Resolver = rrCache.Resolver
conn.Entity.IPScope = rrCache.Resolver.IPScope
} else {
// Get resolvers for this query to determine the resolving scope.
resolvers, _, _ := resolver.GetResolversInScope(ctx, q)
if len(resolvers) > 0 {
conn.Entity.IPScope = resolvers[0].Info.IPScope
}
}
switch conn.Verdict {
// We immediately save blocked, dropped or failed verdicts so
// they pop up in the UI.
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
conn.Save()
// For undecided or accepted connections we don't save them yet, because
// that will happen later anyway.
case network.VerdictUndecided, network.VerdictAccept:
// Check if we have a response.
if rrCache == nil {
conn.Failed("internal error: no reply", "")
return
}
// Mark successful queries as internal in order to hide them in the simple interface.
// These requests were most probably made for another process and only add confusion if listed.
if conn.Process().IsSystemResolver() {
conn.Internal = true
}
// Save the request as open, as we don't know if there will be a connection or not.
network.SaveOpenDNSRequest(q, rrCache, conn)
firewall.UpdateIPsAndCNAMEs(q, rrCache, conn)
case network.VerdictUndeterminable:
fallthrough
default:
tracer.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.VerdictVerb(), conn)
}
}()
// Check request with the privacy filter before resolving.
firewall.FilterConnection(ctx, conn, nil, true, false)
// Check if there is a responder from the firewall.
// In special cases, the firewall might want to respond the query itself.
// A reason for this might be that the request is sink-holed to a forced
// IP address in which case we "accept" it, but let the firewall handle
// the resolving as it wishes.
if responder, ok := conn.Reason.Context.(nsutil.Responder); ok {
tracer.Infof("nameserver: handing over request for %s to special filter responder: %s", q.ID(), conn.Reason.Msg)
return reply(responder, conn)
}
// Check if there is a Verdict to act upon.
switch conn.Verdict { //nolint:exhaustive // Only checking for specific values.
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed:
tracer.Infof(
"nameserver: returning %s response for %s to %s",
conn.VerdictVerb(),
q.ID(),
conn.Process(),
)
return reply(conn, conn)
}
// Resolve request.
rrCache, err = resolver.Resolve(ctx, q)
// Handle error.
if err != nil {
switch {
case errors.Is(err, resolver.ErrNotFound):
// Try alternatives domain names for unofficial domain spaces.
rrCache = checkAlternativeCaches(ctx, q)
if rrCache == nil {
tracer.Tracef("nameserver: %s", err)
conn.Failed("domain does not exist", "")
return reply(nsutil.NxDomain("nxdomain: " + err.Error()))
}
case errors.Is(err, resolver.ErrBlocked):
tracer.Tracef("nameserver: %s", err)
conn.Block(err.Error(), "")
return reply(nsutil.BlockIP("blocked: " + err.Error()))
case errors.Is(err, resolver.ErrLocalhost):
tracer.Tracef("nameserver: returning localhost records")
conn.Accept("allowing query for localhost", "")
return reply(nsutil.Localhost())
case errors.Is(err, resolver.ErrOffline):
if rrCache == nil {
tracer.Debugf("nameserver: not resolving %s, device is offline", q.ID())
conn.Failed("not resolving, device is offline", "")
return reply(nsutil.ServerFailure(err.Error()))
}
// If an rrCache was returned, it's usable as a backup.
rrCache.IsBackup = true
log.Tracer(ctx).Debugf("nameserver: device is offline, using backup cache for %s", q.ID())
default:
tracer.Warningf("nameserver: failed to resolve %s: %s", q.ID(), err)
conn.Failed(fmt.Sprintf("query failed: %s", err), "")
addFailingQuery(q, err)
return reply(nsutil.ServerFailure("internal error: " + err.Error()))
}
}
// Handle special cases.
switch {
case rrCache == nil:
tracer.Warning("nameserver: received successful, but empty reply from resolver")
addFailingQuery(q, errors.New("emptry reply from resolver"))
return reply(nsutil.ServerFailure("internal error: empty reply"))
case rrCache.RCode == dns.RcodeNameError:
// Try alternatives domain names for unofficial domain spaces.
altRRCache := checkAlternativeCaches(ctx, q)
if altRRCache != nil {
rrCache = altRRCache
} else {
// Return now if NXDomain.
return reply(nsutil.NxDomain("no answer found (NXDomain)"))
}
}
// Check with firewall again after resolving.
tracer.Trace("nameserver: deciding on resolved dns")
rrCache = firewall.FilterResolvedDNS(ctx, conn, q, rrCache)
// Check again if there is a responder from the firewall.
if responder, ok := conn.Reason.Context.(nsutil.Responder); ok {
tracer.Infof("nameserver: handing over request for %s to special filter responder: %s", q.ID(), conn.Reason.Msg)
return reply(responder)
}
// Check if there is a Verdict to act upon.
switch conn.Verdict { //nolint:exhaustive // Only checking for specific values.
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed:
tracer.Infof(
"nameserver: returning %s response for %s to %s",
conn.VerdictVerb(),
q.ID(),
conn.Process(),
)
return reply(conn, conn, rrCache)
}
// Revert back to non-standard question format, if we had to convert.
if nonStandardQuestionFormat {
rrCache.ReplaceAnswerNames(originalQuestion.Name)
}
// Reply with successful response.
noAnswerIndicator := ""
if len(rrCache.Answer) == 0 {
noAnswerIndicator = "/no answer"
}
tracer.Infof(
"nameserver: returning %s response (%s%s) for %s to %s",
conn.VerdictVerb(),
dns.RcodeToString[rrCache.RCode],
noAnswerIndicator,
q.ID(),
conn.Process(),
)
return reply(rrCache, conn, rrCache)
}
func checkAlternativeCaches(ctx context.Context, q *resolver.Query) *resolver.RRCache {
// Do not try alternatives when the query is in a public suffix.
// This also includes arpa. and local.
if q.ICANNSpace {
return nil
}
// Check if the env resolver has something.
pmEnvQ := &resolver.Query{
FQDN: q.FQDN + "local." + resolver.InternalSpecialUseDomain,
QType: q.QType,
}
rrCache, err := resolver.QueryPortmasterEnv(ctx, pmEnvQ)
if err == nil && rrCache != nil && rrCache.RCode == dns.RcodeSuccess {
makeAlternativeRecord(ctx, q, rrCache, pmEnvQ.FQDN)
return rrCache
}
// Check if we have anything in cache
localFQDN := q.FQDN + "local."
rrCache, err = resolver.GetRRCache(localFQDN, q.QType)
if err == nil && rrCache != nil && rrCache.RCode == dns.RcodeSuccess {
makeAlternativeRecord(ctx, q, rrCache, localFQDN)
return rrCache
}
return nil
}
func makeAlternativeRecord(ctx context.Context, q *resolver.Query, rrCache *resolver.RRCache, altName string) {
log.Tracer(ctx).Debugf("using %s to answer query", altName)
// Duplicate answers so they match the query.
copied := make([]dns.RR, 0, len(rrCache.Answer))
for _, answer := range rrCache.Answer {
if strings.ToLower(answer.Header().Name) == altName {
copiedAnswer := dns.Copy(answer)
copiedAnswer.Header().Name = q.FQDN
copied = append(copied, copiedAnswer)
}
}
if len(copied) > 0 {
rrCache.Answer = append(rrCache.Answer, copied...)
}
// Update the question.
rrCache.Domain = q.FQDN
}

View File

@@ -0,0 +1,202 @@
package nsutil
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
)
// ErrNilRR is returned when a parsed RR is nil.
var ErrNilRR = errors.New("is nil")
// Responder defines the interface that any block/deny reason interface
// may implement to support sending custom DNS responses for a given reason.
// That is, if a reason context implements the Responder interface the
// ReplyWithDNS method will be called instead of creating the default
// zero-ip response.
type Responder interface {
// ReplyWithDNS is called when a DNS response to a DNS message is
// crafted because the request is either denied or blocked.
ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns.Msg
}
// RRProvider defines the interface that any block/deny reason interface
// may implement to support adding additional DNS resource records to
// the DNS responses extra (additional) section.
type RRProvider interface {
// GetExtraRRs is called when a DNS response to a DNS message is
// crafted because the request is either denied or blocked.
GetExtraRRs(ctx context.Context, request *dns.Msg) []dns.RR
}
// ResponderFunc is a convenience type to use a function
// directly as a Responder.
type ResponderFunc func(ctx context.Context, request *dns.Msg) *dns.Msg
// ReplyWithDNS implements the Responder interface and calls rf.
func (rf ResponderFunc) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns.Msg {
return rf(ctx, request)
}
// MarshalJSON disables JSON marshaling for ResponderFunc.
func (rf ResponderFunc) MarshalJSON() ([]byte, error) {
return json.Marshal(nil)
}
// BlockIP is a ResponderFunc than replies with either 0.0.0.17 or ::17 for
// each A or AAAA question respectively. If there is no A or AAAA question, it
// defaults to replying with NXDomain.
func BlockIP(msgs ...string) ResponderFunc {
return createResponderFunc(
"blocking",
"0.0.0.17",
"::17",
msgs...,
)
}
// ZeroIP is a ResponderFunc than replies with either 0.0.0.0 or :: for each A
// or AAAA question respectively. If there is no A or AAAA question, it
// defaults to replying with NXDomain.
func ZeroIP(msgs ...string) ResponderFunc {
return createResponderFunc(
"zero ip",
"0.0.0.0",
"::",
msgs...,
)
}
// Localhost is a ResponderFunc than replies with localhost IP addresses.
// If there is no A or AAAA question, it defaults to replying with NXDomain.
func Localhost(msgs ...string) ResponderFunc {
return createResponderFunc(
"localhost",
"127.0.0.1",
"::1",
msgs...,
)
}
func createResponderFunc(responderName, aAnswer, aaaaAnswer string, msgs ...string) ResponderFunc {
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
reply := new(dns.Msg)
hasErr := false
for _, question := range request.Question {
var rr dns.RR
var err error
switch question.Qtype {
case dns.TypeA:
rr, err = dns.NewRR(question.Name + " 1 IN A " + aAnswer)
case dns.TypeAAAA:
rr, err = dns.NewRR(question.Name + " 1 IN AAAA " + aaaaAnswer)
}
switch {
case err != nil:
log.Tracer(ctx).Errorf("nameserver: failed to create %s response for %s: %s", responderName, question.Name, err)
hasErr = true
case rr != nil:
reply.Answer = append(reply.Answer, rr)
}
}
switch {
case hasErr && len(reply.Answer) == 0:
reply.SetRcode(request, dns.RcodeServerFailure)
case len(reply.Answer) == 0:
reply.SetRcode(request, dns.RcodeNameError)
default:
reply.SetRcode(request, dns.RcodeSuccess)
}
AddMessagesToReply(ctx, reply, log.InfoLevel, msgs...)
return reply
}
}
// NxDomain returns a ResponderFunc that replies with NXDOMAIN.
func NxDomain(msgs ...string) ResponderFunc {
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
reply := new(dns.Msg).SetRcode(request, dns.RcodeNameError)
AddMessagesToReply(ctx, reply, log.InfoLevel, msgs...)
// According to RFC4074 (https://tools.ietf.org/html/rfc4074), there are
// nameservers that incorrectly respond with NXDomain instead of an empty
// SUCCESS response when there are other RRs for the queried domain name.
// This can lead to the software thinking that no RRs exist for that
// domain. In order to mitigate this a bit, we slightly delay NXDomain
// responses.
time.Sleep(500 * time.Millisecond)
return reply
}
}
// Refused returns a ResponderFunc that replies with REFUSED.
func Refused(msgs ...string) ResponderFunc {
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
reply := new(dns.Msg).SetRcode(request, dns.RcodeRefused)
AddMessagesToReply(ctx, reply, log.InfoLevel, msgs...)
return reply
}
}
// ServerFailure returns a ResponderFunc that replies with SERVFAIL.
func ServerFailure(msgs ...string) ResponderFunc {
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
reply := new(dns.Msg).SetRcode(request, dns.RcodeServerFailure)
AddMessagesToReply(ctx, reply, log.InfoLevel, msgs...)
return reply
}
}
// MakeMessageRecord creates an informational resource record that can be added
// to the extra section of a reply.
func MakeMessageRecord(level log.Severity, msg string) (dns.RR, error) { //nolint:interfacer
rr, err := dns.NewRR(fmt.Sprintf(
`%s.portmaster. 0 IN TXT "%s"`,
strings.ToLower(level.String()),
msg,
))
if err != nil {
return nil, err
}
if rr == nil {
return nil, ErrNilRR
}
return rr, nil
}
// AddMessagesToReply creates information resource records using
// MakeMessageRecord and immediately adds them to the extra section of the given
// reply. If an error occurs, the resource record will not be added, and the
// error will be logged.
func AddMessagesToReply(ctx context.Context, reply *dns.Msg, level log.Severity, msgs ...string) {
for _, msg := range msgs {
// Ignore empty messages.
if msg == "" {
continue
}
// Create resources record.
rr, err := MakeMessageRecord(level, msg)
if err != nil {
log.Tracer(ctx).Warningf("nameserver: failed to add message to reply: %s", err)
continue
}
// Add to extra section of the reply.
reply.Extra = append(reply.Extra, rr)
}
}

View File

@@ -0,0 +1,72 @@
package nameserver
import (
"context"
"fmt"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/nameserver/nsutil"
)
// sendResponse sends a response to query using w. The response message is
// created by responder. If addExtraRRs is not nil and implements the
// RRProvider interface then it will be also used to add more RRs in the
// extra section.
func sendResponse(
ctx context.Context,
w dns.ResponseWriter,
request *dns.Msg,
responder nsutil.Responder,
rrProviders ...nsutil.RRProvider,
) error {
// Have the Responder craft a DNS reply.
reply := responder.ReplyWithDNS(ctx, request)
if reply == nil {
// Dropping query.
return nil
}
// Signify that we are a recursive resolver.
// While we do not handle recursion directly, we can safely assume, that we
// always forward to a recursive resolver.
reply.RecursionAvailable = true
// Add extra RRs through a custom RRProvider.
for _, rrProvider := range rrProviders {
if rrProvider != nil {
rrs := rrProvider.GetExtraRRs(ctx, request)
reply.Extra = append(reply.Extra, rrs...)
}
}
// Write reply.
if err := writeDNSResponse(ctx, w, reply); err != nil {
return fmt.Errorf("failed to send response: %w", err)
}
return nil
}
func writeDNSResponse(ctx context.Context, w dns.ResponseWriter, m *dns.Msg) (err error) {
defer func() {
// recover from panic
if panicErr := recover(); panicErr != nil {
err = fmt.Errorf("panic: %s", panicErr)
log.Tracer(ctx).Debugf("nameserver: panic caused by this msg: %#v", m)
}
}()
err = w.WriteMsg(m)
if err != nil {
// If we receive an error we might have exceeded the message size with all
// our extra information records. Retry again without the extra section.
log.Tracer(ctx).Tracef("nameserver: retrying to write dns message without extra section, error was: %s", err)
m.Extra = nil
noExtraErr := w.WriteMsg(m)
if noExtraErr == nil {
return fmt.Errorf("failed to write dns message without extra section: %w", err)
}
}
return
}