Merge pull request #90 from safing/fix/post-factum-reviews-1
Implement Post Factum Review Suggestions
This commit is contained in:
@@ -261,7 +261,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
|
|||||||
rrCache, err := resolver.Resolve(ctx, q)
|
rrCache, err := resolver.Resolve(ctx, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: analyze nxdomain requests, malware could be trying DGA-domains
|
// TODO: analyze nxdomain requests, malware could be trying DGA-domains
|
||||||
tracer.Warningf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err)
|
tracer.Debugf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err)
|
||||||
|
|
||||||
if errors.Is(err, resolver.ErrBlocked) {
|
if errors.Is(err, resolver.ErrBlocked) {
|
||||||
conn.Block(err.Error())
|
conn.Block(err.Error())
|
||||||
|
|||||||
@@ -106,8 +106,6 @@ var (
|
|||||||
|
|
||||||
captivePortalURL string
|
captivePortalURL string
|
||||||
captivePortalLock sync.Mutex
|
captivePortalLock sync.Mutex
|
||||||
|
|
||||||
waitForever = make(chan time.Time)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -234,7 +232,7 @@ func getDynamicStatusTrigger() <-chan time.Time {
|
|||||||
case StatusSemiOnline:
|
case StatusSemiOnline:
|
||||||
return time.After(5 * time.Minute)
|
return time.After(5 * time.Minute)
|
||||||
case StatusOnline:
|
case StatusOnline:
|
||||||
return waitForever
|
return nil
|
||||||
default: // unknown status
|
default: // unknown status
|
||||||
return time.After(5 * time.Minute)
|
return time.After(5 * time.Minute)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,31 +21,18 @@ var (
|
|||||||
pidsByUser = make(map[int][]int)
|
pidsByUser = make(map[int][]int)
|
||||||
)
|
)
|
||||||
|
|
||||||
// FindConnectionPID returns the pid of the given socket info.
|
// GetPID returns the already existing pid of the given socket info or searches for it.
|
||||||
func FindConnectionPID(socketInfo *socket.ConnectionInfo) (pid int) {
|
// This also acts as a getter for socket.*Info.PID, as locking for that occurs here.
|
||||||
|
func GetPID(socketInfo socket.Info) (pid int) {
|
||||||
pidsByUserLock.Lock()
|
pidsByUserLock.Lock()
|
||||||
defer pidsByUserLock.Unlock()
|
defer pidsByUserLock.Unlock()
|
||||||
|
|
||||||
if socketInfo.PID != socket.UnidentifiedProcessID {
|
if socketInfo.GetPID() != socket.UnidentifiedProcessID {
|
||||||
return socket.UnidentifiedProcessID
|
return socketInfo.GetPID()
|
||||||
}
|
}
|
||||||
|
|
||||||
pid = findPID(socketInfo.UID, socketInfo.Inode)
|
pid = findPID(socketInfo.GetUID(), socketInfo.GetInode())
|
||||||
socketInfo.PID = pid
|
socketInfo.SetPID(pid)
|
||||||
return pid
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindBindPID returns the pid of the given socket info.
|
|
||||||
func FindBindPID(socketInfo *socket.BindInfo) (pid int) {
|
|
||||||
pidsByUserLock.Lock()
|
|
||||||
defer pidsByUserLock.Unlock()
|
|
||||||
|
|
||||||
if socketInfo.PID != socket.UnidentifiedProcessID {
|
|
||||||
return socket.UnidentifiedProcessID
|
|
||||||
}
|
|
||||||
|
|
||||||
pid = findPID(socketInfo.UID, socketInfo.Inode)
|
|
||||||
socketInfo.PID = pid
|
|
||||||
return pid
|
return pid
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,6 +162,9 @@ entryLoop:
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readDirNames only reads the directory names. Using ioutil.ReadDir() would call `lstat` on every
|
||||||
|
// resulting directory name, which we don't need. This function will be called a lot, so we should
|
||||||
|
// refrain from unnecessary work.
|
||||||
func readDirNames(dir string) (names []string) {
|
func readDirNames(dir string) (names []string) {
|
||||||
file, err := os.Open(dir)
|
file, err := os.Open(dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -14,12 +14,12 @@ func TestSockets(t *testing.T) {
|
|||||||
}
|
}
|
||||||
fmt.Println("\nTCP 4 connections:")
|
fmt.Println("\nTCP 4 connections:")
|
||||||
for _, connection := range connections {
|
for _, connection := range connections {
|
||||||
pid := FindConnectionPID(connection)
|
pid := GetPID(connection)
|
||||||
fmt.Printf("%d: %+v\n", pid, connection)
|
fmt.Printf("%d: %+v\n", pid, connection)
|
||||||
}
|
}
|
||||||
fmt.Println("\nTCP 4 listeners:")
|
fmt.Println("\nTCP 4 listeners:")
|
||||||
for _, listener := range listeners {
|
for _, listener := range listeners {
|
||||||
pid := FindBindPID(listener)
|
pid := GetPID(listener)
|
||||||
fmt.Printf("%d: %+v\n", pid, listener)
|
fmt.Printf("%d: %+v\n", pid, listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,12 +29,12 @@ func TestSockets(t *testing.T) {
|
|||||||
}
|
}
|
||||||
fmt.Println("\nTCP 6 connections:")
|
fmt.Println("\nTCP 6 connections:")
|
||||||
for _, connection := range connections {
|
for _, connection := range connections {
|
||||||
pid := FindConnectionPID(connection)
|
pid := GetPID(connection)
|
||||||
fmt.Printf("%d: %+v\n", pid, connection)
|
fmt.Printf("%d: %+v\n", pid, connection)
|
||||||
}
|
}
|
||||||
fmt.Println("\nTCP 6 listeners:")
|
fmt.Println("\nTCP 6 listeners:")
|
||||||
for _, listener := range listeners {
|
for _, listener := range listeners {
|
||||||
pid := FindBindPID(listener)
|
pid := GetPID(listener)
|
||||||
fmt.Printf("%d: %+v\n", pid, listener)
|
fmt.Printf("%d: %+v\n", pid, listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ func TestSockets(t *testing.T) {
|
|||||||
}
|
}
|
||||||
fmt.Println("\nUDP 4 binds:")
|
fmt.Println("\nUDP 4 binds:")
|
||||||
for _, bind := range binds {
|
for _, bind := range binds {
|
||||||
pid := FindBindPID(bind)
|
pid := GetPID(bind)
|
||||||
fmt.Printf("%d: %+v\n", pid, bind)
|
fmt.Printf("%d: %+v\n", pid, bind)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +54,7 @@ func TestSockets(t *testing.T) {
|
|||||||
}
|
}
|
||||||
fmt.Println("\nUDP 6 binds:")
|
fmt.Println("\nUDP 6 binds:")
|
||||||
for _, bind := range binds {
|
for _, bind := range binds {
|
||||||
pid := FindBindPID(bind)
|
pid := GetPID(bind)
|
||||||
fmt.Printf("%d: %+v\n", pid, bind)
|
fmt.Printf("%d: %+v\n", pid, bind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,3 +29,35 @@ type Address struct {
|
|||||||
IP net.IP
|
IP net.IP
|
||||||
Port uint16
|
Port uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Info is a generic interface to both ConnectionInfo and BindInfo.
|
||||||
|
type Info interface {
|
||||||
|
GetPID() int
|
||||||
|
SetPID(int)
|
||||||
|
GetUID() int
|
||||||
|
GetInode() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPID returns the PID.
|
||||||
|
func (i *ConnectionInfo) GetPID() int { return i.PID }
|
||||||
|
|
||||||
|
// SetPID sets the PID to the given value.
|
||||||
|
func (i *ConnectionInfo) SetPID(pid int) { i.PID = pid }
|
||||||
|
|
||||||
|
// GetUID returns the UID.
|
||||||
|
func (i *ConnectionInfo) GetUID() int { return i.UID }
|
||||||
|
|
||||||
|
// GetInode returns the Inode.
|
||||||
|
func (i *ConnectionInfo) GetInode() int { return i.Inode }
|
||||||
|
|
||||||
|
// GetPID returns the PID.
|
||||||
|
func (i *BindInfo) GetPID() int { return i.PID }
|
||||||
|
|
||||||
|
// SetPID sets the PID to the given value.
|
||||||
|
func (i *BindInfo) SetPID(pid int) { i.PID = pid }
|
||||||
|
|
||||||
|
// GetUID returns the UID.
|
||||||
|
func (i *BindInfo) GetUID() int { return i.UID }
|
||||||
|
|
||||||
|
// GetInode returns the Inode.
|
||||||
|
func (i *BindInfo) GetInode() int { return i.Inode }
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) (
|
|||||||
if localPort == socketInfo.Local.Port &&
|
if localPort == socketInfo.Local.Port &&
|
||||||
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
|
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
|
||||||
table.lock.RUnlock()
|
table.lock.RUnlock()
|
||||||
return checkBindPID(socketInfo, true)
|
return checkPID(socketInfo, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,7 +90,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) (
|
|||||||
if localPort == socketInfo.Local.Port &&
|
if localPort == socketInfo.Local.Port &&
|
||||||
localIP.Equal(socketInfo.Local.IP) {
|
localIP.Equal(socketInfo.Local.IP) {
|
||||||
table.lock.RUnlock()
|
table.lock.RUnlock()
|
||||||
return checkConnectionPID(socketInfo, false)
|
return checkPID(socketInfo, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,12 +138,12 @@ func (table *udpTable) lookup(pktInfo *packet.Info) (
|
|||||||
|
|
||||||
// do not check direction if remoteIP/Port is not given
|
// do not check direction if remoteIP/Port is not given
|
||||||
if pktInfo.RemotePort() == 0 {
|
if pktInfo.RemotePort() == 0 {
|
||||||
return checkBindPID(socketInfo, pktInfo.Inbound)
|
return checkPID(socketInfo, pktInfo.Inbound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// get direction and return
|
// get direction and return
|
||||||
connInbound := table.getDirection(socketInfo, pktInfo)
|
connInbound := table.getDirection(socketInfo, pktInfo)
|
||||||
return checkBindPID(socketInfo, connInbound)
|
return checkPID(socketInfo, connInbound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,30 +14,10 @@ var (
|
|||||||
getUDP6Table = proc.GetUDP6Table
|
getUDP6Table = proc.GetUDP6Table
|
||||||
)
|
)
|
||||||
|
|
||||||
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
|
func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) {
|
||||||
for i := 0; i <= lookupRetries; i++ {
|
for i := 0; i <= lookupRetries; i++ {
|
||||||
// look for PID
|
// look for PID
|
||||||
pid = proc.FindConnectionPID(socketInfo)
|
pid = proc.GetPID(socketInfo)
|
||||||
if pid != socket.UnidentifiedProcessID {
|
|
||||||
// if we found a PID, return
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// every time, except for the last iteration
|
|
||||||
if i < lookupRetries {
|
|
||||||
// we found no PID, we could have been too fast, give the kernel some time to think
|
|
||||||
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
|
|
||||||
time.Sleep(time.Duration(i+1) * baseWaitTime)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return pid, connInbound, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
|
|
||||||
for i := 0; i <= lookupRetries; i++ {
|
|
||||||
// look for PID
|
|
||||||
pid = proc.FindBindPID(socketInfo)
|
|
||||||
if pid != socket.UnidentifiedProcessID {
|
if pid != socket.UnidentifiedProcessID {
|
||||||
// if we found a PID, return
|
// if we found a PID, return
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -12,10 +12,6 @@ var (
|
|||||||
getUDP6Table = iphelper.GetUDP6Table
|
getUDP6Table = iphelper.GetUDP6Table
|
||||||
)
|
)
|
||||||
|
|
||||||
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
|
func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) {
|
||||||
return socketInfo.PID, connInbound, nil
|
return socketInfo.GetPID(), connInbound, nil
|
||||||
}
|
|
||||||
|
|
||||||
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
|
|
||||||
return socketInfo.PID, connInbound, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -228,13 +228,12 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error
|
|||||||
|
|
||||||
// check if we are online
|
// check if we are online
|
||||||
if netenv.GetOnlineStatus() == netenv.StatusOffline {
|
if netenv.GetOnlineStatus() == netenv.StatusOffline {
|
||||||
if netenv.IsOnlineStatusTestDomain(q.FQDN) {
|
if !netenv.IsOnlineStatusTestDomain(q.FQDN) {
|
||||||
log.Tracer(ctx).Debugf("resolver: permitting online status test domain %s to resolve even though offline", q.FQDN)
|
|
||||||
} else {
|
|
||||||
log.Tracer(ctx).Debugf("resolver: not resolving %s, device is offline", q.FQDN)
|
log.Tracer(ctx).Debugf("resolver: not resolving %s, device is offline", q.FQDN)
|
||||||
// we are offline and this is not an online check query
|
// we are offline and this is not an online check query
|
||||||
return nil, ErrOffline
|
return nil, ErrOffline
|
||||||
}
|
}
|
||||||
|
log.Tracer(ctx).Debugf("resolver: permitting online status test domain %s to resolve even though offline", q.FQDN)
|
||||||
}
|
}
|
||||||
|
|
||||||
// start resolving
|
// start resolving
|
||||||
@@ -253,9 +252,6 @@ resolveLoop:
|
|||||||
// resolve
|
// resolve
|
||||||
rrCache, err = resolver.Conn.Query(ctx, q)
|
rrCache, err = resolver.Conn.Query(ctx, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
// TODO: check if we are online?
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, ErrNotFound):
|
case errors.Is(err, ErrNotFound):
|
||||||
// NXDomain, or similar
|
// NXDomain, or similar
|
||||||
@@ -268,6 +264,8 @@ resolveLoop:
|
|||||||
log.Tracer(ctx).Debugf("resolver: not resolving %s, device is offline", q.FQDN)
|
log.Tracer(ctx).Debugf("resolver: not resolving %s, device is offline", q.FQDN)
|
||||||
// we are offline and this is not an online check query
|
// we are offline and this is not an online check query
|
||||||
return nil, ErrOffline
|
return nil, ErrOffline
|
||||||
|
default:
|
||||||
|
log.Tracer(ctx).Debugf("resolver: failed to resolve %s: %s", q.FQDN, err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// no error
|
// no error
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
package resolver
|
|
||||||
|
|
||||||
// DISABLE TESTING FOR NOW: find a way to have tests with the module system
|
|
||||||
|
|
||||||
// import (
|
|
||||||
// "testing"
|
|
||||||
// "time"
|
|
||||||
//
|
|
||||||
// "github.com/miekg/dns"
|
|
||||||
// )
|
|
||||||
|
|
||||||
// func TestResolve(t *testing.T) {
|
|
||||||
// Resolve("google.com.", dns.Type(dns.TypeA), 0)
|
|
||||||
// time.Sleep(200 * time.Millisecond)
|
|
||||||
// }
|
|
||||||
@@ -50,8 +50,7 @@ func TestClientPooling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create separate resolver for this test
|
// create separate resolver for this test
|
||||||
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
|
resolver, _, err := createResolver(testResolver, "config")
|
||||||
// resolver, _, err := createResolver("dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "config")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,14 +22,13 @@ const (
|
|||||||
type TCPResolver struct {
|
type TCPResolver struct {
|
||||||
BasicResolverConn
|
BasicResolverConn
|
||||||
|
|
||||||
clientTTL time.Duration
|
clientTTL time.Duration
|
||||||
dnsClient *dns.Client
|
dnsClient *dns.Client
|
||||||
dnsConnection *dns.Conn
|
|
||||||
connInstanceID *uint32
|
|
||||||
|
|
||||||
|
clientStarted *abool.AtomicBool
|
||||||
|
connInstanceID *uint32
|
||||||
queries chan *dns.Msg
|
queries chan *dns.Msg
|
||||||
inFlightQueries map[uint16]*InFlightQuery
|
inFlightQueries map[uint16]*InFlightQuery
|
||||||
clientStarted *abool.AtomicBool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// InFlightQuery represents an in flight query of a TCPResolver.
|
// InFlightQuery represents an in flight query of a TCPResolver.
|
||||||
@@ -86,254 +85,10 @@ func (tr *TCPResolver) UseTLS() *TCPResolver {
|
|||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tr *TCPResolver) client(workerCtx context.Context) error { //nolint:gocognit,gocyclo // TODO
|
|
||||||
connTimer := time.NewTimer(tr.clientTTL)
|
|
||||||
connClosing := abool.New()
|
|
||||||
var connCtx context.Context
|
|
||||||
var cancelConnCtx func()
|
|
||||||
var recycleConn bool
|
|
||||||
var shuttingDown bool
|
|
||||||
var failCnt int
|
|
||||||
var incoming = make(chan *dns.Msg, 100)
|
|
||||||
|
|
||||||
connMgmt:
|
|
||||||
for {
|
|
||||||
// cleanup old connection
|
|
||||||
if tr.dnsConnection != nil {
|
|
||||||
connClosing.Set()
|
|
||||||
_ = tr.dnsConnection.Close()
|
|
||||||
if cancelConnCtx != nil {
|
|
||||||
cancelConnCtx()
|
|
||||||
}
|
|
||||||
|
|
||||||
tr.dnsConnection = nil
|
|
||||||
atomic.AddUint32(tr.connInstanceID, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if we are shutting down or failing
|
|
||||||
if shuttingDown || failCnt >= FailThreshold || tr.IsFailing() {
|
|
||||||
// reply to all waiting queries
|
|
||||||
tr.Lock()
|
|
||||||
for id, inFlight := range tr.inFlightQueries {
|
|
||||||
close(inFlight.Response)
|
|
||||||
delete(tr.inFlightQueries, id)
|
|
||||||
}
|
|
||||||
tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
|
|
||||||
tr.Unlock()
|
|
||||||
|
|
||||||
// hint network environment at failed connection
|
|
||||||
if failCnt >= FailThreshold {
|
|
||||||
netenv.ReportFailedConnection()
|
|
||||||
}
|
|
||||||
|
|
||||||
// The linter said so. Don't even...
|
|
||||||
if cancelConnCtx != nil {
|
|
||||||
cancelConnCtx()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until there is something to do
|
|
||||||
tr.Lock()
|
|
||||||
waiting := len(tr.inFlightQueries)
|
|
||||||
tr.Unlock()
|
|
||||||
if waiting > 0 {
|
|
||||||
// queue abandoned queries
|
|
||||||
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
|
|
||||||
currentConnInstanceID := atomic.LoadUint32(tr.connInstanceID)
|
|
||||||
tr.Lock()
|
|
||||||
for id, inFlight := range tr.inFlightQueries {
|
|
||||||
if inFlight.Started.Before(ignoreBefore) {
|
|
||||||
// remove
|
|
||||||
delete(tr.inFlightQueries, id)
|
|
||||||
} else if inFlight.ConnInstanceID != currentConnInstanceID {
|
|
||||||
inFlight.ConnInstanceID = currentConnInstanceID
|
|
||||||
// re-inject
|
|
||||||
select {
|
|
||||||
case tr.queries <- inFlight.Msg:
|
|
||||||
default:
|
|
||||||
log.Warningf("resolver: failed to re-inject abandoned query to %s", tr.resolver.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tr.Unlock()
|
|
||||||
} else {
|
|
||||||
// wait for first query
|
|
||||||
select {
|
|
||||||
case <-workerCtx.Done():
|
|
||||||
// abort
|
|
||||||
shuttingDown = true
|
|
||||||
continue connMgmt
|
|
||||||
case msg := <-tr.queries:
|
|
||||||
// re-insert, we will handle later
|
|
||||||
select {
|
|
||||||
case tr.queries <- msg:
|
|
||||||
default:
|
|
||||||
log.Warningf("resolver: failed to re-inject waking query to %s", tr.resolver.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// create connection
|
|
||||||
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
|
|
||||||
// refresh dialer for authenticated local address
|
|
||||||
tr.dnsClient.Dialer = &net.Dialer{
|
|
||||||
LocalAddr: getLocalAddr("tcp"),
|
|
||||||
Timeout: defaultConnectTimeout,
|
|
||||||
KeepAlive: defaultClientTTL,
|
|
||||||
}
|
|
||||||
// connect
|
|
||||||
c, err := tr.dnsClient.Dial(tr.resolver.ServerAddress)
|
|
||||||
if err != nil {
|
|
||||||
tr.ReportFailure()
|
|
||||||
failCnt++
|
|
||||||
if tr.IsFailing() {
|
|
||||||
shuttingDown = true
|
|
||||||
}
|
|
||||||
log.Debugf("resolver: failed to connect to %s (%s)", tr.resolver.Name, tr.resolver.ServerAddress)
|
|
||||||
netenv.ReportFailedConnection()
|
|
||||||
continue connMgmt
|
|
||||||
}
|
|
||||||
tr.dnsConnection = c
|
|
||||||
log.Debugf("resolver: connected to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr())
|
|
||||||
|
|
||||||
// hint network environment at successful connection
|
|
||||||
netenv.ReportSuccessfulConnection()
|
|
||||||
|
|
||||||
// reset timer
|
|
||||||
connTimer.Stop()
|
|
||||||
select {
|
|
||||||
case <-connTimer.C: // try to empty the timer
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
connTimer.Reset(tr.clientTTL)
|
|
||||||
recycleConn = false
|
|
||||||
|
|
||||||
// start reader
|
|
||||||
module.StartWorker("dns client reader", func(ctx context.Context) error {
|
|
||||||
conn := tr.dnsConnection
|
|
||||||
for {
|
|
||||||
msg, err := conn.ReadMsg()
|
|
||||||
if err != nil {
|
|
||||||
if connClosing.SetToIf(false, true) {
|
|
||||||
if cancelConnCtx != nil {
|
|
||||||
cancelConnCtx()
|
|
||||||
}
|
|
||||||
tr.ReportFailure()
|
|
||||||
failCnt++
|
|
||||||
if tr.IsFailing() {
|
|
||||||
shuttingDown = true
|
|
||||||
}
|
|
||||||
log.Warningf("resolver: read error from %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
incoming <- msg
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// query management
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-workerCtx.Done():
|
|
||||||
// shutting down
|
|
||||||
shuttingDown = true
|
|
||||||
continue connMgmt
|
|
||||||
case <-connCtx.Done():
|
|
||||||
// connection error
|
|
||||||
continue connMgmt
|
|
||||||
case <-connTimer.C:
|
|
||||||
// client TTL expired, recycle connection
|
|
||||||
recycleConn = true
|
|
||||||
// trigger check
|
|
||||||
select {
|
|
||||||
case incoming <- nil:
|
|
||||||
default:
|
|
||||||
// quere is full anyway, do nothing
|
|
||||||
}
|
|
||||||
|
|
||||||
case msg := <-tr.queries:
|
|
||||||
// write query
|
|
||||||
_ = tr.dnsConnection.SetWriteDeadline(time.Now().Add(tr.dnsClient.WriteTimeout))
|
|
||||||
err := tr.dnsConnection.WriteMsg(msg)
|
|
||||||
if err != nil {
|
|
||||||
if connClosing.SetToIf(false, true) {
|
|
||||||
if cancelConnCtx != nil {
|
|
||||||
cancelConnCtx()
|
|
||||||
}
|
|
||||||
tr.ReportFailure()
|
|
||||||
failCnt++
|
|
||||||
if tr.IsFailing() {
|
|
||||||
shuttingDown = true
|
|
||||||
}
|
|
||||||
log.Warningf("resolver: write error to %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
continue connMgmt
|
|
||||||
}
|
|
||||||
|
|
||||||
case msg := <-incoming:
|
|
||||||
|
|
||||||
if msg != nil {
|
|
||||||
// handle query from resolver
|
|
||||||
tr.Lock()
|
|
||||||
inFlight, ok := tr.inFlightQueries[msg.Id]
|
|
||||||
if ok {
|
|
||||||
delete(tr.inFlightQueries, msg.Id)
|
|
||||||
}
|
|
||||||
tr.Unlock()
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
select {
|
|
||||||
case inFlight.Response <- msg:
|
|
||||||
failCnt = 0 // reset fail counter
|
|
||||||
// responded!
|
|
||||||
default:
|
|
||||||
// save to cache, if enabled
|
|
||||||
if !inFlight.Query.NoCaching {
|
|
||||||
// persist to database
|
|
||||||
rrCache := inFlight.MakeCacheRecord(msg)
|
|
||||||
rrCache.Clean(600)
|
|
||||||
err = rrCache.Save()
|
|
||||||
if err != nil {
|
|
||||||
log.Warningf(
|
|
||||||
"resolver: failed to cache RR for %s%s: %s",
|
|
||||||
inFlight.Query.FQDN,
|
|
||||||
inFlight.Query.QType.String(),
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debugf(
|
|
||||||
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
|
|
||||||
tr.resolver.Name,
|
|
||||||
tr.dnsConnection.RemoteAddr(),
|
|
||||||
msg.Id,
|
|
||||||
msg.Question,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if we have finished all queries and want to recycle conn
|
|
||||||
if recycleConn {
|
|
||||||
tr.Lock()
|
|
||||||
activeQueries := len(tr.inFlightQueries)
|
|
||||||
tr.Unlock()
|
|
||||||
if activeQueries == 0 {
|
|
||||||
log.Debugf("resolver: recycling conn to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr())
|
|
||||||
continue connMgmt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
|
func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
|
||||||
|
// make sure client is started
|
||||||
|
tr.startClient()
|
||||||
|
|
||||||
// create msg
|
// create msg
|
||||||
msg := &dns.Msg{}
|
msg := &dns.Msg{}
|
||||||
msg.SetQuestion(q.FQDN, uint16(q.QType))
|
msg.SetQuestion(q.FQDN, uint16(q.QType))
|
||||||
@@ -348,20 +103,39 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
|
|||||||
ConnInstanceID: atomic.LoadUint32(tr.connInstanceID),
|
ConnInstanceID: atomic.LoadUint32(tr.connInstanceID),
|
||||||
}
|
}
|
||||||
tr.Lock()
|
tr.Lock()
|
||||||
|
// check for existing query
|
||||||
|
tr.ensureUniqueID(msg)
|
||||||
|
// add query to in flight registry
|
||||||
tr.inFlightQueries[msg.Id] = inFlight
|
tr.inFlightQueries[msg.Id] = inFlight
|
||||||
tr.Unlock()
|
tr.Unlock()
|
||||||
|
|
||||||
// submit msg for writing
|
// submit msg for writing
|
||||||
tr.queries <- msg
|
tr.queries <- msg
|
||||||
|
|
||||||
// make sure client is started
|
|
||||||
if tr.clientStarted.SetToIf(false, true) {
|
|
||||||
module.StartServiceWorker("dns client", 10*time.Millisecond, tr.client)
|
|
||||||
}
|
|
||||||
|
|
||||||
return inFlight
|
return inFlight
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureUniqueID makes sure that ID assigned to msg is unique. TCPResolver must be locked.
|
||||||
|
func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
|
||||||
|
// try a random ID 10000 times
|
||||||
|
for i := 0; i < 10000; i++ { // don't try forever
|
||||||
|
_, exists := tr.inFlightQueries[msg.Id]
|
||||||
|
if !exists {
|
||||||
|
return // we are unique, yay!
|
||||||
|
}
|
||||||
|
msg.Id = dns.Id() // regenerate ID
|
||||||
|
}
|
||||||
|
// go through the complete space
|
||||||
|
var id uint16
|
||||||
|
for ; id <= (1<<16)-1; id++ { // don't try forever
|
||||||
|
_, exists := tr.inFlightQueries[id]
|
||||||
|
if !exists {
|
||||||
|
msg.Id = id
|
||||||
|
return // we are unique, yay!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Query executes the given query against the resolver.
|
// Query executes the given query against the resolver.
|
||||||
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
// submit to client
|
// submit to client
|
||||||
@@ -386,3 +160,291 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
|||||||
|
|
||||||
return inFlight.MakeCacheRecord(reply), nil
|
return inFlight.MakeCacheRecord(reply), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type tcpResolverConnMgr struct {
|
||||||
|
tr *TCPResolver
|
||||||
|
responses chan *dns.Msg
|
||||||
|
failCnt int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *TCPResolver) startClient() {
|
||||||
|
if tr.clientStarted.SetToIf(false, true) {
|
||||||
|
mgr := &tcpResolverConnMgr{
|
||||||
|
tr: tr,
|
||||||
|
responses: make(chan *dns.Msg, 100),
|
||||||
|
}
|
||||||
|
module.StartServiceWorker("dns client", 10*time.Millisecond, mgr.run)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
||||||
|
// connection lifecycle loop
|
||||||
|
for {
|
||||||
|
// check if we are failing
|
||||||
|
if mgr.failCnt >= FailThreshold || mgr.tr.IsFailing() {
|
||||||
|
mgr.shutdown()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for work before creating connection
|
||||||
|
proceed := mgr.waitForWork(workerCtx)
|
||||||
|
if !proceed {
|
||||||
|
mgr.shutdown()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// create connection
|
||||||
|
conn, connClosing, connCtx, cancelConnCtx := mgr.establishConnection(workerCtx)
|
||||||
|
if conn == nil {
|
||||||
|
mgr.failCnt++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// hint network environment at successful connection
|
||||||
|
netenv.ReportSuccessfulConnection()
|
||||||
|
|
||||||
|
// handle queries
|
||||||
|
proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx)
|
||||||
|
if !proceed {
|
||||||
|
mgr.shutdown()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *tcpResolverConnMgr) shutdown() {
|
||||||
|
// reply to all waiting queries
|
||||||
|
mgr.tr.Lock()
|
||||||
|
for id, inFlight := range mgr.tr.inFlightQueries {
|
||||||
|
close(inFlight.Response)
|
||||||
|
delete(mgr.tr.inFlightQueries, id)
|
||||||
|
}
|
||||||
|
mgr.tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
|
||||||
|
atomic.AddUint32(mgr.tr.connInstanceID, 1) // increase instance counter
|
||||||
|
mgr.tr.Unlock()
|
||||||
|
|
||||||
|
// hint network environment at failed connection
|
||||||
|
if mgr.failCnt >= FailThreshold {
|
||||||
|
netenv.ReportFailedConnection()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) {
|
||||||
|
// wait until there is something to do
|
||||||
|
mgr.tr.Lock()
|
||||||
|
waiting := len(mgr.tr.inFlightQueries)
|
||||||
|
mgr.tr.Unlock()
|
||||||
|
if waiting > 0 {
|
||||||
|
// queue abandoned queries
|
||||||
|
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
|
||||||
|
currentConnInstanceID := atomic.LoadUint32(mgr.tr.connInstanceID)
|
||||||
|
mgr.tr.Lock()
|
||||||
|
defer mgr.tr.Unlock()
|
||||||
|
for id, inFlight := range mgr.tr.inFlightQueries {
|
||||||
|
if inFlight.Started.Before(ignoreBefore) {
|
||||||
|
// remove old queries
|
||||||
|
close(inFlight.Response)
|
||||||
|
delete(mgr.tr.inFlightQueries, id)
|
||||||
|
} else if inFlight.ConnInstanceID != currentConnInstanceID {
|
||||||
|
inFlight.ConnInstanceID = currentConnInstanceID
|
||||||
|
// re-inject queries that died with a previously failed connection
|
||||||
|
select {
|
||||||
|
case mgr.tr.queries <- inFlight.Msg:
|
||||||
|
default:
|
||||||
|
log.Warningf("resolver: failed to re-inject abandoned query to %s", mgr.tr.resolver.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// in-flight queries that match the connection instance ID are not changed. They are already in the queue.
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for first query
|
||||||
|
select {
|
||||||
|
case <-workerCtx.Done():
|
||||||
|
return false
|
||||||
|
case msg := <-mgr.tr.queries:
|
||||||
|
// re-insert query, we will handle it later
|
||||||
|
module.StartWorker("reinject triggering dns query", func(ctx context.Context) error {
|
||||||
|
select {
|
||||||
|
case mgr.tr.queries <- msg:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) (
|
||||||
|
conn *dns.Conn,
|
||||||
|
connClosing *abool.AtomicBool,
|
||||||
|
connCtx context.Context,
|
||||||
|
cancelConnCtx context.CancelFunc,
|
||||||
|
) {
|
||||||
|
// refresh dialer to set an authenticated local address
|
||||||
|
// TODO: lock dnsClient (only manager should run at any time, so this should not be an issue)
|
||||||
|
mgr.tr.dnsClient.Dialer = &net.Dialer{
|
||||||
|
LocalAddr: getLocalAddr("tcp"),
|
||||||
|
Timeout: defaultConnectTimeout,
|
||||||
|
KeepAlive: defaultClientTTL,
|
||||||
|
}
|
||||||
|
// connect
|
||||||
|
var err error
|
||||||
|
conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.Name, mgr.tr.resolver.ServerAddress)
|
||||||
|
return nil, nil, nil, nil
|
||||||
|
}
|
||||||
|
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
|
||||||
|
connClosing = abool.New()
|
||||||
|
|
||||||
|
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())
|
||||||
|
|
||||||
|
// start reader
|
||||||
|
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
|
||||||
|
return mgr.msgReader(conn, connClosing, cancelConnCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
return conn, connClosing, connCtx, cancelConnCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter.
|
||||||
|
workerCtx context.Context,
|
||||||
|
conn *dns.Conn,
|
||||||
|
connClosing *abool.AtomicBool,
|
||||||
|
connCtx context.Context,
|
||||||
|
cancelConnCtx context.CancelFunc,
|
||||||
|
) (proceed bool) {
|
||||||
|
var readyToRecycle bool
|
||||||
|
ttlTimer := time.After(mgr.tr.clientTTL)
|
||||||
|
|
||||||
|
// clean up connection
|
||||||
|
defer func() {
|
||||||
|
connClosing.Set() // silence connection errors
|
||||||
|
cancelConnCtx()
|
||||||
|
_ = conn.Close()
|
||||||
|
|
||||||
|
// increase instance counter
|
||||||
|
atomic.AddUint32(mgr.tr.connInstanceID, 1)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-workerCtx.Done():
|
||||||
|
// module shutdown
|
||||||
|
return false
|
||||||
|
|
||||||
|
case <-connCtx.Done():
|
||||||
|
// connection error
|
||||||
|
return true
|
||||||
|
|
||||||
|
case <-ttlTimer:
|
||||||
|
// connection TTL reached, rebuild connection
|
||||||
|
// but handle all in flight queries first
|
||||||
|
readyToRecycle = true
|
||||||
|
// trigger check
|
||||||
|
select {
|
||||||
|
case mgr.responses <- nil:
|
||||||
|
default:
|
||||||
|
// queue is full, check will be triggered anyway
|
||||||
|
}
|
||||||
|
|
||||||
|
case msg := <-mgr.tr.queries:
|
||||||
|
// write query
|
||||||
|
_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
|
||||||
|
err := conn.WriteMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
if connClosing.SetToIf(false, true) {
|
||||||
|
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
case msg := <-mgr.responses:
|
||||||
|
if msg != nil {
|
||||||
|
mgr.handleQueryResponse(conn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if readyToRecycle {
|
||||||
|
// check to see if we can recycle the connection
|
||||||
|
mgr.tr.Lock()
|
||||||
|
activeQueries := len(mgr.tr.inFlightQueries)
|
||||||
|
mgr.tr.Unlock()
|
||||||
|
if activeQueries == 0 {
|
||||||
|
log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) {
|
||||||
|
// handle query from resolver
|
||||||
|
mgr.tr.Lock()
|
||||||
|
inFlight, ok := mgr.tr.inFlightQueries[msg.Id]
|
||||||
|
if ok {
|
||||||
|
delete(mgr.tr.inFlightQueries, msg.Id)
|
||||||
|
}
|
||||||
|
mgr.tr.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
log.Debugf(
|
||||||
|
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
|
||||||
|
mgr.tr.resolver.Name,
|
||||||
|
conn.RemoteAddr(),
|
||||||
|
msg.Id,
|
||||||
|
msg.Question,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case inFlight.Response <- msg:
|
||||||
|
mgr.failCnt = 0 // reset fail counter
|
||||||
|
// responded!
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// no one is listening for that response.
|
||||||
|
}
|
||||||
|
|
||||||
|
// if caching is disabled we're done
|
||||||
|
if inFlight.Query.NoCaching {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// persist to database
|
||||||
|
rrCache := inFlight.MakeCacheRecord(msg)
|
||||||
|
rrCache.Clean(600)
|
||||||
|
err := rrCache.Save()
|
||||||
|
if err != nil {
|
||||||
|
log.Warningf(
|
||||||
|
"resolver: failed to cache RR for %s%s: %s",
|
||||||
|
inFlight.Query.FQDN,
|
||||||
|
inFlight.Query.QType.String(),
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *tcpResolverConnMgr) msgReader(
|
||||||
|
conn *dns.Conn,
|
||||||
|
connClosing *abool.AtomicBool,
|
||||||
|
cancelConnCtx context.CancelFunc,
|
||||||
|
) error {
|
||||||
|
defer cancelConnCtx()
|
||||||
|
for {
|
||||||
|
msg, err := conn.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
if connClosing.SetToIf(false, true) {
|
||||||
|
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mgr.responses <- msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func init() {
|
|||||||
flag.StringVar(
|
flag.StringVar(
|
||||||
&testResolver,
|
&testResolver,
|
||||||
"resolver",
|
"resolver",
|
||||||
"dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty",
|
"dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip",
|
||||||
"set custom resolver for testing",
|
"set custom resolver for testing",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user