diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index ce1b7a62..bc0a9735 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -261,7 +261,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er rrCache, err := resolver.Resolve(ctx, q) if err != nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains - tracer.Warningf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err) + tracer.Debugf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err) if errors.Is(err, resolver.ErrBlocked) { conn.Block(err.Error()) diff --git a/netenv/online-status.go b/netenv/online-status.go index 62b75f6c..014efc8f 100644 --- a/netenv/online-status.go +++ b/netenv/online-status.go @@ -106,8 +106,6 @@ var ( captivePortalURL string captivePortalLock sync.Mutex - - waitForever = make(chan time.Time) ) func init() { @@ -234,7 +232,7 @@ func getDynamicStatusTrigger() <-chan time.Time { case StatusSemiOnline: return time.After(5 * time.Minute) case StatusOnline: - return waitForever + return nil default: // unknown status return time.After(5 * time.Minute) } diff --git a/network/proc/findpid.go b/network/proc/findpid.go index 3b123f4f..0610e361 100644 --- a/network/proc/findpid.go +++ b/network/proc/findpid.go @@ -21,31 +21,18 @@ var ( pidsByUser = make(map[int][]int) ) -// FindConnectionPID returns the pid of the given socket info. -func FindConnectionPID(socketInfo *socket.ConnectionInfo) (pid int) { +// GetPID returns the already existing pid of the given socket info or searches for it. +// This also acts as a getter for socket.*Info.PID, as locking for that occurs here. +func GetPID(socketInfo socket.Info) (pid int) { pidsByUserLock.Lock() defer pidsByUserLock.Unlock() - if socketInfo.PID != socket.UnidentifiedProcessID { - return socket.UnidentifiedProcessID + if socketInfo.GetPID() != socket.UnidentifiedProcessID { + return socketInfo.GetPID() } - pid = findPID(socketInfo.UID, socketInfo.Inode) - socketInfo.PID = pid - return pid -} - -// FindBindPID returns the pid of the given socket info. -func FindBindPID(socketInfo *socket.BindInfo) (pid int) { - pidsByUserLock.Lock() - defer pidsByUserLock.Unlock() - - if socketInfo.PID != socket.UnidentifiedProcessID { - return socket.UnidentifiedProcessID - } - - pid = findPID(socketInfo.UID, socketInfo.Inode) - socketInfo.PID = pid + pid = findPID(socketInfo.GetUID(), socketInfo.GetInode()) + socketInfo.SetPID(pid) return pid } @@ -175,6 +162,9 @@ entryLoop: } +// readDirNames only reads the directory names. Using ioutil.ReadDir() would call `lstat` on every +// resulting directory name, which we don't need. This function will be called a lot, so we should +// refrain from unnecessary work. func readDirNames(dir string) (names []string) { file, err := os.Open(dir) if err != nil { diff --git a/network/proc/tables_test.go b/network/proc/tables_test.go index bcce8498..d5b1959f 100644 --- a/network/proc/tables_test.go +++ b/network/proc/tables_test.go @@ -14,12 +14,12 @@ func TestSockets(t *testing.T) { } fmt.Println("\nTCP 4 connections:") for _, connection := range connections { - pid := FindConnectionPID(connection) + pid := GetPID(connection) fmt.Printf("%d: %+v\n", pid, connection) } fmt.Println("\nTCP 4 listeners:") for _, listener := range listeners { - pid := FindBindPID(listener) + pid := GetPID(listener) fmt.Printf("%d: %+v\n", pid, listener) } @@ -29,12 +29,12 @@ func TestSockets(t *testing.T) { } fmt.Println("\nTCP 6 connections:") for _, connection := range connections { - pid := FindConnectionPID(connection) + pid := GetPID(connection) fmt.Printf("%d: %+v\n", pid, connection) } fmt.Println("\nTCP 6 listeners:") for _, listener := range listeners { - pid := FindBindPID(listener) + pid := GetPID(listener) fmt.Printf("%d: %+v\n", pid, listener) } @@ -44,7 +44,7 @@ func TestSockets(t *testing.T) { } fmt.Println("\nUDP 4 binds:") for _, bind := range binds { - pid := FindBindPID(bind) + pid := GetPID(bind) fmt.Printf("%d: %+v\n", pid, bind) } @@ -54,7 +54,7 @@ func TestSockets(t *testing.T) { } fmt.Println("\nUDP 6 binds:") for _, bind := range binds { - pid := FindBindPID(bind) + pid := GetPID(bind) fmt.Printf("%d: %+v\n", pid, bind) } } diff --git a/network/socket/socket.go b/network/socket/socket.go index e8dfe1d9..22f37ef0 100644 --- a/network/socket/socket.go +++ b/network/socket/socket.go @@ -29,3 +29,35 @@ type Address struct { IP net.IP Port uint16 } + +// Info is a generic interface to both ConnectionInfo and BindInfo. +type Info interface { + GetPID() int + SetPID(int) + GetUID() int + GetInode() int +} + +// GetPID returns the PID. +func (i *ConnectionInfo) GetPID() int { return i.PID } + +// SetPID sets the PID to the given value. +func (i *ConnectionInfo) SetPID(pid int) { i.PID = pid } + +// GetUID returns the UID. +func (i *ConnectionInfo) GetUID() int { return i.UID } + +// GetInode returns the Inode. +func (i *ConnectionInfo) GetInode() int { return i.Inode } + +// GetPID returns the PID. +func (i *BindInfo) GetPID() int { return i.PID } + +// SetPID sets the PID to the given value. +func (i *BindInfo) SetPID(pid int) { i.PID = pid } + +// GetUID returns the UID. +func (i *BindInfo) GetUID() int { return i.UID } + +// GetInode returns the Inode. +func (i *BindInfo) GetInode() int { return i.Inode } diff --git a/network/state/lookup.go b/network/state/lookup.go index b05fe237..48d2f5b6 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -81,7 +81,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) ( if localPort == socketInfo.Local.Port && (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { table.lock.RUnlock() - return checkBindPID(socketInfo, true) + return checkPID(socketInfo, true) } } @@ -90,7 +90,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) ( if localPort == socketInfo.Local.Port && localIP.Equal(socketInfo.Local.IP) { table.lock.RUnlock() - return checkConnectionPID(socketInfo, false) + return checkPID(socketInfo, false) } } @@ -138,12 +138,12 @@ func (table *udpTable) lookup(pktInfo *packet.Info) ( // do not check direction if remoteIP/Port is not given if pktInfo.RemotePort() == 0 { - return checkBindPID(socketInfo, pktInfo.Inbound) + return checkPID(socketInfo, pktInfo.Inbound) } // get direction and return connInbound := table.getDirection(socketInfo, pktInfo) - return checkBindPID(socketInfo, connInbound) + return checkPID(socketInfo, connInbound) } } diff --git a/network/state/system_linux.go b/network/state/system_linux.go index c4b67777..9a0a41ab 100644 --- a/network/state/system_linux.go +++ b/network/state/system_linux.go @@ -14,30 +14,10 @@ var ( getUDP6Table = proc.GetUDP6Table ) -func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { +func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) { for i := 0; i <= lookupRetries; i++ { // look for PID - pid = proc.FindConnectionPID(socketInfo) - if pid != socket.UnidentifiedProcessID { - // if we found a PID, return - break - } - - // every time, except for the last iteration - if i < lookupRetries { - // we found no PID, we could have been too fast, give the kernel some time to think - // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total - time.Sleep(time.Duration(i+1) * baseWaitTime) - } - } - - return pid, connInbound, nil -} - -func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { - for i := 0; i <= lookupRetries; i++ { - // look for PID - pid = proc.FindBindPID(socketInfo) + pid = proc.GetPID(socketInfo) if pid != socket.UnidentifiedProcessID { // if we found a PID, return break diff --git a/network/state/system_windows.go b/network/state/system_windows.go index a03ea5f6..56927366 100644 --- a/network/state/system_windows.go +++ b/network/state/system_windows.go @@ -12,10 +12,6 @@ var ( getUDP6Table = iphelper.GetUDP6Table ) -func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { - return socketInfo.PID, connInbound, nil -} - -func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { - return socketInfo.PID, connInbound, nil +func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) { + return socketInfo.GetPID(), connInbound, nil } diff --git a/resolver/resolve.go b/resolver/resolve.go index 5b0ac937..f6f62ca2 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -228,13 +228,12 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error // check if we are online if netenv.GetOnlineStatus() == netenv.StatusOffline { - if netenv.IsOnlineStatusTestDomain(q.FQDN) { - log.Tracer(ctx).Debugf("resolver: permitting online status test domain %s to resolve even though offline", q.FQDN) - } else { + if !netenv.IsOnlineStatusTestDomain(q.FQDN) { log.Tracer(ctx).Debugf("resolver: not resolving %s, device is offline", q.FQDN) // we are offline and this is not an online check query return nil, ErrOffline } + log.Tracer(ctx).Debugf("resolver: permitting online status test domain %s to resolve even though offline", q.FQDN) } // start resolving @@ -253,9 +252,6 @@ resolveLoop: // resolve rrCache, err = resolver.Conn.Query(ctx, q) if err != nil { - - // TODO: check if we are online? - switch { case errors.Is(err, ErrNotFound): // NXDomain, or similar @@ -268,6 +264,8 @@ resolveLoop: log.Tracer(ctx).Debugf("resolver: not resolving %s, device is offline", q.FQDN) // we are offline and this is not an online check query return nil, ErrOffline + default: + log.Tracer(ctx).Debugf("resolver: failed to resolve %s: %s", q.FQDN, err) } } else { // no error diff --git a/resolver/resolve_test.go b/resolver/resolve_test.go deleted file mode 100644 index f621a635..00000000 --- a/resolver/resolve_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package resolver - -// DISABLE TESTING FOR NOW: find a way to have tests with the module system - -// import ( -// "testing" -// "time" -// -// "github.com/miekg/dns" -// ) - -// func TestResolve(t *testing.T) { -// Resolve("google.com.", dns.Type(dns.TypeA), 0) -// time.Sleep(200 * time.Millisecond) -// } diff --git a/resolver/resolver-pooled_test.go b/resolver/resolver-pooled_test.go index 5b460584..b7da984a 100644 --- a/resolver/resolver-pooled_test.go +++ b/resolver/resolver-pooled_test.go @@ -50,8 +50,7 @@ func TestClientPooling(t *testing.T) { } // create separate resolver for this test - resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config") - // resolver, _, err := createResolver("dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "config") + resolver, _, err := createResolver(testResolver, "config") if err != nil { t.Fatal(err) } diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go index 8111bf94..3e84d942 100644 --- a/resolver/resolver-tcp.go +++ b/resolver/resolver-tcp.go @@ -22,14 +22,13 @@ const ( type TCPResolver struct { BasicResolverConn - clientTTL time.Duration - dnsClient *dns.Client - dnsConnection *dns.Conn - connInstanceID *uint32 + clientTTL time.Duration + dnsClient *dns.Client + clientStarted *abool.AtomicBool + connInstanceID *uint32 queries chan *dns.Msg inFlightQueries map[uint16]*InFlightQuery - clientStarted *abool.AtomicBool } // InFlightQuery represents an in flight query of a TCPResolver. @@ -86,254 +85,10 @@ func (tr *TCPResolver) UseTLS() *TCPResolver { return tr } -func (tr *TCPResolver) client(workerCtx context.Context) error { //nolint:gocognit,gocyclo // TODO - connTimer := time.NewTimer(tr.clientTTL) - connClosing := abool.New() - var connCtx context.Context - var cancelConnCtx func() - var recycleConn bool - var shuttingDown bool - var failCnt int - var incoming = make(chan *dns.Msg, 100) - -connMgmt: - for { - // cleanup old connection - if tr.dnsConnection != nil { - connClosing.Set() - _ = tr.dnsConnection.Close() - if cancelConnCtx != nil { - cancelConnCtx() - } - - tr.dnsConnection = nil - atomic.AddUint32(tr.connInstanceID, 1) - } - - // check if we are shutting down or failing - if shuttingDown || failCnt >= FailThreshold || tr.IsFailing() { - // reply to all waiting queries - tr.Lock() - for id, inFlight := range tr.inFlightQueries { - close(inFlight.Response) - delete(tr.inFlightQueries, id) - } - tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds - tr.Unlock() - - // hint network environment at failed connection - if failCnt >= FailThreshold { - netenv.ReportFailedConnection() - } - - // The linter said so. Don't even... - if cancelConnCtx != nil { - cancelConnCtx() - } - - return nil - } - - // wait until there is something to do - tr.Lock() - waiting := len(tr.inFlightQueries) - tr.Unlock() - if waiting > 0 { - // queue abandoned queries - ignoreBefore := time.Now().Add(-ignoreQueriesAfter) - currentConnInstanceID := atomic.LoadUint32(tr.connInstanceID) - tr.Lock() - for id, inFlight := range tr.inFlightQueries { - if inFlight.Started.Before(ignoreBefore) { - // remove - delete(tr.inFlightQueries, id) - } else if inFlight.ConnInstanceID != currentConnInstanceID { - inFlight.ConnInstanceID = currentConnInstanceID - // re-inject - select { - case tr.queries <- inFlight.Msg: - default: - log.Warningf("resolver: failed to re-inject abandoned query to %s", tr.resolver.Name) - } - } - } - tr.Unlock() - } else { - // wait for first query - select { - case <-workerCtx.Done(): - // abort - shuttingDown = true - continue connMgmt - case msg := <-tr.queries: - // re-insert, we will handle later - select { - case tr.queries <- msg: - default: - log.Warningf("resolver: failed to re-inject waking query to %s", tr.resolver.Name) - } - } - } - - // create connection - connCtx, cancelConnCtx = context.WithCancel(workerCtx) - // refresh dialer for authenticated local address - tr.dnsClient.Dialer = &net.Dialer{ - LocalAddr: getLocalAddr("tcp"), - Timeout: defaultConnectTimeout, - KeepAlive: defaultClientTTL, - } - // connect - c, err := tr.dnsClient.Dial(tr.resolver.ServerAddress) - if err != nil { - tr.ReportFailure() - failCnt++ - if tr.IsFailing() { - shuttingDown = true - } - log.Debugf("resolver: failed to connect to %s (%s)", tr.resolver.Name, tr.resolver.ServerAddress) - netenv.ReportFailedConnection() - continue connMgmt - } - tr.dnsConnection = c - log.Debugf("resolver: connected to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr()) - - // hint network environment at successful connection - netenv.ReportSuccessfulConnection() - - // reset timer - connTimer.Stop() - select { - case <-connTimer.C: // try to empty the timer - default: - } - connTimer.Reset(tr.clientTTL) - recycleConn = false - - // start reader - module.StartWorker("dns client reader", func(ctx context.Context) error { - conn := tr.dnsConnection - for { - msg, err := conn.ReadMsg() - if err != nil { - if connClosing.SetToIf(false, true) { - if cancelConnCtx != nil { - cancelConnCtx() - } - tr.ReportFailure() - failCnt++ - if tr.IsFailing() { - shuttingDown = true - } - log.Warningf("resolver: read error from %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err) - } - return nil - } - incoming <- msg - } - }) - - // query management - for { - select { - case <-workerCtx.Done(): - // shutting down - shuttingDown = true - continue connMgmt - case <-connCtx.Done(): - // connection error - continue connMgmt - case <-connTimer.C: - // client TTL expired, recycle connection - recycleConn = true - // trigger check - select { - case incoming <- nil: - default: - // quere is full anyway, do nothing - } - - case msg := <-tr.queries: - // write query - _ = tr.dnsConnection.SetWriteDeadline(time.Now().Add(tr.dnsClient.WriteTimeout)) - err := tr.dnsConnection.WriteMsg(msg) - if err != nil { - if connClosing.SetToIf(false, true) { - if cancelConnCtx != nil { - cancelConnCtx() - } - tr.ReportFailure() - failCnt++ - if tr.IsFailing() { - shuttingDown = true - } - log.Warningf("resolver: write error to %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err) - } - continue connMgmt - } - - case msg := <-incoming: - - if msg != nil { - // handle query from resolver - tr.Lock() - inFlight, ok := tr.inFlightQueries[msg.Id] - if ok { - delete(tr.inFlightQueries, msg.Id) - } - tr.Unlock() - - if ok { - select { - case inFlight.Response <- msg: - failCnt = 0 // reset fail counter - // responded! - default: - // save to cache, if enabled - if !inFlight.Query.NoCaching { - // persist to database - rrCache := inFlight.MakeCacheRecord(msg) - rrCache.Clean(600) - err = rrCache.Save() - if err != nil { - log.Warningf( - "resolver: failed to cache RR for %s%s: %s", - inFlight.Query.FQDN, - inFlight.Query.QType.String(), - err, - ) - } - } - } - } else { - log.Debugf( - "resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v", - tr.resolver.Name, - tr.dnsConnection.RemoteAddr(), - msg.Id, - msg.Question, - ) - } - } - - // check if we have finished all queries and want to recycle conn - if recycleConn { - tr.Lock() - activeQueries := len(tr.inFlightQueries) - tr.Unlock() - if activeQueries == 0 { - log.Debugf("resolver: recycling conn to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr()) - continue connMgmt - } - } - - } - } - - } -} - func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery { + // make sure client is started + tr.startClient() + // create msg msg := &dns.Msg{} msg.SetQuestion(q.FQDN, uint16(q.QType)) @@ -348,20 +103,39 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery { ConnInstanceID: atomic.LoadUint32(tr.connInstanceID), } tr.Lock() + // check for existing query + tr.ensureUniqueID(msg) + // add query to in flight registry tr.inFlightQueries[msg.Id] = inFlight tr.Unlock() // submit msg for writing tr.queries <- msg - // make sure client is started - if tr.clientStarted.SetToIf(false, true) { - module.StartServiceWorker("dns client", 10*time.Millisecond, tr.client) - } - return inFlight } +// ensureUniqueID makes sure that ID assigned to msg is unique. TCPResolver must be locked. +func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) { + // try a random ID 10000 times + for i := 0; i < 10000; i++ { // don't try forever + _, exists := tr.inFlightQueries[msg.Id] + if !exists { + return // we are unique, yay! + } + msg.Id = dns.Id() // regenerate ID + } + // go through the complete space + var id uint16 + for ; id <= (1<<16)-1; id++ { // don't try forever + _, exists := tr.inFlightQueries[id] + if !exists { + msg.Id = id + return // we are unique, yay! + } + } +} + // Query executes the given query against the resolver. func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { // submit to client @@ -386,3 +160,291 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { return inFlight.MakeCacheRecord(reply), nil } + +type tcpResolverConnMgr struct { + tr *TCPResolver + responses chan *dns.Msg + failCnt int +} + +func (tr *TCPResolver) startClient() { + if tr.clientStarted.SetToIf(false, true) { + mgr := &tcpResolverConnMgr{ + tr: tr, + responses: make(chan *dns.Msg, 100), + } + module.StartServiceWorker("dns client", 10*time.Millisecond, mgr.run) + } +} + +func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { + // connection lifecycle loop + for { + // check if we are failing + if mgr.failCnt >= FailThreshold || mgr.tr.IsFailing() { + mgr.shutdown() + return nil + } + + // wait for work before creating connection + proceed := mgr.waitForWork(workerCtx) + if !proceed { + mgr.shutdown() + return nil + } + + // create connection + conn, connClosing, connCtx, cancelConnCtx := mgr.establishConnection(workerCtx) + if conn == nil { + mgr.failCnt++ + continue + } + + // hint network environment at successful connection + netenv.ReportSuccessfulConnection() + + // handle queries + proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx) + if !proceed { + mgr.shutdown() + return nil + } + } +} + +func (mgr *tcpResolverConnMgr) shutdown() { + // reply to all waiting queries + mgr.tr.Lock() + for id, inFlight := range mgr.tr.inFlightQueries { + close(inFlight.Response) + delete(mgr.tr.inFlightQueries, id) + } + mgr.tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds + atomic.AddUint32(mgr.tr.connInstanceID, 1) // increase instance counter + mgr.tr.Unlock() + + // hint network environment at failed connection + if mgr.failCnt >= FailThreshold { + netenv.ReportFailedConnection() + } +} + +func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) { + // wait until there is something to do + mgr.tr.Lock() + waiting := len(mgr.tr.inFlightQueries) + mgr.tr.Unlock() + if waiting > 0 { + // queue abandoned queries + ignoreBefore := time.Now().Add(-ignoreQueriesAfter) + currentConnInstanceID := atomic.LoadUint32(mgr.tr.connInstanceID) + mgr.tr.Lock() + defer mgr.tr.Unlock() + for id, inFlight := range mgr.tr.inFlightQueries { + if inFlight.Started.Before(ignoreBefore) { + // remove old queries + close(inFlight.Response) + delete(mgr.tr.inFlightQueries, id) + } else if inFlight.ConnInstanceID != currentConnInstanceID { + inFlight.ConnInstanceID = currentConnInstanceID + // re-inject queries that died with a previously failed connection + select { + case mgr.tr.queries <- inFlight.Msg: + default: + log.Warningf("resolver: failed to re-inject abandoned query to %s", mgr.tr.resolver.Name) + } + } + // in-flight queries that match the connection instance ID are not changed. They are already in the queue. + } + return true + } + + // wait for first query + select { + case <-workerCtx.Done(): + return false + case msg := <-mgr.tr.queries: + // re-insert query, we will handle it later + module.StartWorker("reinject triggering dns query", func(ctx context.Context) error { + select { + case mgr.tr.queries <- msg: + case <-time.After(2 * time.Second): + log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Name) + } + return nil + }) + } + + return true +} + +func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) ( + conn *dns.Conn, + connClosing *abool.AtomicBool, + connCtx context.Context, + cancelConnCtx context.CancelFunc, +) { + // refresh dialer to set an authenticated local address + // TODO: lock dnsClient (only manager should run at any time, so this should not be an issue) + mgr.tr.dnsClient.Dialer = &net.Dialer{ + LocalAddr: getLocalAddr("tcp"), + Timeout: defaultConnectTimeout, + KeepAlive: defaultClientTTL, + } + // connect + var err error + conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress) + if err != nil { + log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.Name, mgr.tr.resolver.ServerAddress) + return nil, nil, nil, nil + } + connCtx, cancelConnCtx = context.WithCancel(workerCtx) + connClosing = abool.New() + + log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr()) + + // start reader + module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error { + return mgr.msgReader(conn, connClosing, cancelConnCtx) + }) + + return conn, connClosing, connCtx, cancelConnCtx +} + +func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter. + workerCtx context.Context, + conn *dns.Conn, + connClosing *abool.AtomicBool, + connCtx context.Context, + cancelConnCtx context.CancelFunc, +) (proceed bool) { + var readyToRecycle bool + ttlTimer := time.After(mgr.tr.clientTTL) + + // clean up connection + defer func() { + connClosing.Set() // silence connection errors + cancelConnCtx() + _ = conn.Close() + + // increase instance counter + atomic.AddUint32(mgr.tr.connInstanceID, 1) + }() + + for { + select { + case <-workerCtx.Done(): + // module shutdown + return false + + case <-connCtx.Done(): + // connection error + return true + + case <-ttlTimer: + // connection TTL reached, rebuild connection + // but handle all in flight queries first + readyToRecycle = true + // trigger check + select { + case mgr.responses <- nil: + default: + // queue is full, check will be triggered anyway + } + + case msg := <-mgr.tr.queries: + // write query + _ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout)) + err := conn.WriteMsg(msg) + if err != nil { + if connClosing.SetToIf(false, true) { + log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err) + } + return true + } + + case msg := <-mgr.responses: + if msg != nil { + mgr.handleQueryResponse(conn, msg) + } + + if readyToRecycle { + // check to see if we can recycle the connection + mgr.tr.Lock() + activeQueries := len(mgr.tr.inFlightQueries) + mgr.tr.Unlock() + if activeQueries == 0 { + log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr()) + return true + } + } + + } + } +} + +func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) { + // handle query from resolver + mgr.tr.Lock() + inFlight, ok := mgr.tr.inFlightQueries[msg.Id] + if ok { + delete(mgr.tr.inFlightQueries, msg.Id) + } + mgr.tr.Unlock() + + if !ok { + log.Debugf( + "resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v", + mgr.tr.resolver.Name, + conn.RemoteAddr(), + msg.Id, + msg.Question, + ) + return + } + + select { + case inFlight.Response <- msg: + mgr.failCnt = 0 // reset fail counter + // responded! + return + default: + // no one is listening for that response. + } + + // if caching is disabled we're done + if inFlight.Query.NoCaching { + return + } + + // persist to database + rrCache := inFlight.MakeCacheRecord(msg) + rrCache.Clean(600) + err := rrCache.Save() + if err != nil { + log.Warningf( + "resolver: failed to cache RR for %s%s: %s", + inFlight.Query.FQDN, + inFlight.Query.QType.String(), + err, + ) + } +} + +func (mgr *tcpResolverConnMgr) msgReader( + conn *dns.Conn, + connClosing *abool.AtomicBool, + cancelConnCtx context.CancelFunc, +) error { + defer cancelConnCtx() + for { + msg, err := conn.ReadMsg() + if err != nil { + if connClosing.SetToIf(false, true) { + log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err) + } + return nil + } + mgr.responses <- msg + } +} diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go index d8d6dcff..d8ab14e4 100644 --- a/resolver/resolver_test.go +++ b/resolver/resolver_test.go @@ -21,7 +21,7 @@ func init() { flag.StringVar( &testResolver, "resolver", - "dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", + "dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "set custom resolver for testing", ) }