From 53eb309e72fa1220f94a9801749b6a4f0b2ddc96 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 22:43:06 +0200 Subject: [PATCH] Add TLS resolver connection reusing and pooling Also, fix caching issues and add more tests --- ...{block_detection.go => block-detection.go} | 0 resolver/clients.go | 176 ++++++++++++++--- resolver/mdns.go | 29 ++- resolver/namerecord.go | 2 +- resolver/namerecord_test.go | 27 +++ resolver/pooling_test.go | 184 ++++++++++++++++++ resolver/resolve.go | 6 +- resolver/resolver.go | 53 ++++- resolver/resolvers.go | 49 +++-- resolver/reverse.go | 4 +- resolver/rrcache_test.go | 41 ++++ 11 files changed, 510 insertions(+), 61 deletions(-) rename resolver/{block_detection.go => block-detection.go} (100%) create mode 100644 resolver/namerecord_test.go create mode 100644 resolver/pooling_test.go create mode 100644 resolver/rrcache_test.go diff --git a/resolver/block_detection.go b/resolver/block-detection.go similarity index 100% rename from resolver/block_detection.go rename to resolver/block-detection.go diff --git a/resolver/clients.go b/resolver/clients.go index 6d1ad4b2..096f2af3 100644 --- a/resolver/clients.go +++ b/resolver/clients.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "crypto/tls" "net" "sync" @@ -9,6 +10,12 @@ import ( "github.com/miekg/dns" ) +const ( + defaultClientTTL = 5 * time.Minute + defaultRequestTimeout = 5 * time.Second + connectionEOLGracePeriod = 10 * time.Second +) + var ( localAddrFactory func(network string) net.Addr ) @@ -27,21 +34,72 @@ func getLocalAddr(network string) net.Addr { return nil } -type clientManager struct { - dnsClient *dns.Client - factory func() *dns.Client +type dnsClientManager struct { + lock sync.Mutex - lock sync.Mutex - refreshAfter time.Time - ttl time.Duration // force refresh of connection to reduce traceability + // set by creator + serverAddress string + ttl time.Duration // force refresh of connection to reduce traceability + factory func() *dns.Client + + // internal + pool []*dnsClient } -func newDNSClientManager(_ *Resolver) *clientManager { - return &clientManager{ - ttl: 0, // new client for every request, as we need to randomize the port +type dnsClient struct { + mgr *dnsClientManager + + inUse bool + useUntil time.Time + dead bool + inPool bool + poolIndex int + + client *dns.Client + conn *dns.Conn +} + +// conn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done(). +func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) { + if dc.conn == nil { + dc.conn, err = dc.client.Dial(dc.mgr.serverAddress) + if err != nil { + return nil, false, err + } + return dc.conn, true, nil + } + return dc.conn, false, nil +} + +func (dc *dnsClient) done() { + dc.mgr.lock.Lock() + defer dc.mgr.lock.Unlock() + + dc.inUse = false +} + +func (dc *dnsClient) destroy() { + dc.mgr.lock.Lock() + dc.inUse = true // block from being used + dc.dead = true // abort cleaning + if dc.inPool { + dc.inPool = false + dc.mgr.pool[dc.poolIndex] = nil + } + dc.mgr.lock.Unlock() + + if dc.conn != nil { + _ = dc.conn.Close() + } +} + +func newDNSClientManager(resolver *Resolver) *dnsClientManager { + return &dnsClientManager{ + serverAddress: resolver.ServerAddress, + ttl: 0, // new client for every request, as we need to randomize the port factory: func() *dns.Client { return &dns.Client{ - Timeout: 5 * time.Second, + Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("udp"), }, @@ -50,25 +108,27 @@ func newDNSClientManager(_ *Resolver) *clientManager { } } -func newTCPClientManager(_ *Resolver) *clientManager { - return &clientManager{ - ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) +func newTCPClientManager(resolver *Resolver) *dnsClientManager { + return &dnsClientManager{ + serverAddress: resolver.ServerAddress, + ttl: defaultClientTTL, factory: func() *dns.Client { return &dns.Client{ Net: "tcp", - Timeout: 5 * time.Second, + Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), - KeepAlive: 15 * time.Second, + KeepAlive: defaultClientTTL, }, } }, } } -func newTLSClientManager(resolver *Resolver) *clientManager { - return &clientManager{ - ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) +func newTLSClientManager(resolver *Resolver) *dnsClientManager { + return &dnsClientManager{ + serverAddress: resolver.ServerAddress, + ttl: defaultClientTTL, factory: func() *dns.Client { return &dns.Client{ Net: "tcp-tls", @@ -77,24 +137,90 @@ func newTLSClientManager(resolver *Resolver) *clientManager { ServerName: resolver.VerifyDomain, // TODO: use portbase rng }, - Timeout: 5 * time.Second, + Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), - KeepAlive: 15 * time.Second, + KeepAlive: defaultClientTTL, }, } }, } } -func (cm *clientManager) getDNSClient() *dns.Client { +func (cm *dnsClientManager) getDNSClient() *dnsClient { cm.lock.Lock() defer cm.lock.Unlock() - if cm.dnsClient == nil || cm.ttl == 0 || time.Now().After(cm.refreshAfter) { - cm.dnsClient = cm.factory() - cm.refreshAfter = time.Now().Add(cm.ttl) + // return new immediately if a new client should be used for every request + if cm.ttl == 0 { + return &dnsClient{ + mgr: cm, + client: cm.factory(), + } } - return cm.dnsClient + // get first unused from pool + now := time.Now().UTC() + for _, dc := range cm.pool { + if dc != nil && !dc.inUse && now.Before(dc.useUntil) { + dc.inUse = true + return dc + } + } + + // no available in pool, create new + newClient := &dnsClient{ + mgr: cm, + inUse: true, + useUntil: now.Add(cm.ttl), + inPool: true, + client: cm.factory(), + } + newClient.startCleaner() + + // find free spot in pool + for poolIndex, dc := range cm.pool { + if dc == nil { + cm.pool[poolIndex] = newClient + newClient.poolIndex = poolIndex + return newClient + } + } + + // append to pool + cm.pool = append(cm.pool, newClient) + newClient.poolIndex = len(cm.pool) - 1 + // TODO: shrink pool again? + + return newClient +} + +// startCleaner waits for EOL of the client and then removes it from the pool. +func (dc *dnsClient) startCleaner() { + // While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone. + module.StartWorker("dns client cleanup", func(ctx context.Context) error { + select { + case <-time.After(dc.mgr.ttl + time.Second): + dc.mgr.lock.Lock() + cleanNow := dc.dead || !dc.inUse + dc.mgr.lock.Unlock() + + if cleanNow { + dc.destroy() + return nil + } + case <-ctx.Done(): + // give a short time before kill for graceful request completion + time.Sleep(100 * time.Millisecond) + } + + // wait for grace period to end, then kill + select { + case <-time.After(connectionEOLGracePeriod): + case <-ctx.Done(): + } + + dc.destroy() + return nil + }) } diff --git a/resolver/mdns.go b/resolver/mdns.go index a8ba1ee5..b8595d67 100644 --- a/resolver/mdns.go +++ b/resolver/mdns.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "github.com/safing/portmaster/network/netutils" + "github.com/miekg/dns" "github.com/safing/portbase/log" @@ -29,10 +31,11 @@ var ( questionsLock sync.Mutex mDNSResolver = &Resolver{ - Server: ServerSourceMDNS, - ServerType: ServerTypeDNS, - Source: ServerSourceMDNS, - Conn: &mDNSResolverConn{}, + Server: ServerSourceMDNS, + ServerType: ServerTypeDNS, + ServerIPScope: netutils.SiteLocal, + Source: ServerSourceMDNS, + Conn: &mDNSResolverConn{}, } ) @@ -189,15 +192,21 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error { // get entry from database if saveFullRequest { + // get from database rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype)) + // if we have no cached entry, or it has been updated less more than two seconds ago, or if it expired: + // create new and do not append if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() { rrCache = &RRCache{ - Domain: question.Name, - Question: dns.Type(question.Qtype), + Domain: question.Name, + Question: dns.Type(question.Qtype), + Server: mDNSResolver.Server, + ServerScope: mDNSResolver.ServerIPScope, } } } + // add all entries to RRCache for _, entry := range message.Answer { if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) { if saveFullRequest { @@ -289,9 +298,11 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error { continue } rrCache = &RRCache{ - Domain: v.Header().Name, - Question: dns.Type(v.Header().Class), - Answer: []dns.RR{v}, + Domain: v.Header().Name, + Question: dns.Type(v.Header().Class), + Answer: []dns.RR{v}, + Server: mDNSResolver.Server, + ServerScope: mDNSResolver.ServerIPScope, } rrCache.Clean(60) err := rrCache.Save() diff --git a/resolver/namerecord.go b/resolver/namerecord.go index d94beaaa..1a594e8f 100644 --- a/resolver/namerecord.go +++ b/resolver/namerecord.go @@ -12,7 +12,7 @@ import ( var ( recordDatabase = database.NewInterface(&database.Options{ AlwaysSetRelativateExpiry: 2592000, // 30 days - CacheSize: 128, + CacheSize: 256, }) ) diff --git a/resolver/namerecord_test.go b/resolver/namerecord_test.go new file mode 100644 index 00000000..f0e21a37 --- /dev/null +++ b/resolver/namerecord_test.go @@ -0,0 +1,27 @@ +package resolver + +import "testing" + +func TestNameRecordStorage(t *testing.T) { + testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com." + testQuestion := "A" + + testNameRecord := &NameRecord{ + Domain: testDomain, + Question: testQuestion, + } + + err := testNameRecord.Save() + if err != nil { + t.Fatal(err) + } + + r, err := GetNameRecord(testDomain, testQuestion) + if err != nil { + t.Fatal(err) + } + + if r.Domain != testDomain || r.Question != testQuestion { + t.Fatal("mismatch") + } +} diff --git a/resolver/pooling_test.go b/resolver/pooling_test.go new file mode 100644 index 00000000..dc341f33 --- /dev/null +++ b/resolver/pooling_test.go @@ -0,0 +1,184 @@ +package resolver + +import ( + "sync" + "testing" + + "github.com/miekg/dns" +) + +var ( + domainFeed = make(chan string) +) + +func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Query) { + dnsClient := brc.clientManager.getDNSClient() + + // create query + dnsQuery := new(dns.Msg) + dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) + + // get connection + conn, new, err := dnsClient.getConn() + if err != nil { + t.Fatalf("failed to connect: %s", err) //nolint:staticcheck + } + + // query server + reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn) + if err != nil { + t.Fatal(err) //nolint:staticcheck + } + if reply == nil { + t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck + } + + t.Logf("resolved %s with resolver %d (new=%v) in %s", q.FQDN, dnsClient.poolIndex, new, ttl) + dnsClient.done() + wg.Done() +} + +func TestClientPooling(t *testing.T) { + // skip if short - this test depends on the Internet and might fail randomly + if testing.Short() { + t.Skip() + } + + go feedDomains() + + // create separate resolver for this test + resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config") + if err != nil { + t.Fatal(err) + } + brc := resolver.Conn.(*BasicResolverConn) + + wg := &sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(10) + for i := 0; i < 10; i++ { + go testQuery(t, wg, brc, &Query{ + FQDN: <-domainFeed, + QType: dns.Type(dns.TypeA), + }) + } + wg.Wait() + if len(brc.clientManager.pool) != 10 { + t.Fatalf("unexpected pool size: %d", len(brc.clientManager.pool)) + } + } +} + +func feedDomains() { + for { + for _, domain := range poolingTestDomains { + domainFeed <- domain + } + } +} + +// Data + +var ( + poolingTestDomains = []string{ + "facebook.com.", + "google.com.", + "youtube.com.", + "twitter.com.", + "instagram.com.", + "linkedin.com.", + "microsoft.com.", + "apple.com.", + "wikipedia.org.", + "plus.google.com.", + "en.wikipedia.org.", + "googletagmanager.com.", + "youtu.be.", + "adobe.com.", + "vimeo.com.", + "pinterest.com.", + "itunes.apple.com.", + "play.google.com.", + "maps.google.com.", + "goo.gl.", + "wordpress.com.", + "blogspot.com.", + "bit.ly.", + "github.com.", + "player.vimeo.com.", + "amazon.com.", + "wordpress.org.", + "docs.google.com.", + "yahoo.com.", + "mozilla.org.", + "tumblr.com.", + "godaddy.com.", + "flickr.com.", + "parked-content.godaddy.com.", + "drive.google.com.", + "support.google.com.", + "apache.org.", + "gravatar.com.", + "europa.eu.", + "qq.com.", + "w3.org.", + "nytimes.com.", + "reddit.com.", + "macromedia.com.", + "get.adobe.com.", + "soundcloud.com.", + "sourceforge.net.", + "sites.google.com.", + "nih.gov.", + "amazonaws.com.", + "t.co.", + "support.microsoft.com.", + "forbes.com.", + "theguardian.com.", + "cnn.com.", + "github.io.", + "bbc.co.uk.", + "dropbox.com.", + "whatsapp.com.", + "medium.com.", + "creativecommons.org.", + "www.ncbi.nlm.nih.gov.", + "httpd.apache.org.", + "archive.org.", + "ec.europa.eu.", + "php.net.", + "apps.apple.com.", + "weebly.com.", + "support.apple.com.", + "weibo.com.", + "wixsite.com.", + "issuu.com.", + "who.int.", + "paypal.com.", + "m.facebook.com.", + "oracle.com.", + "msn.com.", + "gnu.org.", + "tinyurl.com.", + "reuters.com.", + "l.facebook.com.", + "cloudflare.com.", + "wsj.com.", + "washingtonpost.com.", + "domainmarket.com.", + "imdb.com.", + "bbc.com.", + "bing.com.", + "accounts.google.com.", + "vk.com.", + "api.whatsapp.com.", + "opera.com.", + "cdc.gov.", + "slideshare.net.", + "wpa.qq.com.", + "harvard.edu.", + "mit.edu.", + "code.google.com.", + "wikimedia.org.", + } +) diff --git a/resolver/resolve.go b/resolver/resolve.go index f13d07c2..57c0c967 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -114,6 +114,7 @@ func Resolve(ctx context.Context, q *Query) (rrCache *RRCache, err error) { rrCache.MixAnswers() return rrCache, nil } + log.Tracer(ctx).Debugf("resolver: waited for another %s%s query, but cache missed!", q.FQDN, q.QType) // if cache is still empty or non-compliant, go ahead and just query } else { // we are the first! @@ -132,14 +133,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache { if err != nil { if err != database.ErrNotFound { log.Tracer(ctx).Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err) - log.Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err) } return nil } // get resolver that rrCache was resolved with - resolver := getResolverByIDWithLocking(rrCache.Server) + resolver := getActiveResolverByIDWithLocking(rrCache.Server) if resolver == nil { + log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %s has been removed", q.FQDN, q.QType.String(), rrCache.Server) return nil } @@ -165,6 +166,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache { }) } + log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0))) return rrCache } diff --git a/resolver/resolver.go b/resolver/resolver.go index 65155fab..244b0c57 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "errors" "net" "sync" "time" @@ -92,7 +93,7 @@ type BasicResolverConn struct { sync.Mutex // for lastFail resolver *Resolver - clientManager *clientManager + clientManager *dnsClientManager lastFail time.Time } @@ -126,18 +127,41 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er // start var reply *dns.Msg + var ttl time.Duration var err error - for i := 0; i < 3; i++ { + var conn *dns.Conn + var new bool + var i int - // log query time - // qStart := time.Now() - reply, _, err = brc.clientManager.getDNSClient().Exchange(dnsQuery, resolver.ServerAddress) - // log.Tracef("resolver: query to %s took %s", resolver.Server, time.Now().Sub(qStart)) + for ; i < 5; i++ { + + // first get connection + dc := brc.clientManager.getDNSClient() + conn, new, err = dc.getConn() + if err != nil { + log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err) + // remove client from pool + dc.destroy() + // try again + continue + } + if new { + log.Tracer(ctx).Tracef("resolver: created new connection to %s", resolver.ServerAddress) + } else { + log.Tracer(ctx).Tracef("resolver: reusing connection to %s", resolver.ServerAddress) + } + + // query server + reply, ttl, err = dc.client.ExchangeWithConn(dnsQuery, conn) + log.Tracer(ctx).Tracef("resolver: query took %s", ttl) // error handling if err != nil { log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err) + // remove client from pool + dc.destroy() + // TODO: handle special cases // 1. connect: network is unreachable // 2. timeout @@ -148,13 +172,23 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er // temporary error if nerr, ok := err.(net.Error); ok && nerr.Timeout() { log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server) + // try again continue } // permanent error break + } else if reply == nil { + // remove client from pool + dc.destroy() + + log.Errorf("resolver: successful query for %s%s to %s, but reply was nil", q.FQDN, q.QType, resolver.Server) + return nil, errors.New("internal error") } + // make client available again + dc.done() + if resolver.IsBlockedUpstream(reply) { return nil, &BlockedUpstreamError{resolver.GetName()} } @@ -166,12 +200,15 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er if err != nil { return nil, err // TODO: mark as failed + } else if reply == nil { + log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), i+1) + return nil, errors.New("internal error") } // hint network environment at successful connection netenv.ReportSuccessfulConnection() - new := &RRCache{ + newRecord := &RRCache{ Domain: q.FQDN, Question: q.QType, Answer: reply.Answer, @@ -182,5 +219,5 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er } // TODO: check if reply.Answer is valid - return new, nil + return newRecord, nil } diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 00ad0d0e..5d65fed3 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -25,7 +25,7 @@ var ( globalResolvers []*Resolver // all (global) resolvers localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope - allResolvers map[string]*Resolver // lookup map of all resolvers + activeResolvers map[string]*Resolver // lookup map of all resolvers resolversLock sync.RWMutex dupReqMap = make(map[string]*sync.WaitGroup) @@ -41,11 +41,11 @@ func indexOfScope(domain string, list []*Scope) int { return -1 } -func getResolverByIDWithLocking(server string) *Resolver { - resolversLock.Lock() - defer resolversLock.Unlock() +func getActiveResolverByIDWithLocking(server string) *Resolver { + resolversLock.RLock() + defer resolversLock.RUnlock() - resolver, ok := allResolvers[server] + resolver, ok := activeResolvers[server] if ok { return resolver } @@ -62,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string { return address } -func clientManagerFactory(serverType string) func(*Resolver) *clientManager { +func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager { switch serverType { case ServerTypeDNS: return newDNSClientManager @@ -152,8 +152,8 @@ func configureSearchDomains(resolver *Resolver, searches []string) { } } -func getConfiguredResolvers() (resolvers []*Resolver) { - for _, server := range configuredNameServers() { +func getConfiguredResolvers(list []string) (resolvers []*Resolver) { + for _, server := range list { resolver, skip, err := createResolver(server, "config") if err != nil { // TODO(ppacher): module error @@ -199,19 +199,40 @@ func loadResolvers() { defer resolversLock.Unlock() newResolvers := append( - getConfiguredResolvers(), + getConfiguredResolvers(configuredNameServers()), getSystemResolvers()..., ) - // save resolvers - globalResolvers = newResolvers - if len(globalResolvers) == 0 { - log.Criticalf("resolver: no (valid) dns servers found in configuration and system") - // TODO(module error) + if len(newResolvers) == 0 { + msg := "no (valid) dns servers found in (user) configuration or system, falling back to defaults" + log.Warningf("resolver: %s", msg) + module.Warning("no-valid-user-resolvers", msg) + + // load defaults directly, overriding config system + newResolvers = getConfiguredResolvers(defaultNameServers) + if len(newResolvers) == 0 { + msg = "no (valid) dns servers found in configuration or system" + log.Criticalf("resolver: %s", msg) + module.Error("no-valid-default-resolvers", msg) + return + } } + // save resolvers + globalResolvers = newResolvers + + // assing resolvers to scopes setLocalAndScopeResolvers(globalResolvers) + // set active resolvers (for cache validation) + // reset + activeResolvers = make(map[string]*Resolver) + // add + for _, resolver := range newResolvers { + activeResolvers[resolver.Server] = resolver + } + activeResolvers[mDNSResolver.Server] = mDNSResolver + // log global resolvers if len(globalResolvers) > 0 { log.Trace("resolver: loaded global resolvers:") diff --git a/resolver/reverse.go b/resolver/reverse.go index 0487cf44..c236818b 100644 --- a/resolver/reverse.go +++ b/resolver/reverse.go @@ -65,12 +65,12 @@ func ResolveIPAndValidate(ctx context.Context, ip string, securityLevel uint8) ( for _, rr := range rrCache.Answer { switch v := rr.(type) { case *dns.A: - log.Infof("A: %s", v.A.String()) + // log.Debugf("A: %s", v.A.String()) if ip == v.A.String() { return ptrName, nil } case *dns.AAAA: - log.Infof("AAAA: %s", v.AAAA.String()) + // log.Debugf("AAAA: %s", v.AAAA.String()) if ip == v.AAAA.String() { return ptrName, nil } diff --git a/resolver/rrcache_test.go b/resolver/rrcache_test.go new file mode 100644 index 00000000..8aaa3094 --- /dev/null +++ b/resolver/rrcache_test.go @@ -0,0 +1,41 @@ +package resolver + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestCaching(t *testing.T) { + testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com." + testQuestion := "A" + + testNameRecord := &NameRecord{ + Domain: testDomain, + Question: testQuestion, + } + + err := testNameRecord.Save() + if err != nil { + t.Fatal(err) + } + + rrCache, err := GetRRCache(testDomain, dns.Type(dns.TypeA)) + if err != nil { + t.Fatal(err) + } + + err = rrCache.Save() + if err != nil { + t.Fatal(err) + } + + rrCache2, err := GetRRCache(testDomain, dns.Type(dns.TypeA)) + if err != nil { + t.Fatal(err) + } + + if rrCache2.Domain != rrCache.Domain { + t.Fatal("something very is wrong") + } +}