wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
60
service/nameserver/config.go
Normal file
60
service/nameserver/config.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"runtime"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portmaster/service/core"
|
||||
)
|
||||
|
||||
// CfgDefaultNameserverAddressKey is the config key for the listen address..
|
||||
const CfgDefaultNameserverAddressKey = "dns/listenAddress"
|
||||
|
||||
var (
|
||||
defaultNameserverAddress = "localhost:53"
|
||||
nameserverAddress string
|
||||
nameserverAddressConfig config.StringOption
|
||||
|
||||
networkServiceMode config.BoolOption
|
||||
)
|
||||
|
||||
func init() {
|
||||
// On Windows, packets are redirected to the same interface.
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultNameserverAddress = "0.0.0.0:53"
|
||||
}
|
||||
|
||||
flag.StringVar(
|
||||
&nameserverAddress,
|
||||
"nameserver-address",
|
||||
defaultNameserverAddress,
|
||||
"set default nameserver address; configuration is stronger",
|
||||
)
|
||||
}
|
||||
|
||||
func registerConfig() error {
|
||||
err := config.Register(&config.Option{
|
||||
Name: "Internal DNS Server Listen Address",
|
||||
Key: CfgDefaultNameserverAddressKey,
|
||||
Description: "Defines the IP address and port on which the internal DNS Server listens.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
DefaultValue: nameserverAddress,
|
||||
ValidationRegex: "^(localhost|[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}|\\[[:0-9A-Fa-f]+\\]):[0-9]{1,5}$",
|
||||
RequiresRestart: true,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: 514,
|
||||
config.CategoryAnnotation: "Development",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nameserverAddressConfig = config.GetAsString(CfgDefaultNameserverAddressKey, nameserverAddress)
|
||||
|
||||
networkServiceMode = config.Concurrent.GetAsBool(core.CfgNetworkServiceKey, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
75
service/nameserver/conflict.go
Normal file
75
service/nameserver/conflict.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
processInfo "github.com/shirou/gopsutil/process"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/network/state"
|
||||
)
|
||||
|
||||
var commonResolverIPs = []net.IP{
|
||||
net.IPv4zero,
|
||||
net.IPv4(127, 0, 0, 1), // default
|
||||
net.IPv4(127, 0, 0, 53), // some resolvers on Linux
|
||||
net.IPv6zero,
|
||||
net.IPv6loopback,
|
||||
}
|
||||
|
||||
func findConflictingProcess(ip net.IP, port uint16) (conflictingProcess *processInfo.Process) {
|
||||
// Evaluate which IPs to check.
|
||||
var ipsToCheck []net.IP
|
||||
if ip.Equal(net.IPv4zero) || ip.Equal(net.IPv6zero) {
|
||||
ipsToCheck = commonResolverIPs
|
||||
} else {
|
||||
ipsToCheck = []net.IP{ip}
|
||||
}
|
||||
|
||||
// Find the conflicting process.
|
||||
var err error
|
||||
for _, resolverIP := range ipsToCheck {
|
||||
conflictingProcess, err = getListeningProcess(resolverIP, port)
|
||||
switch {
|
||||
case err != nil:
|
||||
// Log the error and let the worker try again.
|
||||
log.Warningf("nameserver: failed to find conflicting service: %s", err)
|
||||
case conflictingProcess != nil:
|
||||
// Conflicting service found.
|
||||
return conflictingProcess
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getListeningProcess(resolverIP net.IP, resolverPort uint16) (*processInfo.Process, error) {
|
||||
pid, _, err := state.Lookup(&packet.Info{
|
||||
Inbound: true,
|
||||
Version: 0, // auto-detect
|
||||
Protocol: packet.UDP,
|
||||
Src: nil, // do not record direction
|
||||
SrcPort: 0, // do not record direction
|
||||
Dst: resolverIP,
|
||||
DstPort: resolverPort,
|
||||
}, true)
|
||||
if err != nil {
|
||||
// there may be nothing listening on :53
|
||||
return nil, nil //nolint:nilerr // Treat lookup error as "not found".
|
||||
}
|
||||
|
||||
// Ignore if it's us for some reason.
|
||||
if pid == os.Getpid() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
proc, err := processInfo.NewProcess(int32(pid))
|
||||
if err != nil {
|
||||
// Process may have disappeared already.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return proc, nil
|
||||
}
|
||||
135
service/nameserver/failing.go
Normal file
135
service/nameserver/failing.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
)
|
||||
|
||||
type failingQuery struct {
|
||||
// Until specifies until when the query should be regarded as failing.
|
||||
Until time.Time
|
||||
|
||||
// Keep specifies until when the failing status shall be kept.
|
||||
Keep time.Time
|
||||
|
||||
// Times specifies how often this query failed.
|
||||
Times int
|
||||
|
||||
// Err holds the error the query failed with.
|
||||
Err error
|
||||
}
|
||||
|
||||
const (
|
||||
failingDelay = 900 * time.Millisecond
|
||||
failingBaseDuration = 900 * time.Millisecond
|
||||
failingFactorDuration = 500 * time.Millisecond
|
||||
failingMaxDuration = 30 * time.Second
|
||||
failingKeepAddedDuration = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
failingQueries = make(map[string]*failingQuery)
|
||||
failingQueriesLock sync.RWMutex
|
||||
failingQueriesNetworkChangedFlag = netenv.GetNetworkChangedFlag()
|
||||
)
|
||||
|
||||
func checkIfQueryIsFailing(q *resolver.Query) (failingUntil *time.Time, failingErr error) {
|
||||
// If the network changed, reset the failed queries.
|
||||
if failingQueriesNetworkChangedFlag.IsSet() {
|
||||
failingQueriesNetworkChangedFlag.Refresh()
|
||||
|
||||
failingQueriesLock.Lock()
|
||||
defer failingQueriesLock.Unlock()
|
||||
|
||||
// Compiler optimized map reset.
|
||||
for key := range failingQueries {
|
||||
delete(failingQueries, key)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
failingQueriesLock.RLock()
|
||||
defer failingQueriesLock.RUnlock()
|
||||
|
||||
// Quickly return if map is empty.
|
||||
if len(failingQueries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check if query failed recently.
|
||||
failing, ok := failingQueries[q.ID()]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check if failing query should still be regarded as failing.
|
||||
if time.Now().After(failing.Until) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Return failing error and until when it's valid.
|
||||
return &failing.Until, failing.Err
|
||||
}
|
||||
|
||||
func addFailingQuery(q *resolver.Query, err error) {
|
||||
// Check if we were given an error.
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Exclude reverse and mDNS queries, as they fail _often_ and are usually not
|
||||
// retried quickly.
|
||||
// if strings.HasSuffix(q.FQDN, ".in-addr.arpa.") ||
|
||||
// strings.HasSuffix(q.FQDN, ".ip6.arpa.") ||
|
||||
// strings.HasSuffix(q.FQDN, ".local.") {
|
||||
// return
|
||||
// }
|
||||
|
||||
failingQueriesLock.Lock()
|
||||
defer failingQueriesLock.Unlock()
|
||||
|
||||
failing, ok := failingQueries[q.ID()]
|
||||
if !ok {
|
||||
failing = &failingQuery{Err: err}
|
||||
failingQueries[q.ID()] = failing
|
||||
}
|
||||
|
||||
// Calculate fail duration.
|
||||
// Initial fail duration will be at 900ms, perfect for a normal retry after 1s,
|
||||
// but not any earlier.
|
||||
failDuration := failingBaseDuration + time.Duration(failing.Times)*failingFactorDuration
|
||||
if failDuration > failingMaxDuration {
|
||||
failDuration = failingMaxDuration
|
||||
}
|
||||
|
||||
// Update failing query.
|
||||
failing.Times++
|
||||
failing.Until = time.Now().Add(failDuration)
|
||||
failing.Keep = failing.Until.Add(failingKeepAddedDuration)
|
||||
}
|
||||
|
||||
func cleanFailingQueries(maxRemove, maxMiss int) {
|
||||
failingQueriesLock.Lock()
|
||||
defer failingQueriesLock.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, failing := range failingQueries {
|
||||
if now.After(failing.Keep) {
|
||||
delete(failingQueries, key)
|
||||
|
||||
maxRemove--
|
||||
if maxRemove == 0 {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
maxMiss--
|
||||
if maxMiss == 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
42
service/nameserver/metrics.go
Normal file
42
service/nameserver/metrics.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/metrics"
|
||||
)
|
||||
|
||||
var (
|
||||
requestsHistogram *metrics.Histogram
|
||||
totalHandledRequests *metrics.Counter
|
||||
)
|
||||
|
||||
func registerMetrics() (err error) {
|
||||
requestsHistogram, err = metrics.NewHistogram(
|
||||
"nameserver/request/duration/seconds",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Permission: api.PermitUser,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalHandledRequests, err = metrics.NewCounter(
|
||||
"nameserver/request/total",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
InternalID: "handled_dns_requests",
|
||||
Permission: api.PermitUser,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
Persist: true,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
290
service/nameserver/module.go
Normal file
290
service/nameserver/module.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/modules/subsystems"
|
||||
"github.com/safing/portbase/notifications"
|
||||
"github.com/safing/portmaster/service/compat"
|
||||
"github.com/safing/portmaster/service/firewall"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
stopListeners bool
|
||||
stopListener1 func() error
|
||||
stopListener2 func() error
|
||||
stopListenersLock sync.Mutex
|
||||
|
||||
eventIDConflictingService = "nameserver:conflicting-service"
|
||||
eventIDListenerFailed = "nameserver:listener-failed"
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("nameserver", prep, start, stop, "core", "resolver")
|
||||
subsystems.Register(
|
||||
"dns",
|
||||
"Secure DNS",
|
||||
"DNS resolver with scoping and DNS-over-TLS",
|
||||
module,
|
||||
"config:dns/",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
return registerConfig()
|
||||
}
|
||||
|
||||
func start() error {
|
||||
if err := registerMetrics(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get listen addresses.
|
||||
ip1, ip2, port, err := getListenAddresses(nameserverAddressConfig())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse nameserver listen address: %w", err)
|
||||
}
|
||||
|
||||
// Tell the compat module where we are listening.
|
||||
compat.SetNameserverListenIP(ip1)
|
||||
|
||||
// Get own hostname.
|
||||
hostname, err = os.Hostname()
|
||||
if err != nil {
|
||||
log.Warningf("nameserver: failed to get hostname: %s", err)
|
||||
}
|
||||
hostname += "."
|
||||
|
||||
// Start listener(s).
|
||||
if ip2 == nil {
|
||||
// Start a single listener.
|
||||
startListener(ip1, port, true)
|
||||
|
||||
// Set nameserver matcher in firewall to fast-track dns queries.
|
||||
if ip1.Equal(net.IPv4zero) || ip1.Equal(net.IPv6zero) {
|
||||
// Fast track dns queries destined for any of the local IPs.
|
||||
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
|
||||
dstIsMe, err := netenv.IsMyIP(ip)
|
||||
if err != nil {
|
||||
log.Warningf("nameserver: failed to check if IP %s is local: %s", ip, err)
|
||||
}
|
||||
return dstIsMe
|
||||
})
|
||||
}
|
||||
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
|
||||
return ip.Equal(ip1)
|
||||
})
|
||||
}
|
||||
|
||||
// Dual listener.
|
||||
startListener(ip1, port, true)
|
||||
startListener(ip2, port, false)
|
||||
|
||||
// Fast track dns queries destined for one of the listener IPs.
|
||||
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
|
||||
return ip.Equal(ip1) || ip.Equal(ip2)
|
||||
})
|
||||
}
|
||||
|
||||
func startListener(ip net.IP, port uint16, first bool) {
|
||||
// Start DNS server as service worker.
|
||||
module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error {
|
||||
// Create DNS server.
|
||||
dnsServer := &dns.Server{
|
||||
Addr: net.JoinHostPort(
|
||||
ip.String(),
|
||||
strconv.Itoa(int(port)),
|
||||
),
|
||||
Net: "udp",
|
||||
Handler: dns.HandlerFunc(handleRequestAsWorker),
|
||||
}
|
||||
|
||||
// Register stop function.
|
||||
func() {
|
||||
stopListenersLock.Lock()
|
||||
defer stopListenersLock.Unlock()
|
||||
|
||||
// Check if we should stop
|
||||
if stopListeners {
|
||||
_ = dnsServer.Shutdown()
|
||||
dnsServer = nil
|
||||
return
|
||||
}
|
||||
|
||||
// Register stop function.
|
||||
if first {
|
||||
stopListener1 = dnsServer.Shutdown
|
||||
} else {
|
||||
stopListener2 = dnsServer.Shutdown
|
||||
}
|
||||
}()
|
||||
|
||||
// Check if we should stop.
|
||||
if dnsServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve generic listener error, if primary listener.
|
||||
if first {
|
||||
module.Resolve(eventIDListenerFailed)
|
||||
}
|
||||
|
||||
// Start listening.
|
||||
log.Infof("nameserver: starting to listen on %s", dnsServer.Addr)
|
||||
err := dnsServer.ListenAndServe()
|
||||
if err != nil {
|
||||
// Stop worker without error if we are shutting down.
|
||||
if module.IsStopping() {
|
||||
return nil
|
||||
}
|
||||
log.Warningf("nameserver: failed to listen on %s: %s", dnsServer.Addr, err)
|
||||
handleListenError(err, ip, port, first)
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) {
|
||||
var n *notifications.Notification
|
||||
|
||||
// Create suffix for secondary listener
|
||||
var secondaryEventIDSuffix string
|
||||
if !primaryListener {
|
||||
secondaryEventIDSuffix = "-secondary"
|
||||
}
|
||||
|
||||
// Find a conflicting service.
|
||||
cfProcess := findConflictingProcess(ip, port)
|
||||
if cfProcess != nil {
|
||||
// Report the conflicting process.
|
||||
|
||||
// Build conflicting process description.
|
||||
var cfDescription string
|
||||
cfName, err := cfProcess.Name()
|
||||
if err == nil && cfName != "" {
|
||||
cfDescription = cfName
|
||||
}
|
||||
cfExe, err := cfProcess.Exe()
|
||||
if err == nil && cfDescription != "" {
|
||||
if cfDescription != "" {
|
||||
cfDescription += " (" + cfExe + ")"
|
||||
} else {
|
||||
cfDescription = cfName
|
||||
}
|
||||
}
|
||||
|
||||
// Notify user about conflicting service.
|
||||
n = notifications.Notify(¬ifications.Notification{
|
||||
EventID: eventIDConflictingService + secondaryEventIDSuffix,
|
||||
Type: notifications.Error,
|
||||
Title: "Conflicting DNS Software",
|
||||
Message: fmt.Sprintf(
|
||||
"Restart Portmaster after you have deactivated or properly configured the conflicting software: %s",
|
||||
cfDescription,
|
||||
),
|
||||
ShowOnSystem: true,
|
||||
AvailableActions: []*notifications.Action{
|
||||
{
|
||||
Text: "Open Docs",
|
||||
Type: notifications.ActionTypeOpenURL,
|
||||
Payload: "https://docs.safing.io/portmaster/install/status/software-compatibility",
|
||||
},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// If no conflict is found, report the error directly.
|
||||
n = notifications.Notify(¬ifications.Notification{
|
||||
EventID: eventIDListenerFailed + secondaryEventIDSuffix,
|
||||
Type: notifications.Error,
|
||||
Title: "Secure DNS Error",
|
||||
Message: fmt.Sprintf(
|
||||
"The internal DNS server failed. Restart Portmaster to try again. Error: %s",
|
||||
err,
|
||||
),
|
||||
ShowOnSystem: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Attach error to module, if primary listener.
|
||||
if primaryListener {
|
||||
n.AttachToModule(module)
|
||||
}
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
stopListenersLock.Lock()
|
||||
defer stopListenersLock.Unlock()
|
||||
|
||||
// Stop listeners.
|
||||
stopListeners = true
|
||||
if stopListener1 != nil {
|
||||
if err := stopListener1(); err != nil {
|
||||
log.Warningf("nameserver: failed to stop listener1: %s", err)
|
||||
}
|
||||
}
|
||||
if stopListener2 != nil {
|
||||
if err := stopListener2(); err != nil {
|
||||
log.Warningf("nameserver: failed to stop listener2: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getListenAddresses(listenAddress string) (ip1, ip2 net.IP, port uint16, err error) {
|
||||
// Split host and port.
|
||||
ipString, portString, err := net.SplitHostPort(listenAddress)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf(
|
||||
"failed to parse address %s: %w",
|
||||
listenAddress,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
// Parse the IP address. If the want to listen on localhost, we need to
|
||||
// listen separately for IPv4 and IPv6.
|
||||
if ipString == "localhost" {
|
||||
ip1 = net.IPv4(127, 0, 0, 17)
|
||||
if netenv.IPv6Enabled() {
|
||||
ip2 = net.IPv6loopback
|
||||
} else {
|
||||
log.Warningf("nameserver: no IPv6 stack detected, disabling IPv6 nameserver listener")
|
||||
}
|
||||
} else {
|
||||
ip1 = net.ParseIP(ipString)
|
||||
if ip1 == nil {
|
||||
return nil, nil, 0, fmt.Errorf(
|
||||
"failed to parse IP %s from %s",
|
||||
ipString,
|
||||
listenAddress,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the port.
|
||||
port64, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf(
|
||||
"failed to parse port %s from %s: %w",
|
||||
portString,
|
||||
listenAddress,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return ip1, ip2, uint16(port64), nil
|
||||
}
|
||||
407
service/nameserver/nameserver.go
Normal file
407
service/nameserver/nameserver.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/firewall"
|
||||
"github.com/safing/portmaster/service/nameserver/nsutil"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
)
|
||||
|
||||
var hostname string
|
||||
|
||||
func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) {
|
||||
err := module.RunWorker("handle dns request", func(ctx context.Context) error {
|
||||
return handleRequest(ctx, w, query)
|
||||
})
|
||||
if err != nil {
|
||||
log.Warningf("nameserver: failed to handle dns request: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) error { //nolint:maintidx // TODO
|
||||
// Record metrics.
|
||||
startTime := time.Now()
|
||||
defer requestsHistogram.UpdateDuration(startTime)
|
||||
|
||||
// Check Question, only process the first, that's how everyone does it.
|
||||
var originalQuestion dns.Question
|
||||
switch len(request.Question) {
|
||||
case 0:
|
||||
log.Warning("nameserver: received query without question")
|
||||
return sendResponse(ctx, w, request, nsutil.Refused("no question provided"))
|
||||
case 1:
|
||||
originalQuestion = request.Question[0]
|
||||
default:
|
||||
log.Warningf(
|
||||
"nameserver: received query with multiple questions, first is %s.%s",
|
||||
request.Question[0].Name,
|
||||
dns.Type(request.Question[0].Qtype),
|
||||
)
|
||||
return sendResponse(ctx, w, request, nsutil.Refused("multiple question provided"))
|
||||
}
|
||||
|
||||
// Check the Query Class.
|
||||
if originalQuestion.Qclass != dns.ClassINET {
|
||||
// We only serve IN records.
|
||||
log.Warningf("nameserver: received unsupported qclass %d question for %s", originalQuestion.Qclass, originalQuestion.Name)
|
||||
return sendResponse(ctx, w, request, nsutil.Refused("unsupported qclass"))
|
||||
}
|
||||
|
||||
// Check if we are handling a non-standard query name.
|
||||
var nonStandardQuestionFormat bool
|
||||
lowerCaseQuestion := strings.ToLower(originalQuestion.Name)
|
||||
if lowerCaseQuestion != originalQuestion.Name {
|
||||
nonStandardQuestionFormat = true
|
||||
}
|
||||
|
||||
// Create query for the resolver.
|
||||
q := &resolver.Query{
|
||||
FQDN: lowerCaseQuestion,
|
||||
QType: dns.Type(originalQuestion.Qtype),
|
||||
}
|
||||
|
||||
// Get remote address of request.
|
||||
remoteAddr, ok := w.RemoteAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
log.Warningf("nameserver: failed to get remote address of dns query: is type %+T", w.RemoteAddr())
|
||||
return sendResponse(ctx, w, request, nsutil.Refused("unsupported transport"))
|
||||
}
|
||||
|
||||
// Start context tracer for context-aware logging.
|
||||
ctx, tracer := log.AddTracer(ctx)
|
||||
defer tracer.Submit()
|
||||
tracer.Tracef("nameserver: handling new request for %s from %s:%d", q.ID(), remoteAddr.IP, remoteAddr.Port)
|
||||
|
||||
// Count request.
|
||||
totalHandledRequests.Inc()
|
||||
|
||||
// Setup quick reply function.
|
||||
reply := func(responder nsutil.Responder, rrProviders ...nsutil.RRProvider) error {
|
||||
err := sendResponse(ctx, w, request, responder, rrProviders...)
|
||||
// Log error here instead of returning it in order to keep the context.
|
||||
if err != nil {
|
||||
tracer.Errorf("nameserver: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle request for localhost and the hostname.
|
||||
if strings.HasSuffix(q.FQDN, "localhost.") || q.FQDN == hostname {
|
||||
tracer.Tracef("nameserver: returning localhost records")
|
||||
return reply(nsutil.Localhost())
|
||||
}
|
||||
|
||||
// Validate domain name.
|
||||
if !netutils.IsValidFqdn(q.FQDN) {
|
||||
tracer.Debugf("nameserver: domain name %s is invalid, refusing", q.FQDN)
|
||||
return reply(nsutil.Refused("invalid domain"))
|
||||
}
|
||||
|
||||
// Get public suffix after validation.
|
||||
q.InitPublicSuffixData()
|
||||
|
||||
// Check if query is failing.
|
||||
// Some software retries failing queries excessively. This might not be a
|
||||
// problem normally, but handling a request is pretty expensive for the
|
||||
// Portmaster, as it has to find out who sent the query. If we know the query
|
||||
// will fail with a very high probability, it is beneficial to just kill the
|
||||
// query for some time before doing any expensive work.
|
||||
defer cleanFailingQueries(10, 3)
|
||||
failingUntil, failingErr := checkIfQueryIsFailing(q)
|
||||
if failingErr != nil {
|
||||
remainingFailingDuration := time.Until(*failingUntil)
|
||||
tracer.Debugf("nameserver: returning previous error for %s: %s", q.ID(), failingErr)
|
||||
|
||||
// Delay the response a bit in order to mitigate request flooding.
|
||||
if remainingFailingDuration < failingDelay {
|
||||
// Delay for remainind fail duration.
|
||||
tracer.Tracef("nameserver: delaying failing lookup until end of fail duration for %s", remainingFailingDuration.Round(time.Millisecond))
|
||||
time.Sleep(remainingFailingDuration)
|
||||
return reply(nsutil.ServerFailure(
|
||||
"internal error: "+failingErr.Error(),
|
||||
"delayed failing query to mitigate request flooding",
|
||||
))
|
||||
}
|
||||
// Delay for default duration.
|
||||
tracer.Tracef("nameserver: delaying failing lookup for %s", failingDelay.Round(time.Millisecond))
|
||||
time.Sleep(failingDelay)
|
||||
return reply(nsutil.ServerFailure(
|
||||
"internal error: "+failingErr.Error(),
|
||||
"delayed failing query to mitigate request flooding",
|
||||
fmt.Sprintf("error is cached for another %s", remainingFailingDuration.Round(time.Millisecond)),
|
||||
))
|
||||
}
|
||||
|
||||
// Check if the request is local.
|
||||
local, err := netenv.IsMyIP(remoteAddr.IP)
|
||||
if err != nil {
|
||||
tracer.Warningf("nameserver: failed to check if request for %s is local: %s", q.ID(), err)
|
||||
return reply(nsutil.ServerFailure("internal error: failed to check if request is local"))
|
||||
}
|
||||
|
||||
// Create connection ID for dns request.
|
||||
connID := fmt.Sprintf(
|
||||
"%s-%d-#%d-%s",
|
||||
remoteAddr.IP,
|
||||
remoteAddr.Port,
|
||||
request.Id,
|
||||
q.ID(),
|
||||
)
|
||||
|
||||
// Get connection for this request. This identifies the process behind the request.
|
||||
var conn *network.Connection
|
||||
switch {
|
||||
case local:
|
||||
conn = network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, connID, remoteAddr.IP, uint16(remoteAddr.Port))
|
||||
|
||||
case networkServiceMode():
|
||||
conn, err = network.NewConnectionFromExternalDNSRequest(ctx, q.FQDN, nil, connID, remoteAddr.IP)
|
||||
if err != nil {
|
||||
tracer.Warningf("nameserver: failed to get host/profile for request for %s%s: %s", q.FQDN, q.QType, err)
|
||||
return reply(nsutil.ServerFailure("internal error: failed to get profile"))
|
||||
}
|
||||
|
||||
default:
|
||||
tracer.Warningf("nameserver: external request from %s for %s%s, ignoring", remoteAddr, q.FQDN, q.QType)
|
||||
return reply(nsutil.Refused("external queries are not permitted"))
|
||||
}
|
||||
conn.Lock()
|
||||
defer conn.Unlock()
|
||||
|
||||
// Create reference for the rrCache.
|
||||
var rrCache *resolver.RRCache
|
||||
|
||||
// Once we decided on the connection we might need to save it to the database,
|
||||
// so we defer that check for now.
|
||||
defer func() {
|
||||
// Add metadata to connection.
|
||||
if rrCache != nil {
|
||||
conn.DNSContext = rrCache.ToDNSRequestContext()
|
||||
conn.Resolver = rrCache.Resolver
|
||||
conn.Entity.IPScope = rrCache.Resolver.IPScope
|
||||
} else {
|
||||
// Get resolvers for this query to determine the resolving scope.
|
||||
resolvers, _, _ := resolver.GetResolversInScope(ctx, q)
|
||||
if len(resolvers) > 0 {
|
||||
conn.Entity.IPScope = resolvers[0].Info.IPScope
|
||||
}
|
||||
}
|
||||
|
||||
switch conn.Verdict {
|
||||
// We immediately save blocked, dropped or failed verdicts so
|
||||
// they pop up in the UI.
|
||||
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
|
||||
conn.Save()
|
||||
|
||||
// For undecided or accepted connections we don't save them yet, because
|
||||
// that will happen later anyway.
|
||||
case network.VerdictUndecided, network.VerdictAccept:
|
||||
// Check if we have a response.
|
||||
if rrCache == nil {
|
||||
conn.Failed("internal error: no reply", "")
|
||||
return
|
||||
}
|
||||
|
||||
// Mark successful queries as internal in order to hide them in the simple interface.
|
||||
// These requests were most probably made for another process and only add confusion if listed.
|
||||
if conn.Process().IsSystemResolver() {
|
||||
conn.Internal = true
|
||||
}
|
||||
|
||||
// Save the request as open, as we don't know if there will be a connection or not.
|
||||
network.SaveOpenDNSRequest(q, rrCache, conn)
|
||||
firewall.UpdateIPsAndCNAMEs(q, rrCache, conn)
|
||||
|
||||
case network.VerdictUndeterminable:
|
||||
fallthrough
|
||||
default:
|
||||
tracer.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.VerdictVerb(), conn)
|
||||
}
|
||||
}()
|
||||
|
||||
// Check request with the privacy filter before resolving.
|
||||
firewall.FilterConnection(ctx, conn, nil, true, false)
|
||||
|
||||
// Check if there is a responder from the firewall.
|
||||
// In special cases, the firewall might want to respond the query itself.
|
||||
// A reason for this might be that the request is sink-holed to a forced
|
||||
// IP address in which case we "accept" it, but let the firewall handle
|
||||
// the resolving as it wishes.
|
||||
if responder, ok := conn.Reason.Context.(nsutil.Responder); ok {
|
||||
tracer.Infof("nameserver: handing over request for %s to special filter responder: %s", q.ID(), conn.Reason.Msg)
|
||||
return reply(responder, conn)
|
||||
}
|
||||
|
||||
// Check if there is a Verdict to act upon.
|
||||
switch conn.Verdict { //nolint:exhaustive // Only checking for specific values.
|
||||
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed:
|
||||
tracer.Infof(
|
||||
"nameserver: returning %s response for %s to %s",
|
||||
conn.VerdictVerb(),
|
||||
q.ID(),
|
||||
conn.Process(),
|
||||
)
|
||||
return reply(conn, conn)
|
||||
}
|
||||
|
||||
// Resolve request.
|
||||
rrCache, err = resolver.Resolve(ctx, q)
|
||||
// Handle error.
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, resolver.ErrNotFound):
|
||||
// Try alternatives domain names for unofficial domain spaces.
|
||||
rrCache = checkAlternativeCaches(ctx, q)
|
||||
if rrCache == nil {
|
||||
tracer.Tracef("nameserver: %s", err)
|
||||
conn.Failed("domain does not exist", "")
|
||||
return reply(nsutil.NxDomain("nxdomain: " + err.Error()))
|
||||
}
|
||||
case errors.Is(err, resolver.ErrBlocked):
|
||||
tracer.Tracef("nameserver: %s", err)
|
||||
conn.Block(err.Error(), "")
|
||||
return reply(nsutil.BlockIP("blocked: " + err.Error()))
|
||||
|
||||
case errors.Is(err, resolver.ErrLocalhost):
|
||||
tracer.Tracef("nameserver: returning localhost records")
|
||||
conn.Accept("allowing query for localhost", "")
|
||||
return reply(nsutil.Localhost())
|
||||
|
||||
case errors.Is(err, resolver.ErrOffline):
|
||||
if rrCache == nil {
|
||||
tracer.Debugf("nameserver: not resolving %s, device is offline", q.ID())
|
||||
conn.Failed("not resolving, device is offline", "")
|
||||
return reply(nsutil.ServerFailure(err.Error()))
|
||||
}
|
||||
// If an rrCache was returned, it's usable as a backup.
|
||||
rrCache.IsBackup = true
|
||||
log.Tracer(ctx).Debugf("nameserver: device is offline, using backup cache for %s", q.ID())
|
||||
|
||||
default:
|
||||
tracer.Warningf("nameserver: failed to resolve %s: %s", q.ID(), err)
|
||||
conn.Failed(fmt.Sprintf("query failed: %s", err), "")
|
||||
addFailingQuery(q, err)
|
||||
return reply(nsutil.ServerFailure("internal error: " + err.Error()))
|
||||
}
|
||||
}
|
||||
// Handle special cases.
|
||||
switch {
|
||||
case rrCache == nil:
|
||||
tracer.Warning("nameserver: received successful, but empty reply from resolver")
|
||||
addFailingQuery(q, errors.New("emptry reply from resolver"))
|
||||
return reply(nsutil.ServerFailure("internal error: empty reply"))
|
||||
case rrCache.RCode == dns.RcodeNameError:
|
||||
// Try alternatives domain names for unofficial domain spaces.
|
||||
altRRCache := checkAlternativeCaches(ctx, q)
|
||||
if altRRCache != nil {
|
||||
rrCache = altRRCache
|
||||
} else {
|
||||
// Return now if NXDomain.
|
||||
return reply(nsutil.NxDomain("no answer found (NXDomain)"))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Check with firewall again after resolving.
|
||||
tracer.Trace("nameserver: deciding on resolved dns")
|
||||
rrCache = firewall.FilterResolvedDNS(ctx, conn, q, rrCache)
|
||||
|
||||
// Check again if there is a responder from the firewall.
|
||||
if responder, ok := conn.Reason.Context.(nsutil.Responder); ok {
|
||||
tracer.Infof("nameserver: handing over request for %s to special filter responder: %s", q.ID(), conn.Reason.Msg)
|
||||
return reply(responder)
|
||||
}
|
||||
|
||||
// Check if there is a Verdict to act upon.
|
||||
switch conn.Verdict { //nolint:exhaustive // Only checking for specific values.
|
||||
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed:
|
||||
tracer.Infof(
|
||||
"nameserver: returning %s response for %s to %s",
|
||||
conn.VerdictVerb(),
|
||||
q.ID(),
|
||||
conn.Process(),
|
||||
)
|
||||
return reply(conn, conn, rrCache)
|
||||
}
|
||||
|
||||
// Revert back to non-standard question format, if we had to convert.
|
||||
if nonStandardQuestionFormat {
|
||||
rrCache.ReplaceAnswerNames(originalQuestion.Name)
|
||||
}
|
||||
|
||||
// Reply with successful response.
|
||||
noAnswerIndicator := ""
|
||||
if len(rrCache.Answer) == 0 {
|
||||
noAnswerIndicator = "/no answer"
|
||||
}
|
||||
tracer.Infof(
|
||||
"nameserver: returning %s response (%s%s) for %s to %s",
|
||||
conn.VerdictVerb(),
|
||||
dns.RcodeToString[rrCache.RCode],
|
||||
noAnswerIndicator,
|
||||
q.ID(),
|
||||
conn.Process(),
|
||||
)
|
||||
return reply(rrCache, conn, rrCache)
|
||||
}
|
||||
|
||||
func checkAlternativeCaches(ctx context.Context, q *resolver.Query) *resolver.RRCache {
|
||||
// Do not try alternatives when the query is in a public suffix.
|
||||
// This also includes arpa. and local.
|
||||
if q.ICANNSpace {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the env resolver has something.
|
||||
pmEnvQ := &resolver.Query{
|
||||
FQDN: q.FQDN + "local." + resolver.InternalSpecialUseDomain,
|
||||
QType: q.QType,
|
||||
}
|
||||
rrCache, err := resolver.QueryPortmasterEnv(ctx, pmEnvQ)
|
||||
if err == nil && rrCache != nil && rrCache.RCode == dns.RcodeSuccess {
|
||||
makeAlternativeRecord(ctx, q, rrCache, pmEnvQ.FQDN)
|
||||
return rrCache
|
||||
}
|
||||
|
||||
// Check if we have anything in cache
|
||||
localFQDN := q.FQDN + "local."
|
||||
rrCache, err = resolver.GetRRCache(localFQDN, q.QType)
|
||||
if err == nil && rrCache != nil && rrCache.RCode == dns.RcodeSuccess {
|
||||
makeAlternativeRecord(ctx, q, rrCache, localFQDN)
|
||||
return rrCache
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeAlternativeRecord(ctx context.Context, q *resolver.Query, rrCache *resolver.RRCache, altName string) {
|
||||
log.Tracer(ctx).Debugf("using %s to answer query", altName)
|
||||
|
||||
// Duplicate answers so they match the query.
|
||||
copied := make([]dns.RR, 0, len(rrCache.Answer))
|
||||
for _, answer := range rrCache.Answer {
|
||||
if strings.ToLower(answer.Header().Name) == altName {
|
||||
copiedAnswer := dns.Copy(answer)
|
||||
copiedAnswer.Header().Name = q.FQDN
|
||||
copied = append(copied, copiedAnswer)
|
||||
}
|
||||
}
|
||||
if len(copied) > 0 {
|
||||
rrCache.Answer = append(rrCache.Answer, copied...)
|
||||
}
|
||||
|
||||
// Update the question.
|
||||
rrCache.Domain = q.FQDN
|
||||
}
|
||||
202
service/nameserver/nsutil/nsutil.go
Normal file
202
service/nameserver/nsutil/nsutil.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package nsutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
// ErrNilRR is returned when a parsed RR is nil.
|
||||
var ErrNilRR = errors.New("is nil")
|
||||
|
||||
// Responder defines the interface that any block/deny reason interface
|
||||
// may implement to support sending custom DNS responses for a given reason.
|
||||
// That is, if a reason context implements the Responder interface the
|
||||
// ReplyWithDNS method will be called instead of creating the default
|
||||
// zero-ip response.
|
||||
type Responder interface {
|
||||
// ReplyWithDNS is called when a DNS response to a DNS message is
|
||||
// crafted because the request is either denied or blocked.
|
||||
ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns.Msg
|
||||
}
|
||||
|
||||
// RRProvider defines the interface that any block/deny reason interface
|
||||
// may implement to support adding additional DNS resource records to
|
||||
// the DNS responses extra (additional) section.
|
||||
type RRProvider interface {
|
||||
// GetExtraRRs is called when a DNS response to a DNS message is
|
||||
// crafted because the request is either denied or blocked.
|
||||
GetExtraRRs(ctx context.Context, request *dns.Msg) []dns.RR
|
||||
}
|
||||
|
||||
// ResponderFunc is a convenience type to use a function
|
||||
// directly as a Responder.
|
||||
type ResponderFunc func(ctx context.Context, request *dns.Msg) *dns.Msg
|
||||
|
||||
// ReplyWithDNS implements the Responder interface and calls rf.
|
||||
func (rf ResponderFunc) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns.Msg {
|
||||
return rf(ctx, request)
|
||||
}
|
||||
|
||||
// MarshalJSON disables JSON marshaling for ResponderFunc.
|
||||
func (rf ResponderFunc) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(nil)
|
||||
}
|
||||
|
||||
// BlockIP is a ResponderFunc than replies with either 0.0.0.17 or ::17 for
|
||||
// each A or AAAA question respectively. If there is no A or AAAA question, it
|
||||
// defaults to replying with NXDomain.
|
||||
func BlockIP(msgs ...string) ResponderFunc {
|
||||
return createResponderFunc(
|
||||
"blocking",
|
||||
"0.0.0.17",
|
||||
"::17",
|
||||
msgs...,
|
||||
)
|
||||
}
|
||||
|
||||
// ZeroIP is a ResponderFunc than replies with either 0.0.0.0 or :: for each A
|
||||
// or AAAA question respectively. If there is no A or AAAA question, it
|
||||
// defaults to replying with NXDomain.
|
||||
func ZeroIP(msgs ...string) ResponderFunc {
|
||||
return createResponderFunc(
|
||||
"zero ip",
|
||||
"0.0.0.0",
|
||||
"::",
|
||||
msgs...,
|
||||
)
|
||||
}
|
||||
|
||||
// Localhost is a ResponderFunc than replies with localhost IP addresses.
|
||||
// If there is no A or AAAA question, it defaults to replying with NXDomain.
|
||||
func Localhost(msgs ...string) ResponderFunc {
|
||||
return createResponderFunc(
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
msgs...,
|
||||
)
|
||||
}
|
||||
|
||||
func createResponderFunc(responderName, aAnswer, aaaaAnswer string, msgs ...string) ResponderFunc {
|
||||
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
|
||||
reply := new(dns.Msg)
|
||||
hasErr := false
|
||||
|
||||
for _, question := range request.Question {
|
||||
var rr dns.RR
|
||||
var err error
|
||||
|
||||
switch question.Qtype {
|
||||
case dns.TypeA:
|
||||
rr, err = dns.NewRR(question.Name + " 1 IN A " + aAnswer)
|
||||
case dns.TypeAAAA:
|
||||
rr, err = dns.NewRR(question.Name + " 1 IN AAAA " + aaaaAnswer)
|
||||
}
|
||||
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Tracer(ctx).Errorf("nameserver: failed to create %s response for %s: %s", responderName, question.Name, err)
|
||||
hasErr = true
|
||||
case rr != nil:
|
||||
reply.Answer = append(reply.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case hasErr && len(reply.Answer) == 0:
|
||||
reply.SetRcode(request, dns.RcodeServerFailure)
|
||||
case len(reply.Answer) == 0:
|
||||
reply.SetRcode(request, dns.RcodeNameError)
|
||||
default:
|
||||
reply.SetRcode(request, dns.RcodeSuccess)
|
||||
}
|
||||
|
||||
AddMessagesToReply(ctx, reply, log.InfoLevel, msgs...)
|
||||
|
||||
return reply
|
||||
}
|
||||
}
|
||||
|
||||
// NxDomain returns a ResponderFunc that replies with NXDOMAIN.
|
||||
func NxDomain(msgs ...string) ResponderFunc {
|
||||
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
|
||||
reply := new(dns.Msg).SetRcode(request, dns.RcodeNameError)
|
||||
AddMessagesToReply(ctx, reply, log.InfoLevel, msgs...)
|
||||
|
||||
// According to RFC4074 (https://tools.ietf.org/html/rfc4074), there are
|
||||
// nameservers that incorrectly respond with NXDomain instead of an empty
|
||||
// SUCCESS response when there are other RRs for the queried domain name.
|
||||
// This can lead to the software thinking that no RRs exist for that
|
||||
// domain. In order to mitigate this a bit, we slightly delay NXDomain
|
||||
// responses.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
return reply
|
||||
}
|
||||
}
|
||||
|
||||
// Refused returns a ResponderFunc that replies with REFUSED.
|
||||
func Refused(msgs ...string) ResponderFunc {
|
||||
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
|
||||
reply := new(dns.Msg).SetRcode(request, dns.RcodeRefused)
|
||||
AddMessagesToReply(ctx, reply, log.InfoLevel, msgs...)
|
||||
return reply
|
||||
}
|
||||
}
|
||||
|
||||
// ServerFailure returns a ResponderFunc that replies with SERVFAIL.
|
||||
func ServerFailure(msgs ...string) ResponderFunc {
|
||||
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
|
||||
reply := new(dns.Msg).SetRcode(request, dns.RcodeServerFailure)
|
||||
AddMessagesToReply(ctx, reply, log.InfoLevel, msgs...)
|
||||
return reply
|
||||
}
|
||||
}
|
||||
|
||||
// MakeMessageRecord creates an informational resource record that can be added
|
||||
// to the extra section of a reply.
|
||||
func MakeMessageRecord(level log.Severity, msg string) (dns.RR, error) { //nolint:interfacer
|
||||
rr, err := dns.NewRR(fmt.Sprintf(
|
||||
`%s.portmaster. 0 IN TXT "%s"`,
|
||||
strings.ToLower(level.String()),
|
||||
msg,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rr == nil {
|
||||
return nil, ErrNilRR
|
||||
}
|
||||
return rr, nil
|
||||
}
|
||||
|
||||
// AddMessagesToReply creates information resource records using
|
||||
// MakeMessageRecord and immediately adds them to the extra section of the given
|
||||
// reply. If an error occurs, the resource record will not be added, and the
|
||||
// error will be logged.
|
||||
func AddMessagesToReply(ctx context.Context, reply *dns.Msg, level log.Severity, msgs ...string) {
|
||||
for _, msg := range msgs {
|
||||
// Ignore empty messages.
|
||||
if msg == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create resources record.
|
||||
rr, err := MakeMessageRecord(level, msg)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Warningf("nameserver: failed to add message to reply: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to extra section of the reply.
|
||||
reply.Extra = append(reply.Extra, rr)
|
||||
}
|
||||
}
|
||||
72
service/nameserver/response.go
Normal file
72
service/nameserver/response.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/nameserver/nsutil"
|
||||
)
|
||||
|
||||
// sendResponse sends a response to query using w. The response message is
|
||||
// created by responder. If addExtraRRs is not nil and implements the
|
||||
// RRProvider interface then it will be also used to add more RRs in the
|
||||
// extra section.
|
||||
func sendResponse(
|
||||
ctx context.Context,
|
||||
w dns.ResponseWriter,
|
||||
request *dns.Msg,
|
||||
responder nsutil.Responder,
|
||||
rrProviders ...nsutil.RRProvider,
|
||||
) error {
|
||||
// Have the Responder craft a DNS reply.
|
||||
reply := responder.ReplyWithDNS(ctx, request)
|
||||
if reply == nil {
|
||||
// Dropping query.
|
||||
return nil
|
||||
}
|
||||
// Signify that we are a recursive resolver.
|
||||
// While we do not handle recursion directly, we can safely assume, that we
|
||||
// always forward to a recursive resolver.
|
||||
reply.RecursionAvailable = true
|
||||
|
||||
// Add extra RRs through a custom RRProvider.
|
||||
for _, rrProvider := range rrProviders {
|
||||
if rrProvider != nil {
|
||||
rrs := rrProvider.GetExtraRRs(ctx, request)
|
||||
reply.Extra = append(reply.Extra, rrs...)
|
||||
}
|
||||
}
|
||||
|
||||
// Write reply.
|
||||
if err := writeDNSResponse(ctx, w, reply); err != nil {
|
||||
return fmt.Errorf("failed to send response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeDNSResponse(ctx context.Context, w dns.ResponseWriter, m *dns.Msg) (err error) {
|
||||
defer func() {
|
||||
// recover from panic
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = fmt.Errorf("panic: %s", panicErr)
|
||||
log.Tracer(ctx).Debugf("nameserver: panic caused by this msg: %#v", m)
|
||||
}
|
||||
}()
|
||||
|
||||
err = w.WriteMsg(m)
|
||||
if err != nil {
|
||||
// If we receive an error we might have exceeded the message size with all
|
||||
// our extra information records. Retry again without the extra section.
|
||||
log.Tracer(ctx).Tracef("nameserver: retrying to write dns message without extra section, error was: %s", err)
|
||||
m.Extra = nil
|
||||
noExtraErr := w.WriteMsg(m)
|
||||
if noExtraErr == nil {
|
||||
return fmt.Errorf("failed to write dns message without extra section: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user