Merge pull request #90 from safing/fix/post-factum-reviews-1

Implement Post Factum Review Suggestions
This commit is contained in:
Daniel
2020-07-14 15:26:47 +02:00
committed by GitHub
13 changed files with 383 additions and 343 deletions

View File

@@ -228,13 +228,12 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error
// check if we are online
if netenv.GetOnlineStatus() == netenv.StatusOffline {
if netenv.IsOnlineStatusTestDomain(q.FQDN) {
log.Tracer(ctx).Debugf("resolver: permitting online status test domain %s to resolve even though offline", q.FQDN)
} else {
if !netenv.IsOnlineStatusTestDomain(q.FQDN) {
log.Tracer(ctx).Debugf("resolver: not resolving %s, device is offline", q.FQDN)
// we are offline and this is not an online check query
return nil, ErrOffline
}
log.Tracer(ctx).Debugf("resolver: permitting online status test domain %s to resolve even though offline", q.FQDN)
}
// start resolving
@@ -253,9 +252,6 @@ resolveLoop:
// resolve
rrCache, err = resolver.Conn.Query(ctx, q)
if err != nil {
// TODO: check if we are online?
switch {
case errors.Is(err, ErrNotFound):
// NXDomain, or similar
@@ -268,6 +264,8 @@ resolveLoop:
log.Tracer(ctx).Debugf("resolver: not resolving %s, device is offline", q.FQDN)
// we are offline and this is not an online check query
return nil, ErrOffline
default:
log.Tracer(ctx).Debugf("resolver: failed to resolve %s: %s", q.FQDN, err)
}
} else {
// no error

View File

@@ -1,15 +0,0 @@
package resolver
// DISABLE TESTING FOR NOW: find a way to have tests with the module system
// import (
// "testing"
// "time"
//
// "github.com/miekg/dns"
// )
// func TestResolve(t *testing.T) {
// Resolve("google.com.", dns.Type(dns.TypeA), 0)
// time.Sleep(200 * time.Millisecond)
// }

View File

@@ -50,8 +50,7 @@ func TestClientPooling(t *testing.T) {
}
// create separate resolver for this test
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
// resolver, _, err := createResolver("dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "config")
resolver, _, err := createResolver(testResolver, "config")
if err != nil {
t.Fatal(err)
}

View File

@@ -22,14 +22,13 @@ const (
type TCPResolver struct {
// NOTE(review): this diff view lists several fields twice (the old and the
// new ordering of the struct are interleaved); each field is documented at
// its first occurrence only.
// Embedded base connection; presumably the source of Lock/Unlock, IsFailing,
// ReportFailure and the resolver field used by the methods — confirm.
BasicResolverConn
// clientTTL is the time after which an established connection is recycled.
clientTTL time.Duration
// dnsClient is used to dial the resolver; its Dialer is replaced before every dial.
dnsClient *dns.Client
// dnsConnection holds the currently active connection of the client loop, or nil.
dnsConnection *dns.Conn
// connInstanceID is accessed atomically and incremented whenever a
// connection is torn down; in-flight queries record the value they were
// submitted under.
connInstanceID *uint32
clientTTL time.Duration
dnsClient *dns.Client
// clientStarted is set while a client worker is running for this resolver.
clientStarted *abool.AtomicBool
connInstanceID *uint32
// queries carries the messages that are waiting to be written to the connection.
queries chan *dns.Msg
// inFlightQueries maps DNS message IDs to the queries awaiting a response.
// It is accessed with the resolver locked.
inFlightQueries map[uint16]*InFlightQuery
clientStarted *abool.AtomicBool
}
// InFlightQuery represents an in flight query of a TCPResolver.
@@ -86,254 +85,10 @@ func (tr *TCPResolver) UseTLS() *TCPResolver {
return tr
}
// client is the worker function that manages the connection of a TCPResolver.
//
// In an outer loop (connMgmt) it tears down any previous connection, waits
// until there is at least one query to send, dials the resolver and starts a
// reader worker. In the inner loop it writes queued queries to the connection
// and hands incoming responses to the matching in-flight query. The
// connection is rebuilt after read/write errors, or once clientTTL has
// elapsed and no queries are in flight anymore.
//
// It returns nil when workerCtx is canceled, when FailThreshold consecutive
// failures were counted, or when the resolver reports that it is failing. On
// return, the response channels of all in-flight queries are closed.
//
// NOTE(review): failCnt and shuttingDown are also written from within the
// reader worker closure without synchronization, and that closure reads
// tr.dnsConnection, which this loop may reset to nil — looks racy if
// module.StartWorker runs the function concurrently; confirm.
func (tr *TCPResolver) client(workerCtx context.Context) error { //nolint:gocognit,gocyclo // TODO
connTimer := time.NewTimer(tr.clientTTL)
connClosing := abool.New()
var connCtx context.Context
var cancelConnCtx func()
var recycleConn bool
var shuttingDown bool
var failCnt int
// incoming carries messages from the reader worker to the query management loop.
var incoming = make(chan *dns.Msg, 100)
connMgmt:
for {
// cleanup old connection
if tr.dnsConnection != nil {
// mark the connection as closing so that the resulting errors are not reported
connClosing.Set()
_ = tr.dnsConnection.Close()
if cancelConnCtx != nil {
cancelConnCtx()
}
tr.dnsConnection = nil
// a new instance ID marks queries of the old connection as abandoned
atomic.AddUint32(tr.connInstanceID, 1)
}
// check if we are shutting down or failing
if shuttingDown || failCnt >= FailThreshold || tr.IsFailing() {
// reply to all waiting queries
tr.Lock()
for id, inFlight := range tr.inFlightQueries {
close(inFlight.Response)
delete(tr.inFlightQueries, id)
}
tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
tr.Unlock()
// hint network environment at failed connection
if failCnt >= FailThreshold {
netenv.ReportFailedConnection()
}
// The linter said so. Don't even...
if cancelConnCtx != nil {
cancelConnCtx()
}
return nil
}
// wait until there is something to do
tr.Lock()
waiting := len(tr.inFlightQueries)
tr.Unlock()
if waiting > 0 {
// queue abandoned queries
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
currentConnInstanceID := atomic.LoadUint32(tr.connInstanceID)
tr.Lock()
for id, inFlight := range tr.inFlightQueries {
if inFlight.Started.Before(ignoreBefore) {
// remove
delete(tr.inFlightQueries, id)
} else if inFlight.ConnInstanceID != currentConnInstanceID {
inFlight.ConnInstanceID = currentConnInstanceID
// re-inject
select {
case tr.queries <- inFlight.Msg:
default:
log.Warningf("resolver: failed to re-inject abandoned query to %s", tr.resolver.Name)
}
}
}
tr.Unlock()
} else {
// wait for first query
select {
case <-workerCtx.Done():
// abort
shuttingDown = true
continue connMgmt
case msg := <-tr.queries:
// re-insert, we will handle later
select {
case tr.queries <- msg:
default:
log.Warningf("resolver: failed to re-inject waking query to %s", tr.resolver.Name)
}
}
}
// create connection
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
// refresh dialer for authenticated local address
tr.dnsClient.Dialer = &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
Timeout: defaultConnectTimeout,
KeepAlive: defaultClientTTL,
}
// connect
c, err := tr.dnsClient.Dial(tr.resolver.ServerAddress)
if err != nil {
tr.ReportFailure()
failCnt++
if tr.IsFailing() {
shuttingDown = true
}
log.Debugf("resolver: failed to connect to %s (%s)", tr.resolver.Name, tr.resolver.ServerAddress)
netenv.ReportFailedConnection()
continue connMgmt
}
tr.dnsConnection = c
log.Debugf("resolver: connected to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr())
// hint network environment at successful connection
netenv.ReportSuccessfulConnection()
// reset timer
connTimer.Stop()
select {
case <-connTimer.C: // try to empty the timer
default:
}
connTimer.Reset(tr.clientTTL)
recycleConn = false
// start reader
module.StartWorker("dns client reader", func(ctx context.Context) error {
conn := tr.dnsConnection
for {
msg, err := conn.ReadMsg()
if err != nil {
// only the first party to notice the error reports it
if connClosing.SetToIf(false, true) {
if cancelConnCtx != nil {
cancelConnCtx()
}
tr.ReportFailure()
failCnt++
if tr.IsFailing() {
shuttingDown = true
}
log.Warningf("resolver: read error from %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err)
}
return nil
}
incoming <- msg
}
})
// query management
for {
select {
case <-workerCtx.Done():
// shutting down
shuttingDown = true
continue connMgmt
case <-connCtx.Done():
// connection error
continue connMgmt
case <-connTimer.C:
// client TTL expired, recycle connection
recycleConn = true
// trigger check
select {
case incoming <- nil:
default:
// queue is full anyway, do nothing
}
case msg := <-tr.queries:
// write query
_ = tr.dnsConnection.SetWriteDeadline(time.Now().Add(tr.dnsClient.WriteTimeout))
err := tr.dnsConnection.WriteMsg(msg)
if err != nil {
// only the first party to notice the error reports it
if connClosing.SetToIf(false, true) {
if cancelConnCtx != nil {
cancelConnCtx()
}
tr.ReportFailure()
failCnt++
if tr.IsFailing() {
shuttingDown = true
}
log.Warningf("resolver: write error to %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err)
}
continue connMgmt
}
case msg := <-incoming:
// a nil msg is only a trigger for the recycle check below
if msg != nil {
// handle query from resolver
tr.Lock()
inFlight, ok := tr.inFlightQueries[msg.Id]
if ok {
delete(tr.inFlightQueries, msg.Id)
}
tr.Unlock()
if ok {
select {
case inFlight.Response <- msg:
failCnt = 0 // reset fail counter
// responded!
default:
// the response could not be handed over without blocking
// save to cache, if enabled
if !inFlight.Query.NoCaching {
// persist to database
rrCache := inFlight.MakeCacheRecord(msg)
rrCache.Clean(600)
err = rrCache.Save()
if err != nil {
log.Warningf(
"resolver: failed to cache RR for %s%s: %s",
inFlight.Query.FQDN,
inFlight.Query.QType.String(),
err,
)
}
}
}
} else {
log.Debugf(
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
tr.resolver.Name,
tr.dnsConnection.RemoteAddr(),
msg.Id,
msg.Question,
)
}
}
// check if we have finished all queries and want to recycle conn
if recycleConn {
tr.Lock()
activeQueries := len(tr.inFlightQueries)
tr.Unlock()
if activeQueries == 0 {
log.Debugf("resolver: recycling conn to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr())
continue connMgmt
}
}
}
}
}
}
func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
// make sure client is started
tr.startClient()
// create msg
msg := &dns.Msg{}
msg.SetQuestion(q.FQDN, uint16(q.QType))
@@ -348,20 +103,39 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
ConnInstanceID: atomic.LoadUint32(tr.connInstanceID),
}
tr.Lock()
// check for existing query
tr.ensureUniqueID(msg)
// add query to in flight registry
tr.inFlightQueries[msg.Id] = inFlight
tr.Unlock()
// submit msg for writing
tr.queries <- msg
// make sure client is started
if tr.clientStarted.SetToIf(false, true) {
module.StartServiceWorker("dns client", 10*time.Millisecond, tr.client)
}
return inFlight
}
// ensureUniqueID makes sure that the ID assigned to msg is not used by any
// other in-flight query. The TCPResolver must be locked by the caller.
//
// It first retries with random IDs a bounded number of times and then falls
// back to a linear scan of the complete 16-bit ID space. If every single ID is
// already in use, it returns with msg keeping its last (colliding) ID.
func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
	// try a random ID 10000 times
	for i := 0; i < 10000; i++ { // don't try forever
		if _, exists := tr.inFlightQueries[msg.Id]; !exists {
			return // we are unique, yay!
		}
		msg.Id = dns.Id() // regenerate ID
	}

	// Go through the complete ID space. The counter is deliberately an int: a
	// uint16 counter can never exceed 65535, so a loop bounded by that value
	// wraps around and spins forever - with the lock held - when all IDs are
	// in use.
	for i := 0; i < 1<<16; i++ {
		id := uint16(i)
		if _, exists := tr.inFlightQueries[id]; !exists {
			msg.Id = id
			return // we are unique, yay!
		}
	}
}
// Query executes the given query against the resolver.
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// submit to client
@@ -386,3 +160,291 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
return inFlight.MakeCacheRecord(reply), nil
}
// tcpResolverConnMgr holds the state of one run of the connection manager
// worker of a TCPResolver.
type tcpResolverConnMgr struct {
// tr is the resolver this manager handles connections for.
tr *TCPResolver
// responses carries messages from the connection reader to the query
// handler; a nil message is used as a trigger to re-check connection recycling.
responses chan *dns.Msg
// failCnt counts failed connection attempts; it is reset when a response
// is delivered to a waiting query.
failCnt int
}
// startClient starts the connection manager worker for this resolver, unless
// one has already been started.
func (tr *TCPResolver) startClient() {
	// Atomically flip the flag; whoever loses the race has nothing to do.
	if !tr.clientStarted.SetToIf(false, true) {
		return
	}

	connMgr := &tcpResolverConnMgr{
		tr:        tr,
		responses: make(chan *dns.Msg, 100),
	}
	module.StartServiceWorker("dns client", 10*time.Millisecond, connMgr.run)
}
// run is the worker function of the connection manager. It repeatedly waits
// for work, establishes a connection and handles queries on it until either
// the resolver is failing or one of the steps signals that the manager should
// stop. It always calls shutdown before returning and always returns nil.
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
	// Every iteration covers the lifetime of one connection.
	for {
		// Stop if we are failing, or if waiting for work tells us to stop.
		// Work is only waited for when we are not failing.
		failing := mgr.failCnt >= FailThreshold || mgr.tr.IsFailing()
		if failing || !mgr.waitForWork(workerCtx) {
			mgr.shutdown()
			return nil
		}

		// Try to connect; count the failure and start over if that did not work.
		conn, connClosing, connCtx, cancelConnCtx := mgr.establishConnection(workerCtx)
		if conn == nil {
			mgr.failCnt++
			continue
		}

		// Let the network environment know that we have connectivity.
		netenv.ReportSuccessfulConnection()

		// Serve queries on this connection until it needs to be replaced.
		if !mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx) {
			mgr.shutdown()
			return nil
		}
	}
}
// shutdown ends the manager: it closes the response channel of every in-flight
// query, marks the client as not started, bumps the connection instance ID and
// reports a failed connection if the fail threshold was reached.
func (mgr *tcpResolverConnMgr) shutdown() {
	tr := mgr.tr

	tr.Lock()
	// Answer every waiting query by closing its response channel.
	for queryID, query := range tr.inFlightQueries {
		close(query.Response)
		delete(tr.inFlightQueries, queryID)
	}
	// Done while locked so the flag is cleared before submitQuery proceeds.
	tr.clientStarted.UnSet()
	// Advance the instance counter.
	atomic.AddUint32(tr.connInstanceID, 1)
	tr.Unlock()

	// Only hint the network environment when we gave up because of failures.
	if mgr.failCnt < FailThreshold {
		return
	}
	netenv.ReportFailedConnection()
}
// waitForWork blocks until there is a reason to establish a connection.
//
// If queries are already in flight, outdated ones are dropped (their response
// channel is closed) and those submitted under a previous connection instance
// are queued again; it then returns true immediately. Otherwise it waits for
// the first query, which is put back into the queue by a separate worker, and
// returns true. It returns false if workerCtx is canceled while waiting.
func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) {
	tr := mgr.tr

	// Find out whether anything is pending already.
	tr.Lock()
	pending := len(tr.inFlightQueries)
	tr.Unlock()

	if pending == 0 {
		// Nothing to do yet: sleep until the first query arrives.
		select {
		case <-workerCtx.Done():
			return false
		case wakeMsg := <-tr.queries:
			// Put the query back, it will be handled once connected.
			module.StartWorker("reinject triggering dns query", func(ctx context.Context) error {
				select {
				case tr.queries <- wakeMsg:
				case <-time.After(2 * time.Second):
					log.Warningf("resolver: failed to re-inject waking query to %s", tr.resolver.Name)
				}
				return nil
			})
		}
		return true
	}

	// There are pending queries: take care of abandoned ones.
	cutoff := time.Now().Add(-ignoreQueriesAfter)
	instanceID := atomic.LoadUint32(tr.connInstanceID)

	tr.Lock()
	defer tr.Unlock()
	for queryID, query := range tr.inFlightQueries {
		switch {
		case query.Started.Before(cutoff):
			// Too old, give up on it.
			close(query.Response)
			delete(tr.inFlightQueries, queryID)
		case query.ConnInstanceID != instanceID:
			// The query died with an earlier connection, queue it again.
			query.ConnInstanceID = instanceID
			select {
			case tr.queries <- query.Msg:
			default:
				log.Warningf("resolver: failed to re-inject abandoned query to %s", tr.resolver.Name)
			}
		}
		// Queries carrying the current instance ID are left alone, they are
		// in the queue already.
	}
	return true
}
// establishConnection dials the resolver and, on success, starts a reader
// worker for the new connection.
//
// It returns the connection, a flag that is used to silence errors once the
// connection is being closed, a context derived from workerCtx for the
// connection, and the cancel function of that context. If dialing fails, the
// error is logged and all four return values are nil.
func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) (
	conn *dns.Conn,
	connClosing *abool.AtomicBool,
	connCtx context.Context,
	cancelConnCtx context.CancelFunc,
) {
	// refresh dialer to set an authenticated local address
	// TODO: lock dnsClient (only manager should run at any time, so this should not be an issue)
	mgr.tr.dnsClient.Dialer = &net.Dialer{
		LocalAddr: getLocalAddr("tcp"),
		Timeout:   defaultConnectTimeout,
		KeepAlive: defaultClientTTL,
	}

	// connect
	var err error
	conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress)
	if err != nil {
		// Include the dial error, like the read/write error logs do, so that
		// connection problems can actually be diagnosed.
		log.Debugf(
			"resolver: failed to connect to %s (%s): %s",
			mgr.tr.resolver.Name,
			mgr.tr.resolver.ServerAddress,
			err,
		)
		return nil, nil, nil, nil
	}
	connCtx, cancelConnCtx = context.WithCancel(workerCtx)
	connClosing = abool.New()

	log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())

	// start reader
	// The worker's own context is not used: the reader ends when reading from
	// the connection fails, e.g. because the connection was closed.
	module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(_ context.Context) error {
		return mgr.msgReader(conn, connClosing, cancelConnCtx)
	})

	return conn, connClosing, connCtx, cancelConnCtx
}
// queryHandler serves queries on an established connection: it writes queued
// queries to conn and dispatches responses received via mgr.responses.
//
// It returns false if workerCtx is canceled, in which case the manager should
// stop. It returns true if the connection should be replaced: after a
// connection error (connCtx canceled or a write error), or after the client
// TTL has passed and no queries are in flight anymore.
//
// Before returning, it marks the connection as closing, cancels connCtx,
// closes conn and increases the connection instance ID.
func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter.
	workerCtx context.Context,
	conn *dns.Conn,
	connClosing *abool.AtomicBool,
	connCtx context.Context,
	cancelConnCtx context.CancelFunc,
) (proceed bool) {
	var readyToRecycle bool

	// Use an explicit timer instead of time.After so that it can be released
	// as soon as this connection is done, instead of lingering for the full
	// client TTL.
	ttlTimer := time.NewTimer(mgr.tr.clientTTL)
	defer ttlTimer.Stop()

	// clean up connection
	defer func() {
		connClosing.Set() // silence connection errors
		cancelConnCtx()
		_ = conn.Close()

		// increase instance counter
		atomic.AddUint32(mgr.tr.connInstanceID, 1)
	}()

	for {
		select {
		case <-workerCtx.Done():
			// module shutdown
			return false

		case <-connCtx.Done():
			// connection error
			return true

		case <-ttlTimer.C:
			// connection TTL reached, rebuild connection
			// but handle all in flight queries first
			readyToRecycle = true

			// trigger check
			select {
			case mgr.responses <- nil:
			default:
				// queue is full, check will be triggered anyway
			}

		case msg := <-mgr.tr.queries:
			// write query
			_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
			err := conn.WriteMsg(msg)
			if err != nil {
				// only log if nobody else has flagged the connection as closing yet
				if connClosing.SetToIf(false, true) {
					log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err)
				}
				return true
			}

		case msg := <-mgr.responses:
			// a nil msg is only a trigger for the recycle check below
			if msg != nil {
				mgr.handleQueryResponse(conn, msg)
			}

			if readyToRecycle {
				// check to see if we can recycle the connection
				mgr.tr.Lock()
				activeQueries := len(mgr.tr.inFlightQueries)
				mgr.tr.Unlock()
				if activeQueries == 0 {
					log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())
					return true
				}
			}
		}
	}
}
// handleQueryResponse matches a received message to its in-flight query and
// delivers it. If the response cannot be handed over without blocking, it is
// saved to the cache instead, unless caching is disabled for the query.
// Messages that match no in-flight query are logged and dropped.
func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) {
	tr := mgr.tr

	// Look up the query this message answers and unregister it.
	tr.Lock()
	query, known := tr.inFlightQueries[msg.Id]
	delete(tr.inFlightQueries, msg.Id) // no-op when the ID is not registered
	tr.Unlock()

	if !known {
		log.Debugf(
			"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
			tr.resolver.Name,
			conn.RemoteAddr(),
			msg.Id,
			msg.Question,
		)
		return
	}

	// Try to hand the response over without blocking.
	select {
	case query.Response <- msg:
		// Delivered: a working round trip resets the fail counter.
		mgr.failCnt = 0
		return
	default:
		// Nobody is listening for this response.
	}

	// Nothing more to do when the query does not want caching.
	if query.Query.NoCaching {
		return
	}

	// Persist the response to the database.
	record := query.MakeCacheRecord(msg)
	record.Clean(600)
	if err := record.Save(); err != nil {
		log.Warningf(
			"resolver: failed to cache RR for %s%s: %s",
			query.Query.FQDN,
			query.Query.QType.String(),
			err,
		)
	}
}
// msgReader reads messages from conn and forwards them to mgr.responses until
// reading fails. The read error is logged only if the connection had not been
// marked as closing before. It cancels the connection context when it ends and
// always returns nil.
func (mgr *tcpResolverConnMgr) msgReader(
	conn *dns.Conn,
	connClosing *abool.AtomicBool,
	cancelConnCtx context.CancelFunc,
) error {
	defer cancelConnCtx()

	for {
		reply, readErr := conn.ReadMsg()
		if readErr == nil {
			// Pass the message on and keep reading.
			mgr.responses <- reply
			continue
		}

		// Reading failed: report it, unless the connection is closing anyway.
		if connClosing.SetToIf(false, true) {
			log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), readErr)
		}
		return nil
	}
}

View File

@@ -21,7 +21,7 @@ func init() {
flag.StringVar(
&testResolver,
"resolver",
"dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty",
"dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip",
"set custom resolver for testing",
)
}