diff --git a/resolver/clients.go b/resolver/clients.go index 096f2af3..e3456759 100644 --- a/resolver/clients.go +++ b/resolver/clients.go @@ -12,8 +12,9 @@ import ( const ( defaultClientTTL = 5 * time.Minute - defaultRequestTimeout = 5 * time.Second - connectionEOLGracePeriod = 10 * time.Second + defaultRequestTimeout = 3 * time.Second // dns query + defaultConnectTimeout = 2 * time.Second // tcp/tls + connectionEOLGracePeriod = 7 * time.Second ) var ( @@ -43,23 +44,17 @@ type dnsClientManager struct { factory func() *dns.Client // internal - pool []*dnsClient + pool sync.Pool } type dnsClient struct { - mgr *dnsClientManager - - inUse bool - useUntil time.Time - dead bool - inPool bool - poolIndex int - - client *dns.Client - conn *dns.Conn + mgr *dnsClientManager + client *dns.Client + conn *dns.Conn + useUntil time.Time } -// conn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done(). +// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done(). func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) { if dc.conn == nil { dc.conn, err = dc.client.Dial(dc.mgr.serverAddress) @@ -71,23 +66,11 @@ func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) { return dc.conn, false, nil } -func (dc *dnsClient) done() { - dc.mgr.lock.Lock() - defer dc.mgr.lock.Unlock() - - dc.inUse = false +func (dc *dnsClient) addToPool() { + dc.mgr.pool.Put(dc) } func (dc *dnsClient) destroy() { - dc.mgr.lock.Lock() - dc.inUse = true // block from being used - dc.dead = true // abort cleaning - if dc.inPool { - dc.inPool = false - dc.mgr.pool[dc.poolIndex] = nil - } - dc.mgr.lock.Unlock() - if dc.conn != nil { _ = dc.conn.Close() } @@ -118,6 +101,7 @@ func newTCPClientManager(resolver *Resolver) *dnsClientManager { Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), + Timeout: defaultConnectTimeout, KeepAlive: defaultClientTTL, }, } @@ -140,6 +124,7 @@ func newTLSClientManager(resolver *Resolver) *dnsClientManager { Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), + Timeout: defaultConnectTimeout, KeepAlive: defaultClientTTL, }, } @@ -159,11 +144,18 @@ func (cm *dnsClientManager) getDNSClient() *dnsClient { } } - // get first unused from pool + // get cached client from pool now := time.Now().UTC() - for _, dc := range cm.pool { - if dc != nil && !dc.inUse && now.Before(dc.useUntil) { - dc.inUse = true + +poolLoop: + for { + dc, ok := cm.pool.Get().(*dnsClient) + switch { + case !ok || dc == nil: // cache empty (probably, pool may always return nil!) + break poolLoop // create new + case now.After(dc.useUntil): + continue // get next + default: return dc } } @@ -171,27 +163,11 @@ func (cm *dnsClientManager) getDNSClient() *dnsClient { // no available in pool, create new newClient := &dnsClient{ mgr: cm, - inUse: true, - useUntil: now.Add(cm.ttl), - inPool: true, client: cm.factory(), + useUntil: now.Add(cm.ttl), } newClient.startCleaner() - // find free spot in pool - for poolIndex, dc := range cm.pool { - if dc == nil { - cm.pool[poolIndex] = newClient - newClient.poolIndex = poolIndex - return newClient - } - } - - // append to pool - cm.pool = append(cm.pool, newClient) - newClient.poolIndex = len(cm.pool) - 1 - // TODO: shrink pool again? - return newClient } @@ -200,26 +176,12 @@ func (dc *dnsClient) startCleaner() { // While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone. module.StartWorker("dns client cleanup", func(ctx context.Context) error { select { - case <-time.After(dc.mgr.ttl + time.Second): - dc.mgr.lock.Lock() - cleanNow := dc.dead || !dc.inUse - dc.mgr.lock.Unlock() - - if cleanNow { - dc.destroy() - return nil - } + case <-time.After(dc.mgr.ttl + connectionEOLGracePeriod): + // destroy case <-ctx.Done(): // give a short time before kill for graceful request completion time.Sleep(100 * time.Millisecond) } - - // wait for grace period to end, then kill - select { - case <-time.After(connectionEOLGracePeriod): - case <-ctx.Done(): - } - dc.destroy() return nil }) diff --git a/resolver/pooling_test.go b/resolver/pooling_test.go index dc341f33..3c03c14c 100644 --- a/resolver/pooling_test.go +++ b/resolver/pooling_test.go @@ -2,6 +2,7 @@ package resolver import ( "sync" + "sync/atomic" "testing" "github.com/miekg/dns" @@ -11,7 +12,7 @@ var ( domainFeed = make(chan string) ) -func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Query) { +func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) { dnsClient := brc.clientManager.getDNSClient() // create query @@ -23,6 +24,9 @@ func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Quer if err != nil { t.Fatalf("failed to connect: %s", err) //nolint:staticcheck } + if new { + atomic.AddUint32(newCnt, 1) + } // query server reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn) @@ -33,8 +37,8 @@ func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Quer t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck } - t.Logf("resolved %s with resolver %d (new=%v) in %s", q.FQDN, dnsClient.poolIndex, new, ttl) - dnsClient.done() + t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl) + dnsClient.addToPool() wg.Done() } @@ -54,17 +58,18 @@ func TestClientPooling(t *testing.T) { brc := resolver.Conn.(*BasicResolverConn) wg := &sync.WaitGroup{} + var newCnt uint32 for i := 0; i < 10; i++ { wg.Add(10) for i := 0; i < 10; i++ { - go testQuery(t, wg, brc, &Query{ + go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck FQDN: <-domainFeed, QType: dns.Type(dns.TypeA), }) } wg.Wait() - if len(brc.clientManager.pool) != 10 { - t.Fatalf("unexpected pool size: %d", len(brc.clientManager.pool)) + if newCnt > uint32(10+i) { + t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i) } } } diff --git a/resolver/resolve.go b/resolver/resolve.go index ee76d5fa..e5c406ae 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -14,8 +14,6 @@ import ( ) var ( - mtAsyncResolve = "async resolve" - // basic errors // ErrNotFound is a basic error that will match all "not found" errors @@ -160,7 +158,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache { log.Tracer(ctx).Trace("resolver: serving from cache, requesting new") // resolve async - module.StartLowPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error { + module.StartWorker("resolve async", func(ctx context.Context) error { _, _ = resolveAndCache(ctx, q) return nil }) @@ -220,11 +218,6 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error return nil, ErrNoCompliance } - // prep - lastFailBoundary := time.Now().Add( - -time.Duration(nameserverRetryRate()) * time.Second, - ) - // start resolving var i int diff --git a/resolver/resolver.go b/resolver/resolver.go index d44bba8b..8921e2db 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -215,8 +215,8 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er return nil, errors.New("internal error") } - // make client available again - dc.done() + // make client available (again) + dc.addToPool() if resolver.IsBlockedUpstream(reply) { return nil, &BlockedUpstreamError{resolver.GetName()}