From 25b1d59663bcfcb495da102a45c207c1d3fcb841 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 25 Oct 2019 13:35:02 +0200 Subject: [PATCH] Revamp intel and nameserver packages --- intel/clients.go | 25 +- intel/config.go | 67 ++++- intel/intel.go | 13 +- intel/main.go | 50 ++-- intel/main_test.go | 29 +- intel/mdns.go | 238 ++++++++++----- intel/resolve.go | 545 ++++++++++++++-------------------- intel/resolver-scopes.go | 276 +++++++++++++++++ intel/resolver.go | 390 +++++++----------------- intel/resolvers.go | 357 ++++++++++++++++++++++ intel/resolvers_test.go | 36 +++ intel/reverse.go | 48 +-- intel/reverse_test.go | 19 +- intel/rrcache.go | 128 ++++---- intel/special.go | 51 ---- nameserver/nameserver.go | 228 ++++++++------ nameserver/only/nameserver.go | 187 +++++++----- nameserver/takeover.go | 48 ++- 18 files changed, 1675 insertions(+), 1060 deletions(-) create mode 100644 intel/resolver-scopes.go create mode 100644 intel/resolvers.go create mode 100644 intel/resolvers_test.go delete mode 100644 intel/special.go diff --git a/intel/clients.go b/intel/clients.go index 4e6302bd..38880eb5 100644 --- a/intel/clients.go +++ b/intel/clients.go @@ -13,7 +13,7 @@ var ( localAddrFactory func(network string) net.Addr ) -// SetLocalAddrFactory supplied the intel package with a function to set local addresses for connections. +// SetLocalAddrFactory supplies the intel package with a function to get permitted local addresses for connections. func SetLocalAddrFactory(laf func(network string) net.Addr) { if localAddrFactory == nil { localAddrFactory = laf @@ -36,11 +36,9 @@ type clientManager struct { ttl time.Duration // force refresh of connection to reduce traceability } -// ref: https://godoc.org/github.com/miekg/dns#Client - -func newDNSClientManager(resolver *Resolver) *clientManager { +func newDNSClientManager(_ *Resolver) *clientManager { return &clientManager{ - // ttl: 1 * time.Minute, + ttl: 0, // new client for every request, as we need to randomize the port factory: func() *dns.Client { return &dns.Client{ Timeout: 5 * time.Second, @@ -52,15 +50,16 @@ func newDNSClientManager(resolver *Resolver) *clientManager { } } -func newTCPClientManager(resolver *Resolver) *clientManager { +func newTCPClientManager(_ *Resolver) *clientManager { return &clientManager{ - // ttl: 5 * time.Minute, + ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) factory: func() *dns.Client { return &dns.Client{ Net: "tcp", Timeout: 5 * time.Second, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), + KeepAlive: 15 * time.Second, }, } }, @@ -69,19 +68,19 @@ func newTCPClientManager(resolver *Resolver) *clientManager { func newTLSClientManager(resolver *Resolver) *clientManager { return &clientManager{ - // ttl: 5 * time.Minute, + ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) factory: func() *dns.Client { return &dns.Client{ Net: "tcp-tls", TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, ServerName: resolver.VerifyDomain, - // TODO: use custom random - // Rand: io.Reader, + // TODO: use portbase rng }, Timeout: 5 * time.Second, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), + KeepAlive: 15 * time.Second, }, } }, @@ -90,18 +89,18 @@ func newTLSClientManager(resolver *Resolver) *clientManager { func newHTTPSClientManager(resolver *Resolver) *clientManager { return &clientManager{ - // ttl: 5 * time.Minute, + ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) factory: func() *dns.Client { new := &dns.Client{ Net: "https", TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, - // TODO: use custom random - // Rand: io.Reader, + // TODO: use portbase rng }, Timeout: 5 * time.Second, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), + KeepAlive: 15 * time.Second, }, } if resolver.VerifyDomain != "" { diff --git a/intel/config.go b/intel/config.go index 6cdd0468..ef697d0e 100644 --- a/intel/config.go +++ b/intel/config.go @@ -1,6 +1,9 @@ package intel import ( + "fmt" + "strings" + "github.com/safing/portbase/config" "github.com/safing/portmaster/status" ) @@ -8,13 +11,18 @@ import ( var ( configuredNameServers config.StringArrayOption defaultNameServers = []string{ + // "dot://9.9.9.9:853?verify=dns.quad9.net&", // Quad9 + // "dot|149.112.112.112:853|dns.quad9.net", // Quad9 + // "dot://[2620:fe::fe]:853?verify=dns.quad9.net&name=Quad9" // Quad9 + // "dot://[2620:fe::9]:853?verify=dns.quad9.net&name=Quad9" // Quad9 + + "dot|1.1.1.1:853|cloudflare-dns.com", // Cloudflare + "dot|1.0.0.1:853|cloudflare-dns.com", // Cloudflare + "dns|9.9.9.9:53", // Quad9 + "dns|149.112.112.112:53", // Quad9 "dns|1.1.1.1:53", // Cloudflare "dns|1.0.0.1:53", // Cloudflare - "dns|9.9.9.9:53", // Quad9 - "tls|1.1.1.1:853|cloudflare-dns.com", // Cloudflare - "tls|1.0.0.1:853|cloudflare-dns.com", // Cloudflare - "tls|9.9.9.9:853|dns.quad9.net", // Quad9 - // "https|cloudflare-dns.com/dns-query", // HTTPS still experimental + // "doh|cloudflare-dns.com/dns-query", // DoH still experimental } nameserverRetryRate config.IntOption @@ -22,15 +30,17 @@ var ( doNotUseAssignedNameservers status.SecurityLevelOption doNotUseInsecureProtocols status.SecurityLevelOption doNotResolveSpecialDomains status.SecurityLevelOption + doNotResolveTestDomains status.SecurityLevelOption ) -func prep() error { +func prepConfig() error { err := config.Register(&config.Option{ Name: "Nameservers (DNS)", Key: "intel/nameservers", Description: "Nameserver to use for resolving DNS requests.", - ExpertiseLevel: config.ExpertiseLevelExpert, OptType: config.OptTypeStringArray, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, DefaultValue: defaultNameServers, ValidationRegex: "^(dns|tcp|tls|https)|[a-z0-9\\.|-]+$", }) @@ -43,8 +53,9 @@ func prep() error { Name: "Nameserver Retry Rate", Key: "intel/nameserverRetryRate", Description: "Rate at which to retry failed nameservers, in seconds.", - ExpertiseLevel: config.ExpertiseLevelExpert, OptType: config.OptTypeInt, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, DefaultValue: 600, }) if err != nil { @@ -56,8 +67,9 @@ func prep() error { Name: "Do not use Multicast DNS", Key: "intel/doNotUseMulticastDNS", Description: "Multicast DNS queries other devices in the local network", - ExpertiseLevel: config.ExpertiseLevelExpert, OptType: config.OptTypeInt, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, ExternalOptType: "security level", DefaultValue: 6, ValidationRegex: "^(7|6|4)$", @@ -71,8 +83,9 @@ func prep() error { Name: "Do not use assigned Nameservers", Key: "intel/doNotUseAssignedNameservers", Description: "that were acquired by the network (dhcp) or system", - ExpertiseLevel: config.ExpertiseLevelExpert, OptType: config.OptTypeInt, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, ExternalOptType: "security level", DefaultValue: 4, ValidationRegex: "^(7|6|4)$", @@ -86,8 +99,9 @@ func prep() error { Name: "Do not resolve insecurely", Key: "intel/doNotUseInsecureProtocols", Description: "Do not resolve domains with insecure protocols, ie. plain DNS", - ExpertiseLevel: config.ExpertiseLevelExpert, OptType: config.OptTypeInt, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, ExternalOptType: "security level", DefaultValue: 4, ValidationRegex: "^(7|6|4)$", @@ -100,11 +114,12 @@ func prep() error { err = config.Register(&config.Option{ Name: "Do not resolve special domains", Key: "intel/doNotResolveSpecialDomains", - Description: "Do not resolve special (top level) domains: example, example.com, example.net, example.org, invalid, test, onion. (RFC6761, RFC7686)", - ExpertiseLevel: config.ExpertiseLevelExpert, + Description: fmt.Sprintf("Do not resolve the special top level domains %s", formatScopeList(specialServiceScopes)), OptType: config.OptTypeInt, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, ExternalOptType: "security level", - DefaultValue: 6, + DefaultValue: 7, ValidationRegex: "^(7|6|4)$", }) if err != nil { @@ -112,5 +127,29 @@ func prep() error { } doNotResolveSpecialDomains = status.ConfigIsActiveConcurrent("intel/doNotResolveSpecialDomains") + err = config.Register(&config.Option{ + Name: "Do not resolve test domains", + Key: "intel/doNotResolveTestDomains", + Description: fmt.Sprintf("Do not resolve the special testing top level domains %s", formatScopeList(localTestScopes)), + OptType: config.OptTypeInt, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, + ExternalOptType: "security level", + DefaultValue: 6, + ValidationRegex: "^(7|6|4)$", + }) + if err != nil { + return err + } + doNotResolveTestDomains = status.ConfigIsActiveConcurrent("intel/doNotResolveTestDomains") + return nil } + +func formatScopeList(list []string) string { + formatted := make([]string, 0, len(list)) + for _, domain := range list { + formatted = append(formatted, strings.Trim(domain, ".")) + } + return strings.Join(formatted, ", ") +} diff --git a/intel/intel.go b/intel/intel.go index 66125213..6103d328 100644 --- a/intel/intel.go +++ b/intel/intel.go @@ -1,11 +1,13 @@ package intel import ( + "context" "fmt" "sync" "github.com/safing/portbase/database" "github.com/safing/portbase/database/record" + "github.com/safing/portbase/log" ) var ( @@ -61,6 +63,13 @@ func (intel *Intel) Save() error { } // GetIntel fetches intelligence data for the given domain. -func GetIntel(domain string) (*Intel, error) { - return &Intel{Domain: domain}, nil +func GetIntel(ctx context.Context, q *Query) (*Intel, error) { + // sanity check + if q == nil || !q.check() { + return nil, ErrInvalid + } + + log.Tracer(ctx).Trace("intel: getting intel") + // TODO + return &Intel{Domain: q.FQDN}, nil } diff --git a/intel/main.go b/intel/main.go index c2afa943..2d84b1c8 100644 --- a/intel/main.go +++ b/intel/main.go @@ -2,8 +2,7 @@ package intel import ( "context" - - "github.com/miekg/dns" + "time" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" @@ -12,30 +11,41 @@ import ( _ "github.com/safing/portmaster/core" ) +var ( + module *modules.Module +) + func init() { - modules.Register("intel", prep, start, nil, "core") + module = modules.Register("intel", prep, start, nil, "core", "network") +} + +func prep() error { + return prepConfig() } func start() error { // load resolvers from config and environment - loadResolvers(false) + loadResolvers() - go listenToMDNS() + err := module.RegisterEventHook( + "network", + "network changed", + "update nameservers", + func(_ context.Context, _ interface{}) error { + loadResolvers() + log.Debug("intel: reloaded nameservers due to network change") + return nil + }, + ) + if err != nil { + return err + } + + module.StartServiceWorker( + "mdns handler", + 5*time.Second, + listenToMDNS, + ) return nil } - -// GetIntelAndRRs returns intel and DNS resource records for the given domain. -func GetIntelAndRRs(ctx context.Context, domain string, qtype dns.Type, securityLevel uint8) (intel *Intel, rrs *RRCache) { - log.Tracer(ctx).Trace("intel: getting intel") - intel, err := GetIntel(domain) - if err != nil { - log.Tracer(ctx).Warningf("intel: failed to get intel: %s", err) - log.Errorf("intel: failed to get intel: %s", err) - intel = nil - } - - log.Tracer(ctx).Tracef("intel: getting records") - rrs = Resolve(ctx, domain, qtype, securityLevel) - return -} diff --git a/intel/main_test.go b/intel/main_test.go index add9a295..c3c4b4da 100644 --- a/intel/main_test.go +++ b/intel/main_test.go @@ -4,34 +4,29 @@ import ( "os" "testing" - "github.com/safing/portbase/database/dbmodule" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/core" ) func TestMain(m *testing.M) { // setup - testDir := os.TempDir() - dbmodule.SetDatabaseLocation(testDir) - err := modules.Start() + tmpDir, err := core.InitForTesting() if err != nil { - if err == modules.ErrCleanExit { - os.Exit(0) - } else { - err = modules.Shutdown() - if err != nil { - log.Shutdown() - } - os.Exit(1) - } + panic(err) } + // setup package + err = prep() + if err != nil { + panic(err) + } + loadResolvers() + // run tests rv := m.Run() // teardown - modules.Shutdown() - os.RemoveAll(testDir) + core.StopTesting() + _ = os.RemoveAll(tmpDir) // exit with test run return value os.Exit(rv) diff --git a/intel/mdns.go b/intel/mdns.go index c9aa675c..7f1672ed 100644 --- a/intel/mdns.go +++ b/intel/mdns.go @@ -25,13 +25,33 @@ var ( unicast4Conn *net.UDPConn unicast6Conn *net.UDPConn - questions = make(map[uint16]savedQuestion) + questions = make(map[uint16]*savedQuestion) questionsLock sync.Mutex + + mDNSResolver = &Resolver{ + Server: ServerSourceMDNS, + ServerType: ServerTypeDNS, + Source: ServerSourceMDNS, + Conn: &mDNSResolverConn{}, + } ) +type mDNSResolverConn struct{} + +func (mrc *mDNSResolverConn) Query(ctx context.Context, q *Query) (*RRCache, error) { + return queryMulticastDNS(ctx, q) +} + +func (mrc *mDNSResolverConn) MarkFailed() {} + +func (mrc *mDNSResolverConn) LastFail() time.Time { + return time.Time{} +} + type savedQuestion struct { question dns.Question - expires int64 + expires time.Time + response chan *RRCache } func indexOfRR(entry *dns.RR_Header, list *[]dns.RR) int { @@ -43,16 +63,23 @@ func indexOfRR(entry *dns.RR_Header, list *[]dns.RR) int { return -1 } -func listenToMDNS() { +//nolint:gocyclo,gocognit // TODO: make simpler +func listenToMDNS(ctx context.Context) error { var err error messages := make(chan *dns.Msg) + // TODO: init and start every listener in its own service worker + // this will make the more resilient and actually able to restart + multicast4Conn, err = net.ListenMulticastUDP("udp4", nil, &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251), Port: 5353}) if err != nil { // TODO: retry after some time log.Warningf("intel(mdns): failed to create udp4 listen multicast socket: %s", err) } else { - go listenForDNSPackets(multicast4Conn, messages) + module.StartServiceWorker("mdns udp4 multicast listener", 0, func(ctx context.Context) error { + return listenForDNSPackets(multicast4Conn, messages) + }) + defer multicast4Conn.Close() } multicast6Conn, err = net.ListenMulticastUDP("udp6", nil, &net.UDPAddr{IP: net.IP([]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb}), Port: 5353}) @@ -60,7 +87,10 @@ func listenToMDNS() { // TODO: retry after some time log.Warningf("intel(mdns): failed to create udp6 listen multicast socket: %s", err) } else { - go listenForDNSPackets(multicast6Conn, messages) + module.StartServiceWorker("mdns udp6 multicast listener", 0, func(ctx context.Context) error { + return listenForDNSPackets(multicast6Conn, messages) + }) + defer multicast6Conn.Close() } unicast4Conn, err = net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) @@ -68,7 +98,10 @@ func listenToMDNS() { // TODO: retry after some time log.Warningf("intel(mdns): failed to create udp4 listen socket: %s", err) } else { - go listenForDNSPackets(unicast4Conn, messages) + module.StartServiceWorker("mdns udp4 unicast listener", 0, func(ctx context.Context) error { + return listenForDNSPackets(unicast4Conn, messages) + }) + defer unicast4Conn.Close() } unicast6Conn, err = net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) @@ -76,14 +109,32 @@ func listenToMDNS() { // TODO: retry after some time log.Warningf("intel(mdns): failed to create udp6 listen socket: %s", err) } else { - go listenForDNSPackets(unicast6Conn, messages) + module.StartServiceWorker("mdns udp6 unicast listener", 0, func(ctx context.Context) error { + return listenForDNSPackets(unicast6Conn, messages) + }) + defer unicast6Conn.Close() } + // start message handler + module.StartServiceWorker("mdns message handler", 0, func(ctx context.Context) error { + return handleMDNSMessages(ctx, messages) + }) + + // wait for shutdown + <-module.Ctx.Done() + return nil +} + +//nolint:gocyclo,gocognit // TODO +func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error { for { select { + case <-ctx.Done(): + return nil case message := <-messages: // log.Tracef("intel: got net mdns message: %s", message) + var err error var question *dns.Question var saveFullRequest bool scavengedRecords := make(map[string]dns.RR) @@ -91,7 +142,7 @@ func listenToMDNS() { // save every received response // if previous save was less than 2 seconds ago, add to response, else replace - // pick out A and AAAA records and save seperately + // pick out A and AAAA records and save separately // continue if not response if !message.Response { @@ -111,24 +162,28 @@ func listenToMDNS() { continue } + // return saved question + questionsLock.Lock() + savedQ := questions[message.MsgHdr.Id] + questionsLock.Unlock() + // get question, some servers do not reply with question - if len(message.Question) == 0 { - questionsLock.Lock() - savedQ, ok := questions[message.MsgHdr.Id] - questionsLock.Unlock() - if ok { - question = &savedQ.question - } - } else { + if len(message.Question) > 0 { question = &message.Question[0] + // if questions do not match, disregard saved question + if savedQ != nil && message.Question[0].String() != savedQ.question.String() { + savedQ = nil + } + } else if savedQ != nil { + question = &savedQ.question } if question != nil { // continue if class is not INTERNET if question.Qclass != dns.ClassINET && question.Qclass != DNSClassMulticast { - // log.Tracef("intel: mdns question is not of class INET, ignoring") continue } + // mark request to be saved saveFullRequest = true } @@ -144,7 +199,7 @@ func listenToMDNS() { } for _, entry := range message.Answer { - if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScopes(entry.Header().Name, localReverseScopes) { + if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) { if saveFullRequest { k := indexOfRR(entry.Header(), &rrCache.Answer) if k == -1 { @@ -166,7 +221,7 @@ func listenToMDNS() { } } for _, entry := range message.Ns { - if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScopes(entry.Header().Name, localReverseScopes) { + if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) { if saveFullRequest { k := indexOfRR(entry.Header(), &rrCache.Ns) if k == -1 { @@ -188,7 +243,7 @@ func listenToMDNS() { } } for _, entry := range message.Extra { - if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScopes(entry.Header().Name, localReverseScopes) { + if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) { if saveFullRequest { k := indexOfRR(entry.Header(), &rrCache.Extra) if k == -1 { @@ -213,7 +268,19 @@ func listenToMDNS() { var questionID string if saveFullRequest { rrCache.Clean(60) - rrCache.Save() + err := rrCache.Save() + if err != nil { + log.Warningf("intel: failed to cache RR %s: %s", rrCache.Domain, err) + } + + // return finished response + if savedQ != nil { + select { + case savedQ.response <- rrCache: + default: + } + } + questionID = fmt.Sprintf("%s%s", question.Name, dns.Type(question.Qtype).String()) } @@ -227,107 +294,122 @@ func listenToMDNS() { Answer: []dns.RR{v}, } rrCache.Clean(60) - rrCache.Save() + err := rrCache.Save() + if err != nil { + log.Warningf("intel: failed to cache RR %s: %s", rrCache.Domain, err) + } // log.Tracef("intel: mdns scavenged %s", k) } } cleanSavedQuestions() - } } -func listenForDNSPackets(conn *net.UDPConn, messages chan *dns.Msg) { +func listenForDNSPackets(conn *net.UDPConn, messages chan *dns.Msg) error { buf := make([]byte, 65536) for { - // log.Tracef("debug: listening...") n, err := conn.Read(buf) - // n, _, err := conn.ReadFrom(buf) - // n, _, err := conn.ReadFromUDP(buf) if err != nil { - // log.Tracef("intel: failed to read packet: %s", err) - continue + if module.ShutdownInProgress() { + return nil + } + log.Debugf("intel: failed to read packet: %s", err) + return err } - // log.Tracef("debug: read something...") message := new(dns.Msg) if err = message.Unpack(buf[:n]); err != nil { - // log.Tracef("intel: failed to unpack message: %s", err) + log.Debugf("intel: failed to unpack message: %s", err) continue } - // log.Tracef("debug: parsed message...") messages <- message } } -func queryMulticastDNS(ctx context.Context, fqdn string, qtype dns.Type) (*RRCache, error) { - log.Tracer(ctx).Trace("intel: resolving with mDNS") - - q := new(dns.Msg) - q.SetQuestion(fqdn, uint16(qtype)) - // request unicast response - // q.Question[0].Qclass |= 1 << 15 - q.RecursionDesired = false - - saveQuestion(q) - - questionsLock.Lock() - defer questionsLock.Unlock() - questions[q.MsgHdr.Id] = savedQuestion{ - question: q.Question[0], - expires: time.Now().Add(10 * time.Second).Unix(), - } - - buf, err := q.Pack() - if err != nil { - return nil, fmt.Errorf("failed to pack query: %s", err) - } +func queryMulticastDNS(ctx context.Context, q *Query) (*RRCache, error) { + // check for active connections if unicast4Conn == nil && unicast6Conn == nil { return nil, errors.New("unicast mdns connections not initialized") } - if unicast4Conn != nil && uint16(qtype) != dns.TypeAAAA { - unicast4Conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + + // trace log + log.Tracer(ctx).Trace("intel: resolving with mDNS") + + // create query + dnsQuery := new(dns.Msg) + dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) + // request unicast response + // q.Question[0].Qclass |= 1 << 15 + dnsQuery.RecursionDesired = false + + // create response channel + response := make(chan *RRCache) + + // save question + questionsLock.Lock() + defer questionsLock.Unlock() + questions[dnsQuery.MsgHdr.Id] = &savedQuestion{ + question: dnsQuery.Question[0], + expires: time.Now().Add(10 * time.Second), + response: response, + } + + // pack qeury + buf, err := dnsQuery.Pack() + if err != nil { + return nil, fmt.Errorf("failed to pack query: %s", err) + } + + // send queries + if unicast4Conn != nil && uint16(q.QType) != dns.TypeAAAA { + err = unicast4Conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + if err != nil { + return nil, fmt.Errorf("failed to configure query (set timout): %s", err) + } + _, err = unicast4Conn.WriteToUDP(buf, &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251), Port: 5353}) if err != nil { return nil, fmt.Errorf("failed to send query: %s", err) } } - if unicast6Conn != nil && uint16(qtype) != dns.TypeA { - unicast6Conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + if unicast6Conn != nil && uint16(q.QType) != dns.TypeA { + err = unicast6Conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + if err != nil { + return nil, fmt.Errorf("failed to configure query (set timout): %s", err) + } + _, err = unicast6Conn.WriteToUDP(buf, &net.UDPAddr{IP: net.IP([]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb}), Port: 5353}) if err != nil { return nil, fmt.Errorf("failed to send query: %s", err) } } - time.Sleep(1 * time.Second) - - rrCache, err := GetRRCache(fqdn, qtype) - if err == nil { - return rrCache, nil + // wait for response or timeout + select { + case rrCache := <-response: + if rrCache != nil { + return rrCache, nil + } + case <-time.After(1 * time.Second): + // check cache again + rrCache, err := GetRRCache(q.FQDN, q.QType) + if err != nil { + return rrCache, nil + } } - return nil, nil -} - -func saveQuestion(q *dns.Msg) { - questionsLock.Lock() - defer questionsLock.Unlock() - // log.Tracef("intel: saving mdns question id=%d, name=%s", q.MsgHdr.Id, q.Question[0].Name) - questions[q.MsgHdr.Id] = savedQuestion{ - question: q.Question[0], - expires: time.Now().Add(10 * time.Second).Unix(), - } + return nil, ErrNotFound } func cleanSavedQuestions() { questionsLock.Lock() defer questionsLock.Unlock() - now := time.Now().Unix() - for k, v := range questions { - if v.expires < now { - delete(questions, k) + now := time.Now() + for msgID, savedQuestion := range questions { + if now.After(savedQuestion.expires) { + delete(questions, msgID) } } } diff --git a/intel/resolve.go b/intel/resolve.go index fa0ec343..b3a9d0eb 100644 --- a/intel/resolve.go +++ b/intel/resolve.go @@ -2,10 +2,8 @@ package intel import ( "context" + "errors" "fmt" - "math/rand" - "net" - "strings" "sync" "time" @@ -13,362 +11,261 @@ import ( "github.com/safing/portbase/database" "github.com/safing/portbase/log" - "github.com/safing/portmaster/status" ) -// TODO: make resolver interface for http package +var ( + mtAsyncResolve = "async resolve" -// special tlds: + // basic errors -// localhost. [RFC6761] - respond with 127.0.0.1 and ::1 to A and AAAA queries, else nxdomain + // ErrNotFound is a basic error that will match all "not found" errors + ErrNotFound = errors.New("record does not exist") + // ErrBlocked is basic error that will match all "blocked" errors + ErrBlocked = errors.New("query was blocked") + // ErrLocalhost is returned to *.localhost queries + ErrLocalhost = errors.New("query for localhost") -// local. [RFC6762] - resolve if search, else resolve with mdns -// 10.in-addr.arpa. [RFC6761] -// 16.172.in-addr.arpa. [RFC6761] -// 17.172.in-addr.arpa. [RFC6761] -// 18.172.in-addr.arpa. [RFC6761] -// 19.172.in-addr.arpa. [RFC6761] -// 20.172.in-addr.arpa. [RFC6761] -// 21.172.in-addr.arpa. [RFC6761] -// 22.172.in-addr.arpa. [RFC6761] -// 23.172.in-addr.arpa. [RFC6761] -// 24.172.in-addr.arpa. [RFC6761] -// 25.172.in-addr.arpa. [RFC6761] -// 26.172.in-addr.arpa. [RFC6761] -// 27.172.in-addr.arpa. [RFC6761] -// 28.172.in-addr.arpa. [RFC6761] -// 29.172.in-addr.arpa. [RFC6761] -// 30.172.in-addr.arpa. [RFC6761] -// 31.172.in-addr.arpa. [RFC6761] -// 168.192.in-addr.arpa. [RFC6761] -// 254.169.in-addr.arpa. [RFC6762] -// 8.e.f.ip6.arpa. [RFC6762] -// 9.e.f.ip6.arpa. [RFC6762] -// a.e.f.ip6.arpa. [RFC6762] -// b.e.f.ip6.arpa. [RFC6762] + // detailed errors -// example. [RFC6761] - resolve if search, else return nxdomain -// example.com. [RFC6761] - resolve if search, else return nxdomain -// example.net. [RFC6761] - resolve if search, else return nxdomain -// example.org. [RFC6761] - resolve if search, else return nxdomain -// invalid. [RFC6761] - resolve if search, else return nxdomain -// test. [RFC6761] - resolve if search, else return nxdomain -// onion. [RFC7686] - resolve if search, else return nxdomain + // ErrTestDomainsDisabled wraps ErrBlocked + ErrTestDomainsDisabled = fmt.Errorf("%w: test domains disabled", ErrBlocked) + // ErrSpecialDomainsDisabled wraps ErrBlocked + ErrSpecialDomainsDisabled = fmt.Errorf("%w: special domains disabled", ErrBlocked) + // ErrInvalid wraps ErrNotFound + ErrInvalid = fmt.Errorf("%w: invalid request", ErrNotFound) + // ErrNoCompliance wraps ErrBlocked and is returned when no resolvers were able to comply with the current settings + ErrNoCompliance = fmt.Errorf("%w: no compliant resolvers for this query", ErrBlocked) +) -// resolvers: -// local -// global -// mdns +type Query struct { + FQDN string + QType dns.Type + SecurityLevel uint8 + NoCaching bool + IgnoreFailing bool + LocalResolversOnly bool -// scopes: -// local-inaddr -> local, mdns -// local -> local scopes, mdns -// global -> local scopes, global -// special -> local scopes, local + // internal + dotPrefixedFQDN string +} + +// check runs sanity checks and does some initialization. Returns whether the query passed the basic checks. +func (q *Query) check() (ok bool) { + if q.FQDN == "" { + return false + } + + // init + q.FQDN = dns.Fqdn(q.FQDN) + if q.FQDN == "." { + q.dotPrefixedFQDN = q.FQDN + } else { + q.dotPrefixedFQDN = "." + q.FQDN + } + + return true +} // Resolve resolves the given query for a domain and type and returns a RRCache object or nil, if the query failed. -func Resolve(ctx context.Context, fqdn string, qtype dns.Type, securityLevel uint8) *RRCache { - fqdn = dns.Fqdn(fqdn) +func Resolve(ctx context.Context, q *Query) (rrCache *RRCache, err error) { + // sanity check + if q == nil || !q.check() { + return nil, ErrInvalid + } - // use this to time how long it takes resolve this domain - // timed := time.Now() - // defer log.Tracef("intel: took %s to get resolve %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String()) + // log + log.Tracer(ctx).Tracef("intel: resolving %s%s", q.FQDN, q.QType) - // check cache - rrCache, err := GetRRCache(fqdn, qtype) - if err != nil { - switch err { - case database.ErrNotFound: - default: - log.Tracer(ctx).Warningf("intel: getting RRCache %s%s from database failed: %s", fqdn, qtype.String(), err) - log.Warningf("intel: getting RRCache %s%s from database failed: %s", fqdn, qtype.String(), err) + // check query compliance + if err = q.checkCompliance(); err != nil { + return nil, err + } + + // check the cache + if !q.NoCaching { + rrCache = checkCache(ctx, q) + if rrCache != nil { + rrCache.MixAnswers() + return rrCache, nil } - return resolveAndCache(ctx, fqdn, qtype, securityLevel) - } - if rrCache.TTL <= time.Now().Unix() { - log.Tracer(ctx).Tracef("intel: serving from cache, requesting new. TTL=%d, now=%d", rrCache.TTL, time.Now().Unix()) - // log.Tracef("intel: serving cache, requesting new. TTL=%d, now=%d", rrCache.TTL, time.Now().Unix()) - rrCache.requestingNew = true - go resolveAndCache(nil, fqdn, qtype, securityLevel) - } - - // randomize records to allow dumb clients (who only look at the first record) to reliably connect - for i := range rrCache.Answer { - j := rand.Intn(i + 1) - rrCache.Answer[i], rrCache.Answer[j] = rrCache.Answer[j], rrCache.Answer[i] - } - - return rrCache -} - -func resolveAndCache(ctx context.Context, fqdn string, qtype dns.Type, securityLevel uint8) (rrCache *RRCache) { - log.Tracer(ctx).Tracef("intel: resolving %s%s", fqdn, qtype.String()) - - // dedup requests - dupKey := fmt.Sprintf("%s%s", fqdn, qtype.String()) - dupReqLock.Lock() - mutex, requestActive := dupReqMap[dupKey] - if !requestActive { - mutex = new(sync.Mutex) - mutex.Lock() - dupReqMap[dupKey] = mutex - dupReqLock.Unlock() - } else { - dupReqLock.Unlock() - log.Tracer(ctx).Tracef("intel: waiting for duplicate query for %s to complete", dupKey) - // log.Tracef("intel: waiting for duplicate query for %s to complete", dupKey) - mutex.Lock() - // wait until duplicate request is finished, then fetch current RRCache and return - mutex.Unlock() - var err error - rrCache, err = GetRRCache(dupKey, qtype) - if err == nil { - return rrCache - } - // must have been nxdomain if we cannot get RRCache - return nil - } - defer func() { - dupReqLock.Lock() - delete(dupReqMap, dupKey) - dupReqLock.Unlock() - mutex.Unlock() - }() - - // resolve - rrCache = intelligentResolve(ctx, fqdn, qtype, securityLevel) - if rrCache == nil { - return nil - } - - // persist to database - rrCache.Clean(600) - rrCache.Save() - - return rrCache -} - -func intelligentResolve(ctx context.Context, fqdn string, qtype dns.Type, securityLevel uint8) *RRCache { - - // TODO: handle being offline - // TODO: handle multiple network connections - - // TODO: handle these in a separate goroutine - // if config.Changed() { - // log.Info("intel: config changed, reloading resolvers") - // loadResolvers(false) - // } else if env.NetworkChanged() { - // log.Info("intel: network changed, reloading resolvers") - // loadResolvers(true) - // } - - resolversLock.RLock() - defer resolversLock.RUnlock() - - lastFailBoundary := time.Now().Unix() - nameserverRetryRate() - preDottedFqdn := "." + fqdn - - // resolve: - // reverse local -> local, mdns - // local -> local scopes, mdns - // special -> local scopes, local - // global -> local scopes, global - - // local reverse scope - if domainInScopes(preDottedFqdn, localReverseScopes) { - // try local resolvers - for _, resolver := range localResolvers { - rrCache, ok := tryResolver(ctx, resolver, lastFailBoundary, fqdn, qtype, securityLevel) - if ok && rrCache != nil && !rrCache.IsNXDomain() { - return rrCache + // dedupe! + markRequestFinished := deduplicateRequest(ctx, q) + if markRequestFinished == nil { + // we waited for another request, recheck the cache! + rrCache = checkCache(ctx, q) + if rrCache != nil { + rrCache.MixAnswers() + return rrCache, nil } - } - // check config - if doNotUseMulticastDNS(securityLevel) { - return nil - } - // try mdns - rrCache, err := queryMulticastDNS(ctx, fqdn, qtype) - if err != nil { - log.Tracer(ctx).Warningf("intel: failed to query mdns: %s", err) - log.Errorf("intel: failed to query mdns: %s", err) - } - return rrCache - } - - // local scopes - for _, scope := range localScopes { - if strings.HasSuffix(preDottedFqdn, scope.Domain) { - for _, resolver := range scope.Resolvers { - rrCache, ok := tryResolver(ctx, resolver, lastFailBoundary, fqdn, qtype, securityLevel) - if ok && rrCache != nil && !rrCache.IsNXDomain() { - return rrCache - } - } - } - } - - switch { - case strings.HasSuffix(preDottedFqdn, ".local."): - // check config - if doNotUseMulticastDNS(securityLevel) { - return nil - } - // try mdns - rrCache, err := queryMulticastDNS(ctx, fqdn, qtype) - if err != nil { - log.Tracer(ctx).Warningf("intel: failed to query mdns: %s", err) - log.Errorf("intel: failed to query mdns: %s", err) - } - return rrCache - case domainInScopes(preDottedFqdn, specialScopes): - // check config - if doNotResolveSpecialDomains(securityLevel) { - return nil - } - // try local resolvers - for _, resolver := range localResolvers { - rrCache, ok := tryResolver(ctx, resolver, lastFailBoundary, fqdn, qtype, securityLevel) - if ok { - return rrCache - } - } - default: - // try global resolvers - for _, resolver := range globalResolvers { - rrCache, ok := tryResolver(ctx, resolver, lastFailBoundary, fqdn, qtype, securityLevel) - if ok { - return rrCache - } - } - } - - log.Tracer(ctx).Warningf("intel: failed to resolve %s%s: all resolvers failed (or were skipped to fulfill the security level)", fqdn, qtype.String()) - log.Criticalf("intel: failed to resolve %s%s: all resolvers failed (or were skipped to fulfill the security level), resetting servers...", fqdn, qtype.String()) - go resetResolverFailStatus() - - return nil - - // TODO: check if there would be resolvers available in lower security modes and alert user -} - -func tryResolver(ctx context.Context, resolver *Resolver, lastFailBoundary int64, fqdn string, qtype dns.Type, securityLevel uint8) (*RRCache, bool) { - log.Tracer(ctx).Tracef("intel: resolving with %s", resolver) - - // skip if not security level denies insecure protocols - if doNotUseInsecureProtocols(securityLevel) && resolver.ServerType == "dns" { - log.Tracer(ctx).Tracef("intel: skipping resolver %s, because it isn't allowed to operate on the current security level: %d|%d", resolver, status.ActiveSecurityLevel(), securityLevel) - return nil, false - } - - // skip if not security level denies assigned dns servers - if doNotUseAssignedNameservers(securityLevel) && resolver.Source == "dhcp" { - log.Tracer(ctx).Tracef("intel: skipping resolver %s, because assigned nameservers are not allowed on the current security level: %d|%d", resolver, status.ActiveSecurityLevel(), securityLevel) - return nil, false - } - // check if failed recently - if resolver.LastFail() > lastFailBoundary { - log.Tracer(ctx).Tracef("intel: skipping resolver %s, because it failed recently", resolver) - return nil, false - } - // TODO: put SkipFqdnBeforeInit back into !resolver.Initialized.IsSet() as soon as Go1.9 arrives and we can use a custom resolver - // skip resolver if initializing and fqdn is set to skip - if fqdn == resolver.SkipFqdnBeforeInit { - log.Tracer(ctx).Tracef("intel: skipping resolver %s, because %s is set to be skipped before init", resolver, fqdn) - return nil, false - } - // check if resolver is already initialized - if !resolver.Initialized() { - // first should init, others wait - resolver.InitLock.Lock() - if resolver.Initialized() { - // unlock immediately if resolver was initialized - resolver.InitLock.Unlock() + // if cache is still empty or non-compliant, go ahead and just query } else { - // initialize and unlock when finished - defer resolver.InitLock.Unlock() - } - // check if previous init failed - if resolver.LastFail() > lastFailBoundary { - return nil, false - } - } - // resolve - rrCache, err := query(ctx, resolver, fqdn, qtype) - if err != nil { - // check if failing is disabled - if resolver.LastFail() == -1 { - log.Tracer(ctx).Tracef("intel: non-failing resolver %s failed, moving to next: %s", resolver, err) - // log.Tracef("intel: non-failing resolver %s failed (%s), moving to next", resolver, err) - return nil, false - } - log.Tracer(ctx).Warningf("intel: resolver %s failed, moving to next: %s", resolver, err) - log.Warningf("intel: resolver %s failed, moving to next: %s", resolver, err) - resolver.Lock() - resolver.failReason = err.Error() - resolver.lastFail = time.Now().Unix() - resolver.initialized = false - resolver.Unlock() - return nil, false - } - resolver.Lock() - resolver.initialized = true - resolver.Unlock() + // we are the first! + defer markRequestFinished() - return rrCache, true + } + } + + return resolveAndCache(ctx, q) } -func query(ctx context.Context, resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) { +func checkCache(ctx context.Context, q *Query) *RRCache { + rrCache, err := GetRRCache(q.FQDN, q.QType) - q := new(dns.Msg) - q.SetQuestion(fqdn, uint16(qtype)) + // failed to get from cache + if err != nil { + if err != database.ErrNotFound { + log.Tracer(ctx).Warningf("intel: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err) + log.Warningf("intel: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err) + } + return nil + } - var reply *dns.Msg - var err error - for i := 0; i < 3; i++ { + // get resolver that rrCache was resolved with + resolver := getResolverByIDWithLocking(rrCache.Server) + if resolver == nil { + return nil + } - // log query time - // qStart := time.Now() - reply, _, err = resolver.clientManager.getDNSClient().Exchange(q, resolver.ServerAddress) - // log.Tracef("intel: query to %s took %s", resolver.Server, time.Now().Sub(qStart)) + // check compliance of resolver + err = resolver.checkCompliance(ctx, q) + if err != nil { + log.Tracer(ctx).Debugf("intel: cached entry for %s%s does not comply to query parameters: %s", q.FQDN, q.QType.String(), err) + return nil + } - // error handling - if err != nil { - log.Tracer(ctx).Tracef("intel: query to %s encountered error: %s", resolver.Server, err) + // check if expired + if rrCache.Expired() { + rrCache.Lock() + rrCache.requestingNew = true + rrCache.Unlock() - // TODO: handle special cases - // 1. connect: network is unreachable - // 2. timeout + log.Tracer(ctx).Trace("intel: serving from cache, requesting new") - // temporary error - if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - log.Tracer(ctx).Tracef("intel: retrying to resolve %s%s with %s, error is temporary", fqdn, qtype, resolver.Server) + // resolve async + module.StartMediumPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error { + _, _ = resolveAndCache(ctx, q) + return nil + }) + } + + return rrCache +} + +func deduplicateRequest(ctx context.Context, q *Query) (finishRequest func()) { + // create identifier key + dupKey := fmt.Sprintf("%s%s", q.FQDN, q.QType.String()) + + dupReqLock.Lock() + defer dupReqLock.Unlock() + + // get duplicate request waitgroup + wg, requestActive := dupReqMap[dupKey] + + // someone else is already on it! + if requestActive { + // log that we are waiting + log.Tracer(ctx).Tracef("intel: waiting for duplicate query for %s to complete", dupKey) + // wait + wg.Wait() + // done! + return nil + } + + // we are currently the only one doing a request for this + + // create new waitgroup + wg = new(sync.WaitGroup) + // add worker (us!) + wg.Add(1) + // add to registry + dupReqMap[dupKey] = wg + + // return function to mark request as finished + return func() { + dupReqLock.Lock() + defer dupReqLock.Unlock() + // mark request as done + wg.Done() + // delete from registry + delete(dupReqMap, dupKey) + } +} + +func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error) { + // get resolvers + resolvers := GetResolversInScope(ctx, q) + if len(resolvers) == 0 { + return nil, ErrNoCompliance + } + + // prep + lastFailBoundary := time.Now().Add( + -time.Duration(nameserverRetryRate()) * time.Second, + ) + + // start resolving + + var i int + // once with skipping recently failed resolvers, once without +resolveLoop: + for i = 0; i < 2; i++ { + for _, resolver := range resolvers { + // check if resolver failed recently (on first run) + if i == 0 && resolver.Conn.LastFail().After(lastFailBoundary) { + log.Tracer(ctx).Tracef("intel: skipping resolver %s, because it failed recently", resolver) continue } - // permanent error - break - } + // resolve + rrCache, err = resolver.Conn.Query(ctx, q) + if err != nil { - // no error - break + // FIXME: check if we are online? + + switch { + case errors.Is(err, ErrNotFound): + // NXDomain, or similar + return nil, err + case errors.Is(err, ErrBlocked): + // some resolvers might also block + return nil, err + } + } else { + // no error + if rrCache == nil { + // defensive: assume NXDomain + return nil, ErrNotFound + } + break resolveLoop + } + } } + // tried all resolvers, possibly twice + if i > 1 { + return nil, fmt.Errorf("all %d query-compliant resolvers failed, last error: %s", len(resolvers), err) + } + + // check for error if err != nil { return nil, err } - new := &RRCache{ - Domain: fqdn, - Question: qtype, - Answer: reply.Answer, - Ns: reply.Ns, - Extra: reply.Extra, - Server: resolver.Server, - ServerScope: resolver.ServerIPScope, + // check for result + if rrCache == nil /* defensive */ { + return nil, ErrNotFound } - // TODO: check if reply.Answer is valid - return new, nil + // cache if enabled + if !q.NoCaching { + // persist to database + rrCache.Clean(600) + err = rrCache.Save() + if err != nil { + log.Warningf("intel: failed to cache RR for %s%s: %s", q.FQDN, q.QType.String(), err) + } + } + + return rrCache, nil } diff --git a/intel/resolver-scopes.go b/intel/resolver-scopes.go new file mode 100644 index 00000000..2da53b38 --- /dev/null +++ b/intel/resolver-scopes.go @@ -0,0 +1,276 @@ +package intel + +import ( + "context" + "errors" + "strings" + + "github.com/miekg/dns" + "github.com/safing/portbase/log" +) + +// special scopes: + +// localhost. [RFC6761] - respond with 127.0.0.1 and ::1 to A and AAAA queries, else nxdomain + +// local. [RFC6762] - resolve if search, else resolve with mdns +// 10.in-addr.arpa. [RFC6761] +// 16.172.in-addr.arpa. [RFC6761] +// 17.172.in-addr.arpa. [RFC6761] +// 18.172.in-addr.arpa. [RFC6761] +// 19.172.in-addr.arpa. [RFC6761] +// 20.172.in-addr.arpa. [RFC6761] +// 21.172.in-addr.arpa. [RFC6761] +// 22.172.in-addr.arpa. [RFC6761] +// 23.172.in-addr.arpa. [RFC6761] +// 24.172.in-addr.arpa. [RFC6761] +// 25.172.in-addr.arpa. [RFC6761] +// 26.172.in-addr.arpa. [RFC6761] +// 27.172.in-addr.arpa. [RFC6761] +// 28.172.in-addr.arpa. [RFC6761] +// 29.172.in-addr.arpa. [RFC6761] +// 30.172.in-addr.arpa. [RFC6761] +// 31.172.in-addr.arpa. [RFC6761] +// 168.192.in-addr.arpa. [RFC6761] +// 254.169.in-addr.arpa. [RFC6762] +// 8.e.f.ip6.arpa. [RFC6762] +// 9.e.f.ip6.arpa. [RFC6762] +// a.e.f.ip6.arpa. [RFC6762] +// b.e.f.ip6.arpa. [RFC6762] + +// example. [RFC6761] - resolve if search, else return nxdomain +// example.com. [RFC6761] - resolve if search, else return nxdomain +// example.net. [RFC6761] - resolve if search, else return nxdomain +// example.org. [RFC6761] - resolve if search, else return nxdomain +// invalid. [RFC6761] - resolve if search, else return nxdomain +// test. [RFC6761] - resolve if search, else return nxdomain +// onion. [RFC7686] - resolve if search, else return nxdomain + +// resolvers: +// local +// global +// mdns + +var ( + // RFC6761 - respond with 127.0.0.1 and ::1 to A and AAAA queries respectively, else nxdomain + localhost = ".localhost." + + // RFC6761 - always respond with nxdomain + invalid = ".invalid." + + // RFC6762 - resolve locally + local = ".local." + + // local reverse dns + localReverseScopes = []string{ + ".10.in-addr.arpa.", // RFC6761 + ".16.172.in-addr.arpa.", // RFC6761 + ".17.172.in-addr.arpa.", // RFC6761 + ".18.172.in-addr.arpa.", // RFC6761 + ".19.172.in-addr.arpa.", // RFC6761 + ".20.172.in-addr.arpa.", // RFC6761 + ".21.172.in-addr.arpa.", // RFC6761 + ".22.172.in-addr.arpa.", // RFC6761 + ".23.172.in-addr.arpa.", // RFC6761 + ".24.172.in-addr.arpa.", // RFC6761 + ".25.172.in-addr.arpa.", // RFC6761 + ".26.172.in-addr.arpa.", // RFC6761 + ".27.172.in-addr.arpa.", // RFC6761 + ".28.172.in-addr.arpa.", // RFC6761 + ".29.172.in-addr.arpa.", // RFC6761 + ".30.172.in-addr.arpa.", // RFC6761 + ".31.172.in-addr.arpa.", // RFC6761 + ".168.192.in-addr.arpa.", // RFC6761 + ".254.169.in-addr.arpa.", // RFC6762 + ".8.e.f.ip6.arpa.", // RFC6762 + ".9.e.f.ip6.arpa.", // RFC6762 + ".a.e.f.ip6.arpa.", // RFC6762 + ".b.e.f.ip6.arpa.", // RFC6762 + } + + // RFC6761 - only resolve locally + localTestScopes = []string{ + ".example.", + ".example.com.", + ".example.net.", + ".example.org.", + ".test.", + } + + // resolve globally - resolving these should be disabled by default + specialServiceScopes = []string{ + ".onion.", // Tor Hidden Services, RFC7686 + ".bit.", // Namecoin, https://www.namecoin.org/ + } +) + +func domainInScope(dotPrefixedFQDN string, scopeList []string) bool { + for _, scope := range scopeList { + if strings.HasSuffix(dotPrefixedFQDN, scope) { + return true + } + } + return false +} + +// GetResolversInScope returns all resolvers that are in scope the resolve the given query and options. +func GetResolversInScope(ctx context.Context, q *Query) (selected []*Resolver) { + resolversLock.RLock() + defer resolversLock.RUnlock() + + // resolver selection: + // local -> local scopes, mdns + // local-inaddr -> local, mdns + // global -> local scopes, global + // special -> local scopes, local + + // check local scopes + for _, scope := range localScopes { + if strings.HasSuffix(q.dotPrefixedFQDN, scope.Domain) { + // scoped resolvers + for _, resolver := range scope.Resolvers { + if err := resolver.checkCompliance(ctx, q); err == nil { + selected = append(selected, resolver) + } else { + log.Tracef("skipping non-compliant resolver: %s", resolver.Server) + } + } + } + } + // if there was a match with a local scope, stop here + if len(selected) > 0 { + // add mdns + if err := mDNSResolver.checkCompliance(ctx, q); err == nil { + selected = append(selected, mDNSResolver) + } else { + log.Tracef("skipping non-compliant resolver: %s", mDNSResolver.Server) + } + return selected + } + + // check local reverse scope + if domainInScope(q.dotPrefixedFQDN, localReverseScopes) { + // local resolvers + for _, resolver := range localResolvers { + if err := resolver.checkCompliance(ctx, q); err == nil { + selected = append(selected, resolver) + } else { + log.Tracef("skipping non-compliant resolver: %s", resolver.Server) + } + } + // mdns resolver + if err := mDNSResolver.checkCompliance(ctx, q); err == nil { + selected = append(selected, mDNSResolver) + } else { + log.Tracef("skipping non-compliant resolver: %s", mDNSResolver.Server) + } + return selected + } + + // check for .local mdns + if strings.HasSuffix(q.dotPrefixedFQDN, local) { + // add mdns + if err := mDNSResolver.checkCompliance(ctx, q); err == nil { + selected = append(selected, mDNSResolver) + } else { + log.Tracef("skipping non-compliant resolver: %s", mDNSResolver.Server) + } + return selected + } + + // check for test scopes + if domainInScope(q.dotPrefixedFQDN, localTestScopes) { + // local resolvers + for _, resolver := range localResolvers { + if err := resolver.checkCompliance(ctx, q); err == nil { + selected = append(selected, resolver) + } else { + log.Tracef("skipping non-compliant resolver: %s", resolver.Server) + } + } + return selected + } + + // finally, query globally + for _, resolver := range globalResolvers { + if err := resolver.checkCompliance(ctx, q); err == nil { + selected = append(selected, resolver) + } else { + log.Tracef("skipping non-compliant resolver: %s", resolver.Server) + } + } + return selected +} + +var ( + errInsecureProtocol = errors.New("insecure protocols disabled") + errAssignedServer = errors.New("assigned (dhcp) nameservers disabled") + errMulticastDNS = errors.New("multicast DNS disabled") + errSkip = errors.New("this fqdn cannot resolved by this resolver") +) + +func (q *Query) checkCompliance() error { + // RFC6761 - always respond with nxdomain + if strings.HasSuffix(q.dotPrefixedFQDN, invalid) { + return ErrNotFound + } + + // RFC6761 - respond with 127.0.0.1 and ::1 to A and AAAA queries respectively, else nxdomain + if strings.HasSuffix(q.dotPrefixedFQDN, localhost) { + switch uint16(q.QType) { + case dns.TypeA, dns.TypeAAAA: + return ErrLocalhost + default: + return ErrNotFound + } + } + + // special TLDs + if doNotResolveSpecialDomains(q.SecurityLevel) && + domainInScope(q.dotPrefixedFQDN, specialServiceScopes) { + return ErrSpecialDomainsDisabled + } + + // testing TLDs + if doNotResolveTestDomains(q.SecurityLevel) && + domainInScope(q.dotPrefixedFQDN, localTestScopes) { + return ErrTestDomainsDisabled + } + + return nil +} + +func (resolver *Resolver) checkCompliance(_ context.Context, q *Query) error { + if q.FQDN == resolver.SkipFQDN { + return errSkip + } + + if doNotUseInsecureProtocols(q.SecurityLevel) { + switch resolver.ServerType { + case ServerTypeDNS: + return errInsecureProtocol + case ServerTypeTCP: + return errInsecureProtocol + case ServerTypeDoT: + // compliant + case ServerTypeDoH: + // compliant + default: + return errInsecureProtocol + } + } + + if doNotUseAssignedNameservers(q.SecurityLevel) { + if resolver.Source == ServerSourceAssigned { + return errAssignedServer + } + } + + if doNotUseMulticastDNS(q.SecurityLevel) { + if resolver.Source == ServerSourceMDNS { + return errMulticastDNS + } + } + + return nil +} diff --git a/intel/resolver.go b/intel/resolver.go index e7e830cc..a0d74a77 100644 --- a/intel/resolver.go +++ b/intel/resolver.go @@ -1,331 +1,151 @@ package intel import ( - "errors" - "fmt" + "context" "net" - "sort" - "strconv" - "strings" "sync" + "time" "github.com/miekg/dns" - "github.com/safing/portbase/log" - "github.com/safing/portmaster/network/environment" - "github.com/safing/portmaster/network/netutils" +) + +const ( + ServerTypeDNS = "dns" + ServerTypeTCP = "tcp" + ServerTypeDoT = "dot" + ServerTypeDoH = "doh" + + ServerSourceConfigured = "config" + ServerSourceAssigned = "dhcp" + ServerSourceMDNS = "mdns" ) // Resolver holds information about an active resolver. type Resolver struct { - sync.Mutex + // Server config url (and ID) + Server string - // static - Server string + // Parsed config ServerType string ServerAddress string ServerIP net.IP ServerIPScope int8 ServerPort uint16 - VerifyDomain string - Source string - clientManager *clientManager - Search *[]string - SkipFqdnBeforeInit string + // Special Options + VerifyDomain string + Search []string + SkipFQDN string - InitLock sync.Mutex + Source string - // must be locked - initialized bool - lastFail int64 - failReason string - fails int - expires int64 - - // TODO: add Expiration (for server got from DHCP / ICMPv6) + // logic interface + Conn ResolverConn } -// Initialized returns the internal initialized value while locking the Resolver. -func (r *Resolver) Initialized() bool { - r.Lock() - defer r.Unlock() - return r.initialized +// String returns the URL representation of the resolver. +func (resolver *Resolver) String() string { + return resolver.Server +} + +// ResolverConn is an interface to implement different types of query backends. +type ResolverConn interface { + Query(ctx context.Context, q *Query) (*RRCache, error) + MarkFailed() + LastFail() time.Time +} + +// BasicResolverConn implements ResolverConn for standard dns clients. +type BasicResolverConn struct { + sync.Mutex // for lastFail + + resolver *Resolver + clientManager *clientManager + lastFail time.Time +} + +// MarkFailed marks the resolver as failed. +func (brc *BasicResolverConn) MarkFailed() { + if !environment.Online() { + // don't mark failed if we are offline + return + } + + brc.Lock() + defer brc.Unlock() + brc.lastFail = time.Now() } // LastFail returns the internal lastfail value while locking the Resolver. -func (r *Resolver) LastFail() int64 { - r.Lock() - defer r.Unlock() - return r.lastFail +func (brc *BasicResolverConn) LastFail() time.Time { + brc.Lock() + defer brc.Unlock() + return brc.lastFail } -// FailReason returns the internal failreason value while locking the Resolver. -func (r *Resolver) FailReason() string { - r.Lock() - defer r.Unlock() - return r.failReason -} +func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, error) { + // convenience + resolver := brc.resolver -// Fails returns the internal fails value while locking the Resolver. -func (r *Resolver) Fails() int { - r.Lock() - defer r.Unlock() - return r.fails -} + // create query + dnsQuery := new(dns.Msg) + dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) -// Expires returns the internal expires value while locking the Resolver. -func (r *Resolver) Expires() int64 { - r.Lock() - defer r.Unlock() - return r.expires -} + // start + var reply *dns.Msg + var err error + for i := 0; i < 3; i++ { -func (r *Resolver) String() string { - return r.Server -} + // log query time + // qStart := time.Now() + reply, _, err = brc.clientManager.getDNSClient().Exchange(dnsQuery, resolver.ServerAddress) + // log.Tracef("intel: query to %s took %s", resolver.Server, time.Now().Sub(qStart)) -// Scope defines a domain scope and which resolvers can resolve it. -type Scope struct { - Domain string - Resolvers []*Resolver -} + // error handling + if err != nil { + log.Tracer(ctx).Tracef("intel: query to %s encountered error: %s", resolver.Server, err) -var ( - globalResolvers []*Resolver // all resolvers - localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges - localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope - resolversLock sync.RWMutex + // TODO: handle special cases + // 1. connect: network is unreachable + // 2. timeout - env = environment.NewInterface() + // hint network environment at failed connection + environment.ReportFailedConnection() - dupReqMap = make(map[string]*sync.Mutex) - dupReqLock sync.Mutex -) - -func indexOfResolver(server string, list []*Resolver) int { - for k, v := range list { - if v.Server == server { - return k - } - } - return -1 -} - -func indexOfScope(domain string, list []*Scope) int { - for k, v := range list { - if v.Domain == domain { - return k - } - } - return -1 -} - -func parseAddress(server string) (net.IP, uint16, error) { - delimiter := strings.LastIndex(server, ":") - if delimiter < 0 { - return nil, 0, errors.New("port missing") - } - ip := net.ParseIP(strings.Trim(server[:delimiter], "[]")) - if ip == nil { - return nil, 0, errors.New("invalid IP address") - } - port, err := strconv.Atoi(server[delimiter+1:]) - if err != nil || port < 1 || port > 65536 { - return nil, 0, errors.New("invalid port") - } - return ip, uint16(port), nil -} - -func urlFormatAddress(ip net.IP, port uint16) string { - var address string - if ipv4 := ip.To4(); ipv4 != nil { - address = fmt.Sprintf("%s:%d", ipv4.String(), port) - } else { - address = fmt.Sprintf("[%s]:%d", ip.String(), port) - } - return address -} - -func loadResolvers(resetResolvers bool) { - // TODO: what happens when a lot of processes want to reload at once? we do not need to run this multiple times in a short time frame. - resolversLock.Lock() - defer resolversLock.Unlock() - - var newResolvers []*Resolver - -configuredServersLoop: - for _, server := range configuredNameServers() { - key := indexOfResolver(server, newResolvers) - if key >= 0 { - continue configuredServersLoop - } - key = indexOfResolver(server, globalResolvers) - if resetResolvers || key == -1 { - - parts := strings.Split(server, "|") - if len(parts) < 2 { - log.Warningf("intel: nameserver format invalid: %s", server) - continue configuredServersLoop + // temporary error + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + log.Tracer(ctx).Tracef("intel: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server) + continue } - ip, port, err := parseAddress(parts[1]) - if err != nil && strings.ToLower(parts[0]) != "https" { - log.Warningf("intel: nameserver (%s) address invalid: %s", server, err) - continue configuredServersLoop - } - - new := &Resolver{ - Server: server, - ServerType: strings.ToLower(parts[0]), - ServerAddress: parts[1], - ServerIP: ip, - ServerIPScope: netutils.ClassifyIP(ip), - ServerPort: port, - Source: "config", - } - - switch new.ServerType { - case "dns": - new.clientManager = newDNSClientManager(new) - case "tcp": - new.clientManager = newTCPClientManager(new) - case "tls": - if len(parts) < 3 { - log.Warningf("intel: nameserver missing verification domain as third parameter: %s", server) - continue configuredServersLoop - } - new.VerifyDomain = parts[2] - new.clientManager = newTLSClientManager(new) - case "https": - new.SkipFqdnBeforeInit = dns.Fqdn(strings.Split(parts[1], ":")[0]) - if len(parts) > 2 { - new.VerifyDomain = parts[2] - } - new.clientManager = newHTTPSClientManager(new) - default: - log.Warningf("intel: nameserver (%s) type invalid: %s", server, parts[0]) - continue configuredServersLoop - } - newResolvers = append(newResolvers, new) - } else { - newResolvers = append(newResolvers, globalResolvers[key]) + // permanent error + break } + + // no error + break } - // add local resolvers - assignedNameservers := environment.Nameservers() -assignedServersLoop: - for _, nameserver := range assignedNameservers { - server := fmt.Sprintf("dns|%s", urlFormatAddress(nameserver.IP, 53)) - key := indexOfResolver(server, newResolvers) - if key >= 0 { - continue assignedServersLoop - } - key = indexOfResolver(server, globalResolvers) - if resetResolvers || key == -1 { - - new := &Resolver{ - Server: server, - ServerType: "dns", - ServerAddress: urlFormatAddress(nameserver.IP, 53), - ServerIP: nameserver.IP, - ServerIPScope: netutils.ClassifyIP(nameserver.IP), - ServerPort: 53, - Source: "dhcp", - } - new.clientManager = newDNSClientManager(new) - - if netutils.IPIsLAN(nameserver.IP) && len(nameserver.Search) > 0 { - // only allow searches for local resolvers - var newSearch []string - for _, value := range nameserver.Search { - newSearch = append(newSearch, fmt.Sprintf(".%s.", strings.Trim(value, "."))) - } - new.Search = &newSearch - } - newResolvers = append(newResolvers, new) - } else { - newResolvers = append(newResolvers, globalResolvers[key]) - } + if err != nil { + return nil, err + // FIXME: mark as failed } - // save resolvers - globalResolvers = newResolvers - if len(globalResolvers) == 0 { - log.Criticalf("intel: no (valid) dns servers found in configuration and system") - } - - // make list with local resolvers - localResolvers = make([]*Resolver, 0) - for _, resolver := range globalResolvers { - if resolver.ServerIP != nil && netutils.IPIsLAN(resolver.ServerIP) { - localResolvers = append(localResolvers, resolver) - } - } - - // add resolvers to every scope the cover - localScopes = make([]*Scope, 0) - for _, resolver := range globalResolvers { - - if resolver.Search != nil { - // add resolver to custom searches - for _, search := range *resolver.Search { - if search == "." { - continue - } - key := indexOfScope(search, localScopes) - if key == -1 { - localScopes = append(localScopes, &Scope{ - Domain: search, - Resolvers: []*Resolver{resolver}, - }) - } else { - localScopes[key].Resolvers = append(localScopes[key].Resolvers, resolver) - } - } - - } - } - - // sort scopes by length - sort.Slice(localScopes, - func(i, j int) bool { - return len(localScopes[i].Domain) > len(localScopes[j].Domain) - }, - ) - - log.Trace("intel: loaded global resolvers:") - for _, resolver := range globalResolvers { - log.Tracef("intel: %s", resolver.Server) - } - log.Trace("intel: loaded local resolvers:") - for _, resolver := range localResolvers { - log.Tracef("intel: %s", resolver.Server) - } - log.Trace("intel: loaded scopes:") - for _, scope := range localScopes { - var scopeServers []string - for _, resolver := range scope.Resolvers { - scopeServers = append(scopeServers, resolver.Server) - } - log.Tracef("intel: %s: %s", scope.Domain, strings.Join(scopeServers, ", ")) + // hint network environment at successful connection + environment.ReportSuccessfulConnection() + + new := &RRCache{ + Domain: q.FQDN, + Question: q.QType, + Answer: reply.Answer, + Ns: reply.Ns, + Extra: reply.Extra, + Server: resolver.Server, + ServerScope: resolver.ServerIPScope, } -} - -// resetResolverFailStatus resets all resolver failures. -func resetResolverFailStatus() { - resolversLock.Lock() - defer resolversLock.Unlock() - - log.Tracef("old: %+v %+v, ", globalResolvers, localResolvers) - for _, resolver := range append(globalResolvers, localResolvers...) { - resolver.Lock() - resolver.failReason = "" - resolver.lastFail = 0 - resolver.Unlock() - } - log.Tracef("new: %+v %+v, ", globalResolvers, localResolvers) + // TODO: check if reply.Answer is valid + return new, nil } diff --git a/intel/resolvers.go b/intel/resolvers.go new file mode 100644 index 00000000..9185c8e6 --- /dev/null +++ b/intel/resolvers.go @@ -0,0 +1,357 @@ +package intel + +import ( + "errors" + "fmt" + "net" + "sort" + "strconv" + "strings" + "sync" + + "golang.org/x/net/publicsuffix" + + "github.com/miekg/dns" + "github.com/safing/portbase/log" + "github.com/safing/portmaster/network/environment" + "github.com/safing/portmaster/network/netutils" +) + +// Scope defines a domain scope and which resolvers can resolve it. +type Scope struct { + Domain string + Resolvers []*Resolver +} + +var ( + globalResolvers []*Resolver // all (global) resolvers + localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges + localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope + allResolvers map[string]*Resolver // lookup map of all resolvers + resolversLock sync.RWMutex + + dupReqMap = make(map[string]*sync.WaitGroup) + dupReqLock sync.Mutex +) + +func indexOfResolver(server string, list []*Resolver) int { + for k, v := range list { + if v.Server == server { + return k + } + } + return -1 +} + +func indexOfScope(domain string, list []*Scope) int { + for k, v := range list { + if v.Domain == domain { + return k + } + } + return -1 +} + +func getResolverByIDWithLocking(server string) *Resolver { + resolversLock.Lock() + defer resolversLock.Unlock() + + resolver, ok := allResolvers[server] + if ok { + return resolver + } + return nil +} + +func parseAddress(server string) (net.IP, uint16, error) { + delimiter := strings.LastIndex(server, ":") + if delimiter < 0 { + return nil, 0, errors.New("port missing") + } + ip := net.ParseIP(strings.Trim(server[:delimiter], "[]")) + if ip == nil { + return nil, 0, errors.New("invalid IP address") + } + port, err := strconv.Atoi(server[delimiter+1:]) + if err != nil || port < 1 || port > 65536 { + return nil, 0, errors.New("invalid port") + } + return ip, uint16(port), nil +} + +func urlFormatAddress(ip net.IP, port uint16) string { + var address string + if ipv4 := ip.To4(); ipv4 != nil { + address = fmt.Sprintf("%s:%d", ipv4.String(), port) + } else { + address = fmt.Sprintf("[%s]:%d", ip.String(), port) + } + return address +} + +//nolint:gocyclo,gocognit +func loadResolvers() { + // TODO: what happens when a lot of processes want to reload at once? we do not need to run this multiple times in a short time frame. + resolversLock.Lock() + defer resolversLock.Unlock() + + var newResolvers []*Resolver + +configuredServersLoop: + for _, server := range configuredNameServers() { + key := indexOfResolver(server, newResolvers) + if key >= 0 { + continue configuredServersLoop + } + key = indexOfResolver(server, globalResolvers) + if key == -1 { + + parts := strings.Split(server, "|") + if len(parts) < 2 { + log.Warningf("intel: nameserver format invalid: %s", server) + continue configuredServersLoop + } + + var ipScope int8 + ip, port, err := parseAddress(parts[1]) + if err == nil { + ipScope = netutils.ClassifyIP(ip) + if ipScope == netutils.HostLocal { + log.Warningf(`intel: cannot use configured localhost nameserver "%s"`, server) + continue configuredServersLoop + } + } else { + if strings.ToLower(parts[0]) == "doh" { + ipScope = netutils.Global + } else { + log.Warningf("intel: nameserver (%s) address invalid: %s", server, err) + continue configuredServersLoop + } + } + + // create new structs + newConn := &BasicResolverConn{} + new := &Resolver{ + Server: server, + ServerType: strings.ToLower(parts[0]), + ServerAddress: parts[1], + ServerIP: ip, + ServerIPScope: ipScope, + ServerPort: port, + Source: "config", + Conn: newConn, + } + + // refer back + newConn.resolver = new + + switch new.ServerType { + case "dns": + newConn.clientManager = newDNSClientManager(new) + case "tcp": + newConn.clientManager = newTCPClientManager(new) + case "dot": + if len(parts) < 3 { + log.Warningf("intel: nameserver missing verification domain as third parameter: %s", server) + continue configuredServersLoop + } + new.VerifyDomain = parts[2] + newConn.clientManager = newTLSClientManager(new) + case "doh": + new.SkipFQDN = dns.Fqdn(strings.Split(parts[1], ":")[0]) + if len(parts) > 2 { + new.VerifyDomain = parts[2] + } + newConn.clientManager = newHTTPSClientManager(new) + default: + log.Warningf("intel: nameserver (%s) type invalid: %s", server, parts[0]) + continue configuredServersLoop + } + newResolvers = append(newResolvers, new) + } else { + newResolvers = append(newResolvers, globalResolvers[key]) + } + } + + // add local resolvers + assignedNameservers := environment.Nameservers() +assignedServersLoop: + for _, nameserver := range assignedNameservers { + server := fmt.Sprintf("dns|%s", urlFormatAddress(nameserver.IP, 53)) + key := indexOfResolver(server, newResolvers) + if key >= 0 { + continue assignedServersLoop + } + key = indexOfResolver(server, globalResolvers) + if key == -1 { + + ipScope := netutils.ClassifyIP(nameserver.IP) + if ipScope == netutils.HostLocal { + log.Infof(`intel: cannot use assigned localhost nameserver at %s`, nameserver.IP) + continue assignedServersLoop + } + + // create new structs + newConn := &BasicResolverConn{} + new := &Resolver{ + Server: server, + ServerType: "dns", + ServerAddress: urlFormatAddress(nameserver.IP, 53), + ServerIP: nameserver.IP, + ServerIPScope: ipScope, + ServerPort: 53, + Source: "dhcp", + Conn: newConn, + } + + // refer back + newConn.resolver = new + + // add client manager + newConn.clientManager = newDNSClientManager(new) + + if netutils.IPIsLAN(nameserver.IP) && len(nameserver.Search) > 0 { + // only allow searches for local resolvers + for _, value := range nameserver.Search { + trimmedDomain := strings.Trim(value, ".") + if checkSearchScope(trimmedDomain) { + new.Search = append(new.Search, fmt.Sprintf(".%s.", strings.Trim(value, "."))) + } + } + // cap to mitigate exploitation via malicious local resolver + if len(new.Search) > 100 { + new.Search = new.Search[:100] + } + } + newResolvers = append(newResolvers, new) + } else { + newResolvers = append(newResolvers, globalResolvers[key]) + } + } + + // save resolvers + globalResolvers = newResolvers + if len(globalResolvers) == 0 { + log.Criticalf("intel: no (valid) dns servers found in configuration and system") + } + + // make list with local resolvers + localResolvers = make([]*Resolver, 0) + for _, resolver := range globalResolvers { + if resolver.ServerIP != nil && netutils.IPIsLAN(resolver.ServerIP) { + localResolvers = append(localResolvers, resolver) + } + } + + // add resolvers to every scope the cover + localScopes = make([]*Scope, 0) + for _, resolver := range globalResolvers { + + if resolver.Search != nil { + // add resolver to custom searches + for _, search := range resolver.Search { + if search == "." { + continue + } + key := indexOfScope(search, localScopes) + if key == -1 { + localScopes = append(localScopes, &Scope{ + Domain: search, + Resolvers: []*Resolver{resolver}, + }) + } else { + localScopes[key].Resolvers = append(localScopes[key].Resolvers, resolver) + } + } + + } + } + + // sort scopes by length + sort.Slice(localScopes, + func(i, j int) bool { + return len(localScopes[i].Domain) > len(localScopes[j].Domain) + }, + ) + + // log global resolvers + if len(globalResolvers) > 0 { + log.Trace("intel: loaded global resolvers:") + for _, resolver := range globalResolvers { + log.Tracef("intel: %s", resolver.Server) + } + } else { + log.Warning("intel: no global resolvers loaded") + } + + // log local resolvers + if len(localResolvers) > 0 { + log.Trace("intel: loaded local resolvers:") + for _, resolver := range localResolvers { + log.Tracef("intel: %s", resolver.Server) + } + } else { + log.Info("intel: no local resolvers loaded") + } + + // log scopes + if len(localScopes) > 0 { + log.Trace("intel: loaded scopes:") + for _, scope := range localScopes { + var scopeServers []string + for _, resolver := range scope.Resolvers { + scopeServers = append(scopeServers, resolver.Server) + } + log.Tracef("intel: %s: %s", scope.Domain, strings.Join(scopeServers, ", ")) + } + } else { + log.Info("intel: no scopes loaded") + } + + // alert if no resolvers are loaded + if len(globalResolvers) == 0 && len(localResolvers) == 0 { + log.Critical("intel: no resolvers loaded!") + } +} + +func checkSearchScope(searchDomain string) (ok bool) { + // sanity check + if len(searchDomain) == 0 || + searchDomain[0] == '.' || + searchDomain[len(searchDomain)-1] == '.' { + return false + } + + // add more subdomains to use official publicsuffix package for our cause + searchDomain = "*.*.*.*.*." + searchDomain + + // get suffix + suffix, icann := publicsuffix.PublicSuffix(searchDomain) + // sanity check + if len(suffix) == 0 { + return false + } + // inexistent (custom) tlds are okay + // this will include special service domains! (.onion, .bit, ...) + if !icann && !strings.Contains(suffix, ".") { + return true + } + + // check if suffix is a special service domain (may be handled fully by local nameserver) + if domainInScope("."+suffix+".", specialServiceScopes) { + return true + } + + // build eTLD+1 + split := len(searchDomain) - len(suffix) - 1 + eTLDplus1 := searchDomain[1+strings.LastIndex(searchDomain[:split], "."):] + + // scope check + //nolint:gosimple // want comment + if strings.Contains(eTLDplus1, "*") { + // oops, search domain is too high up the hierarchy + return false + } + + return true +} diff --git a/intel/resolvers_test.go b/intel/resolvers_test.go new file mode 100644 index 00000000..877c03aa --- /dev/null +++ b/intel/resolvers_test.go @@ -0,0 +1,36 @@ +package intel + +import "testing" + +func TestCheckResolverSearchScope(t *testing.T) { + + test := func(t *testing.T, domain string, expectedResult bool) { + if checkSearchScope(domain) != expectedResult { + if expectedResult { + t.Errorf("domain %s failed scope test", domain) + } else { + t.Errorf("domain %s should fail scope test", domain) + } + } + } + + // should fail (invalid) + test(t, ".", false) + test(t, ".com.", false) + test(t, "com.", false) + test(t, ".com", false) + + // should succeed + test(t, "a.com", true) + test(t, "b.a.com", true) + test(t, "c.b.a.com", true) + test(t, "onion", true) + test(t, "a.onion", true) + test(t, "b.a.onion", true) + test(t, "c.b.a.onion", true) + + test(t, "bit", true) + test(t, "a.bit", true) + test(t, "b.a.bit", true) + test(t, "c.b.a.bit", true) +} diff --git a/intel/reverse.go b/intel/reverse.go index fb1251d6..85be2230 100644 --- a/intel/reverse.go +++ b/intel/reverse.go @@ -1,26 +1,32 @@ package intel import ( - "errors" + "context" + "fmt" "strings" - "github.com/safing/portbase/log" "github.com/miekg/dns" + "github.com/safing/portbase/log" ) // ResolveIPAndValidate finds (reverse DNS), validates (forward DNS) and returns the domain name assigned to the given IP. -func ResolveIPAndValidate(ip string, securityLevel uint8) (domain string, err error) { +func ResolveIPAndValidate(ctx context.Context, ip string, securityLevel uint8) (domain string, err error) { // get reversed DNS address - rQ, err := dns.ReverseAddr(ip) + reverseIP, err := dns.ReverseAddr(ip) if err != nil { log.Tracef("intel: failed to get reverse address of %s: %s", ip, err) - return "", err + return "", ErrInvalid } // get PTR record - rrCache := Resolve(nil, rQ, dns.Type(dns.TypePTR), securityLevel) - if rrCache == nil { - return "", errors.New("querying for PTR record failed (may be NXDomain)") + q := &Query{ + FQDN: reverseIP, + QType: dns.Type(dns.TypePTR), + SecurityLevel: securityLevel, + } + rrCache, err := Resolve(ctx, q) + if err != nil || rrCache == nil { + return "", fmt.Errorf("failed to resolve %s%s: %w", q.FQDN, q.QType, err) } // get result from record @@ -35,23 +41,27 @@ func ResolveIPAndValidate(ip string, securityLevel uint8) (domain string, err er // check for nxDomain if ptrName == "" { - return "", errors.New("no PTR record for IP (nxDomain)") + return "", fmt.Errorf("%w: %s%s", ErrNotFound, q.FQDN, q.QType) } - log.Infof("ptrName: %s", ptrName) - // get forward record - if strings.Contains(ip, ":") { - rrCache = Resolve(nil, ptrName, dns.Type(dns.TypeAAAA), securityLevel) - } else { - rrCache = Resolve(nil, ptrName, dns.Type(dns.TypeA), securityLevel) + q = &Query{ + FQDN: ptrName, + SecurityLevel: securityLevel, } - if rrCache == nil { - return "", errors.New("querying for A/AAAA record failed (may be NXDomain)") + // IPv4/6 switch + if strings.Contains(ip, ":") { + q.QType = dns.Type(dns.TypeAAAA) + } else { + q.QType = dns.Type(dns.TypeA) + } + // resolve + rrCache, err = Resolve(ctx, q) + if err != nil || rrCache == nil { + return "", fmt.Errorf("failed to resolve %s%s: %w", q.FQDN, q.QType, err) } // check for matching A/AAAA record - log.Infof("rr: %s", rrCache) for _, rr := range rrCache.Answer { switch v := rr.(type) { case *dns.A: @@ -68,5 +78,5 @@ func ResolveIPAndValidate(ip string, securityLevel uint8) (domain string, err er } // no match - return "", errors.New("validation failed") + return "", ErrBlocked } diff --git a/intel/reverse_test.go b/intel/reverse_test.go index e53d1a6f..533edd5d 100644 --- a/intel/reverse_test.go +++ b/intel/reverse_test.go @@ -1,10 +1,19 @@ package intel -import "testing" +import ( + "context" + "testing" + + "github.com/safing/portbase/log" +) func testReverse(t *testing.T, ip, result, expectedErr string) { - domain, err := ResolveIPAndValidate(ip, 0) + ctx, tracer := log.AddTracer(context.Background()) + defer tracer.Submit() + + domain, err := ResolveIPAndValidate(ctx, ip, 0) if err != nil { + tracer.Warning(err.Error()) if expectedErr == "" || err.Error() != expectedErr { t.Errorf("reverse-validating %s: unexpected error: %s", ip, err) } @@ -18,11 +27,11 @@ func testReverse(t *testing.T, ip, result, expectedErr string) { func TestResolveIPAndValidate(t *testing.T) { testReverse(t, "198.41.0.4", "a.root-servers.net.", "") - testReverse(t, "9.9.9.9", "dns.quad9.net.", "") + // testReverse(t, "9.9.9.9", "dns.quad9.net.", "") // started resolving to dns9.quad9.net. testReverse(t, "2620:fe::fe", "dns.quad9.net.", "") testReverse(t, "1.1.1.1", "one.one.one.one.", "") testReverse(t, "2606:4700:4700::1111", "one.one.one.one.", "") - testReverse(t, "93.184.216.34", "example.com.", "no PTR record for IP (nxDomain)") - testReverse(t, "185.199.109.153", "sites.github.io.", "no PTR record for IP (nxDomain)") + testReverse(t, "93.184.216.34", "example.com.", "record does not exist: 34.216.184.93.in-addr.arpa.PTR") + testReverse(t, "185.199.109.153", "sites.github.io.", "record does not exist: 153.109.199.185.in-addr.arpa.PTR") } diff --git a/intel/rrcache.go b/intel/rrcache.go index c99f8b4b..c828433b 100644 --- a/intel/rrcache.go +++ b/intel/rrcache.go @@ -2,40 +2,62 @@ package intel import ( "fmt" + "math/rand" "net" + "sync" "time" "github.com/miekg/dns" ) // RRCache is used to cache DNS data +//nolint:maligned // TODO type RRCache struct { - Domain string - Question dns.Type + sync.Mutex - Answer []dns.RR - Ns []dns.RR - Extra []dns.RR - TTL int64 + Domain string // constant + Question dns.Type // constant - Server string - ServerScope int8 + Answer []dns.RR // might be mixed + Ns []dns.RR // constant + Extra []dns.RR // constant + TTL int64 // constant - updated int64 - servedFromCache bool - requestingNew bool - Filtered bool - FilteredEntries []string + Server string // constant + ServerScope int8 // constant + + servedFromCache bool // mutable + requestingNew bool // mutable + Filtered bool // mutable + FilteredEntries []string // mutable + + updated int64 // mutable +} + +// Expired returns whether the record has expired. +func (rrCache *RRCache) Expired() bool { + return rrCache.TTL <= time.Now().Unix() +} + +// MixAnswers randomizes the answer records to allow dumb clients (who only look at the first record) to reliably connect. +func (rrCache *RRCache) MixAnswers() { + rrCache.Lock() + defer rrCache.Unlock() + + for i := range rrCache.Answer { + j := rand.Intn(i + 1) + rrCache.Answer[i], rrCache.Answer[j] = rrCache.Answer[j], rrCache.Answer[i] + } } // Clean sets all TTLs to 17 and sets cache expiry with specified minimum. -func (m *RRCache) Clean(minExpires uint32) { +func (rrCache *RRCache) Clean(minExpires uint32) { var lowestTTL uint32 = 0xFFFFFFFF var header *dns.RR_Header // set TTLs to 17 // TODO: double append? is there something more elegant? - for _, rr := range append(m.Answer, append(m.Ns, m.Extra...)...) { + for _, rr := range append(rrCache.Answer, append(rrCache.Ns, rrCache.Extra...)...) { header = rr.Header() if lowestTTL > header.Ttl { lowestTTL = header.Ttl @@ -49,12 +71,12 @@ func (m *RRCache) Clean(minExpires uint32) { } // log.Tracef("lowest TTL is %d", lowestTTL) - m.TTL = time.Now().Unix() + int64(lowestTTL) + rrCache.TTL = time.Now().Unix() + int64(lowestTTL) } // ExportAllARecords return of a list of all A and AAAA IP addresses. -func (m *RRCache) ExportAllARecords() (ips []net.IP) { - for _, rr := range m.Answer { +func (rrCache *RRCache) ExportAllARecords() (ips []net.IP) { + for _, rr := range rrCache.Answer { if rr.Header().Class != dns.ClassINET { continue } @@ -76,23 +98,23 @@ func (m *RRCache) ExportAllARecords() (ips []net.IP) { } // ToNameRecord converts the RRCache to a NameRecord for cleaner persistence. -func (m *RRCache) ToNameRecord() *NameRecord { +func (rrCache *RRCache) ToNameRecord() *NameRecord { new := &NameRecord{ - Domain: m.Domain, - Question: m.Question.String(), - TTL: m.TTL, - Server: m.Server, - ServerScope: m.ServerScope, + Domain: rrCache.Domain, + Question: rrCache.Question.String(), + TTL: rrCache.TTL, + Server: rrCache.Server, + ServerScope: rrCache.ServerScope, } // stringify RR entries - for _, entry := range m.Answer { + for _, entry := range rrCache.Answer { new.Answer = append(new.Answer, entry.String()) } - for _, entry := range m.Ns { + for _, entry := range rrCache.Ns { new.Ns = append(new.Ns, entry.String()) } - for _, entry := range m.Extra { + for _, entry := range rrCache.Extra { new.Extra = append(new.Extra, entry.String()) } @@ -100,8 +122,8 @@ func (m *RRCache) ToNameRecord() *NameRecord { } // Save saves the RRCache to the database as a NameRecord. -func (m *RRCache) Save() error { - return m.ToNameRecord().Save() +func (rrCache *RRCache) Save() error { + return rrCache.ToNameRecord().Save() } // GetRRCache tries to load the corresponding NameRecord from the database and convert it. @@ -143,25 +165,25 @@ func GetRRCache(domain string, question dns.Type) (*RRCache, error) { } // ServedFromCache marks the RRCache as served from cache. -func (m *RRCache) ServedFromCache() bool { - return m.servedFromCache +func (rrCache *RRCache) ServedFromCache() bool { + return rrCache.servedFromCache } // RequestingNew informs that it has expired and new RRs are being fetched. -func (m *RRCache) RequestingNew() bool { - return m.requestingNew +func (rrCache *RRCache) RequestingNew() bool { + return rrCache.requestingNew } // Flags formats ServedFromCache and RequestingNew to a condensed, flag-like format. -func (m *RRCache) Flags() string { +func (rrCache *RRCache) Flags() string { var s string - if m.servedFromCache { + if rrCache.servedFromCache { s += "C" } - if m.requestingNew { + if rrCache.requestingNew { s += "R" } - if m.Filtered { + if rrCache.Filtered { s += "F" } @@ -172,27 +194,27 @@ func (m *RRCache) Flags() string { } // IsNXDomain returnes whether the result is nxdomain. -func (m *RRCache) IsNXDomain() bool { - return len(m.Answer) == 0 +func (rrCache *RRCache) IsNXDomain() bool { + return len(rrCache.Answer) == 0 } // ShallowCopy returns a shallow copy of the cache. slices are not copied, but referenced. -func (m *RRCache) ShallowCopy() *RRCache { +func (rrCache *RRCache) ShallowCopy() *RRCache { return &RRCache{ - Domain: m.Domain, - Question: m.Question, - Answer: m.Answer, - Ns: m.Ns, - Extra: m.Extra, - TTL: m.TTL, + Domain: rrCache.Domain, + Question: rrCache.Question, + Answer: rrCache.Answer, + Ns: rrCache.Ns, + Extra: rrCache.Extra, + TTL: rrCache.TTL, - Server: m.Server, - ServerScope: m.ServerScope, + Server: rrCache.Server, + ServerScope: rrCache.ServerScope, - updated: m.updated, - servedFromCache: m.servedFromCache, - requestingNew: m.requestingNew, - Filtered: m.Filtered, - FilteredEntries: m.FilteredEntries, + updated: rrCache.updated, + servedFromCache: rrCache.servedFromCache, + requestingNew: rrCache.requestingNew, + Filtered: rrCache.Filtered, + FilteredEntries: rrCache.FilteredEntries, } } diff --git a/intel/special.go b/intel/special.go deleted file mode 100644 index e2301df3..00000000 --- a/intel/special.go +++ /dev/null @@ -1,51 +0,0 @@ -package intel - -import "strings" - -var ( - localReverseScopes = []string{ - ".10.in-addr.arpa.", - ".16.172.in-addr.arpa.", - ".17.172.in-addr.arpa.", - ".18.172.in-addr.arpa.", - ".19.172.in-addr.arpa.", - ".20.172.in-addr.arpa.", - ".21.172.in-addr.arpa.", - ".22.172.in-addr.arpa.", - ".23.172.in-addr.arpa.", - ".24.172.in-addr.arpa.", - ".25.172.in-addr.arpa.", - ".26.172.in-addr.arpa.", - ".27.172.in-addr.arpa.", - ".28.172.in-addr.arpa.", - ".29.172.in-addr.arpa.", - ".30.172.in-addr.arpa.", - ".31.172.in-addr.arpa.", - ".168.192.in-addr.arpa.", - ".254.169.in-addr.arpa.", - ".8.e.f.ip6.arpa.", - ".9.e.f.ip6.arpa.", - ".a.e.f.ip6.arpa.", - ".b.e.f.ip6.arpa.", - } - - // RFC6761, RFC7686 - specialScopes = []string{ - ".example.", - ".example.com.", - ".example.net.", - ".example.org.", - ".invalid.", - ".test.", - ".onion.", - } -) - -func domainInScopes(fqdn string, list []string) bool { - for _, scope := range list { - if strings.HasSuffix(fqdn, scope) { - return true - } - } - return false -} diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index d2e1aa5a..00799984 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -3,7 +3,9 @@ package nameserver import ( "context" "net" - "runtime" + "strings" + + "github.com/safing/portmaster/network/environment" "github.com/miekg/dns" @@ -18,23 +20,20 @@ import ( ) var ( - localhostIPs []dns.RR -) + module *modules.Module + dnsServer *dns.Server + mtDNSRequest = "dns request" -var ( - listenAddress = "127.0.0.1:53" - localhostIP = net.IPv4(127, 0, 0, 1) + listenAddress = "0.0.0.0:53" + IPv4Localhost = net.IPv4(127, 0, 0, 1) + localhostRRs []dns.RR ) func init() { - modules.Register("nameserver", prep, start, nil, "core", "intel") - - if runtime.GOOS == "windows" { - listenAddress = "0.0.0.0:53" - } + module = modules.Register("nameserver", initLocalhostRRs, start, stop, "core", "intel", "network") } -func prep() error { +func initLocalhostRRs() error { localhostIPv4, err := dns.NewRR("localhost. 17 IN A 127.0.0.1") if err != nil { return err @@ -45,153 +44,202 @@ func prep() error { return err } - localhostIPs = []dns.RR{localhostIPv4, localhostIPv6} - + localhostRRs = []dns.RR{localhostIPv4, localhostIPv6} return nil } func start() error { - server := &dns.Server{Addr: listenAddress, Net: "udp"} - dns.HandleFunc(".", handleRequest) - go run(server) + dnsServer = &dns.Server{Addr: listenAddress, Net: "udp"} + dns.HandleFunc(".", handleRequestAsMicroTask) + + module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error { + err := dnsServer.ListenAndServe() + if err != nil { + // check if we are shutting down + if module.ShutdownInProgress() { + return nil + } + // is something blocking our port? + checkErr := checkForConflictingService() + if checkErr != nil { + return checkErr + } + } + return err + }) + return nil } -func run(server *dns.Server) { - for { - err := server.ListenAndServe() - if err != nil { - log.Errorf("nameserver: server failed: %s", err) - checkForConflictingService(err) - } +func stop() error { + if dnsServer != nil { + return dnsServer.Shutdown() + } + return nil +} + +func returnNXDomain(w dns.ResponseWriter, query *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(query, dns.RcodeNameError) + _ = w.WriteMsg(m) +} + +func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(query, dns.RcodeServerFailure) + _ = w.WriteMsg(m) +} + +func handleRequestAsMicroTask(w dns.ResponseWriter, query *dns.Msg) { + err := module.RunMicroTask(&mtDNSRequest, func(ctx context.Context) error { + return handleRequest(ctx, w, query) + }) + if err != nil { + log.Warningf("intel: failed to handle dns request: %s", err) } } -func nxDomain(w dns.ResponseWriter, query *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(query, dns.RcodeNameError) - w.WriteMsg(m) -} - -func handleRequest(w dns.ResponseWriter, query *dns.Msg) { +func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) error { + // return with server failure if offline + if environment.GetOnlineStatus() == environment.StatusOffline { + returnServerFailure(w, query) + return nil + } // only process first question, that's how everyone does it. question := query.Question[0] - fqdn := dns.Fqdn(question.Name) - qtype := dns.Type(question.Qtype) + q := &intel.Query{ + FQDN: question.Name, + QType: dns.Type(question.Qtype), + } // check class if question.Qclass != dns.ClassINET { // we only serve IN records, return nxdomain - nxDomain(w, query) - return + returnNXDomain(w, query) + return nil } // handle request for localhost - if fqdn == "localhost." { + if strings.HasSuffix(q.FQDN, "localhost.") { m := new(dns.Msg) m.SetReply(query) - m.Answer = localhostIPs - w.WriteMsg(m) + m.Answer = localhostRRs + _ = w.WriteMsg(m) + return nil } // get addresses remoteAddr, ok := w.RemoteAddr().(*net.UDPAddr) if !ok { - log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", fqdn, qtype) - return + log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", q.FQDN, q.QType) + return nil } - if !remoteAddr.IP.Equal(localhostIP) { + if !remoteAddr.IP.Equal(IPv4Localhost) { // if request is not coming from 127.0.0.1, check if it's really local localAddr, ok := w.RemoteAddr().(*net.UDPAddr) if !ok { - log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", fqdn, qtype) - return + log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", q.FQDN, q.QType) + return nil } // ignore external request if !remoteAddr.IP.Equal(localAddr.IP) { - log.Warningf("nameserver: external request for %s%s, ignoring", fqdn, qtype) - return + log.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType) + return nil } } // check if valid domain name - if !netutils.IsValidFqdn(fqdn) { - log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", fqdn) - nxDomain(w, query) - return + if !netutils.IsValidFqdn(q.FQDN) { + log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", q.FQDN) + returnNXDomain(w, query) + return nil } // start tracer - ctx := log.AddTracer(context.Background()) - log.Tracer(ctx).Tracef("nameserver: handling new request for %s%s from %s:%d", fqdn, qtype, remoteAddr.IP, remoteAddr.Port) + ctx, tracer := log.AddTracer(ctx) + tracer.Tracef("nameserver: handling new request for %s%s from %s:%d", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port) // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain // get connection - comm, err := network.GetCommunicationByDNSRequest(ctx, remoteAddr.IP, uint16(remoteAddr.Port), fqdn) + comm, err := network.GetCommunicationByDNSRequest(ctx, remoteAddr.IP, uint16(remoteAddr.Port), q.FQDN) if err != nil { - log.ErrorTracef(ctx, "nameserver: could not identify process of %s:%d, returning nxdomain: %s", remoteAddr.IP, remoteAddr.Port, err) - nxDomain(w, query) - return + tracer.Errorf("nameserver: could not identify process of %s:%d, returning nxdomain: %s", remoteAddr.IP, remoteAddr.Port, err) + returnNXDomain(w, query) + return nil } defer func() { go comm.SaveIfNeeded() }() + // save security level to query + q.SecurityLevel = comm.Process().ProfileSet().SecurityLevel() + // check for possible DNS tunneling / data transmission // TODO: improve this - lms := algs.LmsScoreOfDomain(fqdn) + lms := algs.LmsScoreOfDomain(q.FQDN) // log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms) if lms < 10 { - log.WarningTracef(ctx, "nameserver: possible data tunnel by %s: %s has lms score of %f, returning nxdomain", comm.Process(), fqdn, lms) - nxDomain(w, query) - return + tracer.Warningf("nameserver: possible data tunnel by %s: %s has lms score of %f, returning nxdomain", comm.Process(), q.FQDN, lms) + returnNXDomain(w, query) + return nil } // check profile before we even get intel and rr - firewall.DecideOnCommunicationBeforeIntel(comm, fqdn) + firewall.DecideOnCommunicationBeforeIntel(comm, q.FQDN) comm.Lock() comm.SaveWhenFinished() comm.Unlock() if comm.GetVerdict() == network.VerdictBlock || comm.GetVerdict() == network.VerdictDrop { - log.InfoTracef(ctx, "nameserver: %s denied before intel, returning nxdomain", comm) - nxDomain(w, query) - return + tracer.Infof("nameserver: %s denied before intel, returning nxdomain", comm) + returnNXDomain(w, query) + return nil } // get intel and RRs - domainIntel, rrCache := intel.GetIntelAndRRs(ctx, fqdn, qtype, comm.Process().ProfileSet().SecurityLevel()) - if rrCache == nil { + rrCache, err := intel.Resolve(ctx, q) + if err != nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains - log.WarningTracef(ctx, "nameserver: %s requested %s%s, is nxdomain", comm.Process(), fqdn, qtype) - nxDomain(w, query) - return + tracer.Warningf("nameserver: %s requested %s%s: %s", comm.Process(), q.FQDN, q.QType, err) + returnNXDomain(w, query) + return nil } - // set intel + // get current intel comm.Lock() - comm.Intel = domainIntel + domainIntel := comm.Intel comm.Unlock() + if domainIntel == nil { + // fetch intel + domainIntel, err = intel.GetIntel(ctx, q) + if err != nil { + tracer.Warningf("nameserver: failed to get intel for %s%s: %s", q.FQDN, q.QType, err) + returnNXDomain(w, query) + } + comm.Lock() + comm.Intel = domainIntel + comm.Unlock() + } // check with intel - firewall.DecideOnCommunicationAfterIntel(comm, fqdn, rrCache) + firewall.DecideOnCommunicationAfterIntel(comm, q.FQDN, rrCache) switch comm.GetVerdict() { case network.VerdictUndecided, network.VerdictBlock, network.VerdictDrop: - log.InfoTracef(ctx, "nameserver: %s denied after intel, returning nxdomain", comm) - nxDomain(w, query) - return + tracer.Infof("nameserver: %s denied after intel, returning nxdomain", comm) + returnNXDomain(w, query) + return nil } // filter DNS response - rrCache = firewall.FilterDNSResponse(comm, fqdn, rrCache) + rrCache = firewall.FilterDNSResponse(comm, q, rrCache) if rrCache == nil { - log.InfoTracef(ctx, "nameserver: %s implicitly denied by filtering the dns response, returning nxdomain", comm) - nxDomain(w, query) - return + tracer.Infof("nameserver: %s implicitly denied by filtering the dns response, returning nxdomain", comm) + returnNXDomain(w, query) + return nil } // save IP addresses to IPInfo @@ -202,12 +250,13 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { if err != nil { ipInfo = &intel.IPInfo{ IP: v.A.String(), - Domains: []string{fqdn}, + Domains: []string{q.FQDN}, } - ipInfo.Save() + _ = ipInfo.Save() } else { - if ipInfo.AddDomain(fqdn) { - ipInfo.Save() + added := ipInfo.AddDomain(q.FQDN) + if added { + _ = ipInfo.Save() } } case *dns.AAAA: @@ -215,12 +264,13 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { if err != nil { ipInfo = &intel.IPInfo{ IP: v.AAAA.String(), - Domains: []string{fqdn}, + Domains: []string{q.FQDN}, } - ipInfo.Save() + _ = ipInfo.Save() } else { - if ipInfo.AddDomain(fqdn) { - ipInfo.Save() + added := ipInfo.AddDomain(q.FQDN) + if added { + _ = ipInfo.Save() } } } @@ -232,6 +282,8 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { m.Answer = rrCache.Answer m.Ns = rrCache.Ns m.Extra = rrCache.Extra - w.WriteMsg(m) - log.DebugTracef(ctx, "nameserver: returning response %s%s to %s", fqdn, qtype, comm.Process()) + _ = w.WriteMsg(m) + tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, comm.Process()) + + return nil } diff --git a/nameserver/only/nameserver.go b/nameserver/only/nameserver.go index 8b6c2ba1..1a392bf3 100644 --- a/nameserver/only/nameserver.go +++ b/nameserver/only/nameserver.go @@ -3,30 +3,34 @@ package only import ( "context" "net" - "time" + "strings" + + "github.com/safing/portmaster/network/environment" "github.com/miekg/dns" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" - "github.com/safing/portmaster/analytics/algs" "github.com/safing/portmaster/intel" "github.com/safing/portmaster/network/netutils" - "github.com/safing/portmaster/status" ) var ( - localhostIPs []dns.RR + module *modules.Module + dnsServer *dns.Server + mtDNSRequest = "dns request" + + listenAddress = "127.0.0.1:53" + IPv4Localhost = net.IPv4(127, 0, 0, 1) + localhostRRs []dns.RR ) func init() { - modules.Register("nameserver", prep, start, nil, "core", "intel") + module = modules.Register("nameserver", initLocalhostRRs, start, stop, "core", "intel", "network") } -func prep() error { - intel.SetLocalAddrFactory(func(network string) net.Addr { return nil }) - +func initLocalhostRRs() error { localhostIPv4, err := dns.NewRR("localhost. 17 IN A 127.0.0.1") if err != nil { return err @@ -37,107 +41,128 @@ func prep() error { return err } - localhostIPs = []dns.RR{localhostIPv4, localhostIPv6} - + localhostRRs = []dns.RR{localhostIPv4, localhostIPv6} return nil } func start() error { - server := &dns.Server{Addr: "0.0.0.0:53", Net: "udp"} - dns.HandleFunc(".", handleRequest) - go run(server) + dnsServer = &dns.Server{Addr: listenAddress, Net: "udp"} + dns.HandleFunc(".", handleRequestAsMicroTask) + + module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error { + err := dnsServer.ListenAndServe() + if err != nil { + // check if we are shutting down + if module.ShutdownInProgress() { + return nil + } + } + return err + }) + return nil } -func run(server *dns.Server) { - for { - err := server.ListenAndServe() - if err != nil { - log.Errorf("nameserver: server failed: %s", err) - log.Info("nameserver: restarting server in 10 seconds") - time.Sleep(10 * time.Second) - } +func stop() error { + if dnsServer != nil { + return dnsServer.Shutdown() + } + return nil +} + +func returnNXDomain(w dns.ResponseWriter, query *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(query, dns.RcodeNameError) + _ = w.WriteMsg(m) +} + +func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(query, dns.RcodeServerFailure) + _ = w.WriteMsg(m) +} + +func handleRequestAsMicroTask(w dns.ResponseWriter, query *dns.Msg) { + err := module.RunMicroTask(&mtDNSRequest, func(ctx context.Context) error { + return handleRequest(ctx, w, query) + }) + if err != nil { + log.Warningf("intel: failed to handle dns request: %s", err) } } -func nxDomain(w dns.ResponseWriter, query *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(query, dns.RcodeNameError) - w.WriteMsg(m) -} - -func handleRequest(w dns.ResponseWriter, query *dns.Msg) { +func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) error { + // return with server failure if offline + if environment.GetOnlineStatus() == environment.StatusOffline { + returnServerFailure(w, query) + return nil + } // only process first question, that's how everyone does it. question := query.Question[0] - fqdn := dns.Fqdn(question.Name) - qtype := dns.Type(question.Qtype) + q := &intel.Query{ + FQDN: question.Name, + QType: dns.Type(question.Qtype), + } // check class if question.Qclass != dns.ClassINET { // we only serve IN records, return nxdomain - nxDomain(w, query) - return + returnNXDomain(w, query) + return nil } // handle request for localhost - if fqdn == "localhost." { + if strings.HasSuffix(q.FQDN, "localhost.") { m := new(dns.Msg) m.SetReply(query) - m.Answer = localhostIPs - w.WriteMsg(m) + m.Answer = localhostRRs + _ = w.WriteMsg(m) + return nil } // get addresses remoteAddr, ok := w.RemoteAddr().(*net.UDPAddr) if !ok { - log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", fqdn, qtype) - return - } - localAddr, ok := w.RemoteAddr().(*net.UDPAddr) - if !ok { - log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", fqdn, qtype) - return + log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", q.FQDN, q.QType) + return nil } + if !remoteAddr.IP.Equal(IPv4Localhost) { + // if request is not coming from 127.0.0.1, check if it's really local - // ignore external request - if !remoteAddr.IP.Equal(localAddr.IP) { - log.Warningf("nameserver: external request for %s%s, ignoring", fqdn, qtype) - return + localAddr, ok := w.RemoteAddr().(*net.UDPAddr) + if !ok { + log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", q.FQDN, q.QType) + return nil + } + + // ignore external request + if !remoteAddr.IP.Equal(localAddr.IP) { + log.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType) + return nil + } } // check if valid domain name - if !netutils.IsValidFqdn(fqdn) { - log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", fqdn) - nxDomain(w, query) - return + if !netutils.IsValidFqdn(q.FQDN) { + log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", q.FQDN) + returnNXDomain(w, query) + return nil } // start tracer - ctx := log.AddTracer(context.Background()) - log.Tracer(ctx).Tracef("nameserver: handling new request for %s%s from %s:%d", fqdn, qtype, remoteAddr.IP, remoteAddr.Port) + ctx, tracer := log.AddTracer(ctx) + tracer.Tracef("nameserver: handling new request for %s%s from %s:%d", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port) // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain - // check for possible DNS tunneling / data transmission - // TODO: improve this - lms := algs.LmsScoreOfDomain(fqdn) - // log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms) - if lms < 10 { - log.WarningTracef(ctx, "nameserver: possible data tunnel by %s:%d: %s has lms score of %f, returning nxdomain", remoteAddr.IP, remoteAddr.Port, fqdn, lms) - nxDomain(w, query) - return - } - // get intel and RRs - // start = time.Now() - _, rrCache := intel.GetIntelAndRRs(ctx, fqdn, qtype, status.SecurityLevelDynamic) - // log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start)) - if rrCache == nil { + rrCache, err := intel.Resolve(ctx, q) + if err != nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains - log.WarningTracef(ctx, "nameserver: %s:%d requested %s%s, is nxdomain", remoteAddr.IP, remoteAddr.Port, fqdn, qtype) - nxDomain(w, query) - return + tracer.Warningf("nameserver: request for %s%s: %s", q.FQDN, q.QType, err) + returnNXDomain(w, query) + return nil } // save IP addresses to IPInfo @@ -148,12 +173,13 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { if err != nil { ipInfo = &intel.IPInfo{ IP: v.A.String(), - Domains: []string{fqdn}, + Domains: []string{q.FQDN}, } - ipInfo.Save() + _ = ipInfo.Save() } else { - if ipInfo.AddDomain(fqdn) { - ipInfo.Save() + added := ipInfo.AddDomain(q.FQDN) + if added { + _ = ipInfo.Save() } } case *dns.AAAA: @@ -161,12 +187,13 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { if err != nil { ipInfo = &intel.IPInfo{ IP: v.AAAA.String(), - Domains: []string{fqdn}, + Domains: []string{q.FQDN}, } - ipInfo.Save() + _ = ipInfo.Save() } else { - if ipInfo.AddDomain(fqdn) { - ipInfo.Save() + added := ipInfo.AddDomain(q.FQDN) + if added { + _ = ipInfo.Save() } } } @@ -178,6 +205,8 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { m.Answer = rrCache.Answer m.Ns = rrCache.Ns m.Extra = rrCache.Extra - w.WriteMsg(m) - log.DebugTracef(ctx, "nameserver: returning response %s%s to %s:%d", fqdn, qtype, remoteAddr.IP, remoteAddr.Port) + _ = w.WriteMsg(m) + tracer.Debugf("nameserver: returning response %s%s", q.FQDN, q.QType) + + return nil } diff --git a/nameserver/takeover.go b/nameserver/takeover.go index ab245b05..e55aa46c 100644 --- a/nameserver/takeover.go +++ b/nameserver/takeover.go @@ -7,20 +7,45 @@ import ( "time" "github.com/safing/portbase/log" + "github.com/safing/portbase/modules" "github.com/safing/portbase/notifications" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/process" ) -func checkForConflictingService(err error) { - pid, err := takeover() - if err != nil || pid == 0 { - log.Info("nameserver: restarting server in 10 seconds") - time.Sleep(10 * time.Second) - return +var ( + otherResolverIPs = []net.IP{ + net.IPv4(127, 0, 0, 1), // default + net.IPv4(127, 0, 0, 53), // some resolvers on Linux + } +) + +func checkForConflictingService() error { + var pid int + var err error + + // check multiple IPs for other resolvers + for _, resolverIP := range otherResolverIPs { + pid, err = takeover(resolverIP) + if err == nil && pid != 0 { + break + } + } + // handle returns + if err != nil { + log.Infof("nameserver: could not stop conflicting service: %s", err) + // leave original service-worker error intact + return nil + } + if pid == 0 { + // no conflicting service identified + return nil } - log.Infof("nameserver: stopped conflicting name service with pid %d", pid) + // we killed something! + + // wait for a short duration for the other service to shut down + time.Sleep(10 * time.Millisecond) // notify user (¬ifications.Notification{ @@ -28,15 +53,14 @@ func checkForConflictingService(err error) { Message: fmt.Sprintf("Portmaster stopped a conflicting name service (pid %d) to gain required system integration.", pid), }).Save() - // wait for a short duration for the other service to shut down - time.Sleep(100 * time.Millisecond) + // restart via service-worker logic + return fmt.Errorf("%w: stopped conflicting name service with pid %d", modules.ErrRestartNow, pid) } -func takeover() (int, error) { - pid, _, err := process.GetPidByEndpoints(net.IPv4(127, 0, 0, 1), 53, net.IPv4(127, 0, 0, 1), 65535, packet.UDP) +func takeover(resolverIP net.IP) (int, error) { + pid, _, err := process.GetPidByEndpoints(resolverIP, 53, resolverIP, 65535, packet.UDP) if err != nil { // there may be nothing listening on :53 - log.Tracef("nameserver: expected conflicting name service, but could not find anything listenting on :53") return 0, nil }