From bffe4a9eaff9ef01978c630f31f7c5472b2213d8 Mon Sep 17 00:00:00 2001 From: Patrick Pacher Date: Fri, 17 Apr 2020 15:55:52 +0200 Subject: [PATCH 1/7] Add CNAME blocking support --- intel/entity.go | 63 ++++++++++++-- nameserver/nameserver.go | 124 +++++++++++++++++++-------- nameserver/only/nameserver.go | 76 ++++++++++------ network/connection.go | 26 +++--- profile/config.go | 20 +++++ profile/endpoints/endpoint-domain.go | 88 ++++++++++++++----- profile/profile-layered.go | 6 ++ resolver/ipinfo.go | 123 ++++++++++++++++++++++---- resolver/ipinfo_test.go | 55 ++++++++---- 9 files changed, 445 insertions(+), 136 deletions(-) diff --git a/intel/entity.go b/intel/entity.go index 290da215..39a3765c 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -32,14 +32,32 @@ type Entity struct { asnListLoaded bool reverseResolveEnabled bool resolveSubDomainLists bool + checkCNAMEs bool + // Protocol is the protcol number used by the connection. Protocol uint8 - Port uint16 - Domain string - IP net.IP - Country string - ASN uint + // Port is the destination port of the connection + Port uint16 + + // Domain is the target domain of the connection. + Domain string + + // CNAME is a list of domain names that have been + // resolved for Domain. + CNAME []string + + // IP is the IP address of the connection. If domain is + // set, IP has been resolved by following all CNAMEs. + IP net.IP + + // Country holds the country the IP address (ASN) is + // located in. + Country string + + // ASN holds the autonomous system number of the IP. + ASN uint + location *geoip.Location Lists []string @@ -79,6 +97,7 @@ func (e *Entity) ResetLists() { e.countryListLoaded = false e.asnListLoaded = false e.resolveSubDomainLists = false + e.checkCNAMEs = false e.loadDomainListOnce = sync.Once{} e.loadIPListOnce = sync.Once{} e.loadCoutryListOnce = sync.Once{} @@ -94,6 +113,21 @@ func (e *Entity) ResolveSubDomainLists(enabled bool) { e.resolveSubDomainLists = enabled } +// EnableCNAMECheck enalbes or disables list lookups for +// entity CNAMEs. +func (e *Entity) EnableCNAMECheck(enabled bool) { + if e.domainListLoaded { + log.Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain) + } + e.checkCNAMEs = enabled +} + +// CNAMECheckEnabled returns true if the entities CNAMEs should +// also be checked. +func (e *Entity) CNAMECheckEnabled() bool { + return e.checkCNAMEs +} + // Domain and IP // EnableReverseResolving enables reverse resolving the domain from the IP on demand. @@ -220,10 +254,23 @@ func (e *Entity) getDomainLists() { } e.loadDomainListOnce.Do(func() { - var domains = []string{domain} + var domainsToInspect = []string{domain} + + if e.checkCNAMEs { + log.Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME) + domainsToInspect = append(domainsToInspect, e.CNAME...) + } + + var domains []string if e.resolveSubDomainLists { - domains = splitDomain(domain) - log.Tracef("intel: subdomain list resolving is enabled, checking %v", domains) + for _, domain := range domainsToInspect { + subdomains := splitDomain(domain) + domains = append(domains, subdomains...) + + log.Tracef("intel: subdomain list resolving is enabled: %s => %v", domains, subdomains) + } + } else { + domains = domainsToInspect } for _, d := range domains { diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 8f2dbd65..f02151fe 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -2,9 +2,11 @@ package nameserver import ( "context" + "fmt" "net" "strings" + "github.com/safing/portbase/database" "github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/log" @@ -14,6 +16,7 @@ import ( "github.com/safing/portmaster/netenv" "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/netutils" + "github.com/safing/portmaster/profile/endpoints" "github.com/safing/portmaster/resolver" "github.com/miekg/dns" @@ -87,9 +90,11 @@ func stop() error { return nil } -func returnNXDomain(w dns.ResponseWriter, query *dns.Msg) { +func returnNXDomain(w dns.ResponseWriter, query *dns.Msg, reason string) { m := new(dns.Msg) m.SetRcode(query, dns.RcodeNameError) + rr, _ := dns.NewRR("portmaster.block.reason. 0 IN TXT " + fmt.Sprintf("%q", reason)) + m.Extra = []dns.RR{rr} _ = w.WriteMsg(m) } @@ -126,7 +131,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er if question.Qclass != dns.ClassINET { // we only serve IN records, return nxdomain log.Warningf("nameserver: only IN record requests are supported but received Qclass %d, returning NXDOMAIN", question.Qclass) - returnNXDomain(w, query) + returnNXDomain(w, query, "wrong type") return nil } @@ -166,7 +171,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // check if valid domain name if !netutils.IsValidFqdn(q.FQDN) { log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", q.FQDN) - returnNXDomain(w, query) + returnNXDomain(w, query, "invalid domain") return nil } @@ -177,7 +182,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain // get connection - conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, remoteAddr.IP, uint16(remoteAddr.Port)) + conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port)) // once we decided on the connection we might need to save it to the database // so we defer that check right now. @@ -202,7 +207,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // TODO: this has been obsoleted due to special profiles if conn.Process().Profile() == nil { tracer.Infof("nameserver: failed to find process for request %s, returning NXDOMAIN", conn) - returnNXDomain(w, query) + returnNXDomain(w, query, "unknown process") // NOTE(ppacher): saving unknown process connection might end up in a lot of // processes. Consider disabling that via config. conn.Failed("Unknown process") @@ -218,7 +223,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms) if lms < 10 { tracer.Warningf("nameserver: possible data tunnel by %s: %s has lms score of %f, returning nxdomain", conn.Process(), q.FQDN, lms) - returnNXDomain(w, query) + returnNXDomain(w, query, "lms") conn.Block("Possible data tunnel") return nil } @@ -229,7 +234,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er switch conn.Verdict { case network.VerdictBlock: tracer.Infof("nameserver: %s blocked, returning nxdomain", conn) - returnNXDomain(w, query) + returnNXDomain(w, query, conn.Reason) return nil case network.VerdictDrop, network.VerdictFailed: tracer.Infof("nameserver: %s dropped, not replying", conn) @@ -241,7 +246,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er if err != nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains tracer.Warningf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err) - returnNXDomain(w, query) + returnNXDomain(w, query, conn.Reason) conn.Failed("failed to resolve: " + err.Error()) return nil } @@ -251,41 +256,92 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // TODO: FilterDNSResponse also sets a connection verdict if rrCache == nil { tracer.Infof("nameserver: %s implicitly denied by filtering the dns response, returning nxdomain", conn) - returnNXDomain(w, query) + returnNXDomain(w, query, conn.Reason) conn.Block("DNS response filtered") return nil } // save IP addresses to IPInfo + cnames := make(map[string]string) + ips := make(map[string]struct{}) + for _, rr := range append(rrCache.Answer, rrCache.Extra...) { switch v := rr.(type) { + case *dns.CNAME: + cnames[v.Hdr.Name] = v.Target + case *dns.A: - ipInfo, err := resolver.GetIPInfo(v.A.String()) - if err != nil { - ipInfo = &resolver.IPInfo{ - IP: v.A.String(), - Domains: []string{q.FQDN}, - } - _ = ipInfo.Save() - } else { - added := ipInfo.AddDomain(q.FQDN) - if added { - _ = ipInfo.Save() - } - } + ips[v.A.String()] = struct{}{} + case *dns.AAAA: - ipInfo, err := resolver.GetIPInfo(v.AAAA.String()) - if err != nil { - ipInfo = &resolver.IPInfo{ - IP: v.AAAA.String(), - Domains: []string{q.FQDN}, - } - _ = ipInfo.Save() - } else { - added := ipInfo.AddDomain(q.FQDN) - if added { - _ = ipInfo.Save() - } + ips[v.AAAA.String()] = struct{}{} + } + } + + for ip := range ips { + record := resolver.ResolvedDomain{ + Domain: q.FQDN, + } + + // resolve all CNAMEs in the correct order. + var domain = q.FQDN + for { + nextDomain, isCNAME := cnames[domain] + if !isCNAME { + break + } + + record.CNAMEs = append(record.CNAMEs, nextDomain) + domain = nextDomain + } + + // update the entity to include the cnames + conn.Entity.CNAME = record.CNAMEs + + // get the existing IP info or create a new one + var save bool + info, err := resolver.GetIPInfo(ip) + if err != nil { + if err != database.ErrNotFound { + log.Errorf("nameserver: failed to search for IP info record: %s", err) + } + + info = &resolver.IPInfo{ + IP: ip, + } + save = true + } + + // and the new resolved domain record and save + if new := info.AddDomain(record); new { + save = true + } + if save { + if err := info.Save(); err != nil { + log.Errorf("nameserver: failed to save IP info record: %s", err) + } + } + } + + // if we have CNAMEs and the profile is configured to filter them + // we need to re-check the lists and endpoints here + if conn.Process().Profile().FilterCNAMEs() { + conn.Entity.ResetLists() + conn.Entity.EnableCNAMECheck(true) + + result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) + if result == endpoints.Denied { + conn.Block("endpoint in blocklist: " + reason) + returnNXDomain(w, query, conn.Reason) + return nil + } + + if result == endpoints.NoMatch { + result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) + if result == endpoints.Denied { + conn.Block("endpoint in filterlists: " + reason) + returnNXDomain(w, query, conn.Reason) + return nil } } } diff --git a/nameserver/only/nameserver.go b/nameserver/only/nameserver.go index feaa5ca5..6d5cb5cb 100644 --- a/nameserver/only/nameserver.go +++ b/nameserver/only/nameserver.go @@ -5,6 +5,7 @@ import ( "net" "strings" + "github.com/safing/portbase/database" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" "github.com/safing/portmaster/netenv" @@ -164,35 +165,60 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er } // save IP addresses to IPInfo + cnames := make(map[string]string) + ips := make(map[string]struct{}) + for _, rr := range append(rrCache.Answer, rrCache.Extra...) { switch v := rr.(type) { + case *dns.CNAME: + cnames[v.Hdr.Name] = v.Target + case *dns.A: - ipInfo, err := resolver.GetIPInfo(v.A.String()) - if err != nil { - ipInfo = &resolver.IPInfo{ - IP: v.A.String(), - Domains: []string{q.FQDN}, - } - _ = ipInfo.Save() - } else { - added := ipInfo.AddDomain(q.FQDN) - if added { - _ = ipInfo.Save() - } - } + ips[v.A.String()] = struct{}{} + case *dns.AAAA: - ipInfo, err := resolver.GetIPInfo(v.AAAA.String()) - if err != nil { - ipInfo = &resolver.IPInfo{ - IP: v.AAAA.String(), - Domains: []string{q.FQDN}, - } - _ = ipInfo.Save() - } else { - added := ipInfo.AddDomain(q.FQDN) - if added { - _ = ipInfo.Save() - } + ips[v.AAAA.String()] = struct{}{} + } + } + + for ip := range ips { + record := resolver.ResolvedDomain{ + Domain: q.FQDN, + } + + // resolve all CNAMEs in the correct order. + var domain = q.FQDN + for { + nextDomain, isCNAME := cnames[domain] + if !isCNAME { + break + } + + record.CNAMEs = append(record.CNAMEs, nextDomain) + domain = nextDomain + } + + // get the existing IP info or create a new one + var save bool + info, err := resolver.GetIPInfo(ip) + if err != nil { + if err != database.ErrNotFound { + log.Errorf("nameserver: failed to search for IP info record: %s", err) + } + + info = &resolver.IPInfo{ + IP: ip, + } + save = true + } + + // and the new resolved domain record and save + if new := info.AddDomain(record); new { + save = true + } + if save { + if err := info.Save(); err != nil { + log.Errorf("nameserver: failed to save IP info record: %s", err) } } } diff --git a/network/connection.go b/network/connection.go index b9bef333..e0af4c01 100644 --- a/network/connection.go +++ b/network/connection.go @@ -54,9 +54,9 @@ type Connection struct { //nolint:maligned // TODO: fix alignment } // NewConnectionFromDNSRequest returns a new connection based on the given dns request. -func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, ip net.IP, port uint16) *Connection { +func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection { // get Process - proc, err := process.GetProcessByEndpoints(ctx, ip, port, dnsAddress, dnsPort, packet.UDP) + proc, err := process.GetProcessByEndpoints(ctx, localIP, localPort, dnsAddress, dnsPort, packet.UDP) if err != nil { log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err) proc = process.GetUnidentifiedProcess(ctx) @@ -67,7 +67,8 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, ip net.IP, po Scope: fqdn, Entity: (&intel.Entity{ Domain: fqdn, - }).Init(), + CNAME: cnames, + }), process: proc, Started: timestamp, Ended: timestamp, @@ -104,7 +105,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { IP: pkt.Info().Src, Protocol: uint8(pkt.Info().Protocol), Port: pkt.Info().SrcPort, - }).Init() + }) } else { @@ -113,18 +114,21 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { IP: pkt.Info().Dst, Protocol: uint8(pkt.Info().Protocol), Port: pkt.Info().DstPort, - }).Init() + }) // check if we can find a domain for that IP ipinfo, err := resolver.GetIPInfo(pkt.Info().Dst.String()) if err == nil { + lastResolvedDomain := ipinfo.ResolvedDomains.MostRecentDomain() + if lastResolvedDomain != nil { + scope = lastResolvedDomain.Domain + entity.Domain = lastResolvedDomain.Domain + entity.CNAME = lastResolvedDomain.CNAMEs + removeOpenDNSRequest(proc.Pid, lastResolvedDomain.Domain) + } + } - // outbound to domain - scope = ipinfo.Domains[0] - entity.Domain = scope - removeOpenDNSRequest(proc.Pid, scope) - - } else { + if scope == "" { // outbound direct (possibly P2P) connection switch netutils.ClassifyIP(pkt.Info().Dst) { diff --git a/profile/config.go b/profile/config.go index 2607646a..1890afa6 100644 --- a/profile/config.go +++ b/profile/config.go @@ -30,6 +30,9 @@ var ( CfgOptionFilterSubDomainsKey = "filter/includeSubdomains" cfgOptionFilterSubDomains config.IntOption // security level option + CfgOptionFilterCNAMEKey = "filter/includeCNAMEs" + cfgOptionFilterCNAME config.IntOption // security level option + CfgOptionBlockScopeLocalKey = "filter/blockLocal" cfgOptionBlockScopeLocal config.IntOption // security level option @@ -180,6 +183,23 @@ Examples: cfgOptionFilterLists = config.Concurrent.GetAsStringArray(CfgOptionFilterListKey, []string{}) cfgStringArrayOptions[CfgOptionFilterListKey] = cfgOptionFilterLists + // Include CNAMEs + err = config.Register(&config.Option{ + Name: "Filter CNAMEs", + Key: CfgOptionFilterCNAMEKey, + Description: "Also filter requests where a CNAME would be blocked", + OptType: config.OptTypeInt, + ExternalOptType: "security level", + DefaultValue: status.SecurityLevelsAll, + ValidationRegex: "^(7|6|4)$", + }) + if err != nil { + return err + } + cfgOptionFilterCNAME = config.Concurrent.GetAsInt(CfgOptionFilterCNAMEKey, int64(status.SecurityLevelsAll)) + cfgIntOptions[CfgOptionFilterCNAMEKey] = cfgOptionFilterCNAME + + // Include subdomains err = config.Register(&config.Option{ Name: "Filter SubDomains", Key: CfgOptionFilterSubDomainsKey, diff --git a/profile/endpoints/endpoint-domain.go b/profile/endpoints/endpoint-domain.go index 225a4b5d..3350f053 100644 --- a/profile/endpoints/endpoint-domain.go +++ b/profile/endpoints/endpoint-domain.go @@ -31,35 +31,77 @@ type EndpointDomain struct { Reason string } +func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, string) { + switch ep.MatchType { + case domainMatchTypeExact: + if domain == ep.Domain { + return ep.matchesPPP(entity), ep.Reason + } + case domainMatchTypeZone: + if domain == ep.Domain { + return ep.matchesPPP(entity), ep.Reason + } + if strings.HasSuffix(domain, ep.DomainZone) { + return ep.matchesPPP(entity), ep.Reason + } + case domainMatchTypeSuffix: + if strings.HasSuffix(domain, ep.Domain) { + return ep.matchesPPP(entity), ep.Reason + } + case domainMatchTypePrefix: + if strings.HasPrefix(domain, ep.Domain) { + return ep.matchesPPP(entity), ep.Reason + } + case domainMatchTypeContains: + if strings.Contains(domain, ep.Domain) { + return ep.matchesPPP(entity), ep.Reason + } + } + return NoMatch, "" +} + // Matches checks whether the given entity matches this endpoint definition. func (ep *EndpointDomain) Matches(entity *intel.Entity) (result EPResult, reason string) { if entity.Domain == "" { return NoMatch, "" } - switch ep.MatchType { - case domainMatchTypeExact: - if entity.Domain == ep.Domain { - return ep.matchesPPP(entity), ep.Reason - } - case domainMatchTypeZone: - if entity.Domain == ep.Domain { - return ep.matchesPPP(entity), ep.Reason - } - if strings.HasSuffix(entity.Domain, ep.DomainZone) { - return ep.matchesPPP(entity), ep.Reason - } - case domainMatchTypeSuffix: - if strings.HasSuffix(entity.Domain, ep.Domain) { - return ep.matchesPPP(entity), ep.Reason - } - case domainMatchTypePrefix: - if strings.HasPrefix(entity.Domain, ep.Domain) { - return ep.matchesPPP(entity), ep.Reason - } - case domainMatchTypeContains: - if strings.Contains(entity.Domain, ep.Domain) { - return ep.matchesPPP(entity), ep.Reason + result, reason = ep.check(entity, entity.Domain) + if result != NoMatch { + return + } + + if entity.CNAMECheckEnabled() { + for _, domain := range entity.CNAME { + switch ep.MatchType { + case domainMatchTypeExact: + if domain == ep.Domain { + result, reason = ep.matchesPPP(entity), ep.Reason + } + case domainMatchTypeZone: + if domain == ep.Domain { + result, reason = ep.matchesPPP(entity), ep.Reason + } + if strings.HasSuffix(domain, ep.DomainZone) { + result, reason = ep.matchesPPP(entity), ep.Reason + } + case domainMatchTypeSuffix: + if strings.HasSuffix(domain, ep.Domain) { + result, reason = ep.matchesPPP(entity), ep.Reason + } + case domainMatchTypePrefix: + if strings.HasPrefix(domain, ep.Domain) { + result, reason = ep.matchesPPP(entity), ep.Reason + } + case domainMatchTypeContains: + if strings.Contains(domain, ep.Domain) { + result, reason = ep.matchesPPP(entity), ep.Reason + } + } + + if result == Denied { + return result, reason + } } } diff --git a/profile/profile-layered.go b/profile/profile-layered.go index dd0af165..4ba2210b 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -43,6 +43,7 @@ type LayeredProfile struct { RemoveOutOfScopeDNS config.BoolOption RemoveBlockedDNS config.BoolOption FilterSubDomains config.BoolOption + FilterCNAMEs config.BoolOption PreventBypassing config.BoolOption } @@ -99,6 +100,10 @@ func NewLayeredProfile(localProfile *Profile) *LayeredProfile { CfgOptionFilterSubDomainsKey, cfgOptionFilterSubDomains, ) + new.FilterCNAMEs = new.wrapSecurityLevelOption( + CfgOptionFilterCNAMEKey, + cfgOptionFilterCNAME, + ) new.PreventBypassing = new.wrapSecurityLevelOption( CfgOptionPreventBypassingKey, cfgOptionPreventBypassing, @@ -236,6 +241,7 @@ func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (result end // lists. func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPResult, string) { entity.ResolveSubDomainLists(lp.FilterSubDomains()) + entity.EnableCNAMECheck(lp.FilterCNAMEs()) lookupMap, hasLists := entity.GetListsMap() if !hasLists { diff --git a/resolver/ipinfo.go b/resolver/ipinfo.go index 760b25f3..03c0b59b 100644 --- a/resolver/ipinfo.go +++ b/resolver/ipinfo.go @@ -16,13 +16,92 @@ var ( }) ) +// ResolvedDomain holds a Domain name and a list of +// CNAMES that have been resolved. +type ResolvedDomain struct { + // Domain is the domain as requested by the application. + Domain string + + // CNAMEs is a list of CNAMEs that have been resolved for + // Domain. + CNAMEs []string +} + +// String returns a string representation of ResolvedDomain including +// the CNAME chain. It implements fmt.Stringer +func (resolved *ResolvedDomain) String() string { + ret := resolved.Domain + cnames := "" + + if len(resolved.CNAMEs) > 0 { + cnames = " (-> " + strings.Join(resolved.CNAMEs, "->") + ")" + } + + return ret + cnames +} + +// ResolvedDomains is a helper type for operating on a slice +// of ResolvedDomain +type ResolvedDomains []ResolvedDomain + +// String returns a string representation of all domains joined +// to a single string. +func (rds ResolvedDomains) String() string { + var domains []string + for _, n := range rds { + domains = append(domains, n.String()) + } + return strings.Join(domains, " or ") +} + +// MostRecentDomain returns the most recent domain. +func (rds ResolvedDomains) MostRecentDomain() *ResolvedDomain { + if len(rds) == 0 { + return nil + } + // TODO(ppacher): we could also do that by using ResolvedAt() + mostRecent := rds[len(rds)-1] + return &mostRecent +} + // IPInfo represents various information about an IP. type IPInfo struct { record.Base sync.Mutex - IP string - Domains []string + // IP holds the acutal IP address. + IP string + + // Domains holds a list of domains that have been + // resolved to IP. This field is deprecated and should + // be removed. + // DEPRECATED: remove with alpha. + Domains []string `json:"Domains,omitempty"` + + // ResolvedDomain is a slice of domains that + // have been requested by various applications + // and have been resolved to IP. + ResolvedDomains ResolvedDomains +} + +// AddDomain adds a new resolved domain to ipi. +func (ipi *IPInfo) AddDomain(resolved ResolvedDomain) bool { + for idx, d := range ipi.ResolvedDomains { + if d.Domain == resolved.Domain { + if utils.StringSliceEqual(d.CNAMEs, resolved.CNAMEs) { + return false + } + + // we have a different CNAME chain now, remove the previous + // entry and add it at the end. + ipi.ResolvedDomains = append(ipi.ResolvedDomains[:idx], ipi.ResolvedDomains[idx+1:]...) + ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved) + return true + } + } + + ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved) + return true } func makeIPInfoKey(ip string) string { @@ -46,6 +125,19 @@ func GetIPInfo(ip string) (*IPInfo, error) { if err != nil { return nil, err } + + // Legacy support, + // DEPRECATED: remove with alpha + if len(new.Domains) > 0 && len(new.ResolvedDomains) == 0 { + for _, d := range new.Domains { + new.ResolvedDomains = append(new.ResolvedDomains, ResolvedDomain{ + Domain: d, + // rest is empty... + }) + } + new.Domains = nil // clean up so we remove it from the database + } + return new, nil } @@ -57,17 +149,6 @@ func GetIPInfo(ip string) (*IPInfo, error) { return new, nil } -// AddDomain adds a domain to the list and reports back if it was added, or was already present. -func (ipi *IPInfo) AddDomain(domain string) (added bool) { - ipi.Lock() - defer ipi.Unlock() - if !utils.StringInSlice(ipi.Domains, domain) { - ipi.Domains = append([]string{domain}, ipi.Domains...) - return true - } - return false -} - // Save saves the IPInfo record to the database. func (ipi *IPInfo) Save() error { ipi.Lock() @@ -75,17 +156,21 @@ func (ipi *IPInfo) Save() error { ipi.SetKey(makeIPInfoKey(ipi.IP)) } ipi.Unlock() - return ipInfoDatabase.Put(ipi) -} -// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " -func (ipi *IPInfo) FmtDomains() string { - return strings.Join(ipi.Domains, " or ") + // Legacy support + // Ensure we don't write new Domain fields into the + // database. + // DEPRECATED: remove with alpha + if len(ipi.Domains) > 0 { + ipi.Domains = nil + } + + return ipInfoDatabase.Put(ipi) } // FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " func (ipi *IPInfo) String() string { ipi.Lock() defer ipi.Unlock() - return fmt.Sprintf(" Date: Fri, 17 Apr 2020 17:02:04 +0200 Subject: [PATCH 2/7] Add support to detect upstream DNS resolver blocking --- nameserver/nameserver.go | 94 +++++++++++++++------------- profile/endpoints/endpoint-domain.go | 27 +------- resolver/block_detection.go | 61 ++++++++++++++++++ resolver/config.go | 22 ++++--- resolver/resolve.go | 15 +++++ resolver/resolver.go | 35 ++++++++++- resolver/resolvers.go | 25 ++++++-- 7 files changed, 194 insertions(+), 85 deletions(-) create mode 100644 resolver/block_detection.go diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index f02151fe..013fc749 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -246,8 +246,14 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er if err != nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains tracer.Warningf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err) + + if _, ok := err.(*resolver.BlockedUpstreamError); ok { + conn.Block(err.Error()) + } else { + conn.Failed("failed to resolve: " + err.Error()) + } + returnNXDomain(w, query, conn.Reason) - conn.Failed("failed to resolve: " + err.Error()) return nil } @@ -261,6 +267,51 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er return nil } + updateIPsAndCNAMEs(q, rrCache, conn) + + // if we have CNAMEs and the profile is configured to filter them + // we need to re-check the lists and endpoints here + if conn.Process().Profile().FilterCNAMEs() { + conn.Entity.ResetLists() + conn.Entity.EnableCNAMECheck(true) + + result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) + if result == endpoints.Denied { + conn.Block("endpoint in blocklist: " + reason) + returnNXDomain(w, query, conn.Reason) + return nil + } + + if result == endpoints.NoMatch { + result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) + if result == endpoints.Denied { + conn.Block("endpoint in filterlists: " + reason) + returnNXDomain(w, query, conn.Reason) + return nil + } + } + } + + // reply to query + m := new(dns.Msg) + m.SetReply(query) + m.Answer = rrCache.Answer + m.Ns = rrCache.Ns + m.Extra = rrCache.Extra + + if err := w.WriteMsg(m); err != nil { + log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) + } else { + tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process()) + } + + // save dns request as open + network.SaveOpenDNSRequest(conn) + + return nil +} + +func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) { // save IP addresses to IPInfo cnames := make(map[string]string) ips := make(map[string]struct{}) @@ -322,45 +373,4 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er } } } - - // if we have CNAMEs and the profile is configured to filter them - // we need to re-check the lists and endpoints here - if conn.Process().Profile().FilterCNAMEs() { - conn.Entity.ResetLists() - conn.Entity.EnableCNAMECheck(true) - - result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) - if result == endpoints.Denied { - conn.Block("endpoint in blocklist: " + reason) - returnNXDomain(w, query, conn.Reason) - return nil - } - - if result == endpoints.NoMatch { - result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) - if result == endpoints.Denied { - conn.Block("endpoint in filterlists: " + reason) - returnNXDomain(w, query, conn.Reason) - return nil - } - } - } - - // reply to query - m := new(dns.Msg) - m.SetReply(query) - m.Answer = rrCache.Answer - m.Ns = rrCache.Ns - m.Extra = rrCache.Extra - - if err := w.WriteMsg(m); err != nil { - log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) - } else { - tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process()) - } - - // save dns request as open - network.SaveOpenDNSRequest(conn) - - return nil } diff --git a/profile/endpoints/endpoint-domain.go b/profile/endpoints/endpoint-domain.go index 3350f053..2691ba13 100644 --- a/profile/endpoints/endpoint-domain.go +++ b/profile/endpoints/endpoint-domain.go @@ -73,32 +73,7 @@ func (ep *EndpointDomain) Matches(entity *intel.Entity) (result EPResult, reason if entity.CNAMECheckEnabled() { for _, domain := range entity.CNAME { - switch ep.MatchType { - case domainMatchTypeExact: - if domain == ep.Domain { - result, reason = ep.matchesPPP(entity), ep.Reason - } - case domainMatchTypeZone: - if domain == ep.Domain { - result, reason = ep.matchesPPP(entity), ep.Reason - } - if strings.HasSuffix(domain, ep.DomainZone) { - result, reason = ep.matchesPPP(entity), ep.Reason - } - case domainMatchTypeSuffix: - if strings.HasSuffix(domain, ep.Domain) { - result, reason = ep.matchesPPP(entity), ep.Reason - } - case domainMatchTypePrefix: - if strings.HasPrefix(domain, ep.Domain) { - result, reason = ep.matchesPPP(entity), ep.Reason - } - case domainMatchTypeContains: - if strings.Contains(domain, ep.Domain) { - result, reason = ep.matchesPPP(entity), ep.Reason - } - } - + result, reason = ep.check(entity, domain) if result == Denied { return result, reason } diff --git a/resolver/block_detection.go b/resolver/block_detection.go new file mode 100644 index 00000000..8a4005cd --- /dev/null +++ b/resolver/block_detection.go @@ -0,0 +1,61 @@ +package resolver + +import ( + "net" + + "github.com/miekg/dns" +) + +// Supported upstream block detections +const ( + BlockDetectionRefused = "refused" + BlockDetectionZeroIP = "zeroip" + BlockDetectionEmptyAnswer = "empty" + BlockDetectionDisabled = "disabled" +) + +func isBlockedUpstream(resolver *Resolver, answer *dns.Msg) bool { + if resolver.UpstreamBlockDetection == BlockDetectionDisabled { + return false + } + + switch resolver.UpstreamBlockDetection { + case BlockDetectionRefused: + return answer.Rcode == dns.RcodeRefused + case BlockDetectionZeroIP: + if answer.Rcode != dns.RcodeSuccess { + return false + } + var ips []net.IP + for _, rr := range answer.Answer { + switch v := rr.(type) { + case *dns.A: + ips = append(ips, v.A) + case *dns.AAAA: + ips = append(ips, v.AAAA) + } + } + + if len(ips) == 0 { + return false // we expected an empty IP + } + + for _, ip := range ips { + if ip.To4() != nil { + if !ip.Equal(net.IPv4zero) { + return false + } + } else { + if !ip.To16().Equal(net.IPv6zero) { + return false + } + } + } + + return true + case BlockDetectionEmptyAnswer: + return answer.Rcode == dns.RcodeNameError && len(answer.Ns) == 0 && len(answer.Answer) == 0 && len(answer.Extra) == 0 + } + + return false +} diff --git a/resolver/config.go b/resolver/config.go index d4c4828d..ac9567b0 100644 --- a/resolver/config.go +++ b/resolver/config.go @@ -29,28 +29,30 @@ var ( // We encourage everyone who has the technical abilities to set their own preferred servers. // Default 1: Cloudflare - "dot://1.1.1.1:853?verify=cloudflare-dns.com", // Cloudflare - "dot://1.0.0.1:853?verify=cloudflare-dns.com", // Cloudflare + "dot://1.1.1.1:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", // Cloudflare + "dot://1.0.0.1:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", // Cloudflare // Default 2: Quad9 - "dot://9.9.9.9:853?verify=dns.quad9.net", // Quad9 - "dot://149.112.112.112:853?verify=dns.quad9.net", // Quad9 + "dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", // Quad9 + "dot://149.112.112.112:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", // Quad9 // Fallback 1: Cloudflare - "dns://1.1.1.1:53", // Cloudflare - "dns://1.0.0.1:53", // Cloudflare + "dns://1.1.1.1:53?name=Cloudflare&blockedif=zeroip", // Cloudflare + "dns://1.0.0.1:53?name=Cloudflare&blockedif=zeroip", // Cloudflare // Fallback 2: Quad9 - "dns://9.9.9.9:53", // Quad9 - "dns://149.112.112.112:53", // Quad9 + "dns://9.9.9.9:53?name=Quad9&blockedif=empty", // Quad9 + "dns://149.112.112.112:53?name=Quad9&blockedif=empty", // Quad9 // supported parameters // - `verify=domain`: verify domain (dot only) // future parameters: // // - `name=name`: human readable name for resolver - // - `blockedif=baredns`: how to detect if the dns service blocked something - // - `baredns`: NXDomain result, but without any other record in any section + // - `blockedif=empty`: how to detect if the dns service blocked something + // - `empty`: NXDomain result, but without any other record in any section + // - `refused`: Request was refused + // - `zeroip`: Answer only contains zeroip } CfgOptionNameServersKey = "dns/nameservers" diff --git a/resolver/resolve.go b/resolver/resolve.go index 2c7f532a..f13d07c2 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -37,6 +37,21 @@ var ( ErrNoCompliance = fmt.Errorf("%w: no compliant resolvers for this query", ErrBlocked) ) +// BlockedUpstreamError is returned when a DNS request +// has been blocked by the upstream server. +type BlockedUpstreamError struct { + ResolverName string +} + +func (blocked *BlockedUpstreamError) Error() string { + return fmt.Sprintf("Endpoint blocked by upstream DNS resolver %s", blocked.ResolverName) +} + +// Unwrap implements errors.Unwrapper +func (blocked *BlockedUpstreamError) Unwrap() error { + return ErrBlocked +} + // Query describes a dns query. type Query struct { FQDN string diff --git a/resolver/resolver.go b/resolver/resolver.go index 0201b0c5..65155fab 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -28,6 +28,19 @@ type Resolver struct { // Server config url (and ID) Server string + // Name is the name of the resolver as passed via + // ?name=. + Name string + + // UpstreamBlockDetection defines the detection type + // to identifier upstream DNS query blocking. + // Valid values are: + // - zeroip + // - empty + // - refused (default) + // - disabled + UpstreamBlockDetection string + // Parsed config ServerType string ServerAddress string @@ -46,9 +59,25 @@ type Resolver struct { Conn ResolverConn } +// IsBlockedUpstream returns true if the request has been blocked +// upstream. +func (resolver *Resolver) IsBlockedUpstream(answer *dns.Msg) bool { + return isBlockedUpstream(resolver, answer) +} + +// GetName returns the name of the server. If no name +// is configured the server address is returned. +func (resolver *Resolver) GetName() string { + if resolver.Name != "" { + return resolver.Name + } + + return resolver.Server +} + // String returns the URL representation of the resolver. func (resolver *Resolver) String() string { - return resolver.Server + return resolver.GetName() } // ResolverConn is an interface to implement different types of query backends. @@ -126,6 +155,10 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er break } + if resolver.IsBlockedUpstream(reply) { + return nil, &BlockedUpstreamError{resolver.GetName()} + } + // no error break } diff --git a/resolver/resolvers.go b/resolver/resolvers.go index b4451f11..00ad0d0e 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -107,13 +107,26 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { return nil, false, fmt.Errorf("DOT must have a verify query parameter set") } + blockType := query.Get("blockedif") + if blockType == "" { + blockType = BlockDetectionRefused + } + + switch blockType { + case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP: + default: + return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)") + } + new := &Resolver{ - Server: resolverURL, - ServerType: u.Scheme, - ServerAddress: u.Host, - ServerIPScope: scope, - Source: source, - VerifyDomain: verifyDomain, + Server: resolverURL, + ServerType: u.Scheme, + ServerAddress: u.Host, + ServerIPScope: scope, + Source: source, + VerifyDomain: verifyDomain, + Name: query.Get("name"), + UpstreamBlockDetection: blockType, } newConn := &BasicResolverConn{ From 42ccb3e39a5469bfdf7c8be81dca1c65c2eda082 Mon Sep 17 00:00:00 2001 From: Patrick Pacher Date: Mon, 20 Apr 2020 08:25:34 +0200 Subject: [PATCH 3/7] Small changes based on review comments --- intel/entity.go | 11 +++++++++++ nameserver/nameserver.go | 3 ++- profile/config.go | 1 + 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/intel/entity.go b/intel/entity.go index 39a3765c..f01b53cb 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -273,6 +273,8 @@ func (e *Entity) getDomainLists() { domains = domainsToInspect } + domains = makeDistinct(domains) + for _, d := range domains { log.Tracef("intel: loading domain list for %s", d) list, err := filterlists.LookupDomain(d) @@ -441,3 +443,12 @@ func buildLookupMap(l []string) filterlists.LookupMap { return m } + +func makeDistinct(slice []string) []string { + lm := buildLookupMap(slice) + result := make([]string, 0, len(lm)) + for key := range lm { + result = append(result, key) + } + return result +} diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 013fc749..a30d933c 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -2,6 +2,7 @@ package nameserver import ( "context" + "errors" "fmt" "net" "strings" @@ -247,7 +248,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // TODO: analyze nxdomain requests, malware could be trying DGA-domains tracer.Warningf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err) - if _, ok := err.(*resolver.BlockedUpstreamError); ok { + if errors.Is(err, &resolver.BlockedUpstreamError{}) { conn.Block(err.Error()) } else { conn.Failed("failed to resolve: " + err.Error()) diff --git a/profile/config.go b/profile/config.go index 1890afa6..8842530e 100644 --- a/profile/config.go +++ b/profile/config.go @@ -192,6 +192,7 @@ Examples: ExternalOptType: "security level", DefaultValue: status.SecurityLevelsAll, ValidationRegex: "^(7|6|4)$", + ExpertiseLevel: config.ExpertiseLevelExpert, }) if err != nil { return err From eeb358425df918cd665337e5fadaa8c84c9430b2 Mon Sep 17 00:00:00 2001 From: Patrick Pacher Date: Mon, 20 Apr 2020 11:36:34 +0200 Subject: [PATCH 4/7] Refactor entity list handling --- intel/entity.go | 183 ++++++++++++++++++++++------ intel/filterlists/lookup_map.go | 25 ---- intel/filterlists/lookup_test.go | 92 -------------- intel/lists.go | 40 ------ profile/endpoints/endpoint-lists.go | 16 ++- profile/endpoints/endpoint.go | 2 +- profile/profile-layered.go | 26 ++-- 7 files changed, 167 insertions(+), 217 deletions(-) delete mode 100644 intel/filterlists/lookup_map.go delete mode 100644 intel/filterlists/lookup_test.go delete mode 100644 intel/lists.go diff --git a/intel/entity.go b/intel/entity.go index f01b53cb..e46ac041 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -16,6 +16,43 @@ import ( "golang.org/x/net/publicsuffix" ) +// ListMatch represents an entity that has been +// matched against filterlists. +type ListMatch struct { + Entity string + ActiveLists []string + InactiveLists []string +} + +func (lm *ListMatch) String() string { + inactive := "" + if len(lm.InactiveLists) > 0 { + inactive = " and in deactivated lists " + strings.Join(lm.InactiveLists, ", ") + } + return fmt.Sprintf( + "%s in activated lists %s%s", + lm.Entity, + strings.Join(lm.ActiveLists, ","), + inactive, + ) +} + +// ListBlockReason is a list of list matches. +type ListBlockReason []ListMatch + +func (br ListBlockReason) String() string { + if len(br) == 0 { + return "" + } + + matches := make([]string, len(br)) + for idx, lm := range br { + matches[idx] = lm.String() + } + + return strings.Join(matches, " and ") +} + // Entity describes a remote endpoint in many different ways. // It embeddes a sync.Mutex but none of the endpoints own // functions performs locking. The caller MUST ENSURE @@ -60,8 +97,18 @@ type Entity struct { location *geoip.Location - Lists []string - ListsMap filterlists.LookupMap + // BlockedByLists holds list source IDs that + // are used to block the entity. + BlockedByLists []string + + // BlockedEntities holds a list of entities that + // have been blocked. Values can be used as a key + // for the ListOccurences map. + BlockedEntities []string + + // ListOccurences is a map that matches an entity (Domain, IPs, ASN, Country, Sub-domain) + // to a list of sources where the entity has been observed in. + ListOccurences map[string][]string // we only load each data above at most once fetchLocationOnce sync.Once @@ -90,8 +137,11 @@ func (e *Entity) ResetLists() { // TODO(ppacher): our actual goal is to reset the domain // list right now so we could be more efficient by keeping // the other lists around. - e.Lists = nil - e.ListsMap = nil + + // FIXME + //e.Lists = nil + //e.ListsMap = nil + e.ListOccurences = nil e.domainListLoaded = false e.ipListLoaded = false e.countryListLoaded = false @@ -238,9 +288,19 @@ func (e *Entity) getLists() { e.getCountryLists() } -func (e *Entity) mergeList(list []string) { - e.Lists = mergeStringList(e.Lists, list) - e.ListsMap = buildLookupMap(e.Lists) +func (e *Entity) mergeList(key string, list []string) { + if len(list) == 0 { + return + } + + if e.ListOccurences == nil { + e.ListOccurences = make(map[string][]string) + } + + e.ListOccurences[key] = mergeStringList(e.ListOccurences[key], list) + + //e.Lists = mergeStringList(e.Lists, list) + //e.ListsMap = buildLookupMap(e.Lists) } func (e *Entity) getDomainLists() { @@ -284,7 +344,7 @@ func (e *Entity) getDomainLists() { return } - e.mergeList(list) + e.mergeList(d, list) } e.domainListLoaded = true }) @@ -328,7 +388,8 @@ func (e *Entity) getASNLists() { log.Tracef("intel: loading ASN list for %d", asn) e.loadAsnListOnce.Do(func() { - list, err := filterlists.LookupASNString(fmt.Sprintf("%d", asn)) + asnStr := fmt.Sprintf("%d", asn) + list, err := filterlists.LookupASNString(asnStr) if err != nil { log.Errorf("intel: failed to get ASN blocklist for %d: %s", asn, err) e.loadAsnListOnce = sync.Once{} @@ -336,7 +397,7 @@ func (e *Entity) getASNLists() { } e.asnListLoaded = true - e.mergeList(list) + e.mergeList(asnStr, list) }) } @@ -360,7 +421,7 @@ func (e *Entity) getCountryLists() { } e.countryListLoaded = true - e.mergeList(list) + e.mergeList(country, list) }) } @@ -393,28 +454,71 @@ func (e *Entity) getIPLists() { return } e.ipListLoaded = true - e.mergeList(list) + e.mergeList(ip.String(), list) }) } -// GetLists returns the filter list identifiers the entity matched and whether this data is set. -func (e *Entity) GetLists() ([]string, bool) { +// LoadLists searches all filterlists for all occurences of +// this entity. +func (e *Entity) LoadLists() bool { e.getLists() - if e.Lists == nil { - return nil, false + if e.ListOccurences == nil { + return false } - return e.Lists, true + return true } -// GetListsMap is like GetLists but returns a lookup map for list IDs. -func (e *Entity) GetListsMap() (filterlists.LookupMap, bool) { - e.getLists() +// MatchLists matches the entities lists against a slice +// of source IDs and updates various entity properties +// like BlockedByLists, ListOccurences and BlockedEntitites. +func (e *Entity) MatchLists(lists []string) bool { + e.BlockedByLists = nil + e.BlockedEntities = nil - if e.ListsMap == nil { - return nil, false + lm := makeMap(lists) + for key, keyLists := range e.ListOccurences { + for _, keyListID := range keyLists { + if _, ok := lm[keyListID]; ok { + e.BlockedByLists = append(e.BlockedByLists, keyListID) + e.BlockedEntities = append(e.BlockedEntities, key) + } + } } - return e.ListsMap, true + + makeDistinct(e.BlockedByLists) + + return len(e.BlockedByLists) > 0 +} + +// ListBlockReason returns the block reason for this entity. +func (e *Entity) ListBlockReason() ListBlockReason { + blockedBy := make([]ListMatch, len(e.BlockedEntities)) + + lm := makeMap(e.BlockedByLists) + + for idx, blockedEntity := range e.BlockedEntities { + if entityLists, ok := e.ListOccurences[blockedEntity]; ok { + var activeLists []string + var inactiveLists []string + + for _, l := range entityLists { + if _, ok := lm[l]; ok { + activeLists = append(activeLists, l) + } else { + inactiveLists = append(inactiveLists, l) + } + } + + blockedBy[idx] = ListMatch{ + Entity: blockedEntity, + ActiveLists: activeLists, + InactiveLists: inactiveLists, + } + } + } + + return blockedBy } func mergeStringList(a, b []string) []string { @@ -434,21 +538,26 @@ func mergeStringList(a, b []string) []string { return res } -func buildLookupMap(l []string) filterlists.LookupMap { - m := make(filterlists.LookupMap, len(l)) - - for _, s := range l { - m[s] = struct{}{} - } - - return m -} - func makeDistinct(slice []string) []string { - lm := buildLookupMap(slice) - result := make([]string, 0, len(lm)) - for key := range lm { - result = append(result, key) + m := make(map[string]struct{}, len(slice)) + var result []string + + for _, v := range slice { + if _, ok := m[v]; ok { + continue + } + + m[v] = struct{}{} + result = append(result, v) } + return result } + +func makeMap(slice []string) map[string]struct{} { + lm := make(map[string]struct{}) + for _, v := range slice { + lm[v] = struct{}{} + } + return lm +} diff --git a/intel/filterlists/lookup_map.go b/intel/filterlists/lookup_map.go deleted file mode 100644 index 5a161c91..00000000 --- a/intel/filterlists/lookup_map.go +++ /dev/null @@ -1,25 +0,0 @@ -package filterlists - -import "strings" - -// LookupMap is a helper type for matching a list of endpoint sources -// against a map. -type LookupMap map[string]struct{} - -// Match checks if a source in `list` is part of lm. -// Matches are joined to string and returned. -// If nothing is found, an empty string is returned. -func (lm LookupMap) Match(list []string) string { - matches := make([]string, 0, len(list)) - for _, l := range list { - if _, ok := lm[l]; ok { - matches = append(matches, l) - } - } - - if len(matches) == 0 { - return "" - } - - return strings.Join(matches, ", ") -} diff --git a/intel/filterlists/lookup_test.go b/intel/filterlists/lookup_test.go deleted file mode 100644 index 0a2e9220..00000000 --- a/intel/filterlists/lookup_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package filterlists - -/* - -func TestLookupASN(t *testing.T) { - lists, err := LookupASNString("123") - assert.NoError(t, err) - assert.Equal(t, []string{"TEST"}, lists) - - lists, err = LookupASNString("does-not-exist") - assert.NoError(t, err) - assert.Empty(t, lists) - - defer testMarkNotLoaded()() - lists, err = LookupASNString("123") - assert.NoError(t, err) - assert.Empty(t, lists) -} - -func TestLookupCountry(t *testing.T) { - lists, err := LookupCountry("AT") - assert.NoError(t, err) - assert.Equal(t, []string{"TEST"}, lists) - - lists, err = LookupCountry("does-not-exist") - assert.NoError(t, err) - assert.Empty(t, lists) - - defer testMarkNotLoaded()() - lists, err = LookupCountry("AT") - assert.NoError(t, err) - assert.Empty(t, lists) -} - -func TestLookupIP(t *testing.T) { - lists, err := LookupIP(net.IP{1, 1, 1, 1}) - assert.NoError(t, err) - assert.Equal(t, []string{"TEST"}, lists) - - lists, err = LookupIP(net.IP{127, 0, 0, 1}) - assert.NoError(t, err) - assert.Empty(t, lists) - - defer testMarkNotLoaded()() - lists, err = LookupIP(net.IP{1, 1, 1, 1}) - assert.NoError(t, err) - assert.Empty(t, lists) -} - -func TestLookupDomain(t *testing.T) { - lists, err := LookupDomain("example.com") - assert.NoError(t, err) - assert.Equal(t, []string{"TEST"}, lists) - - lists, err = LookupDomain("does-not-exist") - assert.NoError(t, err) - assert.Empty(t, lists) - - defer testMarkNotLoaded()() - lists, err = LookupDomain("example.com") - assert.NoError(t, err) - assert.Empty(t, lists) -} - -// testMarkNotLoaded ensures that functions believe -// filterlists are not yet loaded. It returns a -// func that restores the previous state. -func testMarkNotLoaded() func() { - if isLoaded() { - filterListsLoaded = make(chan struct{}) - return func() { - close(filterListsLoaded) - } - } - - return func() {} -} - -// testMarkLoaded is like testMarkNotLoaded but ensures -// isLoaded() return true. It returns a function to restore -// the previous state. -func testMarkLoaded() func() { - if !isLoaded() { - close(filterListsLoaded) - return func() { - filterListsLoaded = make(chan struct{}) - } - } - - return func() {} -} -*/ diff --git a/intel/lists.go b/intel/lists.go deleted file mode 100644 index 08ca0925..00000000 --- a/intel/lists.go +++ /dev/null @@ -1,40 +0,0 @@ -package intel - -// ListSet holds a set of list IDs. -type ListSet struct { - match []string -} - -// NewListSet returns a new ListSet with the given list IDs. -func NewListSet(lists []string) *ListSet { - // TODO: validate lists - return &ListSet{ - match: lists, - } -} - -// Matches returns whether there is a match in the given list IDs. -func (ls *ListSet) Matches(lists []string) (matches bool) { - for _, list := range lists { - for _, entry := range ls.match { - if entry == list { - return true - } - } - } - - return false -} - -// MatchSet returns the matching list IDs. -func (ls *ListSet) MatchSet(lists []string) (matched []string) { - for _, list := range lists { - for _, entry := range ls.match { - if entry == list { - matched = append(matched, list) - } - } - } - - return -} diff --git a/profile/endpoints/endpoint-lists.go b/profile/endpoints/endpoint-lists.go index fe151347..b11da060 100644 --- a/profile/endpoints/endpoint-lists.go +++ b/profile/endpoints/endpoint-lists.go @@ -10,21 +10,19 @@ import ( type EndpointLists struct { EndpointBase - ListSet *intel.ListSet + ListSet []string Lists string Reason string } // Matches checks whether the given entity matches this endpoint definition. func (ep *EndpointLists) Matches(entity *intel.Entity) (result EPResult, reason string) { - lists, ok := entity.GetLists() - if !ok { - return Undeterminable, "" - } - matched := ep.ListSet.MatchSet(lists) - if len(matched) > 0 { - return ep.matchesPPP(entity), ep.Reason + entity.LoadLists() + + if entity.MatchLists(ep.ListSet) { + return ep.matchesPPP(entity), entity.ListBlockReason().String() } + return NoMatch, "" } @@ -36,7 +34,7 @@ func parseTypeList(fields []string) (Endpoint, error) { if strings.HasPrefix(fields[1], "L:") { lists := strings.Split(strings.TrimPrefix(fields[1], "L:"), ",") ep := &EndpointLists{ - ListSet: intel.NewListSet(lists), + ListSet: lists, Lists: "L:" + strings.Join(lists, ","), Reason: "matched lists " + strings.Join(lists, ","), } diff --git a/profile/endpoints/endpoint.go b/profile/endpoints/endpoint.go index 687834cd..b634b911 100644 --- a/profile/endpoints/endpoint.go +++ b/profile/endpoints/endpoint.go @@ -11,7 +11,7 @@ import ( // Endpoint describes an Endpoint Matcher type Endpoint interface { - Matches(entity *intel.Entity) (result EPResult, reason string) + Matches(entity *intel.Entity) (EPResult, string) String() string } diff --git a/profile/profile-layered.go b/profile/profile-layered.go index 4ba2210b..edb8a8a2 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -243,27 +243,27 @@ func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPRe entity.ResolveSubDomainLists(lp.FilterSubDomains()) entity.EnableCNAMECheck(lp.FilterCNAMEs()) - lookupMap, hasLists := entity.GetListsMap() - if !hasLists { - return endpoints.NoMatch, "" - } - for _, layer := range lp.layers { - if reason := lookupMap.Match(layer.filterListIDs); reason != "" { - return endpoints.Denied, reason - } - - // only check the first layer that has filter list - // IDs defined. + // search for the first layer that has filterListIDs set if len(layer.filterListIDs) > 0 { + entity.LoadLists() + + if entity.MatchLists(layer.filterListIDs) { + return endpoints.Denied, entity.ListBlockReason().String() + } + return endpoints.NoMatch, "" } } cfgLock.RLock() defer cfgLock.RUnlock() - if reason := lookupMap.Match(cfgFilterLists); reason != "" { - return endpoints.Denied, reason + if len(cfgFilterLists) > 0 { + entity.LoadLists() + + if entity.MatchLists(cfgFilterLists) { + return endpoints.Denied, entity.ListBlockReason().String() + } } return endpoints.NoMatch, "" From 8c5526a69b74f19daeb78deb05c07c51aadf3696 Mon Sep 17 00:00:00 2001 From: Patrick Pacher Date: Mon, 20 Apr 2020 17:19:48 +0200 Subject: [PATCH 5/7] Add support for verdict and decision reason context --- firewall/inspection/inspection.go | 4 +- firewall/master.go | 12 ++--- intel/block_reason.go | 51 +++++++++++++++++++ intel/entity.go | 37 -------------- nameserver/nameserver.go | 4 +- network/connection.go | 72 ++++++++++++++++++--------- profile/endpoints/endpoint-any.go | 4 +- profile/endpoints/endpoint-asn.go | 21 ++++---- profile/endpoints/endpoint-country.go | 13 ++--- profile/endpoints/endpoint-domain.go | 30 +++++------ profile/endpoints/endpoint-ip.go | 15 +++--- profile/endpoints/endpoint-iprange.go | 14 +++--- profile/endpoints/endpoint-lists.go | 16 +++--- profile/endpoints/endpoint.go | 31 +++++++++++- profile/endpoints/endpoints.go | 10 +++- profile/endpoints/reason.go | 34 +++++++++++++ profile/profile-layered.go | 26 +++++----- 17 files changed, 246 insertions(+), 148 deletions(-) create mode 100644 intel/block_reason.go create mode 100644 profile/endpoints/reason.go diff --git a/firewall/inspection/inspection.go b/firewall/inspection/inspection.go index 55629b19..7dc59494 100644 --- a/firewall/inspection/inspection.go +++ b/firewall/inspection/inspection.go @@ -85,11 +85,11 @@ func RunInspectors(conn *network.Connection, pkt packet.Packet) (network.Verdict verdict = network.VerdictDrop continueInspection = true case BLOCK_CONN: - conn.SetVerdict(network.VerdictBlock) + conn.SetVerdict(network.VerdictBlock, "", nil) verdict = conn.Verdict activeInspectors[key] = true case DROP_CONN: - conn.SetVerdict(network.VerdictDrop) + conn.SetVerdict(network.VerdictDrop, "", nil) verdict = conn.Verdict activeInspectors[key] = true case STOP_INSPECTING: diff --git a/firewall/master.go b/firewall/master.go index 86196e36..cafd11f2 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -143,9 +143,6 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: } } - var result endpoints.EPResult - var reason string - if p.PreventBypassing() { // check for bypass protection result, reason := PreventBypassing(conn) @@ -160,6 +157,9 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: } } + var result endpoints.EPResult + var reason endpoints.Reason + // check endpoints list if conn.Inbound { result, reason = p.MatchServiceEndpoint(conn.Entity) @@ -168,10 +168,10 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: } switch result { case endpoints.Denied: - conn.Deny("endpoint is blacklisted: " + reason) // Block Outbound / Drop Inbound + conn.DenyWithContext(reason.String(), reason.Context()) return case endpoints.Permitted: - conn.Accept("endpoint is whitelisted: " + reason) + conn.AcceptWithContext(reason.String(), reason.Context()) return } // continuing with result == NoMatch @@ -180,7 +180,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: result, reason = p.MatchFilterLists(conn.Entity) switch result { case endpoints.Denied: - conn.Deny("endpoint in filterlists: " + reason) + conn.DenyWithContext(reason.String(), reason.Context()) return case endpoints.NoMatch: // nothing to do diff --git a/intel/block_reason.go b/intel/block_reason.go new file mode 100644 index 00000000..5a14c6cd --- /dev/null +++ b/intel/block_reason.go @@ -0,0 +1,51 @@ +package intel + +import ( + "fmt" + "strings" +) + +// ListMatch represents an entity that has been +// matched against filterlists. +type ListMatch struct { + Entity string + ActiveLists []string + InactiveLists []string +} + +func (lm *ListMatch) String() string { + inactive := "" + if len(lm.InactiveLists) > 0 { + inactive = " and in deactivated lists " + strings.Join(lm.InactiveLists, ", ") + } + return fmt.Sprintf( + "%s in activated lists %s%s", + lm.Entity, + strings.Join(lm.ActiveLists, ","), + inactive, + ) +} + +// ListBlockReason is a list of list matches. +type ListBlockReason []ListMatch + +func (br ListBlockReason) String() string { + if len(br) == 0 { + return "" + } + + matches := make([]string, len(br)) + for idx, lm := range br { + matches[idx] = lm.String() + } + + return strings.Join(matches, " and ") +} + +// Context returns br wrapped into a map. It implements +// the endpoints.Reason interface. +func (br ListBlockReason) Context() interface{} { + return map[string]interface{}{ + "filterlists": br, + } +} diff --git a/intel/entity.go b/intel/entity.go index e46ac041..8d2e21fb 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -16,43 +16,6 @@ import ( "golang.org/x/net/publicsuffix" ) -// ListMatch represents an entity that has been -// matched against filterlists. -type ListMatch struct { - Entity string - ActiveLists []string - InactiveLists []string -} - -func (lm *ListMatch) String() string { - inactive := "" - if len(lm.InactiveLists) > 0 { - inactive = " and in deactivated lists " + strings.Join(lm.InactiveLists, ", ") - } - return fmt.Sprintf( - "%s in activated lists %s%s", - lm.Entity, - strings.Join(lm.ActiveLists, ","), - inactive, - ) -} - -// ListBlockReason is a list of list matches. -type ListBlockReason []ListMatch - -func (br ListBlockReason) String() string { - if len(br) == 0 { - return "" - } - - matches := make([]string, len(br)) - for idx, lm := range br { - matches[idx] = lm.String() - } - - return strings.Join(matches, " and ") -} - // Entity describes a remote endpoint in many different ways. // It embeddes a sync.Mutex but none of the endpoints own // functions performs locking. The caller MUST ENSURE diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index a30d933c..9b3d8016 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -278,7 +278,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) if result == endpoints.Denied { - conn.Block("endpoint in blocklist: " + reason) + conn.BlockWithContext(reason.String(), reason.Context()) returnNXDomain(w, query, conn.Reason) return nil } @@ -286,7 +286,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er if result == endpoints.NoMatch { result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) if result == endpoints.Denied { - conn.Block("endpoint in filterlists: " + reason) + conn.BlockWithContext(reason.String(), reason.Context()) returnNXDomain(w, query, conn.Reason) return nil } diff --git a/network/connection.go b/network/connection.go index e0af4c01..bbc088dc 100644 --- a/network/connection.go +++ b/network/connection.go @@ -31,9 +31,10 @@ type Connection struct { //nolint:maligned // TODO: fix alignment Entity *intel.Entity // needs locking, instance is never shared process *process.Process - Verdict Verdict - Reason string - ReasonID string // format source[:id[:id]] // TODO + Verdict Verdict + Reason string + ReasonContext interface{} + ReasonID string // format source[:id[:id]] // TODO Started int64 Ended int64 @@ -164,59 +165,82 @@ func GetConnection(id string) (*Connection, bool) { return conn, ok } -// Accept accepts the connection. -func (conn *Connection) Accept(reason string) { - if conn.SetVerdict(VerdictAccept) { - conn.Reason = reason +// AcceptWithContext accepts the connection. +func (conn *Connection) AcceptWithContext(reason string, ctx interface{}) { + if conn.SetVerdict(VerdictAccept, reason, ctx) { log.Infof("filter: granting connection %s, %s", conn, conn.Reason) } else { log.Warningf("filter: tried to accept %s, but current verdict is %s", conn, conn.Verdict) } } -// Block blocks the connection. -func (conn *Connection) Block(reason string) { - if conn.SetVerdict(VerdictBlock) { - conn.Reason = reason +// Accept is like AcceptWithContext but only accepts a reason. +func (conn *Connection) Accept(reason string) { + conn.AcceptWithContext(reason, nil) +} + +// BlockWithContext blocks the connection. +func (conn *Connection) BlockWithContext(reason string, ctx interface{}) { + if conn.SetVerdict(VerdictBlock, reason, ctx) { log.Infof("filter: blocking connection %s, %s", conn, conn.Reason) } else { log.Warningf("filter: tried to block %s, but current verdict is %s", conn, conn.Verdict) } } -// Drop drops the connection. -func (conn *Connection) Drop(reason string) { - if conn.SetVerdict(VerdictDrop) { - conn.Reason = reason +// Block is like BlockWithContext but does only accepts a reason. +func (conn *Connection) Block(reason string) { + conn.BlockWithContext(reason, nil) +} + +// DropWithContext drops the connection. +func (conn *Connection) DropWithContext(reason string, ctx interface{}) { + if conn.SetVerdict(VerdictDrop, reason, ctx) { log.Infof("filter: dropping connection %s, %s", conn, conn.Reason) } else { log.Warningf("filter: tried to drop %s, but current verdict is %s", conn, conn.Verdict) } } -// Deny blocks or drops the link depending on the connection direction. -func (conn *Connection) Deny(reason string) { +// Drop is like DropWithContext but does only accepts a reason. +func (conn *Connection) Drop(reason string) { + conn.DropWithContext(reason, nil) +} + +// DenyWithContext blocks or drops the link depending on the connection direction. +func (conn *Connection) DenyWithContext(reason string, ctx interface{}) { if conn.Inbound { - conn.Drop(reason) + conn.DropWithContext(reason, ctx) } else { - conn.Block(reason) + conn.BlockWithContext(reason, ctx) } } -// Failed marks the connection with VerdictFailed and stores the reason. -func (conn *Connection) Failed(reason string) { - if conn.SetVerdict(VerdictFailed) { - conn.Reason = reason +// Deny is like DenyWithContext but only accepts a reason. +func (conn *Connection) Deny(reason string) { + conn.DenyWithContext(reason, nil) +} + +// FailedWithContext marks the connection with VerdictFailed and stores the reason. +func (conn *Connection) FailedWithContext(reason string, ctx interface{}) { + if conn.SetVerdict(VerdictFailed, reason, ctx) { log.Infof("filter: dropping connection %s because of an internal error: %s", conn, reason) } else { log.Warningf("filter: tried to drop %s due to error but current verdict is %s", conn, conn.Verdict) } } +// Failed is like FailedWithContext but only accepts a string. +func (conn *Connection) Failed(reason string) { + conn.FailedWithContext(reason, nil) +} + // SetVerdict sets a new verdict for the connection, making sure it does not interfere with previous verdicts. -func (conn *Connection) SetVerdict(newVerdict Verdict) (ok bool) { +func (conn *Connection) SetVerdict(newVerdict Verdict, reason string, ctx interface{}) (ok bool) { if newVerdict >= conn.Verdict { conn.Verdict = newVerdict + conn.Reason = reason + conn.ReasonContext = ctx return true } return false diff --git a/profile/endpoints/endpoint-any.go b/profile/endpoints/endpoint-any.go index 34fd65b4..8e8deb98 100644 --- a/profile/endpoints/endpoint-any.go +++ b/profile/endpoints/endpoint-any.go @@ -8,8 +8,8 @@ type EndpointAny struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointAny) Matches(entity *intel.Entity) (result EPResult, reason string) { - return ep.matchesPPP(entity), "matches *" +func (ep *EndpointAny) Matches(entity *intel.Entity) (EPResult, Reason) { + return ep.match(ep, entity, "*", "matches") } func (ep *EndpointAny) String() string { diff --git a/profile/endpoints/endpoint-asn.go b/profile/endpoints/endpoint-asn.go index dee94f2d..6713d199 100644 --- a/profile/endpoints/endpoint-asn.go +++ b/profile/endpoints/endpoint-asn.go @@ -16,24 +16,22 @@ var ( type EndpointASN struct { EndpointBase - ASN uint - Reason string + ASN uint } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointASN) Matches(entity *intel.Entity) (result EPResult, reason string) { - if entity.IP == nil { - return Undeterminable, "" - } - +func (ep *EndpointASN) Matches(entity *intel.Entity) (EPResult, Reason) { asn, ok := entity.GetASN() if !ok { - return Undeterminable, "" + return Undeterminable, nil } + if asn == ep.ASN { - return ep.matchesPPP(entity), ep.Reason + asnStr := strconv.Itoa(int(ep.ASN)) + return ep.match(ep, entity, asnStr, "IP is part of AS") } - return NoMatch, "" + + return NoMatch, nil } func (ep *EndpointASN) String() string { @@ -48,8 +46,7 @@ func parseTypeASN(fields []string) (Endpoint, error) { } ep := &EndpointASN{ - ASN: uint(asn), - Reason: "IP is part of AS" + strconv.FormatInt(int64(asn), 10), + ASN: uint(asn), } return ep.parsePPP(ep, fields) } diff --git a/profile/endpoints/endpoint-country.go b/profile/endpoints/endpoint-country.go index 63d21bfc..85449cf5 100644 --- a/profile/endpoints/endpoint-country.go +++ b/profile/endpoints/endpoint-country.go @@ -19,19 +19,16 @@ type EndpointCountry struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointCountry) Matches(entity *intel.Entity) (result EPResult, reason string) { - if entity.IP == nil { - return Undeterminable, "" - } - +func (ep *EndpointCountry) Matches(entity *intel.Entity) (EPResult, Reason) { country, ok := entity.GetCountry() if !ok { - return Undeterminable, "" + return Undeterminable, nil } + if country == ep.Country { - return ep.matchesPPP(entity), "IP is located in " + country + return ep.match(ep, entity, country, "IP is located in") } - return NoMatch, "" + return NoMatch, nil } func (ep *EndpointCountry) String() string { diff --git a/profile/endpoints/endpoint-domain.go b/profile/endpoints/endpoint-domain.go index 2691ba13..fbd0dcf9 100644 --- a/profile/endpoints/endpoint-domain.go +++ b/profile/endpoints/endpoint-domain.go @@ -28,47 +28,48 @@ type EndpointDomain struct { Domain string DomainZone string MatchType uint8 - Reason string } -func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, string) { +func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, Reason) { + result, reason := ep.match(ep, entity, ep.Domain, "domain matches") + switch ep.MatchType { case domainMatchTypeExact: if domain == ep.Domain { - return ep.matchesPPP(entity), ep.Reason + return result, reason } case domainMatchTypeZone: if domain == ep.Domain { - return ep.matchesPPP(entity), ep.Reason + return result, reason } if strings.HasSuffix(domain, ep.DomainZone) { - return ep.matchesPPP(entity), ep.Reason + return result, reason } case domainMatchTypeSuffix: if strings.HasSuffix(domain, ep.Domain) { - return ep.matchesPPP(entity), ep.Reason + return result, reason } case domainMatchTypePrefix: if strings.HasPrefix(domain, ep.Domain) { - return ep.matchesPPP(entity), ep.Reason + return result, reason } case domainMatchTypeContains: if strings.Contains(domain, ep.Domain) { - return ep.matchesPPP(entity), ep.Reason + return result, reason } } - return NoMatch, "" + return NoMatch, nil } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointDomain) Matches(entity *intel.Entity) (result EPResult, reason string) { +func (ep *EndpointDomain) Matches(entity *intel.Entity) (EPResult, Reason) { if entity.Domain == "" { - return NoMatch, "" + return NoMatch, nil } - result, reason = ep.check(entity, entity.Domain) + result, reason := ep.check(entity, entity.Domain) if result != NoMatch { - return + return result, reason } if entity.CNAMECheckEnabled() { @@ -80,7 +81,7 @@ func (ep *EndpointDomain) Matches(entity *intel.Entity) (result EPResult, reason } } - return NoMatch, "" + return NoMatch, nil } func (ep *EndpointDomain) String() string { @@ -93,7 +94,6 @@ func parseTypeDomain(fields []string) (Endpoint, error) { if domainRegex.MatchString(domain) || altDomainRegex.MatchString(domain) { ep := &EndpointDomain{ OriginalValue: domain, - Reason: "domain matches " + domain, } // fix domain ending diff --git a/profile/endpoints/endpoint-ip.go b/profile/endpoints/endpoint-ip.go index 6fa742c0..43ea47f7 100644 --- a/profile/endpoints/endpoint-ip.go +++ b/profile/endpoints/endpoint-ip.go @@ -10,19 +10,19 @@ import ( type EndpointIP struct { EndpointBase - IP net.IP - Reason string + IP net.IP } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointIP) Matches(entity *intel.Entity) (result EPResult, reason string) { +func (ep *EndpointIP) Matches(entity *intel.Entity) (EPResult, Reason) { if entity.IP == nil { - return Undeterminable, "" + return Undeterminable, nil } + if ep.IP.Equal(entity.IP) { - return ep.matchesPPP(entity), ep.Reason + return ep.match(ep, entity, ep.IP.String(), "IP matches") } - return NoMatch, "" + return NoMatch, nil } func (ep *EndpointIP) String() string { @@ -33,8 +33,7 @@ func parseTypeIP(fields []string) (Endpoint, error) { ip := net.ParseIP(fields[1]) if ip != nil { ep := &EndpointIP{ - IP: ip, - Reason: "IP is " + ip.String(), + IP: ip, } return ep.parsePPP(ep, fields) } diff --git a/profile/endpoints/endpoint-iprange.go b/profile/endpoints/endpoint-iprange.go index 16afa4bb..bc0d22fe 100644 --- a/profile/endpoints/endpoint-iprange.go +++ b/profile/endpoints/endpoint-iprange.go @@ -10,19 +10,18 @@ import ( type EndpointIPRange struct { EndpointBase - Net *net.IPNet - Reason string + Net *net.IPNet } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointIPRange) Matches(entity *intel.Entity) (result EPResult, reason string) { +func (ep *EndpointIPRange) Matches(entity *intel.Entity) (EPResult, Reason) { if entity.IP == nil { - return Undeterminable, "" + return Undeterminable, nil } if ep.Net.Contains(entity.IP) { - return ep.matchesPPP(entity), ep.Reason + return ep.match(ep, entity, ep.Net.String(), "IP is in") } - return NoMatch, "" + return NoMatch, nil } func (ep *EndpointIPRange) String() string { @@ -33,8 +32,7 @@ func parseTypeIPRange(fields []string) (Endpoint, error) { _, net, err := net.ParseCIDR(fields[1]) if err == nil { ep := &EndpointIPRange{ - Net: net, - Reason: "IP is part of " + net.String(), + Net: net, } return ep.parsePPP(ep, fields) } diff --git a/profile/endpoints/endpoint-lists.go b/profile/endpoints/endpoint-lists.go index b11da060..27ec8b00 100644 --- a/profile/endpoints/endpoint-lists.go +++ b/profile/endpoints/endpoint-lists.go @@ -12,18 +12,19 @@ type EndpointLists struct { ListSet []string Lists string - Reason string } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointLists) Matches(entity *intel.Entity) (result EPResult, reason string) { - entity.LoadLists() - - if entity.MatchLists(ep.ListSet) { - return ep.matchesPPP(entity), entity.ListBlockReason().String() +func (ep *EndpointLists) Matches(entity *intel.Entity) (EPResult, Reason) { + if !entity.LoadLists() { + return Undeterminable, nil } - return NoMatch, "" + if entity.MatchLists(ep.ListSet) { + return ep.match(ep, entity, ep.Lists, "filterlist contains", "filterlist", entity.ListBlockReason()) + } + + return NoMatch, nil } func (ep *EndpointLists) String() string { @@ -36,7 +37,6 @@ func parseTypeList(fields []string) (Endpoint, error) { ep := &EndpointLists{ ListSet: lists, Lists: "L:" + strings.Join(lists, ","), - Reason: "matched lists " + strings.Join(lists, ","), } return ep.parsePPP(ep, fields) } diff --git a/profile/endpoints/endpoint.go b/profile/endpoints/endpoint.go index b634b911..43741c40 100644 --- a/profile/endpoints/endpoint.go +++ b/profile/endpoints/endpoint.go @@ -11,7 +11,7 @@ import ( // Endpoint describes an Endpoint Matcher type Endpoint interface { - Matches(entity *intel.Entity) (EPResult, string) + Matches(entity *intel.Entity) (EPResult, Reason) String() string } @@ -24,6 +24,35 @@ type EndpointBase struct { //nolint:maligned // TODO Permitted bool } +func (ep *EndpointBase) match(s fmt.Stringer, entity *intel.Entity, value, desc string, keval ...interface{}) (EPResult, Reason) { + result := ep.matchesPPP(entity) + if result == Undeterminable || result == NoMatch { + return result, nil + } + + return result, ep.makeReason(s, value, desc) +} + +func (ep *EndpointBase) makeReason(s fmt.Stringer, value, desc string, keyval ...interface{}) Reason { + r := &reason{ + description: desc, + Filter: ep.renderPPP(s.String()), + Permitted: ep.Permitted, + Value: value, + } + + r.Extra = make(map[string]interface{}) + + for idx := 0; idx < int(len(keyval)/2); idx += 2 { + key := keyval[idx] + val := keyval[idx+1] + + r.Extra[key.(string)] = val + } + + return r +} + func (ep *EndpointBase) matchesPPP(entity *intel.Entity) (result EPResult) { // only check if protocol is defined if ep.Protocol > 0 { diff --git a/profile/endpoints/endpoints.go b/profile/endpoints/endpoints.go index 7a49a0c8..d8af9423 100644 --- a/profile/endpoints/endpoints.go +++ b/profile/endpoints/endpoints.go @@ -21,6 +21,12 @@ const ( Permitted ) +// IsDecision returns true if result represents a decision +// and false if result is NoMatch or Undeterminable. +func IsDecision(result EPResult) bool { + return result == Denied || result == Permitted +} + // ParseEndpoints parses a list of endpoints and returns a list of Endpoints for matching. func ParseEndpoints(entries []string) (Endpoints, error) { var firstErr error @@ -57,7 +63,7 @@ func (e Endpoints) IsSet() bool { } // Match checks whether the given entity matches any of the endpoint definitions in the list. -func (e Endpoints) Match(entity *intel.Entity) (result EPResult, reason string) { +func (e Endpoints) Match(entity *intel.Entity) (result EPResult, reason Reason) { for _, entry := range e { if entry != nil { if result, reason = entry.Matches(entity); result != NoMatch { @@ -66,7 +72,7 @@ func (e Endpoints) Match(entity *intel.Entity) (result EPResult, reason string) } } - return NoMatch, "" + return NoMatch, nil } func (e Endpoints) String() string { diff --git a/profile/endpoints/reason.go b/profile/endpoints/reason.go new file mode 100644 index 00000000..d137c3b3 --- /dev/null +++ b/profile/endpoints/reason.go @@ -0,0 +1,34 @@ +package endpoints + +// Reason describes the reason why an endpoint has been +// permitted or blocked. +type Reason interface { + // String should return a human readable string + // describing the decision reason. + String() string + + // Context returns the context that was used + // for the decision. + Context() interface{} +} + +type reason struct { + description string + Filter string + Value string + Permitted bool + Extra map[string]interface{} +} + +func (r *reason) String() string { + prefix := "endpoint in blocklist: " + if r.Permitted { + prefix = "endpoint in whitelist: " + } + + return prefix + r.description + " " + r.Value +} + +func (r *reason) Context() interface{} { + return r +} diff --git a/profile/profile-layered.go b/profile/profile-layered.go index edb8a8a2..45311662 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -204,12 +204,12 @@ func (lp *LayeredProfile) DefaultAction() uint8 { } // MatchEndpoint checks if the given endpoint matches an entry in any of the profiles. -func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (result endpoints.EPResult, reason string) { +func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { for _, layer := range lp.layers { if layer.endpoints.IsSet() { - result, reason = layer.endpoints.Match(entity) - if result != endpoints.NoMatch { - return + result, reason := layer.endpoints.Match(entity) + if endpoints.IsDecision(result) { + return result, reason } } } @@ -220,14 +220,14 @@ func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (result endpoints. } // MatchServiceEndpoint checks if the given endpoint of an inbound connection matches an entry in any of the profiles. -func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (result endpoints.EPResult, reason string) { +func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { entity.EnableReverseResolving() for _, layer := range lp.layers { if layer.serviceEndpoints.IsSet() { - result, reason = layer.serviceEndpoints.Match(entity) - if result != endpoints.NoMatch { - return + result, reason := layer.serviceEndpoints.Match(entity) + if endpoints.IsDecision(result) { + return result, reason } } } @@ -239,7 +239,7 @@ func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (result end // MatchFilterLists matches the entity against the set of filter // lists. -func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPResult, string) { +func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { entity.ResolveSubDomainLists(lp.FilterSubDomains()) entity.EnableCNAMECheck(lp.FilterCNAMEs()) @@ -249,10 +249,10 @@ func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPRe entity.LoadLists() if entity.MatchLists(layer.filterListIDs) { - return endpoints.Denied, entity.ListBlockReason().String() + return endpoints.Denied, entity.ListBlockReason() } - return endpoints.NoMatch, "" + return endpoints.NoMatch, nil } } @@ -262,11 +262,11 @@ func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPRe entity.LoadLists() if entity.MatchLists(cfgFilterLists) { - return endpoints.Denied, entity.ListBlockReason().String() + return endpoints.Denied, entity.ListBlockReason() } } - return endpoints.NoMatch, "" + return endpoints.NoMatch, nil } // AddEndpoint adds an endpoint to the local endpoint list, saves the local profile and reloads the configuration. From f2e41a0d32d2f9319cf82945c6ba732ba40d10e6 Mon Sep 17 00:00:00 2001 From: Patrick Pacher Date: Tue, 21 Apr 2020 09:55:49 +0200 Subject: [PATCH 6/7] Refactoring of FilterDNSResponse and DecideOnConnection --- firewall/dns.go | 230 +++++++++++++++++++++ firewall/master.go | 367 +++++++++++++++------------------ intel/block_reason.go | 52 ++++- intel/entity.go | 16 +- nameserver/nameserver.go | 132 +++--------- profile/endpoints/endpoint.go | 6 +- profile/endpoints/endpoints.go | 2 +- resolver/ipinfo.go | 6 +- 8 files changed, 480 insertions(+), 331 deletions(-) create mode 100644 firewall/dns.go diff --git a/firewall/dns.go b/firewall/dns.go new file mode 100644 index 00000000..d6421e2e --- /dev/null +++ b/firewall/dns.go @@ -0,0 +1,230 @@ +package firewall + +import ( + "net" + "os" + "strings" + + "github.com/miekg/dns" + "github.com/safing/portbase/database" + "github.com/safing/portbase/log" + "github.com/safing/portmaster/network" + "github.com/safing/portmaster/network/netutils" + "github.com/safing/portmaster/profile" + "github.com/safing/portmaster/profile/endpoints" + "github.com/safing/portmaster/resolver" +) + +func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, scope int8) ([]dns.RR, []string, int) { + goodEntries := make([]dns.RR, 0, len(entries)) + filteredRecords := make([]string, 0, len(entries)) + + // keeps track of the number of valid and allowed + // A and AAAA records. + var allowedAddressRecords int + + for _, rr := range entries { + // get IP and classification + var ip net.IP + switch v := rr.(type) { + case *dns.A: + ip = v.A + case *dns.AAAA: + ip = v.AAAA + default: + // add non A/AAAA entries + goodEntries = append(goodEntries, rr) + continue + } + classification := netutils.ClassifyIP(ip) + + if p.RemoveOutOfScopeDNS() { + switch { + case classification == netutils.HostLocal: + // No DNS should return localhost addresses + filteredRecords = append(filteredRecords, rr.String()) + continue + case scope == netutils.Global && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): + // No global DNS should return LAN addresses + filteredRecords = append(filteredRecords, rr.String()) + continue + } + } + + if p.RemoveBlockedDNS() { + // filter by flags + switch { + case p.BlockScopeInternet() && classification == netutils.Global: + filteredRecords = append(filteredRecords, rr.String()) + continue + case p.BlockScopeLAN() && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): + filteredRecords = append(filteredRecords, rr.String()) + continue + case p.BlockScopeLocal() && classification == netutils.HostLocal: + filteredRecords = append(filteredRecords, rr.String()) + continue + } + + // TODO: filter by endpoint list (IP only) + } + + // if survived, add to good entries + allowedAddressRecords++ + goodEntries = append(goodEntries, rr) + } + + return goodEntries, filteredRecords, allowedAddressRecords +} + +func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *resolver.RRCache { + // do not modify own queries + if conn.Process().Pid == os.Getpid() { + return rrCache + } + + // get profile + p := conn.Process().Profile() + if p == nil { + conn.Block("no profile") + return nil + } + + // check if DNS response filtering is completely turned off + if !p.RemoveOutOfScopeDNS() && !p.RemoveBlockedDNS() { + return rrCache + } + + // duplicate entry + rrCache = rrCache.ShallowCopy() + rrCache.FilteredEntries = make([]string, 0) + + var filteredRecords []string + var validIPs int + + rrCache.Answer, filteredRecords, validIPs = filterDNSSection(rrCache.Answer, p, rrCache.ServerScope) + rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...) + + // FIXME(ppacher): should we consider valid IPs from the extra section? + rrCache.Extra, filteredRecords, _ = filterDNSSection(rrCache.Extra, p, rrCache.ServerScope) + rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...) + + if len(rrCache.FilteredEntries) > 0 { + rrCache.Filtered = true + if validIPs == 0 { + conn.Block("no addresses returned for this domain are permitted") + return nil + } + + log.Infof("filter: filtered DNS replies for %s: %s", conn, strings.Join(rrCache.FilteredEntries, ", ")) + } + + return rrCache +} + +// DecideOnResolvedDNS filters a dns response according to the application profile and settings. +func DecideOnResolvedDNS(conn *network.Connection, q *resolver.Query, rrCache *resolver.RRCache) *resolver.RRCache { + updatedRR := filterDNSResponse(conn, rrCache) + if updatedRR == nil { + return nil + } + + updateIPsAndCNAMEs(q, rrCache, conn) + + if mayBlockCNAMEs(conn) { + return nil + } + + // TODO: Gate17 integration + // tunnelInfo, err := AssignTunnelIP(fqdn) + + return updatedRR +} + +func mayBlockCNAMEs(conn *network.Connection) bool { + // if we have CNAMEs and the profile is configured to filter them + // we need to re-check the lists and endpoints here + if conn.Process().Profile().FilterCNAMEs() { + conn.Entity.ResetLists() + conn.Entity.EnableCNAMECheck(true) + + result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) + if result == endpoints.Denied { + conn.BlockWithContext(reason.String(), reason.Context()) + return true + } + + if result == endpoints.NoMatch { + result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) + if result == endpoints.Denied { + conn.BlockWithContext(reason.String(), reason.Context()) + return true + } + } + } + + return false +} + +func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) { + // save IP addresses to IPInfo + cnames := make(map[string]string) + ips := make(map[string]struct{}) + + for _, rr := range append(rrCache.Answer, rrCache.Extra...) { + switch v := rr.(type) { + case *dns.CNAME: + cnames[v.Hdr.Name] = v.Target + + case *dns.A: + ips[v.A.String()] = struct{}{} + + case *dns.AAAA: + ips[v.AAAA.String()] = struct{}{} + } + } + + for ip := range ips { + record := resolver.ResolvedDomain{ + Domain: q.FQDN, + } + + // resolve all CNAMEs in the correct order. + var domain = q.FQDN + for { + nextDomain, isCNAME := cnames[domain] + if !isCNAME { + break + } + + record.CNAMEs = append(record.CNAMEs, nextDomain) + domain = nextDomain + } + + // update the entity to include the cnames + conn.Entity.CNAME = record.CNAMEs + + // get the existing IP info or create a new one + var save bool + info, err := resolver.GetIPInfo(ip) + if err != nil { + if err != database.ErrNotFound { + log.Errorf("nameserver: failed to search for IP info record: %s", err) + } + + info = &resolver.IPInfo{ + IP: ip, + } + save = true + } + + // and the new resolved domain record and save + if new := info.AddDomain(record); new { + save = true + } + if save { + if err := info.Save(); err != nil { + log.Errorf("nameserver: failed to save IP info record: %s", err) + } + } + } +} diff --git a/firewall/master.go b/firewall/master.go index cafd11f2..06a0f5f1 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -2,7 +2,6 @@ package firewall import ( "fmt" - "net" "os" "path/filepath" "strings" @@ -14,10 +13,8 @@ import ( "github.com/safing/portmaster/process" "github.com/safing/portmaster/profile" "github.com/safing/portmaster/profile/endpoints" - "github.com/safing/portmaster/resolver" "github.com/agext/levenshtein" - "github.com/miekg/dns" ) // Call order: @@ -26,7 +23,7 @@ import ( // 1. DecideOnConnection // is called when a DNS query is made, may set verdict to Undeterminable to permit a DNS reply. // is called with a nil packet. -// 2. FilterDNSResponse +// 2. DecideOnResolvedDNS // is called to (possibly) filter out A/AAAA records that the filter would deny later. // // Network Connection: @@ -35,7 +32,7 @@ import ( // DecideOnConnection makes a decision about a connection. // When called, the connection and profile is already locked. -func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint:gocognit,gocyclo // TODO +func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { // update profiles and check if communication needs reevaluation if conn.UpdateAndCheck() { log.Infof("filter: re-evaluating verdict on %s", conn) @@ -46,14 +43,47 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: } } + var deciders = []func(*network.Connection, packet.Packet) bool{ + checkPortmasterConnection, + checkSelfCommunication, + checkProfileExists, + checkConnectionType, + checkConnectionScope, + checkEndpointLists, + checkBypassPrevention, + checkFilterLists, + checkInbound, + checkDefaultPermit, + checkAutoPermitRelated, + checkDefaultAction, + } + + for _, decider := range deciders { + if decider(conn, pkt) { + return + } + } + + // DefaultAction == DefaultActionBlock + conn.Deny("endpoint is not whitelisted (default=block)") +} + +// checkPortmasterConnection allows all connection that originate from +// portmaster itself. +func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool { // grant self if conn.Process().Pid == os.Getpid() { log.Infof("filter: granting own connection %s", conn) conn.Verdict = network.VerdictAccept conn.Internal = true - return + return true } + return false +} + +// checkSelfCommunication checks if the process is communicating with itself. +func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { // check if process is communicating with itself if pkt != nil { // TODO: evaluate the case where different IPs in the 127/8 net are used. @@ -77,89 +107,30 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: } else if otherProcess.Pid == conn.Process().Pid { conn.Accept("connection to self") conn.Internal = true - return + return true } } } } - // get profile - p := conn.Process().Profile() - if p == nil { - conn.Block("no profile") - return - } - - // check conn type - switch conn.Scope { - case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: - if p.BlockInbound() { - if conn.Scope == network.IncomingHost { - conn.Block("inbound connections blocked") - } else { - conn.Drop("inbound connections blocked") - } - return - } - case network.PeerLAN, network.PeerInternet, network.PeerInvalid: - // Important: PeerHost is and should be missing! - if p.BlockP2P() { - conn.Block("direct connections (P2P) blocked") - return - } - } - - // check scopes - if conn.Entity.IP != nil { - classification := netutils.ClassifyIP(conn.Entity.IP) - - switch classification { - case netutils.Global, netutils.GlobalMulticast: - if p.BlockScopeInternet() { - conn.Deny("Internet access blocked") // Block Outbound / Drop Inbound - return - } - case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast: - if p.BlockScopeLAN() { - conn.Block("LAN access blocked") // Block Outbound / Drop Inbound - return - } - case netutils.HostLocal: - if p.BlockScopeLocal() { - conn.Block("Localhost access blocked") // Block Outbound / Drop Inbound - return - } - default: // netutils.Invalid - conn.Deny("invalid IP") // Block Outbound / Drop Inbound - return - } - } else if conn.Entity.Domain != "" { - // DNS Query - // DNS is expected to resolve to LAN or Internet addresses - // TODO: handle domains mapped to localhost - if p.BlockScopeInternet() && p.BlockScopeLAN() { - conn.Block("Internet and LAN access blocked") - return - } - } - - if p.PreventBypassing() { - // check for bypass protection - result, reason := PreventBypassing(conn) - switch result { - case endpoints.Denied: - conn.Block("bypass prevention: " + reason) - return - case endpoints.Permitted: - conn.Accept("bypass prevention: " + reason) - return - case endpoints.NoMatch: - } + return false +} + +func checkProfileExists(conn *network.Connection, _ packet.Packet) bool { + if conn.Process().Profile() == nil { + conn.Block("unknown process or profile") + return true } + return false +} +func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool { var result endpoints.EPResult var reason endpoints.Reason + // there must always be a profile. + p := conn.Process().Profile() + // check endpoints list if conn.Inbound { result, reason = p.MatchServiceEndpoint(conn.Entity) @@ -169,167 +140,151 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: switch result { case endpoints.Denied: conn.DenyWithContext(reason.String(), reason.Context()) - return + return true case endpoints.Permitted: conn.AcceptWithContext(reason.String(), reason.Context()) - return + return true } - // continuing with result == NoMatch + return false +} + +func checkConnectionType(conn *network.Connection, _ packet.Packet) bool { + p := conn.Process().Profile() + + // check conn type + switch conn.Scope { + case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: + if p.BlockInbound() { + if conn.Scope == network.IncomingHost { + conn.Block("inbound connections blocked") + } else { + conn.Drop("inbound connections blocked") + } + return true + } + case network.PeerLAN, network.PeerInternet, network.PeerInvalid: + // Important: PeerHost is and should be missing! + if p.BlockP2P() { + conn.Block("direct connections (P2P) blocked") + return true + } + } + + return false +} + +func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool { + p := conn.Process().Profile() + + // check scopes + if conn.Entity.IP != nil { + classification := netutils.ClassifyIP(conn.Entity.IP) + + switch classification { + case netutils.Global, netutils.GlobalMulticast: + if p.BlockScopeInternet() { + conn.Deny("Internet access blocked") // Block Outbound / Drop Inbound + return true + } + case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast: + if p.BlockScopeLAN() { + conn.Block("LAN access blocked") // Block Outbound / Drop Inbound + return true + } + case netutils.HostLocal: + if p.BlockScopeLocal() { + conn.Block("Localhost access blocked") // Block Outbound / Drop Inbound + return true + } + default: // netutils.Invalid + conn.Deny("invalid IP") // Block Outbound / Drop Inbound + return true + } + } else if conn.Entity.Domain != "" { + // DNS Query + // DNS is expected to resolve to LAN or Internet addresses + // TODO: handle domains mapped to localhost + if p.BlockScopeInternet() && p.BlockScopeLAN() { + conn.Block("Internet and LAN access blocked") + return true + } + } + return false +} + +func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool { + if conn.Process().Profile().PreventBypassing() { + // check for bypass protection + result, reason := PreventBypassing(conn) + switch result { + case endpoints.Denied: + conn.Block("bypass prevention: " + reason) + return true + case endpoints.Permitted: + conn.Accept("bypass prevention: " + reason) + return true + case endpoints.NoMatch: + } + } + return false +} + +func checkFilterLists(conn *network.Connection, _ packet.Packet) bool { // apply privacy filter lists - result, reason = p.MatchFilterLists(conn.Entity) + p := conn.Process().Profile() + + result, reason := p.MatchFilterLists(conn.Entity) switch result { case endpoints.Denied: conn.DenyWithContext(reason.String(), reason.Context()) - return + return true case endpoints.NoMatch: // nothing to do default: log.Debugf("filter: filter lists returned unsupported verdict: %s", result) } + return false +} +func checkInbound(conn *network.Connection, _ packet.Packet) bool { // implicit default=block for inbound if conn.Inbound { conn.Drop("endpoint is not whitelisted (incoming is always default=block)") - return + return true } + return false +} +func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool { // check default action + p := conn.Process().Profile() if p.DefaultAction() == profile.DefaultActionPermit { conn.Accept("endpoint is not blacklisted (default=permit)") - return + return true } + return false +} - // check relation +func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool { + p := conn.Process().Profile() if !p.DisableAutoPermit() { related, reason := checkRelation(conn) if related { conn.Accept(reason) - return + return true } } - - // prompt - if p.DefaultAction() == profile.DefaultActionAsk { - prompt(conn, pkt) - return - } - - // DefaultAction == DefaultActionBlock - conn.Deny("endpoint is not whitelisted (default=block)") + return false } -// FilterDNSResponse filters a dns response according to the application profile and settings. -func FilterDNSResponse(conn *network.Connection, q *resolver.Query, rrCache *resolver.RRCache) *resolver.RRCache { //nolint:gocognit // TODO - // do not modify own queries - if conn.Process().Pid == os.Getpid() { - return rrCache - } - - // get profile +func checkDefaultAction(conn *network.Connection, pkt packet.Packet) bool { p := conn.Process().Profile() - if p == nil { - conn.Block("no profile") - return nil + if p.DefaultAction() == profile.DefaultActionAsk { + prompt(conn, pkt) + return true } - - // check if DNS response filtering is completely turned off - if !p.RemoveOutOfScopeDNS() && !p.RemoveBlockedDNS() { - return rrCache - } - - // duplicate entry - rrCache = rrCache.ShallowCopy() - rrCache.FilteredEntries = make([]string, 0) - - // change information - var addressesRemoved int - var addressesOk int - - // loop vars - var classification int8 - var ip net.IP - - // filter function - filterEntries := func(entries []dns.RR) (goodEntries []dns.RR) { - goodEntries = make([]dns.RR, 0, len(entries)) - - for _, rr := range entries { - - // get IP and classification - switch v := rr.(type) { - case *dns.A: - ip = v.A - case *dns.AAAA: - ip = v.AAAA - default: - // add non A/AAAA entries - goodEntries = append(goodEntries, rr) - continue - } - classification = netutils.ClassifyIP(ip) - - if p.RemoveOutOfScopeDNS() { - switch { - case classification == netutils.HostLocal: - // No DNS should return localhost addresses - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - case rrCache.ServerScope == netutils.Global && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): - // No global DNS should return LAN addresses - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - } - } - - if p.RemoveBlockedDNS() { - // filter by flags - switch { - case p.BlockScopeInternet() && classification == netutils.Global: - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - case p.BlockScopeLAN() && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - case p.BlockScopeLocal() && classification == netutils.HostLocal: - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - } - - // TODO: filter by endpoint list (IP only) - } - - // if survived, add to good entries - addressesOk++ - goodEntries = append(goodEntries, rr) - } - return - } - - rrCache.Answer = filterEntries(rrCache.Answer) - rrCache.Extra = filterEntries(rrCache.Extra) - - if addressesRemoved > 0 { - rrCache.Filtered = true - if addressesOk == 0 { - conn.Block("no addresses returned for this domain are permitted") - return nil - } - } - - if rrCache.Filtered { - log.Infof("filter: filtered DNS replies for %s: %s", conn, strings.Join(rrCache.FilteredEntries, ", ")) - } - - // TODO: Gate17 integration - // tunnelInfo, err := AssignTunnelIP(fqdn) - - return rrCache + return false } // checkRelation tries to find a relation between a process and a communication. This is for better out of the box experience and is _not_ meant to thwart intentional malware. diff --git a/intel/block_reason.go b/intel/block_reason.go index 5a14c6cd..040d6c01 100644 --- a/intel/block_reason.go +++ b/intel/block_reason.go @@ -1,8 +1,12 @@ package intel import ( + "encoding/json" "fmt" "strings" + + "github.com/miekg/dns" + "github.com/safing/portbase/log" ) // ListMatch represents an entity that has been @@ -45,7 +49,49 @@ func (br ListBlockReason) String() string { // Context returns br wrapped into a map. It implements // the endpoints.Reason interface. func (br ListBlockReason) Context() interface{} { - return map[string]interface{}{ - "filterlists": br, - } + return br +} + +// MarshalJSON marshals the list block reason into a map +// prefixed with filterlists. +func (br ListBlockReason) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + // we convert to []ListMatch to avoid recursing + // here. + "filterlists": []ListMatch(br), + }) +} + +// ToRRs returns a set of dns TXT records that describe the +// block reason. +func (br ListBlockReason) ToRRs() []dns.RR { + rrs := make([]dns.RR, 0, len(br)) + + for _, lm := range br { + blockedBy, err := dns.NewRR(fmt.Sprintf( + "%s-blockedBy. 0 IN TXT %q", + strings.TrimRight(lm.Entity, "."), + strings.Join(lm.ActiveLists, ","), + )) + if err == nil { + rrs = append(rrs, blockedBy) + } else { + log.Errorf("intel: failed to create TXT RR for block reason: %s", err) + } + + if len(lm.InactiveLists) > 0 { + wouldBeBlockedBy, err := dns.NewRR(fmt.Sprintf( + "%s-wouldBeBlockedBy. 0 IN TXT %q", + strings.TrimRight(lm.Entity, "."), + strings.Join(lm.ActiveLists, ","), + )) + if err == nil { + rrs = append(rrs, wouldBeBlockedBy) + } else { + log.Errorf("intel: failed to create TXT RR for block reason: %s", err) + } + } + } + + return rrs } diff --git a/intel/entity.go b/intel/entity.go index 8d2e21fb..af96343d 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -101,10 +101,10 @@ func (e *Entity) ResetLists() { // list right now so we could be more efficient by keeping // the other lists around. - // FIXME - //e.Lists = nil - //e.ListsMap = nil + e.BlockedByLists = nil + e.BlockedEntities = nil e.ListOccurences = nil + e.domainListLoaded = false e.ipListLoaded = false e.countryListLoaded = false @@ -421,15 +421,12 @@ func (e *Entity) getIPLists() { }) } -// LoadLists searches all filterlists for all occurences of +// LoadLists searches all filterlists for all occurrences of // this entity. func (e *Entity) LoadLists() bool { e.getLists() - if e.ListOccurences == nil { - return false - } - return true + return e.ListOccurences != nil } // MatchLists matches the entities lists against a slice @@ -450,6 +447,7 @@ func (e *Entity) MatchLists(lists []string) bool { } makeDistinct(e.BlockedByLists) + makeDistinct(e.BlockedEntities) return len(e.BlockedByLists) > 0 } @@ -503,7 +501,7 @@ func mergeStringList(a, b []string) []string { func makeDistinct(slice []string) []string { m := make(map[string]struct{}, len(slice)) - var result []string + result := make([]string, 0, len(slice)) for _, v := range slice { if _, ok := m[v]; ok { diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 9b3d8016..03d71701 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -7,7 +7,6 @@ import ( "net" "strings" - "github.com/safing/portbase/database" "github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/log" @@ -17,7 +16,6 @@ import ( "github.com/safing/portmaster/netenv" "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/netutils" - "github.com/safing/portmaster/profile/endpoints" "github.com/safing/portmaster/resolver" "github.com/miekg/dns" @@ -91,12 +89,27 @@ func stop() error { return nil } -func returnNXDomain(w dns.ResponseWriter, query *dns.Msg, reason string) { +func returnNXDomain(w dns.ResponseWriter, query *dns.Msg, reason string, reasonContext interface{}) { m := new(dns.Msg) m.SetRcode(query, dns.RcodeNameError) - rr, _ := dns.NewRR("portmaster.block.reason. 0 IN TXT " + fmt.Sprintf("%q", reason)) + rr, _ := dns.NewRR("portmaster.block-reason. 0 IN TXT " + fmt.Sprintf("%q", reason)) m.Extra = []dns.RR{rr} - _ = w.WriteMsg(m) + + if reasonContext != nil { + if v, ok := reasonContext.(interface { + ToRRs() []dns.RR + }); ok { + m.Extra = append(m.Extra, v.ToRRs()...) + } else if v, ok := reasonContext.(interface { + ToRR() dns.RR + }); ok { + m.Extra = append(m.Extra, v.ToRR()) + } + } + + if err := w.WriteMsg(m); err != nil { + log.Errorf("nameserver: failed to send response: %s", err) + } } func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) { @@ -132,7 +145,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er if question.Qclass != dns.ClassINET { // we only serve IN records, return nxdomain log.Warningf("nameserver: only IN record requests are supported but received Qclass %d, returning NXDOMAIN", question.Qclass) - returnNXDomain(w, query, "wrong type") + returnNXDomain(w, query, "wrong type", nil) return nil } @@ -172,7 +185,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // check if valid domain name if !netutils.IsValidFqdn(q.FQDN) { log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", q.FQDN) - returnNXDomain(w, query, "invalid domain") + returnNXDomain(w, query, "invalid domain", nil) return nil } @@ -208,10 +221,10 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // TODO: this has been obsoleted due to special profiles if conn.Process().Profile() == nil { tracer.Infof("nameserver: failed to find process for request %s, returning NXDOMAIN", conn) - returnNXDomain(w, query, "unknown process") // NOTE(ppacher): saving unknown process connection might end up in a lot of // processes. Consider disabling that via config. conn.Failed("Unknown process") + returnNXDomain(w, query, "unknown process", conn.ReasonContext) return nil } @@ -224,8 +237,8 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms) if lms < 10 { tracer.Warningf("nameserver: possible data tunnel by %s: %s has lms score of %f, returning nxdomain", conn.Process(), q.FQDN, lms) - returnNXDomain(w, query, "lms") conn.Block("Possible data tunnel") + returnNXDomain(w, query, "lms", conn.ReasonContext) return nil } @@ -235,7 +248,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er switch conn.Verdict { case network.VerdictBlock: tracer.Infof("nameserver: %s blocked, returning nxdomain", conn) - returnNXDomain(w, query, conn.Reason) + returnNXDomain(w, query, conn.Reason, conn.ReasonContext) return nil case network.VerdictDrop, network.VerdictFailed: tracer.Infof("nameserver: %s dropped, not replying", conn) @@ -254,45 +267,16 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er conn.Failed("failed to resolve: " + err.Error()) } - returnNXDomain(w, query, conn.Reason) + returnNXDomain(w, query, conn.Reason, conn.ReasonContext) return nil } - // filter DNS response - rrCache = firewall.FilterDNSResponse(conn, q, rrCache) - // TODO: FilterDNSResponse also sets a connection verdict + rrCache = firewall.DecideOnResolvedDNS(conn, q, rrCache) if rrCache == nil { - tracer.Infof("nameserver: %s implicitly denied by filtering the dns response, returning nxdomain", conn) - returnNXDomain(w, query, conn.Reason) - conn.Block("DNS response filtered") + returnNXDomain(w, query, conn.Reason, conn.ReasonContext) return nil } - updateIPsAndCNAMEs(q, rrCache, conn) - - // if we have CNAMEs and the profile is configured to filter them - // we need to re-check the lists and endpoints here - if conn.Process().Profile().FilterCNAMEs() { - conn.Entity.ResetLists() - conn.Entity.EnableCNAMECheck(true) - - result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) - if result == endpoints.Denied { - conn.BlockWithContext(reason.String(), reason.Context()) - returnNXDomain(w, query, conn.Reason) - return nil - } - - if result == endpoints.NoMatch { - result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) - if result == endpoints.Denied { - conn.BlockWithContext(reason.String(), reason.Context()) - returnNXDomain(w, query, conn.Reason) - return nil - } - } - } - // reply to query m := new(dns.Msg) m.SetReply(query) @@ -311,67 +295,3 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er return nil } - -func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) { - // save IP addresses to IPInfo - cnames := make(map[string]string) - ips := make(map[string]struct{}) - - for _, rr := range append(rrCache.Answer, rrCache.Extra...) { - switch v := rr.(type) { - case *dns.CNAME: - cnames[v.Hdr.Name] = v.Target - - case *dns.A: - ips[v.A.String()] = struct{}{} - - case *dns.AAAA: - ips[v.AAAA.String()] = struct{}{} - } - } - - for ip := range ips { - record := resolver.ResolvedDomain{ - Domain: q.FQDN, - } - - // resolve all CNAMEs in the correct order. - var domain = q.FQDN - for { - nextDomain, isCNAME := cnames[domain] - if !isCNAME { - break - } - - record.CNAMEs = append(record.CNAMEs, nextDomain) - domain = nextDomain - } - - // update the entity to include the cnames - conn.Entity.CNAME = record.CNAMEs - - // get the existing IP info or create a new one - var save bool - info, err := resolver.GetIPInfo(ip) - if err != nil { - if err != database.ErrNotFound { - log.Errorf("nameserver: failed to search for IP info record: %s", err) - } - - info = &resolver.IPInfo{ - IP: ip, - } - save = true - } - - // and the new resolved domain record and save - if new := info.AddDomain(record); new { - save = true - } - if save { - if err := info.Save(); err != nil { - log.Errorf("nameserver: failed to save IP info record: %s", err) - } - } - } -} diff --git a/profile/endpoints/endpoint.go b/profile/endpoints/endpoint.go index 43741c40..76847ac7 100644 --- a/profile/endpoints/endpoint.go +++ b/profile/endpoints/endpoint.go @@ -24,13 +24,13 @@ type EndpointBase struct { //nolint:maligned // TODO Permitted bool } -func (ep *EndpointBase) match(s fmt.Stringer, entity *intel.Entity, value, desc string, keval ...interface{}) (EPResult, Reason) { +func (ep *EndpointBase) match(s fmt.Stringer, entity *intel.Entity, value, desc string, keyval ...interface{}) (EPResult, Reason) { result := ep.matchesPPP(entity) if result == Undeterminable || result == NoMatch { return result, nil } - return result, ep.makeReason(s, value, desc) + return result, ep.makeReason(s, value, desc, keyval...) } func (ep *EndpointBase) makeReason(s fmt.Stringer, value, desc string, keyval ...interface{}) Reason { @@ -43,7 +43,7 @@ func (ep *EndpointBase) makeReason(s fmt.Stringer, value, desc string, keyval .. r.Extra = make(map[string]interface{}) - for idx := 0; idx < int(len(keyval)/2); idx += 2 { + for idx := 0; idx < len(keyval)/2; idx += 2 { key := keyval[idx] val := keyval[idx+1] diff --git a/profile/endpoints/endpoints.go b/profile/endpoints/endpoints.go index d8af9423..f74edd24 100644 --- a/profile/endpoints/endpoints.go +++ b/profile/endpoints/endpoints.go @@ -24,7 +24,7 @@ const ( // IsDecision returns true if result represents a decision // and false if result is NoMatch or Undeterminable. func IsDecision(result EPResult) bool { - return result == Denied || result == Permitted + return result == Denied || result == Permitted || result == Undeterminable } // ParseEndpoints parses a list of endpoints and returns a list of Endpoints for matching. diff --git a/resolver/ipinfo.go b/resolver/ipinfo.go index 03c0b59b..0ecf9766 100644 --- a/resolver/ipinfo.go +++ b/resolver/ipinfo.go @@ -47,9 +47,9 @@ type ResolvedDomains []ResolvedDomain // String returns a string representation of all domains joined // to a single string. func (rds ResolvedDomains) String() string { - var domains []string - for _, n := range rds { - domains = append(domains, n.String()) + domains := make([]string, len(rds)) + for idx, n := range rds { + domains[idx] = n.String() } return strings.Join(domains, " or ") } From a07c36159be8d058c0f85eec2375d546ae01d3d7 Mon Sep 17 00:00:00 2001 From: Patrick Pacher Date: Tue, 21 Apr 2020 09:59:36 +0200 Subject: [PATCH 7/7] Minor bug fix --- firewall/dns.go | 2 +- intel/block_reason.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/firewall/dns.go b/firewall/dns.go index d6421e2e..d9a868cb 100644 --- a/firewall/dns.go +++ b/firewall/dns.go @@ -104,7 +104,7 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *res rrCache.Answer, filteredRecords, validIPs = filterDNSSection(rrCache.Answer, p, rrCache.ServerScope) rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...) - // FIXME(ppacher): should we consider valid IPs from the extra section? + // we don't count the valid IPs in the extra section rrCache.Extra, filteredRecords, _ = filterDNSSection(rrCache.Extra, p, rrCache.ServerScope) rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...) diff --git a/intel/block_reason.go b/intel/block_reason.go index 040d6c01..09b89db2 100644 --- a/intel/block_reason.go +++ b/intel/block_reason.go @@ -83,7 +83,7 @@ func (br ListBlockReason) ToRRs() []dns.RR { wouldBeBlockedBy, err := dns.NewRR(fmt.Sprintf( "%s-wouldBeBlockedBy. 0 IN TXT %q", strings.TrimRight(lm.Entity, "."), - strings.Join(lm.ActiveLists, ","), + strings.Join(lm.InactiveLists, ","), )) if err == nil { rrs = append(rrs, wouldBeBlockedBy)