diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 00cdc55e..35aff4e7 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -151,7 +151,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) // that will happen later anyway. case network.VerdictUndecided, network.VerdictAccept: // Save the request as open, as we don't know if there will be a connection or not. - network.SaveOpenDNSRequest(conn) + network.SaveOpenDNSRequest(conn, uint16(q.QType)) firewall.UpdateIPsAndCNAMEs(q, rrCache, conn) default: diff --git a/network/dns.go b/network/dns.go index ef75159a..eaaedf29 100644 --- a/network/dns.go +++ b/network/dns.go @@ -29,33 +29,38 @@ const ( openDNSRequestLimit = 3 * time.Second ) -func getDNSRequestCacheKey(pid int, fqdn string) string { - return strconv.Itoa(pid) + "/" + fqdn +func getDNSRequestCacheKey(pid int, fqdn string, qType uint16) string { + return strconv.Itoa(pid) + "/" + fqdn + dns.Type(qType).String() } func removeOpenDNSRequest(pid int, fqdn string) { openDNSRequestsLock.Lock() defer openDNSRequestsLock.Unlock() - key := getDNSRequestCacheKey(pid, fqdn) - _, ok := openDNSRequests[key] - if ok { - delete(openDNSRequests, key) - return - } + // Delete PID-specific requests. + delete(openDNSRequests, getDNSRequestCacheKey(pid, fqdn, dns.TypeA)) + delete(openDNSRequests, getDNSRequestCacheKey(pid, fqdn, dns.TypeAAAA)) + // If process is known, also check for non-attributed requests. if pid != process.UnidentifiedProcessID { - // check if there is an open dns request from an unidentified process - delete(openDNSRequests, unidentifiedProcessScopePrefix+fqdn) + delete(openDNSRequests, getDNSRequestCacheKey(process.UnidentifiedProcessID, fqdn, dns.TypeA)) + delete(openDNSRequests, getDNSRequestCacheKey(process.UnidentifiedProcessID, fqdn, dns.TypeAAAA)) } } // SaveOpenDNSRequest saves a dns request connection that was allowed to proceed. -func SaveOpenDNSRequest(conn *Connection) { +func SaveOpenDNSRequest(conn *Connection, qType uint16) { openDNSRequestsLock.Lock() defer openDNSRequestsLock.Unlock() - key := getDNSRequestCacheKey(conn.process.Pid, conn.Entity.Domain) + // Only save open A and AAAA requests. + switch qType { + case dns.TypeA, dns.TypeAAAA: + default: + return + } + + key := getDNSRequestCacheKey(conn.process.Pid, conn.Entity.Domain, qType) if existingConn, ok := openDNSRequests[key]; ok { // End previous request and save it. existingConn.Lock()