Merge pull request #38 from safing/feature/cname-blocking

Add CNAME blocking and upstream DNS resolver block detection support
This commit is contained in:
Daniel
2020-04-21 12:58:19 +02:00
committed by GitHub
30 changed files with 1253 additions and 625 deletions

230
firewall/dns.go Normal file
View File

@@ -0,0 +1,230 @@
package firewall
import (
"net"
"os"
"strings"
"github.com/miekg/dns"
"github.com/safing/portbase/database"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/profile"
"github.com/safing/portmaster/profile/endpoints"
"github.com/safing/portmaster/resolver"
)
func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, scope int8) ([]dns.RR, []string, int) {
goodEntries := make([]dns.RR, 0, len(entries))
filteredRecords := make([]string, 0, len(entries))
// keeps track of the number of valid and allowed
// A and AAAA records.
var allowedAddressRecords int
for _, rr := range entries {
// get IP and classification
var ip net.IP
switch v := rr.(type) {
case *dns.A:
ip = v.A
case *dns.AAAA:
ip = v.AAAA
default:
// add non A/AAAA entries
goodEntries = append(goodEntries, rr)
continue
}
classification := netutils.ClassifyIP(ip)
if p.RemoveOutOfScopeDNS() {
switch {
case classification == netutils.HostLocal:
// No DNS should return localhost addresses
filteredRecords = append(filteredRecords, rr.String())
continue
case scope == netutils.Global && (classification == netutils.SiteLocal || classification == netutils.LinkLocal):
// No global DNS should return LAN addresses
filteredRecords = append(filteredRecords, rr.String())
continue
}
}
if p.RemoveBlockedDNS() {
// filter by flags
switch {
case p.BlockScopeInternet() && classification == netutils.Global:
filteredRecords = append(filteredRecords, rr.String())
continue
case p.BlockScopeLAN() && (classification == netutils.SiteLocal || classification == netutils.LinkLocal):
filteredRecords = append(filteredRecords, rr.String())
continue
case p.BlockScopeLocal() && classification == netutils.HostLocal:
filteredRecords = append(filteredRecords, rr.String())
continue
}
// TODO: filter by endpoint list (IP only)
}
// if survived, add to good entries
allowedAddressRecords++
goodEntries = append(goodEntries, rr)
}
return goodEntries, filteredRecords, allowedAddressRecords
}
func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *resolver.RRCache {
// do not modify own queries
if conn.Process().Pid == os.Getpid() {
return rrCache
}
// get profile
p := conn.Process().Profile()
if p == nil {
conn.Block("no profile")
return nil
}
// check if DNS response filtering is completely turned off
if !p.RemoveOutOfScopeDNS() && !p.RemoveBlockedDNS() {
return rrCache
}
// duplicate entry
rrCache = rrCache.ShallowCopy()
rrCache.FilteredEntries = make([]string, 0)
var filteredRecords []string
var validIPs int
rrCache.Answer, filteredRecords, validIPs = filterDNSSection(rrCache.Answer, p, rrCache.ServerScope)
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
// we don't count the valid IPs in the extra section
rrCache.Extra, filteredRecords, _ = filterDNSSection(rrCache.Extra, p, rrCache.ServerScope)
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
if len(rrCache.FilteredEntries) > 0 {
rrCache.Filtered = true
if validIPs == 0 {
conn.Block("no addresses returned for this domain are permitted")
return nil
}
log.Infof("filter: filtered DNS replies for %s: %s", conn, strings.Join(rrCache.FilteredEntries, ", "))
}
return rrCache
}
// DecideOnResolvedDNS filters a dns response according to the application profile and settings.
func DecideOnResolvedDNS(conn *network.Connection, q *resolver.Query, rrCache *resolver.RRCache) *resolver.RRCache {
updatedRR := filterDNSResponse(conn, rrCache)
if updatedRR == nil {
return nil
}
updateIPsAndCNAMEs(q, rrCache, conn)
if mayBlockCNAMEs(conn) {
return nil
}
// TODO: Gate17 integration
// tunnelInfo, err := AssignTunnelIP(fqdn)
return updatedRR
}
func mayBlockCNAMEs(conn *network.Connection) bool {
// if we have CNAMEs and the profile is configured to filter them
// we need to re-check the lists and endpoints here
if conn.Process().Profile().FilterCNAMEs() {
conn.Entity.ResetLists()
conn.Entity.EnableCNAMECheck(true)
result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity)
if result == endpoints.Denied {
conn.BlockWithContext(reason.String(), reason.Context())
return true
}
if result == endpoints.NoMatch {
result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity)
if result == endpoints.Denied {
conn.BlockWithContext(reason.String(), reason.Context())
return true
}
}
}
return false
}
func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) {
// save IP addresses to IPInfo
cnames := make(map[string]string)
ips := make(map[string]struct{})
for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
switch v := rr.(type) {
case *dns.CNAME:
cnames[v.Hdr.Name] = v.Target
case *dns.A:
ips[v.A.String()] = struct{}{}
case *dns.AAAA:
ips[v.AAAA.String()] = struct{}{}
}
}
for ip := range ips {
record := resolver.ResolvedDomain{
Domain: q.FQDN,
}
// resolve all CNAMEs in the correct order.
var domain = q.FQDN
for {
nextDomain, isCNAME := cnames[domain]
if !isCNAME {
break
}
record.CNAMEs = append(record.CNAMEs, nextDomain)
domain = nextDomain
}
// update the entity to include the cnames
conn.Entity.CNAME = record.CNAMEs
// get the existing IP info or create a new one
var save bool
info, err := resolver.GetIPInfo(ip)
if err != nil {
if err != database.ErrNotFound {
log.Errorf("nameserver: failed to search for IP info record: %s", err)
}
info = &resolver.IPInfo{
IP: ip,
}
save = true
}
// and the new resolved domain record and save
if new := info.AddDomain(record); new {
save = true
}
if save {
if err := info.Save(); err != nil {
log.Errorf("nameserver: failed to save IP info record: %s", err)
}
}
}
}

View File

@@ -85,11 +85,11 @@ func RunInspectors(conn *network.Connection, pkt packet.Packet) (network.Verdict
verdict = network.VerdictDrop
continueInspection = true
case BLOCK_CONN:
conn.SetVerdict(network.VerdictBlock)
conn.SetVerdict(network.VerdictBlock, "", nil)
verdict = conn.Verdict
activeInspectors[key] = true
case DROP_CONN:
conn.SetVerdict(network.VerdictDrop)
conn.SetVerdict(network.VerdictDrop, "", nil)
verdict = conn.Verdict
activeInspectors[key] = true
case STOP_INSPECTING:

View File

@@ -2,7 +2,6 @@ package firewall
import (
"fmt"
"net"
"os"
"path/filepath"
"strings"
@@ -14,10 +13,8 @@ import (
"github.com/safing/portmaster/process"
"github.com/safing/portmaster/profile"
"github.com/safing/portmaster/profile/endpoints"
"github.com/safing/portmaster/resolver"
"github.com/agext/levenshtein"
"github.com/miekg/dns"
)
// Call order:
@@ -26,7 +23,7 @@ import (
// 1. DecideOnConnection
// is called when a DNS query is made, may set verdict to Undeterminable to permit a DNS reply.
// is called with a nil packet.
// 2. FilterDNSResponse
// 2. DecideOnResolvedDNS
// is called to (possibly) filter out A/AAAA records that the filter would deny later.
//
// Network Connection:
@@ -35,7 +32,7 @@ import (
// DecideOnConnection makes a decision about a connection.
// When called, the connection and profile is already locked.
func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint:gocognit,gocyclo // TODO
func DecideOnConnection(conn *network.Connection, pkt packet.Packet) {
// update profiles and check if communication needs reevaluation
if conn.UpdateAndCheck() {
log.Infof("filter: re-evaluating verdict on %s", conn)
@@ -46,14 +43,47 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint:
}
}
var deciders = []func(*network.Connection, packet.Packet) bool{
checkPortmasterConnection,
checkSelfCommunication,
checkProfileExists,
checkConnectionType,
checkConnectionScope,
checkEndpointLists,
checkBypassPrevention,
checkFilterLists,
checkInbound,
checkDefaultPermit,
checkAutoPermitRelated,
checkDefaultAction,
}
for _, decider := range deciders {
if decider(conn, pkt) {
return
}
}
// DefaultAction == DefaultActionBlock
conn.Deny("endpoint is not whitelisted (default=block)")
}
// checkPortmasterConnection allows all connection that originate from
// portmaster itself.
func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool {
// grant self
if conn.Process().Pid == os.Getpid() {
log.Infof("filter: granting own connection %s", conn)
conn.Verdict = network.VerdictAccept
conn.Internal = true
return
return true
}
return false
}
// checkSelfCommunication checks if the process is communicating with itself.
func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool {
// check if process is communicating with itself
if pkt != nil {
// TODO: evaluate the case where different IPs in the 127/8 net are used.
@@ -77,18 +107,50 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint:
} else if otherProcess.Pid == conn.Process().Pid {
conn.Accept("connection to self")
conn.Internal = true
return
return true
}
}
}
}
// get profile
p := conn.Process().Profile()
if p == nil {
conn.Block("no profile")
return
return false
}
func checkProfileExists(conn *network.Connection, _ packet.Packet) bool {
if conn.Process().Profile() == nil {
conn.Block("unknown process or profile")
return true
}
return false
}
func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool {
var result endpoints.EPResult
var reason endpoints.Reason
// there must always be a profile.
p := conn.Process().Profile()
// check endpoints list
if conn.Inbound {
result, reason = p.MatchServiceEndpoint(conn.Entity)
} else {
result, reason = p.MatchEndpoint(conn.Entity)
}
switch result {
case endpoints.Denied:
conn.DenyWithContext(reason.String(), reason.Context())
return true
case endpoints.Permitted:
conn.AcceptWithContext(reason.String(), reason.Context())
return true
}
return false
}
func checkConnectionType(conn *network.Connection, _ packet.Packet) bool {
p := conn.Process().Profile()
// check conn type
switch conn.Scope {
@@ -99,16 +161,22 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint:
} else {
conn.Drop("inbound connections blocked")
}
return
return true
}
case network.PeerLAN, network.PeerInternet, network.PeerInvalid:
// Important: PeerHost is and should be missing!
if p.BlockP2P() {
conn.Block("direct connections (P2P) blocked")
return
return true
}
}
return false
}
func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool {
p := conn.Process().Profile()
// check scopes
if conn.Entity.IP != nil {
classification := netutils.ClassifyIP(conn.Entity.IP)
@@ -117,21 +185,21 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint:
case netutils.Global, netutils.GlobalMulticast:
if p.BlockScopeInternet() {
conn.Deny("Internet access blocked") // Block Outbound / Drop Inbound
return
return true
}
case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast:
if p.BlockScopeLAN() {
conn.Block("LAN access blocked") // Block Outbound / Drop Inbound
return
return true
}
case netutils.HostLocal:
if p.BlockScopeLocal() {
conn.Block("Localhost access blocked") // Block Outbound / Drop Inbound
return
return true
}
default: // netutils.Invalid
conn.Deny("invalid IP") // Block Outbound / Drop Inbound
return
return true
}
} else if conn.Entity.Domain != "" {
// DNS Query
@@ -139,197 +207,84 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint:
// TODO: handle domains mapped to localhost
if p.BlockScopeInternet() && p.BlockScopeLAN() {
conn.Block("Internet and LAN access blocked")
return
return true
}
}
return false
}
var result endpoints.EPResult
var reason string
if p.PreventBypassing() {
func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool {
if conn.Process().Profile().PreventBypassing() {
// check for bypass protection
result, reason := PreventBypassing(conn)
switch result {
case endpoints.Denied:
conn.Block("bypass prevention: " + reason)
return
return true
case endpoints.Permitted:
conn.Accept("bypass prevention: " + reason)
return
return true
case endpoints.NoMatch:
}
}
return false
}
// check endpoints list
if conn.Inbound {
result, reason = p.MatchServiceEndpoint(conn.Entity)
} else {
result, reason = p.MatchEndpoint(conn.Entity)
}
switch result {
case endpoints.Denied:
conn.Deny("endpoint is blacklisted: " + reason) // Block Outbound / Drop Inbound
return
case endpoints.Permitted:
conn.Accept("endpoint is whitelisted: " + reason)
return
}
// continuing with result == NoMatch
func checkFilterLists(conn *network.Connection, _ packet.Packet) bool {
// apply privacy filter lists
result, reason = p.MatchFilterLists(conn.Entity)
p := conn.Process().Profile()
result, reason := p.MatchFilterLists(conn.Entity)
switch result {
case endpoints.Denied:
conn.Deny("endpoint in filterlists: " + reason)
return
conn.DenyWithContext(reason.String(), reason.Context())
return true
case endpoints.NoMatch:
// nothing to do
default:
log.Debugf("filter: filter lists returned unsupported verdict: %s", result)
}
return false
}
func checkInbound(conn *network.Connection, _ packet.Packet) bool {
// implicit default=block for inbound
if conn.Inbound {
conn.Drop("endpoint is not whitelisted (incoming is always default=block)")
return
return true
}
return false
}
func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool {
// check default action
p := conn.Process().Profile()
if p.DefaultAction() == profile.DefaultActionPermit {
conn.Accept("endpoint is not blacklisted (default=permit)")
return
return true
}
return false
}
// check relation
func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool {
p := conn.Process().Profile()
if !p.DisableAutoPermit() {
related, reason := checkRelation(conn)
if related {
conn.Accept(reason)
return
return true
}
}
// prompt
if p.DefaultAction() == profile.DefaultActionAsk {
prompt(conn, pkt)
return
}
// DefaultAction == DefaultActionBlock
conn.Deny("endpoint is not whitelisted (default=block)")
return false
}
// FilterDNSResponse filters a dns response according to the application profile and settings.
func FilterDNSResponse(conn *network.Connection, q *resolver.Query, rrCache *resolver.RRCache) *resolver.RRCache { //nolint:gocognit // TODO
// do not modify own queries
if conn.Process().Pid == os.Getpid() {
return rrCache
}
// get profile
func checkDefaultAction(conn *network.Connection, pkt packet.Packet) bool {
p := conn.Process().Profile()
if p == nil {
conn.Block("no profile")
return nil
if p.DefaultAction() == profile.DefaultActionAsk {
prompt(conn, pkt)
return true
}
// check if DNS response filtering is completely turned off
if !p.RemoveOutOfScopeDNS() && !p.RemoveBlockedDNS() {
return rrCache
}
// duplicate entry
rrCache = rrCache.ShallowCopy()
rrCache.FilteredEntries = make([]string, 0)
// change information
var addressesRemoved int
var addressesOk int
// loop vars
var classification int8
var ip net.IP
// filter function
filterEntries := func(entries []dns.RR) (goodEntries []dns.RR) {
goodEntries = make([]dns.RR, 0, len(entries))
for _, rr := range entries {
// get IP and classification
switch v := rr.(type) {
case *dns.A:
ip = v.A
case *dns.AAAA:
ip = v.AAAA
default:
// add non A/AAAA entries
goodEntries = append(goodEntries, rr)
continue
}
classification = netutils.ClassifyIP(ip)
if p.RemoveOutOfScopeDNS() {
switch {
case classification == netutils.HostLocal:
// No DNS should return localhost addresses
addressesRemoved++
rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String())
continue
case rrCache.ServerScope == netutils.Global && (classification == netutils.SiteLocal || classification == netutils.LinkLocal):
// No global DNS should return LAN addresses
addressesRemoved++
rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String())
continue
}
}
if p.RemoveBlockedDNS() {
// filter by flags
switch {
case p.BlockScopeInternet() && classification == netutils.Global:
addressesRemoved++
rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String())
continue
case p.BlockScopeLAN() && (classification == netutils.SiteLocal || classification == netutils.LinkLocal):
addressesRemoved++
rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String())
continue
case p.BlockScopeLocal() && classification == netutils.HostLocal:
addressesRemoved++
rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String())
continue
}
// TODO: filter by endpoint list (IP only)
}
// if survived, add to good entries
addressesOk++
goodEntries = append(goodEntries, rr)
}
return
}
rrCache.Answer = filterEntries(rrCache.Answer)
rrCache.Extra = filterEntries(rrCache.Extra)
if addressesRemoved > 0 {
rrCache.Filtered = true
if addressesOk == 0 {
conn.Block("no addresses returned for this domain are permitted")
return nil
}
}
if rrCache.Filtered {
log.Infof("filter: filtered DNS replies for %s: %s", conn, strings.Join(rrCache.FilteredEntries, ", "))
}
// TODO: Gate17 integration
// tunnelInfo, err := AssignTunnelIP(fqdn)
return rrCache
return false
}
// checkRelation tries to find a relation between a process and a communication. This is for better out of the box experience and is _not_ meant to thwart intentional malware.

97
intel/block_reason.go Normal file
View File

@@ -0,0 +1,97 @@
package intel
import (
"encoding/json"
"fmt"
"strings"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
)
// ListMatch represents an entity that has been
// matched against filterlists.
type ListMatch struct {
Entity string
ActiveLists []string
InactiveLists []string
}
func (lm *ListMatch) String() string {
inactive := ""
if len(lm.InactiveLists) > 0 {
inactive = " and in deactivated lists " + strings.Join(lm.InactiveLists, ", ")
}
return fmt.Sprintf(
"%s in activated lists %s%s",
lm.Entity,
strings.Join(lm.ActiveLists, ","),
inactive,
)
}
// ListBlockReason is a list of list matches.
type ListBlockReason []ListMatch
func (br ListBlockReason) String() string {
if len(br) == 0 {
return ""
}
matches := make([]string, len(br))
for idx, lm := range br {
matches[idx] = lm.String()
}
return strings.Join(matches, " and ")
}
// Context returns br wrapped into a map. It implements
// the endpoints.Reason interface.
func (br ListBlockReason) Context() interface{} {
return br
}
// MarshalJSON marshals the list block reason into a map
// prefixed with filterlists.
func (br ListBlockReason) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]interface{}{
// we convert to []ListMatch to avoid recursing
// here.
"filterlists": []ListMatch(br),
})
}
// ToRRs returns a set of dns TXT records that describe the
// block reason.
func (br ListBlockReason) ToRRs() []dns.RR {
rrs := make([]dns.RR, 0, len(br))
for _, lm := range br {
blockedBy, err := dns.NewRR(fmt.Sprintf(
"%s-blockedBy. 0 IN TXT %q",
strings.TrimRight(lm.Entity, "."),
strings.Join(lm.ActiveLists, ","),
))
if err == nil {
rrs = append(rrs, blockedBy)
} else {
log.Errorf("intel: failed to create TXT RR for block reason: %s", err)
}
if len(lm.InactiveLists) > 0 {
wouldBeBlockedBy, err := dns.NewRR(fmt.Sprintf(
"%s-wouldBeBlockedBy. 0 IN TXT %q",
strings.TrimRight(lm.Entity, "."),
strings.Join(lm.InactiveLists, ","),
))
if err == nil {
rrs = append(rrs, wouldBeBlockedBy)
} else {
log.Errorf("intel: failed to create TXT RR for block reason: %s", err)
}
}
}
return rrs
}

View File

@@ -32,18 +32,46 @@ type Entity struct {
asnListLoaded bool
reverseResolveEnabled bool
resolveSubDomainLists bool
checkCNAMEs bool
// Protocol is the protcol number used by the connection.
Protocol uint8
Port uint16
Domain string
IP net.IP
Country string
ASN uint
// Port is the destination port of the connection
Port uint16
// Domain is the target domain of the connection.
Domain string
// CNAME is a list of domain names that have been
// resolved for Domain.
CNAME []string
// IP is the IP address of the connection. If domain is
// set, IP has been resolved by following all CNAMEs.
IP net.IP
// Country holds the country the IP address (ASN) is
// located in.
Country string
// ASN holds the autonomous system number of the IP.
ASN uint
location *geoip.Location
Lists []string
ListsMap filterlists.LookupMap
// BlockedByLists holds list source IDs that
// are used to block the entity.
BlockedByLists []string
// BlockedEntities holds a list of entities that
// have been blocked. Values can be used as a key
// for the ListOccurences map.
BlockedEntities []string
// ListOccurences is a map that matches an entity (Domain, IPs, ASN, Country, Sub-domain)
// to a list of sources where the entity has been observed in.
ListOccurences map[string][]string
// we only load each data above at most once
fetchLocationOnce sync.Once
@@ -72,13 +100,17 @@ func (e *Entity) ResetLists() {
// TODO(ppacher): our actual goal is to reset the domain
// list right now so we could be more efficient by keeping
// the other lists around.
e.Lists = nil
e.ListsMap = nil
e.BlockedByLists = nil
e.BlockedEntities = nil
e.ListOccurences = nil
e.domainListLoaded = false
e.ipListLoaded = false
e.countryListLoaded = false
e.asnListLoaded = false
e.resolveSubDomainLists = false
e.checkCNAMEs = false
e.loadDomainListOnce = sync.Once{}
e.loadIPListOnce = sync.Once{}
e.loadCoutryListOnce = sync.Once{}
@@ -94,6 +126,21 @@ func (e *Entity) ResolveSubDomainLists(enabled bool) {
e.resolveSubDomainLists = enabled
}
// EnableCNAMECheck enalbes or disables list lookups for
// entity CNAMEs.
func (e *Entity) EnableCNAMECheck(enabled bool) {
if e.domainListLoaded {
log.Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain)
}
e.checkCNAMEs = enabled
}
// CNAMECheckEnabled returns true if the entities CNAMEs should
// also be checked.
func (e *Entity) CNAMECheckEnabled() bool {
return e.checkCNAMEs
}
// Domain and IP
// EnableReverseResolving enables reverse resolving the domain from the IP on demand.
@@ -204,9 +251,19 @@ func (e *Entity) getLists() {
e.getCountryLists()
}
func (e *Entity) mergeList(list []string) {
e.Lists = mergeStringList(e.Lists, list)
e.ListsMap = buildLookupMap(e.Lists)
func (e *Entity) mergeList(key string, list []string) {
if len(list) == 0 {
return
}
if e.ListOccurences == nil {
e.ListOccurences = make(map[string][]string)
}
e.ListOccurences[key] = mergeStringList(e.ListOccurences[key], list)
//e.Lists = mergeStringList(e.Lists, list)
//e.ListsMap = buildLookupMap(e.Lists)
}
func (e *Entity) getDomainLists() {
@@ -220,12 +277,27 @@ func (e *Entity) getDomainLists() {
}
e.loadDomainListOnce.Do(func() {
var domains = []string{domain}
if e.resolveSubDomainLists {
domains = splitDomain(domain)
log.Tracef("intel: subdomain list resolving is enabled, checking %v", domains)
var domainsToInspect = []string{domain}
if e.checkCNAMEs {
log.Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME)
domainsToInspect = append(domainsToInspect, e.CNAME...)
}
var domains []string
if e.resolveSubDomainLists {
for _, domain := range domainsToInspect {
subdomains := splitDomain(domain)
domains = append(domains, subdomains...)
log.Tracef("intel: subdomain list resolving is enabled: %s => %v", domains, subdomains)
}
} else {
domains = domainsToInspect
}
domains = makeDistinct(domains)
for _, d := range domains {
log.Tracef("intel: loading domain list for %s", d)
list, err := filterlists.LookupDomain(d)
@@ -235,7 +307,7 @@ func (e *Entity) getDomainLists() {
return
}
e.mergeList(list)
e.mergeList(d, list)
}
e.domainListLoaded = true
})
@@ -279,7 +351,8 @@ func (e *Entity) getASNLists() {
log.Tracef("intel: loading ASN list for %d", asn)
e.loadAsnListOnce.Do(func() {
list, err := filterlists.LookupASNString(fmt.Sprintf("%d", asn))
asnStr := fmt.Sprintf("%d", asn)
list, err := filterlists.LookupASNString(asnStr)
if err != nil {
log.Errorf("intel: failed to get ASN blocklist for %d: %s", asn, err)
e.loadAsnListOnce = sync.Once{}
@@ -287,7 +360,7 @@ func (e *Entity) getASNLists() {
}
e.asnListLoaded = true
e.mergeList(list)
e.mergeList(asnStr, list)
})
}
@@ -311,7 +384,7 @@ func (e *Entity) getCountryLists() {
}
e.countryListLoaded = true
e.mergeList(list)
e.mergeList(country, list)
})
}
@@ -344,28 +417,69 @@ func (e *Entity) getIPLists() {
return
}
e.ipListLoaded = true
e.mergeList(list)
e.mergeList(ip.String(), list)
})
}
// GetLists returns the filter list identifiers the entity matched and whether this data is set.
func (e *Entity) GetLists() ([]string, bool) {
// LoadLists searches all filterlists for all occurrences of
// this entity.
func (e *Entity) LoadLists() bool {
e.getLists()
if e.Lists == nil {
return nil, false
}
return e.Lists, true
return e.ListOccurences != nil
}
// GetListsMap is like GetLists but returns a lookup map for list IDs.
func (e *Entity) GetListsMap() (filterlists.LookupMap, bool) {
e.getLists()
// MatchLists matches the entities lists against a slice
// of source IDs and updates various entity properties
// like BlockedByLists, ListOccurences and BlockedEntitites.
func (e *Entity) MatchLists(lists []string) bool {
e.BlockedByLists = nil
e.BlockedEntities = nil
if e.ListsMap == nil {
return nil, false
lm := makeMap(lists)
for key, keyLists := range e.ListOccurences {
for _, keyListID := range keyLists {
if _, ok := lm[keyListID]; ok {
e.BlockedByLists = append(e.BlockedByLists, keyListID)
e.BlockedEntities = append(e.BlockedEntities, key)
}
}
}
return e.ListsMap, true
makeDistinct(e.BlockedByLists)
makeDistinct(e.BlockedEntities)
return len(e.BlockedByLists) > 0
}
// ListBlockReason returns the block reason for this entity.
func (e *Entity) ListBlockReason() ListBlockReason {
blockedBy := make([]ListMatch, len(e.BlockedEntities))
lm := makeMap(e.BlockedByLists)
for idx, blockedEntity := range e.BlockedEntities {
if entityLists, ok := e.ListOccurences[blockedEntity]; ok {
var activeLists []string
var inactiveLists []string
for _, l := range entityLists {
if _, ok := lm[l]; ok {
activeLists = append(activeLists, l)
} else {
inactiveLists = append(inactiveLists, l)
}
}
blockedBy[idx] = ListMatch{
Entity: blockedEntity,
ActiveLists: activeLists,
InactiveLists: inactiveLists,
}
}
}
return blockedBy
}
func mergeStringList(a, b []string) []string {
@@ -385,12 +499,26 @@ func mergeStringList(a, b []string) []string {
return res
}
func buildLookupMap(l []string) filterlists.LookupMap {
m := make(filterlists.LookupMap, len(l))
func makeDistinct(slice []string) []string {
m := make(map[string]struct{}, len(slice))
result := make([]string, 0, len(slice))
for _, s := range l {
m[s] = struct{}{}
for _, v := range slice {
if _, ok := m[v]; ok {
continue
}
m[v] = struct{}{}
result = append(result, v)
}
return m
return result
}
func makeMap(slice []string) map[string]struct{} {
lm := make(map[string]struct{})
for _, v := range slice {
lm[v] = struct{}{}
}
return lm
}

View File

@@ -1,25 +0,0 @@
package filterlists
import "strings"
// LookupMap is a helper type for matching a list of endpoint sources
// against a map.
type LookupMap map[string]struct{}
// Match checks if a source in `list` is part of lm.
// Matches are joined to string and returned.
// If nothing is found, an empty string is returned.
func (lm LookupMap) Match(list []string) string {
matches := make([]string, 0, len(list))
for _, l := range list {
if _, ok := lm[l]; ok {
matches = append(matches, l)
}
}
if len(matches) == 0 {
return ""
}
return strings.Join(matches, ", ")
}

View File

@@ -1,92 +0,0 @@
package filterlists
/*
func TestLookupASN(t *testing.T) {
lists, err := LookupASNString("123")
assert.NoError(t, err)
assert.Equal(t, []string{"TEST"}, lists)
lists, err = LookupASNString("does-not-exist")
assert.NoError(t, err)
assert.Empty(t, lists)
defer testMarkNotLoaded()()
lists, err = LookupASNString("123")
assert.NoError(t, err)
assert.Empty(t, lists)
}
func TestLookupCountry(t *testing.T) {
lists, err := LookupCountry("AT")
assert.NoError(t, err)
assert.Equal(t, []string{"TEST"}, lists)
lists, err = LookupCountry("does-not-exist")
assert.NoError(t, err)
assert.Empty(t, lists)
defer testMarkNotLoaded()()
lists, err = LookupCountry("AT")
assert.NoError(t, err)
assert.Empty(t, lists)
}
func TestLookupIP(t *testing.T) {
lists, err := LookupIP(net.IP{1, 1, 1, 1})
assert.NoError(t, err)
assert.Equal(t, []string{"TEST"}, lists)
lists, err = LookupIP(net.IP{127, 0, 0, 1})
assert.NoError(t, err)
assert.Empty(t, lists)
defer testMarkNotLoaded()()
lists, err = LookupIP(net.IP{1, 1, 1, 1})
assert.NoError(t, err)
assert.Empty(t, lists)
}
func TestLookupDomain(t *testing.T) {
lists, err := LookupDomain("example.com")
assert.NoError(t, err)
assert.Equal(t, []string{"TEST"}, lists)
lists, err = LookupDomain("does-not-exist")
assert.NoError(t, err)
assert.Empty(t, lists)
defer testMarkNotLoaded()()
lists, err = LookupDomain("example.com")
assert.NoError(t, err)
assert.Empty(t, lists)
}
// testMarkNotLoaded ensures that functions believe
// filterlists are not yet loaded. It returns a
// func that restores the previous state.
func testMarkNotLoaded() func() {
if isLoaded() {
filterListsLoaded = make(chan struct{})
return func() {
close(filterListsLoaded)
}
}
return func() {}
}
// testMarkLoaded is like testMarkNotLoaded but ensures
// isLoaded() return true. It returns a function to restore
// the previous state.
func testMarkLoaded() func() {
if !isLoaded() {
close(filterListsLoaded)
return func() {
filterListsLoaded = make(chan struct{})
}
}
return func() {}
}
*/

View File

@@ -1,40 +0,0 @@
package intel
// ListSet holds a set of list IDs.
type ListSet struct {
match []string
}
// NewListSet returns a new ListSet with the given list IDs.
func NewListSet(lists []string) *ListSet {
// TODO: validate lists
return &ListSet{
match: lists,
}
}
// Matches returns whether there is a match in the given list IDs.
func (ls *ListSet) Matches(lists []string) (matches bool) {
for _, list := range lists {
for _, entry := range ls.match {
if entry == list {
return true
}
}
}
return false
}
// MatchSet returns the matching list IDs.
func (ls *ListSet) MatchSet(lists []string) (matched []string) {
for _, list := range lists {
for _, entry := range ls.match {
if entry == list {
matched = append(matched, list)
}
}
}
return
}

View File

@@ -2,6 +2,8 @@ package nameserver
import (
"context"
"errors"
"fmt"
"net"
"strings"
@@ -87,10 +89,27 @@ func stop() error {
return nil
}
func returnNXDomain(w dns.ResponseWriter, query *dns.Msg) {
func returnNXDomain(w dns.ResponseWriter, query *dns.Msg, reason string, reasonContext interface{}) {
m := new(dns.Msg)
m.SetRcode(query, dns.RcodeNameError)
_ = w.WriteMsg(m)
rr, _ := dns.NewRR("portmaster.block-reason. 0 IN TXT " + fmt.Sprintf("%q", reason))
m.Extra = []dns.RR{rr}
if reasonContext != nil {
if v, ok := reasonContext.(interface {
ToRRs() []dns.RR
}); ok {
m.Extra = append(m.Extra, v.ToRRs()...)
} else if v, ok := reasonContext.(interface {
ToRR() dns.RR
}); ok {
m.Extra = append(m.Extra, v.ToRR())
}
}
if err := w.WriteMsg(m); err != nil {
log.Errorf("nameserver: failed to send response: %s", err)
}
}
func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) {
@@ -126,7 +145,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
if question.Qclass != dns.ClassINET {
// we only serve IN records, return nxdomain
log.Warningf("nameserver: only IN record requests are supported but received Qclass %d, returning NXDOMAIN", question.Qclass)
returnNXDomain(w, query)
returnNXDomain(w, query, "wrong type", nil)
return nil
}
@@ -166,7 +185,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// check if valid domain name
if !netutils.IsValidFqdn(q.FQDN) {
log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", q.FQDN)
returnNXDomain(w, query)
returnNXDomain(w, query, "invalid domain", nil)
return nil
}
@@ -177,7 +196,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
// get connection
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, remoteAddr.IP, uint16(remoteAddr.Port))
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port))
// once we decided on the connection we might need to save it to the database
// so we defer that check right now.
@@ -202,10 +221,10 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// TODO: this has been obsoleted due to special profiles
if conn.Process().Profile() == nil {
tracer.Infof("nameserver: failed to find process for request %s, returning NXDOMAIN", conn)
returnNXDomain(w, query)
// NOTE(ppacher): saving unknown process connection might end up in a lot of
// processes. Consider disabling that via config.
conn.Failed("Unknown process")
returnNXDomain(w, query, "unknown process", conn.ReasonContext)
return nil
}
@@ -218,8 +237,8 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms)
if lms < 10 {
tracer.Warningf("nameserver: possible data tunnel by %s: %s has lms score of %f, returning nxdomain", conn.Process(), q.FQDN, lms)
returnNXDomain(w, query)
conn.Block("Possible data tunnel")
returnNXDomain(w, query, "lms", conn.ReasonContext)
return nil
}
@@ -229,7 +248,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
switch conn.Verdict {
case network.VerdictBlock:
tracer.Infof("nameserver: %s blocked, returning nxdomain", conn)
returnNXDomain(w, query)
returnNXDomain(w, query, conn.Reason, conn.ReasonContext)
return nil
case network.VerdictDrop, network.VerdictFailed:
tracer.Infof("nameserver: %s dropped, not replying", conn)
@@ -241,53 +260,21 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
if err != nil {
// TODO: analyze nxdomain requests, malware could be trying DGA-domains
tracer.Warningf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err)
returnNXDomain(w, query)
conn.Failed("failed to resolve: " + err.Error())
return nil
}
// filter DNS response
rrCache = firewall.FilterDNSResponse(conn, q, rrCache)
// TODO: FilterDNSResponse also sets a connection verdict
if rrCache == nil {
tracer.Infof("nameserver: %s implicitly denied by filtering the dns response, returning nxdomain", conn)
returnNXDomain(w, query)
conn.Block("DNS response filtered")
return nil
}
// save IP addresses to IPInfo
for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
switch v := rr.(type) {
case *dns.A:
ipInfo, err := resolver.GetIPInfo(v.A.String())
if err != nil {
ipInfo = &resolver.IPInfo{
IP: v.A.String(),
Domains: []string{q.FQDN},
}
_ = ipInfo.Save()
} else {
added := ipInfo.AddDomain(q.FQDN)
if added {
_ = ipInfo.Save()
}
}
case *dns.AAAA:
ipInfo, err := resolver.GetIPInfo(v.AAAA.String())
if err != nil {
ipInfo = &resolver.IPInfo{
IP: v.AAAA.String(),
Domains: []string{q.FQDN},
}
_ = ipInfo.Save()
} else {
added := ipInfo.AddDomain(q.FQDN)
if added {
_ = ipInfo.Save()
}
}
if errors.Is(err, &resolver.BlockedUpstreamError{}) {
conn.Block(err.Error())
} else {
conn.Failed("failed to resolve: " + err.Error())
}
returnNXDomain(w, query, conn.Reason, conn.ReasonContext)
return nil
}
rrCache = firewall.DecideOnResolvedDNS(conn, q, rrCache)
if rrCache == nil {
returnNXDomain(w, query, conn.Reason, conn.ReasonContext)
return nil
}
// reply to query

View File

@@ -5,6 +5,7 @@ import (
"net"
"strings"
"github.com/safing/portbase/database"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/netenv"
@@ -164,35 +165,60 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
}
// save IP addresses to IPInfo
cnames := make(map[string]string)
ips := make(map[string]struct{})
for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
switch v := rr.(type) {
case *dns.CNAME:
cnames[v.Hdr.Name] = v.Target
case *dns.A:
ipInfo, err := resolver.GetIPInfo(v.A.String())
if err != nil {
ipInfo = &resolver.IPInfo{
IP: v.A.String(),
Domains: []string{q.FQDN},
}
_ = ipInfo.Save()
} else {
added := ipInfo.AddDomain(q.FQDN)
if added {
_ = ipInfo.Save()
}
}
ips[v.A.String()] = struct{}{}
case *dns.AAAA:
ipInfo, err := resolver.GetIPInfo(v.AAAA.String())
if err != nil {
ipInfo = &resolver.IPInfo{
IP: v.AAAA.String(),
Domains: []string{q.FQDN},
}
_ = ipInfo.Save()
} else {
added := ipInfo.AddDomain(q.FQDN)
if added {
_ = ipInfo.Save()
}
ips[v.AAAA.String()] = struct{}{}
}
}
for ip := range ips {
record := resolver.ResolvedDomain{
Domain: q.FQDN,
}
// resolve all CNAMEs in the correct order.
var domain = q.FQDN
for {
nextDomain, isCNAME := cnames[domain]
if !isCNAME {
break
}
record.CNAMEs = append(record.CNAMEs, nextDomain)
domain = nextDomain
}
// get the existing IP info or create a new one
var save bool
info, err := resolver.GetIPInfo(ip)
if err != nil {
if err != database.ErrNotFound {
log.Errorf("nameserver: failed to search for IP info record: %s", err)
}
info = &resolver.IPInfo{
IP: ip,
}
save = true
}
// and the new resolved domain record and save
if new := info.AddDomain(record); new {
save = true
}
if save {
if err := info.Save(); err != nil {
log.Errorf("nameserver: failed to save IP info record: %s", err)
}
}
}

View File

@@ -31,9 +31,10 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
Entity *intel.Entity // needs locking, instance is never shared
process *process.Process
Verdict Verdict
Reason string
ReasonID string // format source[:id[:id]] // TODO
Verdict Verdict
Reason string
ReasonContext interface{}
ReasonID string // format source[:id[:id]] // TODO
Started int64
Ended int64
@@ -54,9 +55,9 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
}
// NewConnectionFromDNSRequest returns a new connection based on the given dns request.
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, ip net.IP, port uint16) *Connection {
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection {
// get Process
proc, err := process.GetProcessByEndpoints(ctx, ip, port, dnsAddress, dnsPort, packet.UDP)
proc, err := process.GetProcessByEndpoints(ctx, localIP, localPort, dnsAddress, dnsPort, packet.UDP)
if err != nil {
log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err)
proc = process.GetUnidentifiedProcess(ctx)
@@ -67,7 +68,8 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, ip net.IP, po
Scope: fqdn,
Entity: (&intel.Entity{
Domain: fqdn,
}).Init(),
CNAME: cnames,
}),
process: proc,
Started: timestamp,
Ended: timestamp,
@@ -104,7 +106,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
IP: pkt.Info().Src,
Protocol: uint8(pkt.Info().Protocol),
Port: pkt.Info().SrcPort,
}).Init()
})
} else {
@@ -113,18 +115,21 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
IP: pkt.Info().Dst,
Protocol: uint8(pkt.Info().Protocol),
Port: pkt.Info().DstPort,
}).Init()
})
// check if we can find a domain for that IP
ipinfo, err := resolver.GetIPInfo(pkt.Info().Dst.String())
if err == nil {
lastResolvedDomain := ipinfo.ResolvedDomains.MostRecentDomain()
if lastResolvedDomain != nil {
scope = lastResolvedDomain.Domain
entity.Domain = lastResolvedDomain.Domain
entity.CNAME = lastResolvedDomain.CNAMEs
removeOpenDNSRequest(proc.Pid, lastResolvedDomain.Domain)
}
}
// outbound to domain
scope = ipinfo.Domains[0]
entity.Domain = scope
removeOpenDNSRequest(proc.Pid, scope)
} else {
if scope == "" {
// outbound direct (possibly P2P) connection
switch netutils.ClassifyIP(pkt.Info().Dst) {
@@ -160,59 +165,82 @@ func GetConnection(id string) (*Connection, bool) {
return conn, ok
}
// Accept accepts the connection.
func (conn *Connection) Accept(reason string) {
if conn.SetVerdict(VerdictAccept) {
conn.Reason = reason
// AcceptWithContext accepts the connection.
func (conn *Connection) AcceptWithContext(reason string, ctx interface{}) {
if conn.SetVerdict(VerdictAccept, reason, ctx) {
log.Infof("filter: granting connection %s, %s", conn, conn.Reason)
} else {
log.Warningf("filter: tried to accept %s, but current verdict is %s", conn, conn.Verdict)
}
}
// Block blocks the connection.
func (conn *Connection) Block(reason string) {
if conn.SetVerdict(VerdictBlock) {
conn.Reason = reason
// Accept is like AcceptWithContext but only accepts a reason.
func (conn *Connection) Accept(reason string) {
conn.AcceptWithContext(reason, nil)
}
// BlockWithContext blocks the connection.
func (conn *Connection) BlockWithContext(reason string, ctx interface{}) {
if conn.SetVerdict(VerdictBlock, reason, ctx) {
log.Infof("filter: blocking connection %s, %s", conn, conn.Reason)
} else {
log.Warningf("filter: tried to block %s, but current verdict is %s", conn, conn.Verdict)
}
}
// Drop drops the connection.
func (conn *Connection) Drop(reason string) {
if conn.SetVerdict(VerdictDrop) {
conn.Reason = reason
// Block is like BlockWithContext but does only accepts a reason.
func (conn *Connection) Block(reason string) {
conn.BlockWithContext(reason, nil)
}
// DropWithContext drops the connection.
func (conn *Connection) DropWithContext(reason string, ctx interface{}) {
if conn.SetVerdict(VerdictDrop, reason, ctx) {
log.Infof("filter: dropping connection %s, %s", conn, conn.Reason)
} else {
log.Warningf("filter: tried to drop %s, but current verdict is %s", conn, conn.Verdict)
}
}
// Deny blocks or drops the link depending on the connection direction.
func (conn *Connection) Deny(reason string) {
// Drop is like DropWithContext but does only accepts a reason.
func (conn *Connection) Drop(reason string) {
conn.DropWithContext(reason, nil)
}
// DenyWithContext blocks or drops the link depending on the connection direction.
func (conn *Connection) DenyWithContext(reason string, ctx interface{}) {
if conn.Inbound {
conn.Drop(reason)
conn.DropWithContext(reason, ctx)
} else {
conn.Block(reason)
conn.BlockWithContext(reason, ctx)
}
}
// Failed marks the connection with VerdictFailed and stores the reason.
func (conn *Connection) Failed(reason string) {
if conn.SetVerdict(VerdictFailed) {
conn.Reason = reason
// Deny is like DenyWithContext but only accepts a reason.
func (conn *Connection) Deny(reason string) {
conn.DenyWithContext(reason, nil)
}
// FailedWithContext marks the connection with VerdictFailed and stores the reason.
func (conn *Connection) FailedWithContext(reason string, ctx interface{}) {
if conn.SetVerdict(VerdictFailed, reason, ctx) {
log.Infof("filter: dropping connection %s because of an internal error: %s", conn, reason)
} else {
log.Warningf("filter: tried to drop %s due to error but current verdict is %s", conn, conn.Verdict)
}
}
// Failed is like FailedWithContext but only accepts a string.
func (conn *Connection) Failed(reason string) {
conn.FailedWithContext(reason, nil)
}
// SetVerdict sets a new verdict for the connection, making sure it does not interfere with previous verdicts.
func (conn *Connection) SetVerdict(newVerdict Verdict) (ok bool) {
func (conn *Connection) SetVerdict(newVerdict Verdict, reason string, ctx interface{}) (ok bool) {
if newVerdict >= conn.Verdict {
conn.Verdict = newVerdict
conn.Reason = reason
conn.ReasonContext = ctx
return true
}
return false

View File

@@ -30,6 +30,9 @@ var (
CfgOptionFilterSubDomainsKey = "filter/includeSubdomains"
cfgOptionFilterSubDomains config.IntOption // security level option
CfgOptionFilterCNAMEKey = "filter/includeCNAMEs"
cfgOptionFilterCNAME config.IntOption // security level option
CfgOptionBlockScopeLocalKey = "filter/blockLocal"
cfgOptionBlockScopeLocal config.IntOption // security level option
@@ -180,6 +183,24 @@ Examples:
cfgOptionFilterLists = config.Concurrent.GetAsStringArray(CfgOptionFilterListKey, []string{})
cfgStringArrayOptions[CfgOptionFilterListKey] = cfgOptionFilterLists
// Include CNAMEs
err = config.Register(&config.Option{
Name: "Filter CNAMEs",
Key: CfgOptionFilterCNAMEKey,
Description: "Also filter requests where a CNAME would be blocked",
OptType: config.OptTypeInt,
ExternalOptType: "security level",
DefaultValue: status.SecurityLevelsAll,
ValidationRegex: "^(7|6|4)$",
ExpertiseLevel: config.ExpertiseLevelExpert,
})
if err != nil {
return err
}
cfgOptionFilterCNAME = config.Concurrent.GetAsInt(CfgOptionFilterCNAMEKey, int64(status.SecurityLevelsAll))
cfgIntOptions[CfgOptionFilterCNAMEKey] = cfgOptionFilterCNAME
// Include subdomains
err = config.Register(&config.Option{
Name: "Filter SubDomains",
Key: CfgOptionFilterSubDomainsKey,

View File

@@ -8,8 +8,8 @@ type EndpointAny struct {
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointAny) Matches(entity *intel.Entity) (result EPResult, reason string) {
return ep.matchesPPP(entity), "matches *"
func (ep *EndpointAny) Matches(entity *intel.Entity) (EPResult, Reason) {
return ep.match(ep, entity, "*", "matches")
}
func (ep *EndpointAny) String() string {

View File

@@ -16,24 +16,22 @@ var (
type EndpointASN struct {
EndpointBase
ASN uint
Reason string
ASN uint
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointASN) Matches(entity *intel.Entity) (result EPResult, reason string) {
if entity.IP == nil {
return Undeterminable, ""
}
func (ep *EndpointASN) Matches(entity *intel.Entity) (EPResult, Reason) {
asn, ok := entity.GetASN()
if !ok {
return Undeterminable, ""
return Undeterminable, nil
}
if asn == ep.ASN {
return ep.matchesPPP(entity), ep.Reason
asnStr := strconv.Itoa(int(ep.ASN))
return ep.match(ep, entity, asnStr, "IP is part of AS")
}
return NoMatch, ""
return NoMatch, nil
}
func (ep *EndpointASN) String() string {
@@ -48,8 +46,7 @@ func parseTypeASN(fields []string) (Endpoint, error) {
}
ep := &EndpointASN{
ASN: uint(asn),
Reason: "IP is part of AS" + strconv.FormatInt(int64(asn), 10),
ASN: uint(asn),
}
return ep.parsePPP(ep, fields)
}

View File

@@ -19,19 +19,16 @@ type EndpointCountry struct {
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointCountry) Matches(entity *intel.Entity) (result EPResult, reason string) {
if entity.IP == nil {
return Undeterminable, ""
}
func (ep *EndpointCountry) Matches(entity *intel.Entity) (EPResult, Reason) {
country, ok := entity.GetCountry()
if !ok {
return Undeterminable, ""
return Undeterminable, nil
}
if country == ep.Country {
return ep.matchesPPP(entity), "IP is located in " + country
return ep.match(ep, entity, country, "IP is located in")
}
return NoMatch, ""
return NoMatch, nil
}
func (ep *EndpointCountry) String() string {

View File

@@ -28,42 +28,60 @@ type EndpointDomain struct {
Domain string
DomainZone string
MatchType uint8
Reason string
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointDomain) Matches(entity *intel.Entity) (result EPResult, reason string) {
if entity.Domain == "" {
return NoMatch, ""
}
func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, Reason) {
result, reason := ep.match(ep, entity, ep.Domain, "domain matches")
switch ep.MatchType {
case domainMatchTypeExact:
if entity.Domain == ep.Domain {
return ep.matchesPPP(entity), ep.Reason
if domain == ep.Domain {
return result, reason
}
case domainMatchTypeZone:
if entity.Domain == ep.Domain {
return ep.matchesPPP(entity), ep.Reason
if domain == ep.Domain {
return result, reason
}
if strings.HasSuffix(entity.Domain, ep.DomainZone) {
return ep.matchesPPP(entity), ep.Reason
if strings.HasSuffix(domain, ep.DomainZone) {
return result, reason
}
case domainMatchTypeSuffix:
if strings.HasSuffix(entity.Domain, ep.Domain) {
return ep.matchesPPP(entity), ep.Reason
if strings.HasSuffix(domain, ep.Domain) {
return result, reason
}
case domainMatchTypePrefix:
if strings.HasPrefix(entity.Domain, ep.Domain) {
return ep.matchesPPP(entity), ep.Reason
if strings.HasPrefix(domain, ep.Domain) {
return result, reason
}
case domainMatchTypeContains:
if strings.Contains(entity.Domain, ep.Domain) {
return ep.matchesPPP(entity), ep.Reason
if strings.Contains(domain, ep.Domain) {
return result, reason
}
}
return NoMatch, nil
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointDomain) Matches(entity *intel.Entity) (EPResult, Reason) {
if entity.Domain == "" {
return NoMatch, nil
}
result, reason := ep.check(entity, entity.Domain)
if result != NoMatch {
return result, reason
}
if entity.CNAMECheckEnabled() {
for _, domain := range entity.CNAME {
result, reason = ep.check(entity, domain)
if result == Denied {
return result, reason
}
}
}
return NoMatch, ""
return NoMatch, nil
}
func (ep *EndpointDomain) String() string {
@@ -76,7 +94,6 @@ func parseTypeDomain(fields []string) (Endpoint, error) {
if domainRegex.MatchString(domain) || altDomainRegex.MatchString(domain) {
ep := &EndpointDomain{
OriginalValue: domain,
Reason: "domain matches " + domain,
}
// fix domain ending

View File

@@ -10,19 +10,19 @@ import (
type EndpointIP struct {
EndpointBase
IP net.IP
Reason string
IP net.IP
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointIP) Matches(entity *intel.Entity) (result EPResult, reason string) {
func (ep *EndpointIP) Matches(entity *intel.Entity) (EPResult, Reason) {
if entity.IP == nil {
return Undeterminable, ""
return Undeterminable, nil
}
if ep.IP.Equal(entity.IP) {
return ep.matchesPPP(entity), ep.Reason
return ep.match(ep, entity, ep.IP.String(), "IP matches")
}
return NoMatch, ""
return NoMatch, nil
}
func (ep *EndpointIP) String() string {
@@ -33,8 +33,7 @@ func parseTypeIP(fields []string) (Endpoint, error) {
ip := net.ParseIP(fields[1])
if ip != nil {
ep := &EndpointIP{
IP: ip,
Reason: "IP is " + ip.String(),
IP: ip,
}
return ep.parsePPP(ep, fields)
}

View File

@@ -10,19 +10,18 @@ import (
type EndpointIPRange struct {
EndpointBase
Net *net.IPNet
Reason string
Net *net.IPNet
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointIPRange) Matches(entity *intel.Entity) (result EPResult, reason string) {
func (ep *EndpointIPRange) Matches(entity *intel.Entity) (EPResult, Reason) {
if entity.IP == nil {
return Undeterminable, ""
return Undeterminable, nil
}
if ep.Net.Contains(entity.IP) {
return ep.matchesPPP(entity), ep.Reason
return ep.match(ep, entity, ep.Net.String(), "IP is in")
}
return NoMatch, ""
return NoMatch, nil
}
func (ep *EndpointIPRange) String() string {
@@ -33,8 +32,7 @@ func parseTypeIPRange(fields []string) (Endpoint, error) {
_, net, err := net.ParseCIDR(fields[1])
if err == nil {
ep := &EndpointIPRange{
Net: net,
Reason: "IP is part of " + net.String(),
Net: net,
}
return ep.parsePPP(ep, fields)
}

View File

@@ -10,22 +10,21 @@ import (
type EndpointLists struct {
EndpointBase
ListSet *intel.ListSet
ListSet []string
Lists string
Reason string
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointLists) Matches(entity *intel.Entity) (result EPResult, reason string) {
lists, ok := entity.GetLists()
if !ok {
return Undeterminable, ""
func (ep *EndpointLists) Matches(entity *intel.Entity) (EPResult, Reason) {
if !entity.LoadLists() {
return Undeterminable, nil
}
matched := ep.ListSet.MatchSet(lists)
if len(matched) > 0 {
return ep.matchesPPP(entity), ep.Reason
if entity.MatchLists(ep.ListSet) {
return ep.match(ep, entity, ep.Lists, "filterlist contains", "filterlist", entity.ListBlockReason())
}
return NoMatch, ""
return NoMatch, nil
}
func (ep *EndpointLists) String() string {
@@ -36,9 +35,8 @@ func parseTypeList(fields []string) (Endpoint, error) {
if strings.HasPrefix(fields[1], "L:") {
lists := strings.Split(strings.TrimPrefix(fields[1], "L:"), ",")
ep := &EndpointLists{
ListSet: intel.NewListSet(lists),
ListSet: lists,
Lists: "L:" + strings.Join(lists, ","),
Reason: "matched lists " + strings.Join(lists, ","),
}
return ep.parsePPP(ep, fields)
}

View File

@@ -11,7 +11,7 @@ import (
// Endpoint describes an Endpoint Matcher
type Endpoint interface {
Matches(entity *intel.Entity) (result EPResult, reason string)
Matches(entity *intel.Entity) (EPResult, Reason)
String() string
}
@@ -24,6 +24,35 @@ type EndpointBase struct { //nolint:maligned // TODO
Permitted bool
}
func (ep *EndpointBase) match(s fmt.Stringer, entity *intel.Entity, value, desc string, keyval ...interface{}) (EPResult, Reason) {
result := ep.matchesPPP(entity)
if result == Undeterminable || result == NoMatch {
return result, nil
}
return result, ep.makeReason(s, value, desc, keyval...)
}
func (ep *EndpointBase) makeReason(s fmt.Stringer, value, desc string, keyval ...interface{}) Reason {
r := &reason{
description: desc,
Filter: ep.renderPPP(s.String()),
Permitted: ep.Permitted,
Value: value,
}
r.Extra = make(map[string]interface{})
for idx := 0; idx < len(keyval)/2; idx += 2 {
key := keyval[idx]
val := keyval[idx+1]
r.Extra[key.(string)] = val
}
return r
}
func (ep *EndpointBase) matchesPPP(entity *intel.Entity) (result EPResult) {
// only check if protocol is defined
if ep.Protocol > 0 {

View File

@@ -21,6 +21,12 @@ const (
Permitted
)
// IsDecision returns true if result represents a decision
// and false if result is NoMatch or Undeterminable.
func IsDecision(result EPResult) bool {
return result == Denied || result == Permitted || result == Undeterminable
}
// ParseEndpoints parses a list of endpoints and returns a list of Endpoints for matching.
func ParseEndpoints(entries []string) (Endpoints, error) {
var firstErr error
@@ -57,7 +63,7 @@ func (e Endpoints) IsSet() bool {
}
// Match checks whether the given entity matches any of the endpoint definitions in the list.
func (e Endpoints) Match(entity *intel.Entity) (result EPResult, reason string) {
func (e Endpoints) Match(entity *intel.Entity) (result EPResult, reason Reason) {
for _, entry := range e {
if entry != nil {
if result, reason = entry.Matches(entity); result != NoMatch {
@@ -66,7 +72,7 @@ func (e Endpoints) Match(entity *intel.Entity) (result EPResult, reason string)
}
}
return NoMatch, ""
return NoMatch, nil
}
func (e Endpoints) String() string {

View File

@@ -0,0 +1,34 @@
package endpoints
// Reason describes the reason why an endpoint has been
// permitted or blocked.
type Reason interface {
// String should return a human readable string
// describing the decision reason.
String() string
// Context returns the context that was used
// for the decision.
Context() interface{}
}
type reason struct {
description string
Filter string
Value string
Permitted bool
Extra map[string]interface{}
}
func (r *reason) String() string {
prefix := "endpoint in blocklist: "
if r.Permitted {
prefix = "endpoint in whitelist: "
}
return prefix + r.description + " " + r.Value
}
func (r *reason) Context() interface{} {
return r
}

View File

@@ -43,6 +43,7 @@ type LayeredProfile struct {
RemoveOutOfScopeDNS config.BoolOption
RemoveBlockedDNS config.BoolOption
FilterSubDomains config.BoolOption
FilterCNAMEs config.BoolOption
PreventBypassing config.BoolOption
}
@@ -99,6 +100,10 @@ func NewLayeredProfile(localProfile *Profile) *LayeredProfile {
CfgOptionFilterSubDomainsKey,
cfgOptionFilterSubDomains,
)
new.FilterCNAMEs = new.wrapSecurityLevelOption(
CfgOptionFilterCNAMEKey,
cfgOptionFilterCNAME,
)
new.PreventBypassing = new.wrapSecurityLevelOption(
CfgOptionPreventBypassingKey,
cfgOptionPreventBypassing,
@@ -199,12 +204,12 @@ func (lp *LayeredProfile) DefaultAction() uint8 {
}
// MatchEndpoint checks if the given endpoint matches an entry in any of the profiles.
func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (result endpoints.EPResult, reason string) {
func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
for _, layer := range lp.layers {
if layer.endpoints.IsSet() {
result, reason = layer.endpoints.Match(entity)
if result != endpoints.NoMatch {
return
result, reason := layer.endpoints.Match(entity)
if endpoints.IsDecision(result) {
return result, reason
}
}
}
@@ -215,14 +220,14 @@ func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (result endpoints.
}
// MatchServiceEndpoint checks if the given endpoint of an inbound connection matches an entry in any of the profiles.
func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (result endpoints.EPResult, reason string) {
func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
entity.EnableReverseResolving()
for _, layer := range lp.layers {
if layer.serviceEndpoints.IsSet() {
result, reason = layer.serviceEndpoints.Match(entity)
if result != endpoints.NoMatch {
return
result, reason := layer.serviceEndpoints.Match(entity)
if endpoints.IsDecision(result) {
return result, reason
}
}
}
@@ -234,33 +239,34 @@ func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (result end
// MatchFilterLists matches the entity against the set of filter
// lists.
func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPResult, string) {
func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
entity.ResolveSubDomainLists(lp.FilterSubDomains())
lookupMap, hasLists := entity.GetListsMap()
if !hasLists {
return endpoints.NoMatch, ""
}
entity.EnableCNAMECheck(lp.FilterCNAMEs())
for _, layer := range lp.layers {
if reason := lookupMap.Match(layer.filterListIDs); reason != "" {
return endpoints.Denied, reason
}
// only check the first layer that has filter list
// IDs defined.
// search for the first layer that has filterListIDs set
if len(layer.filterListIDs) > 0 {
return endpoints.NoMatch, ""
entity.LoadLists()
if entity.MatchLists(layer.filterListIDs) {
return endpoints.Denied, entity.ListBlockReason()
}
return endpoints.NoMatch, nil
}
}
cfgLock.RLock()
defer cfgLock.RUnlock()
if reason := lookupMap.Match(cfgFilterLists); reason != "" {
return endpoints.Denied, reason
if len(cfgFilterLists) > 0 {
entity.LoadLists()
if entity.MatchLists(cfgFilterLists) {
return endpoints.Denied, entity.ListBlockReason()
}
}
return endpoints.NoMatch, ""
return endpoints.NoMatch, nil
}
// AddEndpoint adds an endpoint to the local endpoint list, saves the local profile and reloads the configuration.

View File

@@ -0,0 +1,61 @@
package resolver
import (
"net"
"github.com/miekg/dns"
)
// Supported upstream block detections
const (
BlockDetectionRefused = "refused"
BlockDetectionZeroIP = "zeroip"
BlockDetectionEmptyAnswer = "empty"
BlockDetectionDisabled = "disabled"
)
func isBlockedUpstream(resolver *Resolver, answer *dns.Msg) bool {
if resolver.UpstreamBlockDetection == BlockDetectionDisabled {
return false
}
switch resolver.UpstreamBlockDetection {
case BlockDetectionRefused:
return answer.Rcode == dns.RcodeRefused
case BlockDetectionZeroIP:
if answer.Rcode != dns.RcodeSuccess {
return false
}
var ips []net.IP
for _, rr := range answer.Answer {
switch v := rr.(type) {
case *dns.A:
ips = append(ips, v.A)
case *dns.AAAA:
ips = append(ips, v.AAAA)
}
}
if len(ips) == 0 {
return false // we expected an empty IP
}
for _, ip := range ips {
if ip.To4() != nil {
if !ip.Equal(net.IPv4zero) {
return false
}
} else {
if !ip.To16().Equal(net.IPv6zero) {
return false
}
}
}
return true
case BlockDetectionEmptyAnswer:
return answer.Rcode == dns.RcodeNameError && len(answer.Ns) == 0 && len(answer.Answer) == 0 && len(answer.Extra) == 0
}
return false
}

View File

@@ -29,28 +29,30 @@ var (
// We encourage everyone who has the technical abilities to set their own preferred servers.
// Default 1: Cloudflare
"dot://1.1.1.1:853?verify=cloudflare-dns.com", // Cloudflare
"dot://1.0.0.1:853?verify=cloudflare-dns.com", // Cloudflare
"dot://1.1.1.1:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", // Cloudflare
"dot://1.0.0.1:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", // Cloudflare
// Default 2: Quad9
"dot://9.9.9.9:853?verify=dns.quad9.net", // Quad9
"dot://149.112.112.112:853?verify=dns.quad9.net", // Quad9
"dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", // Quad9
"dot://149.112.112.112:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", // Quad9
// Fallback 1: Cloudflare
"dns://1.1.1.1:53", // Cloudflare
"dns://1.0.0.1:53", // Cloudflare
"dns://1.1.1.1:53?name=Cloudflare&blockedif=zeroip", // Cloudflare
"dns://1.0.0.1:53?name=Cloudflare&blockedif=zeroip", // Cloudflare
// Fallback 2: Quad9
"dns://9.9.9.9:53", // Quad9
"dns://149.112.112.112:53", // Quad9
"dns://9.9.9.9:53?name=Quad9&blockedif=empty", // Quad9
"dns://149.112.112.112:53?name=Quad9&blockedif=empty", // Quad9
// supported parameters
// - `verify=domain`: verify domain (dot only)
// future parameters:
//
// - `name=name`: human readable name for resolver
// - `blockedif=baredns`: how to detect if the dns service blocked something
// - `baredns`: NXDomain result, but without any other record in any section
// - `blockedif=empty`: how to detect if the dns service blocked something
// - `empty`: NXDomain result, but without any other record in any section
// - `refused`: Request was refused
// - `zeroip`: Answer only contains zeroip
}
CfgOptionNameServersKey = "dns/nameservers"

View File

@@ -16,13 +16,92 @@ var (
})
)
// ResolvedDomain holds a Domain name and a list of
// CNAMES that have been resolved.
type ResolvedDomain struct {
// Domain is the domain as requested by the application.
Domain string
// CNAMEs is a list of CNAMEs that have been resolved for
// Domain.
CNAMEs []string
}
// String returns a string representation of ResolvedDomain including
// the CNAME chain. It implements fmt.Stringer
func (resolved *ResolvedDomain) String() string {
ret := resolved.Domain
cnames := ""
if len(resolved.CNAMEs) > 0 {
cnames = " (-> " + strings.Join(resolved.CNAMEs, "->") + ")"
}
return ret + cnames
}
// ResolvedDomains is a helper type for operating on a slice
// of ResolvedDomain
type ResolvedDomains []ResolvedDomain
// String returns a string representation of all domains joined
// to a single string.
func (rds ResolvedDomains) String() string {
domains := make([]string, len(rds))
for idx, n := range rds {
domains[idx] = n.String()
}
return strings.Join(domains, " or ")
}
// MostRecentDomain returns the most recent domain.
func (rds ResolvedDomains) MostRecentDomain() *ResolvedDomain {
if len(rds) == 0 {
return nil
}
// TODO(ppacher): we could also do that by using ResolvedAt()
mostRecent := rds[len(rds)-1]
return &mostRecent
}
// IPInfo represents various information about an IP.
type IPInfo struct {
record.Base
sync.Mutex
IP string
Domains []string
// IP holds the acutal IP address.
IP string
// Domains holds a list of domains that have been
// resolved to IP. This field is deprecated and should
// be removed.
// DEPRECATED: remove with alpha.
Domains []string `json:"Domains,omitempty"`
// ResolvedDomain is a slice of domains that
// have been requested by various applications
// and have been resolved to IP.
ResolvedDomains ResolvedDomains
}
// AddDomain adds a new resolved domain to ipi.
func (ipi *IPInfo) AddDomain(resolved ResolvedDomain) bool {
for idx, d := range ipi.ResolvedDomains {
if d.Domain == resolved.Domain {
if utils.StringSliceEqual(d.CNAMEs, resolved.CNAMEs) {
return false
}
// we have a different CNAME chain now, remove the previous
// entry and add it at the end.
ipi.ResolvedDomains = append(ipi.ResolvedDomains[:idx], ipi.ResolvedDomains[idx+1:]...)
ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved)
return true
}
}
ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved)
return true
}
func makeIPInfoKey(ip string) string {
@@ -46,6 +125,19 @@ func GetIPInfo(ip string) (*IPInfo, error) {
if err != nil {
return nil, err
}
// Legacy support,
// DEPRECATED: remove with alpha
if len(new.Domains) > 0 && len(new.ResolvedDomains) == 0 {
for _, d := range new.Domains {
new.ResolvedDomains = append(new.ResolvedDomains, ResolvedDomain{
Domain: d,
// rest is empty...
})
}
new.Domains = nil // clean up so we remove it from the database
}
return new, nil
}
@@ -57,17 +149,6 @@ func GetIPInfo(ip string) (*IPInfo, error) {
return new, nil
}
// AddDomain adds a domain to the list and reports back if it was added, or was already present.
func (ipi *IPInfo) AddDomain(domain string) (added bool) {
ipi.Lock()
defer ipi.Unlock()
if !utils.StringInSlice(ipi.Domains, domain) {
ipi.Domains = append([]string{domain}, ipi.Domains...)
return true
}
return false
}
// Save saves the IPInfo record to the database.
func (ipi *IPInfo) Save() error {
ipi.Lock()
@@ -75,17 +156,21 @@ func (ipi *IPInfo) Save() error {
ipi.SetKey(makeIPInfoKey(ipi.IP))
}
ipi.Unlock()
return ipInfoDatabase.Put(ipi)
}
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (ipi *IPInfo) FmtDomains() string {
return strings.Join(ipi.Domains, " or ")
// Legacy support
// Ensure we don't write new Domain fields into the
// database.
// DEPRECATED: remove with alpha
if len(ipi.Domains) > 0 {
ipi.Domains = nil
}
return ipInfoDatabase.Put(ipi)
}
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (ipi *IPInfo) String() string {
ipi.Lock()
defer ipi.Unlock()
return fmt.Sprintf("<IPInfo[%s] %s: %s", ipi.Key(), ipi.IP, ipi.FmtDomains())
return fmt.Sprintf("<IPInfo[%s] %s: %s", ipi.Key(), ipi.IP, ipi.ResolvedDomains.String())
}

View File

@@ -1,25 +1,48 @@
package resolver
import "testing"
import (
"testing"
func testDomains(t *testing.T, ipi *IPInfo, expectedDomains string) {
if ipi.FmtDomains() != expectedDomains {
t.Errorf("unexpected domains '%s', expected '%s'", ipi.FmtDomains(), expectedDomains)
}
}
"github.com/stretchr/testify/assert"
)
func TestIPInfo(t *testing.T) {
ipi := &IPInfo{
IP: "1.2.3.4",
Domains: []string{"example.com.", "sub.example.com."},
example := ResolvedDomain{
Domain: "example.com.",
}
subExample := ResolvedDomain{
Domain: "sub1.example.com",
CNAMEs: []string{"example.com"},
}
testDomains(t, ipi, "example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("sub.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi := &IPInfo{
IP: "1.2.3.4",
ResolvedDomains: ResolvedDomains{
example,
subExample,
},
}
sub2Example := ResolvedDomain{
Domain: "sub2.example.com",
CNAMEs: []string{"sub1.example.com", "example.com"},
}
added := ipi.AddDomain(sub2Example)
assert.True(t, added)
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains)
// try again, should do nothing now
added = ipi.AddDomain(sub2Example)
assert.False(t, added)
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains)
subOverWrite := ResolvedDomain{
Domain: "sub1.example.com",
CNAMEs: []string{}, // now without CNAMEs
}
added = ipi.AddDomain(subOverWrite)
assert.True(t, added)
assert.Equal(t, ResolvedDomains{example, sub2Example, subOverWrite}, ipi.ResolvedDomains)
}

View File

@@ -37,6 +37,21 @@ var (
ErrNoCompliance = fmt.Errorf("%w: no compliant resolvers for this query", ErrBlocked)
)
// BlockedUpstreamError is returned when a DNS request
// has been blocked by the upstream server.
type BlockedUpstreamError struct {
ResolverName string
}
func (blocked *BlockedUpstreamError) Error() string {
return fmt.Sprintf("Endpoint blocked by upstream DNS resolver %s", blocked.ResolverName)
}
// Unwrap implements errors.Unwrapper
func (blocked *BlockedUpstreamError) Unwrap() error {
return ErrBlocked
}
// Query describes a dns query.
type Query struct {
FQDN string

View File

@@ -28,6 +28,19 @@ type Resolver struct {
// Server config url (and ID)
Server string
// Name is the name of the resolver as passed via
// ?name=.
Name string
// UpstreamBlockDetection defines the detection type
// to identifier upstream DNS query blocking.
// Valid values are:
// - zeroip
// - empty
// - refused (default)
// - disabled
UpstreamBlockDetection string
// Parsed config
ServerType string
ServerAddress string
@@ -46,9 +59,25 @@ type Resolver struct {
Conn ResolverConn
}
// IsBlockedUpstream returns true if the request has been blocked
// upstream.
func (resolver *Resolver) IsBlockedUpstream(answer *dns.Msg) bool {
return isBlockedUpstream(resolver, answer)
}
// GetName returns the name of the server. If no name
// is configured the server address is returned.
func (resolver *Resolver) GetName() string {
if resolver.Name != "" {
return resolver.Name
}
return resolver.Server
}
// String returns the URL representation of the resolver.
func (resolver *Resolver) String() string {
return resolver.Server
return resolver.GetName()
}
// ResolverConn is an interface to implement different types of query backends.
@@ -126,6 +155,10 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
break
}
if resolver.IsBlockedUpstream(reply) {
return nil, &BlockedUpstreamError{resolver.GetName()}
}
// no error
break
}

View File

@@ -107,13 +107,26 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
return nil, false, fmt.Errorf("DOT must have a verify query parameter set")
}
blockType := query.Get("blockedif")
if blockType == "" {
blockType = BlockDetectionRefused
}
switch blockType {
case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP:
default:
return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)")
}
new := &Resolver{
Server: resolverURL,
ServerType: u.Scheme,
ServerAddress: u.Host,
ServerIPScope: scope,
Source: source,
VerifyDomain: verifyDomain,
Server: resolverURL,
ServerType: u.Scheme,
ServerAddress: u.Host,
ServerIPScope: scope,
Source: source,
VerifyDomain: verifyDomain,
Name: query.Get("name"),
UpstreamBlockDetection: blockType,
}
newConn := &BasicResolverConn{