Merge pull request #103 from safing/feature/simplify-dns-resolvers

Simplify dns resolvers
This commit is contained in:
Daniel
2020-07-21 16:46:16 +02:00
committed by GitHub
9 changed files with 152 additions and 432 deletions

View File

@@ -2,6 +2,7 @@ package resolver
import (
"context"
"net"
"strings"
"time"
@@ -94,3 +95,21 @@ func start() error {
return nil
}
// localAddrFactory, when set, yields the permitted local address that
// outgoing connections of the given network should bind to.
var localAddrFactory func(network string) net.Addr

// SetLocalAddrFactory supplies the resolver package with a function to get
// permitted local addresses for connections. Once a non-nil factory has been
// set, any further call is ignored.
func SetLocalAddrFactory(laf func(network string) net.Addr) {
	if localAddrFactory != nil {
		// A factory is already registered, keep it.
		return
	}
	localAddrFactory = laf
}

// getLocalAddr returns the local address to use for the given network, as
// provided by the registered factory. It returns nil if no factory is set.
func getLocalAddr(network string) net.Addr {
	if localAddrFactory == nil {
		return nil
	}
	return localAddrFactory(network)
}

View File

@@ -278,6 +278,10 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error
resolveLoop:
for i = 0; i < 2; i++ {
for _, resolver := range resolvers {
if module.IsStopping() {
return nil, errors.New("shutting down")
}
// check if resolver failed recently (on first run)
if i == 0 && resolver.Conn.IsFailing() {
log.Tracer(ctx).Tracef("resolver: skipping resolver %s, because it failed recently", resolver)
@@ -294,19 +298,25 @@ resolveLoop:
case errors.Is(err, ErrBlocked):
// some resolvers might also block
return nil, err
case errors.Is(err, ErrContinue):
continue
case netenv.GetOnlineStatus() == netenv.StatusOffline &&
!netenv.IsConnectivityDomain(q.FQDN):
log.Tracer(ctx).Debugf("resolver: not resolving %s, device is offline", q.FQDN)
// we are offline and this is not an online check query
return nil, ErrOffline
case errors.Is(err, ErrContinue):
continue
case errors.Is(err, ErrTimeout):
resolver.Conn.ReportFailure()
log.Tracer(ctx).Debugf("resolver: query to %s timed out", resolver.GetName())
continue
default:
// includes ErrTimeout
log.Tracer(ctx).Debugf("resolver: failed to resolve %s: %s", q.FQDN, err)
resolver.Conn.ReportFailure()
log.Tracer(ctx).Debugf("resolver: query to %s failed: %s", resolver.GetName(), err)
continue
}
}
if rrCache == nil {
// Defensive: This should normally not happen.
continue
}
break resolveLoop

View File

@@ -241,12 +241,12 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
}
switch entry.(type) {
case *dns.A:
scavengedRecords[fmt.Sprintf("%s_A", entry.Header().Name)] = entry
scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = entry
case *dns.AAAA:
scavengedRecords[fmt.Sprintf("%s_AAAA", entry.Header().Name)] = entry
scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = entry
case *dns.PTR:
if !strings.HasPrefix(entry.Header().Name, "_") {
scavengedRecords[fmt.Sprintf("%s_PTR", entry.Header().Name)] = entry
scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = entry
}
}
}

View File

@@ -0,0 +1,92 @@
package resolver
import (
"context"
"net"
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netenv"
)
// Default durations for resolver connections.
var (
// NOTE(review): defaultClientTTL is not referenced in this file; presumably
// it is used by the other resolver connection types - confirm.
defaultClientTTL = 5 * time.Minute
// defaultRequestTimeout is the upper bound for the timeout of a single DNS
// query, see PlainResolver.Query.
defaultRequestTimeout = 3 * time.Second // dns query
// NOTE(review): defaultConnectTimeout is not referenced in this file;
// presumably it is used when dialing tcp/tls connections - confirm.
defaultConnectTimeout = 5 * time.Second // tcp/tls
)
// PlainResolver is a resolver using plain DNS.
// It embeds BasicResolverConn, which holds the *Resolver the queries are
// sent to, and adds a Query method that uses a fresh dns.Client per query.
type PlainResolver struct {
BasicResolverConn
}
// NewPlainResolver returns a new PlainResolver that sends its queries to the
// given resolver.
func NewPlainResolver(resolver *Resolver) *PlainResolver {
	pr := new(PlainResolver)
	// resolver is promoted from the embedded BasicResolverConn.
	pr.resolver = resolver
	return pr
}
// Query executes the given query against the resolver.
//
// A new dns.Client is created for every call and dials from the local
// address returned by getLocalAddr("udp"). The timeout of the client and its
// dialer is defaultRequestTimeout, shortened to the time remaining until the
// deadline of ctx if that is sooner.
//
// It returns a *BlockedUpstreamError if the resolver reports the reply as
// blocked upstream, and otherwise passes on any error of the exchange. On
// success, the answer, authority and extra sections of the reply are returned
// in a new RRCache.
func (pr *PlainResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
	// create query
	dnsQuery := new(dns.Msg)
	dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))

	// Get timeout from config and context: start with the configured default
	// and only ever shorten it. A context without a deadline must not result
	// in a zero timeout, as zero means "not set" for both dns.Client and
	// net.Dialer and would bypass defaultRequestTimeout.
	timeout := defaultRequestTimeout
	if deadline, ok := ctx.Deadline(); ok {
		if remaining := time.Until(deadline); remaining < timeout {
			timeout = remaining
		}
	}

	// create client
	dnsClient := &dns.Client{
		Timeout: timeout,
		Dialer: &net.Dialer{
			Timeout:   timeout,
			LocalAddr: getLocalAddr("udp"),
		},
	}

	// query server
	reply, ttl, err := dnsClient.Exchange(dnsQuery, pr.resolver.ServerAddress)
	log.Tracer(ctx).Tracef("resolver: query took %s", ttl)

	// error handling
	if err != nil {
		// Hint network environment at failed connection if err is not a timeout.
		if nErr, ok := err.(net.Error); ok && !nErr.Timeout() {
			netenv.ReportFailedConnection()
		}

		return nil, err
	}

	// check if blocked
	if pr.resolver.IsBlockedUpstream(reply) {
		return nil, &BlockedUpstreamError{pr.resolver.GetName()}
	}

	// hint network environment at successful connection
	netenv.ReportSuccessfulConnection()

	newRecord := &RRCache{
		Domain:      q.FQDN,
		Question:    q.QType,
		Answer:      reply.Answer,
		Ns:          reply.Ns,
		Extra:       reply.Extra,
		Server:      pr.resolver.Server,
		ServerScope: pr.resolver.ServerIPScope,
	}

	// TODO: check if reply.Answer is valid
	return newRecord, nil
}

View File

@@ -1,187 +0,0 @@
package resolver
import (
"context"
"crypto/tls"
"net"
"sync"
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/utils"
)
// Defaults and shared state for DNS client handling.
var (
// defaultClientTTL is the lifetime of pooled tcp/tls clients and also the
// keep-alive period of their dialers.
defaultClientTTL = 5 * time.Minute
defaultRequestTimeout = 3 * time.Second // dns query
defaultConnectTimeout = 5 * time.Second // tcp/tls
// connectionEOLGracePeriod is added to a client's ttl before its cleaner
// closes the connection, see startCleaner.
connectionEOLGracePeriod = 7 * time.Second
// localAddrFactory is set via SetLocalAddrFactory and read via getLocalAddr.
localAddrFactory func(network string) net.Addr
)
// SetLocalAddrFactory supplies this package with a function to get permitted
// local addresses for connections. Once a non-nil factory has been set, any
// further call is ignored.
func SetLocalAddrFactory(laf func(network string) net.Addr) {
	if localAddrFactory != nil {
		// A factory is already registered, keep it.
		return
	}
	localAddrFactory = laf
}
// getLocalAddr returns the local address to use for the given network, as
// provided by the registered factory. It returns nil if no factory is set.
func getLocalAddr(network string) net.Addr {
	if localAddrFactory == nil {
		return nil
	}
	return localAddrFactory(network)
}
// dnsClientManager creates dns clients for a single resolver and, if a ttl
// is set, pools them for reuse.
type dnsClientManager struct {
// lock is held for the duration of getDNSClient.
lock sync.Mutex
// set by creator
resolver *Resolver
// A ttl of 0 makes getDNSClient return a new, unpooled client on every call.
ttl time.Duration // force refresh of connection to reduce traceability
// factory builds the *dns.Client for every newly created dnsClient.
factory func() *dns.Client
// internal
// pool holds the *dnsClient values handed back via addToPool.
pool utils.StablePool
}
// dnsClient bundles a dns.Client with its lazily dialed connection and the
// manager it was created by.
type dnsClient struct {
// mgr is the manager that created this client and owns the pool it returns to.
mgr *dnsClientManager
client *dns.Client
// conn is nil until getConn dials it; destroy closes it.
conn *dns.Conn
// useUntil is the point in time after which getDNSClient no longer hands
// out this client from the pool. It stays zero for clients of managers
// with a ttl of 0.
useUntil time.Time
}
// getConn returns the connection of the client and reports whether it was
// newly dialed by this call. An existing connection is reused as is.
// This function may only be called between clientManager.getDNSClient() and
// dnsClient.done().
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
	// Reuse the existing connection, if there is one.
	if dc.conn != nil {
		return dc.conn, false, nil
	}

	// Otherwise, dial the resolver and remember the connection.
	dc.conn, err = dc.client.Dial(dc.mgr.resolver.ServerAddress)
	if err != nil {
		return nil, false, err
	}
	return dc.conn, true, nil
}
// addToPool puts the client into the pool of its manager, from where
// getDNSClient may hand it out again.
func (dc *dnsClient) addToPool() {
dc.mgr.pool.Put(dc)
}
// destroy closes the connection of the client, if it has one. The error of
// closing is deliberately ignored, as this is best-effort cleanup.
func (dc *dnsClient) destroy() {
	if dc.conn == nil {
		return
	}
	_ = dc.conn.Close()
}
// newDNSClientManager returns a client manager for plain dns to the given
// resolver. Its clients use defaultRequestTimeout and dial from the local
// address returned by getLocalAddr("udp").
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
	buildClient := func() *dns.Client {
		dialer := &net.Dialer{
			LocalAddr: getLocalAddr("udp"),
		}
		return &dns.Client{
			Timeout: defaultRequestTimeout,
			Dialer:  dialer,
		}
	}

	return &dnsClientManager{
		resolver: resolver,
		ttl:      0, // new client for every request, as we need to randomize the port
		factory:  buildClient,
	}
}
// newTCPClientManager returns a client manager for dns over tcp to the given
// resolver. Its clients are pooled for defaultClientTTL and dial from the
// local address returned by getLocalAddr("tcp").
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
	buildClient := func() *dns.Client {
		dialer := &net.Dialer{
			LocalAddr: getLocalAddr("tcp"),
			Timeout:   defaultConnectTimeout,
			KeepAlive: defaultClientTTL,
		}
		return &dns.Client{
			Net:     "tcp",
			Timeout: defaultRequestTimeout,
			Dialer:  dialer,
		}
	}

	return &dnsClientManager{
		resolver: resolver,
		ttl:      defaultClientTTL,
		factory:  buildClient,
	}
}
// newTLSClientManager returns a client manager for dns over tls to the given
// resolver. Its clients require at least TLS 1.2, use the VerifyDomain of
// the resolver as the TLS server name, are pooled for defaultClientTTL and
// dial from the local address returned by getLocalAddr("tcp").
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
	buildClient := func() *dns.Client {
		tlsConfig := &tls.Config{
			MinVersion: tls.VersionTLS12,
			ServerName: resolver.VerifyDomain,
			// TODO: use portbase rng
		}
		dialer := &net.Dialer{
			LocalAddr: getLocalAddr("tcp"),
			Timeout:   defaultConnectTimeout,
			KeepAlive: defaultClientTTL,
		}
		return &dns.Client{
			Net:       "tcp-tls",
			TLSConfig: tlsConfig,
			Timeout:   defaultRequestTimeout,
			Dialer:    dialer,
		}
	}

	return &dnsClientManager{
		resolver: resolver,
		ttl:      defaultClientTTL,
		factory:  buildClient,
	}
}
// getDNSClient returns a client for the resolver of the manager.
// If the manager has a ttl of 0, a new, unpooled client is returned on every
// call. Otherwise, the first pooled client that has not yet passed its
// useUntil time is returned. If there is none, a new client with a fresh
// useUntil time is created and its cleaner is started.
func (cm *dnsClientManager) getDNSClient() *dnsClient {
	cm.lock.Lock()
	defer cm.lock.Unlock()

	// Pooling is disabled: always hand out a new client.
	if cm.ttl == 0 {
		return &dnsClient{
			mgr:    cm,
			client: cm.factory(),
		}
	}

	// Search the pool for a client that is still valid. Expired clients are
	// skipped and thereby dropped from the pool.
	now := time.Now().UTC()
	for {
		pooled, ok := cm.pool.Get().(*dnsClient)
		if !ok || pooled == nil {
			// Pool is empty (probably, pool may always return nil!).
			break
		}
		if !now.After(pooled.useUntil) {
			return pooled
		}
	}

	// Nothing usable in the pool, create a new client.
	fresh := &dnsClient{
		mgr:      cm,
		client:   cm.factory(),
		useUntil: now.Add(cm.ttl),
	}
	fresh.startCleaner()
	return fresh
}
// startCleaner starts a worker that waits for the EOL of the client, which
// is the ttl of its manager plus connectionEOLGracePeriod, or for the worker
// context to be canceled, and then destroys the client.
func (dc *dnsClient) startCleaner() {
	// While a single worker to clean all connections may be slightly more
	// performant, one worker per client needs as little locking as possible
	// and is simpler, thus less error prone.
	cleanup := func(ctx context.Context) error {
		eol := time.After(dc.mgr.ttl + connectionEOLGracePeriod)

		select {
		case <-eol:
			// EOL reached, destroy right away.
		case <-ctx.Done():
			// Give a short time before kill for graceful request completion.
			time.Sleep(100 * time.Millisecond)
		}

		dc.destroy()
		return nil
	}

	module.StartWorker("dns client cleanup", cleanup)
}

View File

@@ -1,82 +0,0 @@
package resolver
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
)
// testQuery resolves q with a client from the client manager of brc and
// returns the client to the pool on success. It increments newCnt if a new
// connection had to be dialed for the query.
//
// testQuery is meant to be run in its own goroutine: it calls wg.Done when
// it returns and only reports failures with functions that are safe to call
// from goroutines other than the one running the test.
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
	// Release the wait group on every path, so that the waiting test can
	// never block on a query that bailed out early.
	defer wg.Done()

	dnsClient := brc.clientManager.getDNSClient()

	// create query
	dnsQuery := new(dns.Msg)
	dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))

	// get connection
	conn, isNew, err := dnsClient.getConn()
	if err != nil {
		// Connection problems are only logged, as this test depends on the
		// Internet.
		t.Logf("failed to connect: %s", err)
		return
	}
	if isNew {
		atomic.AddUint32(newCnt, 1)
	}

	// query server
	reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
	if err != nil {
		t.Logf("client failed: %s", err)
		return
	}
	if reply == nil {
		// Fatalf must only be called from the goroutine running the test,
		// so mark the test as failed with Errorf and return instead.
		t.Errorf("resolved %s, but reply was empty!", q.FQDN)
		return
	}

	t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, isNew, ttl)
	dnsClient.addToPool()
}
// TestClientPooling runs ten rounds of ten concurrent queries against a
// dedicated resolver and fails if, after any round, more connections were
// newly dialed than the limit for that round allows.
func TestClientPooling(t *testing.T) {
	// This test depends on the Internet and might fail randomly, so it is
	// skipped in short mode.
	if testing.Short() {
		t.Skip()
	}

	// Use a separate resolver with its own client manager for this test.
	resolver, _, err := createResolver(testResolver, "config")
	if err != nil {
		t.Fatal(err)
	}
	brc := &BasicResolverConn{
		clientManager: clientManagerFactory(resolver.ServerType)(resolver),
		resolver:      resolver,
	}
	resolver.Conn = brc

	const queriesPerRound = 10

	started := time.Now()
	wg := &sync.WaitGroup{}
	var newCnt uint32

	for round := 0; round < 10; round++ {
		wg.Add(queriesPerRound)
		for n := 0; n < queriesPerRound; n++ {
			query := &Query{
				FQDN:  <-domainFeed,
				QType: dns.Type(dns.TypeA),
			}
			go testQuery(t, wg, &newCnt, brc, query) //nolint:staticcheck
		}
		wg.Wait()

		// The permitted number of new connections grows by one per round.
		if limit := 10 + round; newCnt > uint32(limit) {
			t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, limit)
		}
	}

	t.Logf("time taken: %s", time.Since(started))
}

View File

@@ -145,7 +145,6 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
select {
case reply = <-inFlight.Response:
case <-time.After(defaultRequestTimeout):
tr.ReportFailure()
return nil, ErrTimeout
}
@@ -180,6 +179,14 @@ func (tr *TCPResolver) startClient() {
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
// connection lifecycle loop
for {
// check if we are shutting down
select {
case <-workerCtx.Done():
mgr.shutdown()
return nil
default:
}
// check if we are failing
if mgr.failCnt >= FailThreshold || mgr.tr.IsFailing() {
mgr.shutdown()
@@ -251,7 +258,7 @@ func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed b
select {
case mgr.tr.queries <- inFlight.Msg:
default:
log.Warningf("resolver: failed to re-inject abandoned query to %s", mgr.tr.resolver.Name)
log.Warningf("resolver: failed to re-inject abandoned query to %s", mgr.tr.resolver.GetName())
}
}
// in-flight queries that match the connection instance ID are not changed. They are already in the queue.
@@ -269,7 +276,7 @@ func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed b
select {
case mgr.tr.queries <- msg:
case <-time.After(2 * time.Second):
log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Name)
log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.GetName())
}
return nil
})
@@ -295,13 +302,13 @@ func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) (
var err error
conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress)
if err != nil {
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.Name, mgr.tr.resolver.ServerAddress)
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.GetName(), mgr.tr.resolver.ServerAddress)
return nil, nil, nil, nil
}
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
connClosing = abool.New()
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.GetName(), conn.RemoteAddr())
// start reader
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
@@ -358,7 +365,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
err := conn.WriteMsg(msg)
if err != nil {
if connClosing.SetToIf(false, true) {
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err)
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err)
}
return true
}
@@ -374,7 +381,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
activeQueries := len(mgr.tr.inFlightQueries)
mgr.tr.Unlock()
if activeQueries == 0 {
log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())
log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.GetName(), conn.RemoteAddr())
return true
}
}
@@ -395,7 +402,7 @@ func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg)
if !ok {
log.Debugf(
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
mgr.tr.resolver.Name,
mgr.tr.resolver.GetName(),
conn.RemoteAddr(),
msg.Id,
msg.Question,
@@ -441,7 +448,7 @@ func (mgr *tcpResolverConnMgr) msgReader(
msg, err := conn.ReadMsg()
if err != nil {
if connClosing.SetToIf(false, true) {
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err)
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err)
}
return nil
}

View File

@@ -2,13 +2,11 @@ package resolver
import (
"context"
"errors"
"net"
"sync"
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netenv"
)
@@ -106,11 +104,10 @@ type ResolverConn interface { //nolint:go-lint // TODO
type BasicResolverConn struct {
sync.Mutex // for lastFail
resolver *Resolver
clientManager *dnsClientManager
resolver *Resolver
lastFail time.Time
fails int
failingUntil time.Time
fails int
}
// ReportFailure reports that an error occurred with this resolver.
@@ -122,17 +119,12 @@ func (brc *BasicResolverConn) ReportFailure() {
brc.Lock()
defer brc.Unlock()
now := time.Now().UTC()
failDuration := time.Duration(nameserverRetryRate()) * time.Second
// reset fail counter if currently not failing
if now.Add(-failDuration).After(brc.lastFail) {
brc.fails++
if brc.fails > FailThreshold {
brc.failingUntil = time.Now().Add(time.Duration(nameserverRetryRate()) * time.Second)
brc.fails = 0
}
// update
brc.lastFail = now
brc.fails++
}
// IsFailing returns if this resolver is currently failing.
@@ -140,122 +132,5 @@ func (brc *BasicResolverConn) IsFailing() bool {
brc.Lock()
defer brc.Unlock()
failDuration := time.Duration(nameserverRetryRate()) * time.Second
return brc.fails >= FailThreshold && time.Now().UTC().Add(-failDuration).Before(brc.lastFail)
}
// Query executes the given query against the resolver.
func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, error) {
// convenience
resolver := brc.resolver
// create query
dnsQuery := new(dns.Msg)
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
// start
var reply *dns.Msg
var ttl time.Duration
var err error
var conn *dns.Conn
var new bool
var tries int
for ; tries < 3; tries++ {
// first get connection
dc := brc.clientManager.getDNSClient()
conn, new, err = dc.getConn()
if err != nil {
log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err)
// remove client from pool
dc.destroy()
// report that resolver had an error
brc.ReportFailure()
// hint network environment at failed connection
netenv.ReportFailedConnection()
// TODO: handle special cases
// 1. connect: network is unreachable
// 2. timeout
// try again
continue
}
if new {
log.Tracer(ctx).Tracef("resolver: created new connection to %s (%s)", resolver.Name, resolver.ServerAddress)
} else {
log.Tracer(ctx).Tracef("resolver: reusing connection to %s (%s)", resolver.Name, resolver.ServerAddress)
}
// query server
reply, ttl, err = dc.client.ExchangeWithConn(dnsQuery, conn)
log.Tracer(ctx).Tracef("resolver: query took %s", ttl)
// error handling
if err != nil {
log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err)
// remove client from pool
dc.destroy()
// temporary error
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server)
// try again
continue
}
// report failed if dns (nothing happens at getConn())
if resolver.ServerType == ServerTypeDNS {
// report that resolver had an error
brc.ReportFailure()
// hint network environment at failed connection
netenv.ReportFailedConnection()
}
// permanent error
break
} else if reply == nil {
// remove client from pool
dc.destroy()
log.Errorf("resolver: successful query for %s%s to %s, but reply was nil", q.FQDN, q.QType, resolver.Server)
return nil, errors.New("internal error")
}
// make client available (again)
dc.addToPool()
if resolver.IsBlockedUpstream(reply) {
return nil, &BlockedUpstreamError{resolver.GetName()}
}
// no error
break
}
if err != nil {
return nil, err
// TODO: mark as failed
} else if reply == nil {
log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), tries+1)
return nil, errors.New("internal error")
}
// hint network environment at successful connection
netenv.ReportSuccessfulConnection()
newRecord := &RRCache{
Domain: q.FQDN,
Question: q.QType,
Answer: reply.Answer,
Ns: reply.Ns,
Extra: reply.Extra,
Server: resolver.Server,
ServerScope: resolver.ServerIPScope,
}
// TODO: check if reply.Answer is valid
return newRecord, nil
return time.Now().Before(brc.failingUntil)
}

View File

@@ -69,22 +69,8 @@ func resolverConnFactory(resolver *Resolver) ResolverConn {
return NewTCPResolver(resolver)
case ServerTypeDoT:
return NewTCPResolver(resolver).UseTLS()
default:
return &BasicResolverConn{
clientManager: clientManagerFactory(resolver.ServerType)(resolver),
resolver: resolver,
}
}
}
func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
switch serverType {
case ServerTypeDNS:
return newDNSClientManager
case ServerTypeDoT:
return newTLSClientManager
case ServerTypeTCP:
return newTCPClientManager
return NewPlainResolver(resolver)
default:
return nil
}