diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go index 6cab73f5..3e84d942 100644 --- a/resolver/resolver-tcp.go +++ b/resolver/resolver-tcp.go @@ -104,13 +104,7 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery { } tr.Lock() // check for existing query - for i := 0; i < 10; i++ { // don't try forever - _, exists := tr.inFlightQueries[msg.Id] - if !exists { - break // we are unique, yay! - } - msg.Id = dns.Id() // regenerate ID - } + tr.ensureUniqueID(msg) // add query to in flight registry tr.inFlightQueries[msg.Id] = inFlight tr.Unlock() @@ -121,6 +115,27 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery { return inFlight } +// ensureUniqueID makes sure that ID assigned to msg is unique. TCPResolver must be locked. +func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) { + // try a random ID 10000 times + for i := 0; i < 10000; i++ { // don't try forever + _, exists := tr.inFlightQueries[msg.Id] + if !exists { + return // we are unique, yay! + } + msg.Id = dns.Id() // regenerate ID + } + // go through the complete space + var id uint16 + for ; id <= (1<<16)-1; id++ { // don't try forever + _, exists := tr.inFlightQueries[id] + if !exists { + msg.Id = id + return // we are unique, yay! + } + } +} + // Query executes the given query against the resolver. func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { // submit to client @@ -147,32 +162,22 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { } type tcpResolverConnMgr struct { - tr *TCPResolver - workerCtx context.Context - conn *dns.Conn - connCtx context.Context - cancelConnCtx func() - connTimer *time.Timer - connClosing *abool.AtomicBool - responses chan *dns.Msg - failCnt int + tr *TCPResolver + responses chan *dns.Msg + failCnt int } func (tr *TCPResolver) startClient() { if tr.clientStarted.SetToIf(false, true) { mgr := &tcpResolverConnMgr{ - tr: tr, - connTimer: time.NewTimer(tr.clientTTL), - connClosing: abool.New(), - responses: make(chan *dns.Msg, 100), + tr: tr, + responses: make(chan *dns.Msg, 100), } module.StartServiceWorker("dns client", 10*time.Millisecond, mgr.run) } } func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { - mgr.workerCtx = workerCtx - // connection lifecycle loop for { // check if we are failing @@ -181,19 +186,16 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { return nil } - // clean up anything that is left over - mgr.cleanupConnection() - // wait for work before creating connection - proceed := mgr.waitForWork() + proceed := mgr.waitForWork(workerCtx) if !proceed { mgr.shutdown() return nil } // create connection - success := mgr.establishConnection() - if !success { + conn, connClosing, connCtx, cancelConnCtx := mgr.establishConnection(workerCtx) + if conn == nil { mgr.failCnt++ continue } @@ -202,7 +204,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { netenv.ReportSuccessfulConnection() // handle queries - proceed = mgr.queryHandler() + proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx) if !proceed { mgr.shutdown() return nil @@ -210,26 +212,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { } } -func (mgr *tcpResolverConnMgr) cleanupConnection() { - // cleanup old connection - if mgr.conn != nil { - mgr.connClosing.Set() // silence connection errors - _ = mgr.conn.Close() - if mgr.cancelConnCtx != nil { - mgr.cancelConnCtx() - } - - // delete old connection - mgr.conn = nil - - // increase instance counter - atomic.AddUint32(mgr.tr.connInstanceID, 1) - } -} - func (mgr *tcpResolverConnMgr) shutdown() { - mgr.cleanupConnection() - // reply to all waiting queries mgr.tr.Lock() for id, inFlight := range mgr.tr.inFlightQueries { @@ -246,7 +229,7 @@ func (mgr *tcpResolverConnMgr) shutdown() { } } -func (mgr *tcpResolverConnMgr) waitForWork() (proceed bool) { +func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) { // wait until there is something to do mgr.tr.Lock() waiting := len(mgr.tr.inFlightQueries) @@ -256,6 +239,7 @@ func (mgr *tcpResolverConnMgr) waitForWork() (proceed bool) { ignoreBefore := time.Now().Add(-ignoreQueriesAfter) currentConnInstanceID := atomic.LoadUint32(mgr.tr.connInstanceID) mgr.tr.Lock() + defer mgr.tr.Unlock() for id, inFlight := range mgr.tr.inFlightQueries { if inFlight.Started.Before(ignoreBefore) { // remove old queries @@ -272,29 +256,34 @@ func (mgr *tcpResolverConnMgr) waitForWork() (proceed bool) { } // in-flight queries that match the connection instance ID are not changed. They are already in the queue. } - mgr.tr.Unlock() - } else { - // wait for first query - select { - case <-mgr.workerCtx.Done(): - return false - case msg := <-mgr.tr.queries: - // re-insert query, we will handle it later + return true + } + + // wait for first query + select { + case <-workerCtx.Done(): + return false + case msg := <-mgr.tr.queries: + // re-insert query, we will handle it later + module.StartWorker("reinject triggering dns query", func(ctx context.Context) error { select { case mgr.tr.queries <- msg: - default: + case <-time.After(2 * time.Second): log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Name) } - } + return nil + }) } return true } -func (mgr *tcpResolverConnMgr) establishConnection() (success bool) { - // create connection - mgr.connCtx, mgr.cancelConnCtx = context.WithCancel(mgr.workerCtx) - mgr.connClosing = abool.New() +func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) ( + conn *dns.Conn, + connClosing *abool.AtomicBool, + connCtx context.Context, + cancelConnCtx context.CancelFunc, +) { // refresh dialer to set an authenticated local address // TODO: lock dnsClient (only manager should run at any time, so this should not be an issue) mgr.tr.dnsClient.Dialer = &net.Dialer{ @@ -303,42 +292,56 @@ func (mgr *tcpResolverConnMgr) establishConnection() (success bool) { KeepAlive: defaultClientTTL, } // connect - c, err := mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress) + var err error + conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress) if err != nil { log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.Name, mgr.tr.resolver.ServerAddress) - return false + return nil, nil, nil, nil } - mgr.conn = c - log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, mgr.conn.RemoteAddr()) + connCtx, cancelConnCtx = context.WithCancel(workerCtx) + connClosing = abool.New() - // reset timer - mgr.connTimer.Stop() - select { - case <-mgr.connTimer.C: // try to empty the timer - default: - } - mgr.connTimer.Reset(mgr.tr.clientTTL) + log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr()) // start reader - module.StartServiceWorker("dns client reader", 10*time.Millisecond, mgr.msgReader) + module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error { + return mgr.msgReader(conn, connClosing, cancelConnCtx) + }) - return true + return conn, connClosing, connCtx, cancelConnCtx } -func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit +func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter. + workerCtx context.Context, + conn *dns.Conn, + connClosing *abool.AtomicBool, + connCtx context.Context, + cancelConnCtx context.CancelFunc, +) (proceed bool) { var readyToRecycle bool + ttlTimer := time.After(mgr.tr.clientTTL) + + // clean up connection + defer func() { + connClosing.Set() // silence connection errors + cancelConnCtx() + _ = conn.Close() + + // increase instance counter + atomic.AddUint32(mgr.tr.connInstanceID, 1) + }() for { select { - case <-mgr.workerCtx.Done(): + case <-workerCtx.Done(): // module shutdown return false - case <-mgr.connCtx.Done(): + case <-connCtx.Done(): // connection error return true - case <-mgr.connTimer.C: + case <-ttlTimer: // connection TTL reached, rebuild connection // but handle all in flight queries first readyToRecycle = true @@ -351,57 +354,18 @@ func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit case msg := <-mgr.tr.queries: // write query - _ = mgr.conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout)) - err := mgr.conn.WriteMsg(msg) + _ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout)) + err := conn.WriteMsg(msg) if err != nil { - if mgr.connClosing.SetToIf(false, true) { - mgr.cancelConnCtx() - log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, mgr.conn.RemoteAddr(), err) + if connClosing.SetToIf(false, true) { + log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err) } return true } case msg := <-mgr.responses: - if msg != nil { // nil messages only trigger the recycle check - // handle query from resolver - mgr.tr.Lock() - inFlight, ok := mgr.tr.inFlightQueries[msg.Id] - if ok { - delete(mgr.tr.inFlightQueries, msg.Id) - } - mgr.tr.Unlock() - - if ok { - select { - case inFlight.Response <- msg: - mgr.failCnt = 0 // reset fail counter - // responded! - default: - // save to cache, if enabled - if !inFlight.Query.NoCaching { - // persist to database - rrCache := inFlight.MakeCacheRecord(msg) - rrCache.Clean(600) - err := rrCache.Save() - if err != nil { - log.Warningf( - "resolver: failed to cache RR for %s%s: %s", - inFlight.Query.FQDN, - inFlight.Query.QType.String(), - err, - ) - } - } - } - } else { - log.Debugf( - "resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v", - mgr.tr.resolver.Name, - mgr.conn.RemoteAddr(), - msg.Id, - msg.Question, - ) - } + if msg != nil { + mgr.handleQueryResponse(conn, msg) } if readyToRecycle { @@ -410,7 +374,7 @@ func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit activeQueries := len(mgr.tr.inFlightQueries) mgr.tr.Unlock() if activeQueries == 0 { - log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, mgr.conn.RemoteAddr()) + log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr()) return true } } @@ -419,18 +383,65 @@ func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit } } -func (mgr *tcpResolverConnMgr) msgReader(workerCtx context.Context) error { - // copy values from manager - conn := mgr.conn - cancelConnCtx := mgr.cancelConnCtx - connClosing := mgr.connClosing +func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) { + // handle query from resolver + mgr.tr.Lock() + inFlight, ok := mgr.tr.inFlightQueries[msg.Id] + if ok { + delete(mgr.tr.inFlightQueries, msg.Id) + } + mgr.tr.Unlock() + if !ok { + log.Debugf( + "resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v", + mgr.tr.resolver.Name, + conn.RemoteAddr(), + msg.Id, + msg.Question, + ) + return + } + + select { + case inFlight.Response <- msg: + mgr.failCnt = 0 // reset fail counter + // responded! + return + default: + // no one is listening for that response. + } + + // if caching is disabled we're done + if inFlight.Query.NoCaching { + return + } + + // persist to database + rrCache := inFlight.MakeCacheRecord(msg) + rrCache.Clean(600) + err := rrCache.Save() + if err != nil { + log.Warningf( + "resolver: failed to cache RR for %s%s: %s", + inFlight.Query.FQDN, + inFlight.Query.QType.String(), + err, + ) + } +} + +func (mgr *tcpResolverConnMgr) msgReader( + conn *dns.Conn, + connClosing *abool.AtomicBool, + cancelConnCtx context.CancelFunc, +) error { + defer cancelConnCtx() for { msg, err := conn.ReadMsg() if err != nil { if connClosing.SetToIf(false, true) { - cancelConnCtx() - log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, mgr.conn.RemoteAddr(), err) + log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err) } return nil }