From f2e41a0d32d2f9319cf82945c6ba732ba40d10e6 Mon Sep 17 00:00:00 2001 From: Patrick Pacher Date: Tue, 21 Apr 2020 09:55:49 +0200 Subject: [PATCH] Refactoring of FilterDNSResponse and DecideOnConnection --- firewall/dns.go | 230 +++++++++++++++++++++ firewall/master.go | 367 +++++++++++++++------------------ intel/block_reason.go | 52 ++++- intel/entity.go | 16 +- nameserver/nameserver.go | 132 +++--------- profile/endpoints/endpoint.go | 6 +- profile/endpoints/endpoints.go | 2 +- resolver/ipinfo.go | 6 +- 8 files changed, 480 insertions(+), 331 deletions(-) create mode 100644 firewall/dns.go diff --git a/firewall/dns.go b/firewall/dns.go new file mode 100644 index 00000000..d6421e2e --- /dev/null +++ b/firewall/dns.go @@ -0,0 +1,230 @@ +package firewall + +import ( + "net" + "os" + "strings" + + "github.com/miekg/dns" + "github.com/safing/portbase/database" + "github.com/safing/portbase/log" + "github.com/safing/portmaster/network" + "github.com/safing/portmaster/network/netutils" + "github.com/safing/portmaster/profile" + "github.com/safing/portmaster/profile/endpoints" + "github.com/safing/portmaster/resolver" +) + +func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, scope int8) ([]dns.RR, []string, int) { + goodEntries := make([]dns.RR, 0, len(entries)) + filteredRecords := make([]string, 0, len(entries)) + + // keeps track of the number of valid and allowed + // A and AAAA records. + var allowedAddressRecords int + + for _, rr := range entries { + // get IP and classification + var ip net.IP + switch v := rr.(type) { + case *dns.A: + ip = v.A + case *dns.AAAA: + ip = v.AAAA + default: + // add non A/AAAA entries + goodEntries = append(goodEntries, rr) + continue + } + classification := netutils.ClassifyIP(ip) + + if p.RemoveOutOfScopeDNS() { + switch { + case classification == netutils.HostLocal: + // No DNS should return localhost addresses + filteredRecords = append(filteredRecords, rr.String()) + continue + case scope == netutils.Global && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): + // No global DNS should return LAN addresses + filteredRecords = append(filteredRecords, rr.String()) + continue + } + } + + if p.RemoveBlockedDNS() { + // filter by flags + switch { + case p.BlockScopeInternet() && classification == netutils.Global: + filteredRecords = append(filteredRecords, rr.String()) + continue + case p.BlockScopeLAN() && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): + filteredRecords = append(filteredRecords, rr.String()) + continue + case p.BlockScopeLocal() && classification == netutils.HostLocal: + filteredRecords = append(filteredRecords, rr.String()) + continue + } + + // TODO: filter by endpoint list (IP only) + } + + // if survived, add to good entries + allowedAddressRecords++ + goodEntries = append(goodEntries, rr) + } + + return goodEntries, filteredRecords, allowedAddressRecords +} + +func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *resolver.RRCache { + // do not modify own queries + if conn.Process().Pid == os.Getpid() { + return rrCache + } + + // get profile + p := conn.Process().Profile() + if p == nil { + conn.Block("no profile") + return nil + } + + // check if DNS response filtering is completely turned off + if !p.RemoveOutOfScopeDNS() && !p.RemoveBlockedDNS() { + return rrCache + } + + // duplicate entry + rrCache = rrCache.ShallowCopy() + rrCache.FilteredEntries = make([]string, 0) + + var filteredRecords []string + var validIPs int + + rrCache.Answer, filteredRecords, validIPs = filterDNSSection(rrCache.Answer, p, rrCache.ServerScope) + rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...) + + // FIXME(ppacher): should we consider valid IPs from the extra section? + rrCache.Extra, filteredRecords, _ = filterDNSSection(rrCache.Extra, p, rrCache.ServerScope) + rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...) + + if len(rrCache.FilteredEntries) > 0 { + rrCache.Filtered = true + if validIPs == 0 { + conn.Block("no addresses returned for this domain are permitted") + return nil + } + + log.Infof("filter: filtered DNS replies for %s: %s", conn, strings.Join(rrCache.FilteredEntries, ", ")) + } + + return rrCache +} + +// DecideOnResolvedDNS filters a dns response according to the application profile and settings. +func DecideOnResolvedDNS(conn *network.Connection, q *resolver.Query, rrCache *resolver.RRCache) *resolver.RRCache { + updatedRR := filterDNSResponse(conn, rrCache) + if updatedRR == nil { + return nil + } + + updateIPsAndCNAMEs(q, rrCache, conn) + + if mayBlockCNAMEs(conn) { + return nil + } + + // TODO: Gate17 integration + // tunnelInfo, err := AssignTunnelIP(fqdn) + + return updatedRR +} + +func mayBlockCNAMEs(conn *network.Connection) bool { + // if we have CNAMEs and the profile is configured to filter them + // we need to re-check the lists and endpoints here + if conn.Process().Profile().FilterCNAMEs() { + conn.Entity.ResetLists() + conn.Entity.EnableCNAMECheck(true) + + result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) + if result == endpoints.Denied { + conn.BlockWithContext(reason.String(), reason.Context()) + return true + } + + if result == endpoints.NoMatch { + result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) + if result == endpoints.Denied { + conn.BlockWithContext(reason.String(), reason.Context()) + return true + } + } + } + + return false +} + +func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) { + // save IP addresses to IPInfo + cnames := make(map[string]string) + ips := make(map[string]struct{}) + + for _, rr := range append(rrCache.Answer, rrCache.Extra...) { + switch v := rr.(type) { + case *dns.CNAME: + cnames[v.Hdr.Name] = v.Target + + case *dns.A: + ips[v.A.String()] = struct{}{} + + case *dns.AAAA: + ips[v.AAAA.String()] = struct{}{} + } + } + + for ip := range ips { + record := resolver.ResolvedDomain{ + Domain: q.FQDN, + } + + // resolve all CNAMEs in the correct order. + var domain = q.FQDN + for { + nextDomain, isCNAME := cnames[domain] + if !isCNAME { + break + } + + record.CNAMEs = append(record.CNAMEs, nextDomain) + domain = nextDomain + } + + // update the entity to include the cnames + conn.Entity.CNAME = record.CNAMEs + + // get the existing IP info or create a new one + var save bool + info, err := resolver.GetIPInfo(ip) + if err != nil { + if err != database.ErrNotFound { + log.Errorf("nameserver: failed to search for IP info record: %s", err) + } + + info = &resolver.IPInfo{ + IP: ip, + } + save = true + } + + // and the new resolved domain record and save + if new := info.AddDomain(record); new { + save = true + } + if save { + if err := info.Save(); err != nil { + log.Errorf("nameserver: failed to save IP info record: %s", err) + } + } + } +} diff --git a/firewall/master.go b/firewall/master.go index cafd11f2..06a0f5f1 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -2,7 +2,6 @@ package firewall import ( "fmt" - "net" "os" "path/filepath" "strings" @@ -14,10 +13,8 @@ import ( "github.com/safing/portmaster/process" "github.com/safing/portmaster/profile" "github.com/safing/portmaster/profile/endpoints" - "github.com/safing/portmaster/resolver" "github.com/agext/levenshtein" - "github.com/miekg/dns" ) // Call order: @@ -26,7 +23,7 @@ import ( // 1. DecideOnConnection // is called when a DNS query is made, may set verdict to Undeterminable to permit a DNS reply. // is called with a nil packet. -// 2. FilterDNSResponse +// 2. DecideOnResolvedDNS // is called to (possibly) filter out A/AAAA records that the filter would deny later. // // Network Connection: @@ -35,7 +32,7 @@ import ( // DecideOnConnection makes a decision about a connection. // When called, the connection and profile is already locked. -func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint:gocognit,gocyclo // TODO +func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { // update profiles and check if communication needs reevaluation if conn.UpdateAndCheck() { log.Infof("filter: re-evaluating verdict on %s", conn) @@ -46,14 +43,47 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: } } + var deciders = []func(*network.Connection, packet.Packet) bool{ + checkPortmasterConnection, + checkSelfCommunication, + checkProfileExists, + checkConnectionType, + checkConnectionScope, + checkEndpointLists, + checkBypassPrevention, + checkFilterLists, + checkInbound, + checkDefaultPermit, + checkAutoPermitRelated, + checkDefaultAction, + } + + for _, decider := range deciders { + if decider(conn, pkt) { + return + } + } + + // DefaultAction == DefaultActionBlock + conn.Deny("endpoint is not whitelisted (default=block)") +} + +// checkPortmasterConnection allows all connection that originate from +// portmaster itself. +func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool { // grant self if conn.Process().Pid == os.Getpid() { log.Infof("filter: granting own connection %s", conn) conn.Verdict = network.VerdictAccept conn.Internal = true - return + return true } + return false +} + +// checkSelfCommunication checks if the process is communicating with itself. +func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { // check if process is communicating with itself if pkt != nil { // TODO: evaluate the case where different IPs in the 127/8 net are used. @@ -77,89 +107,30 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: } else if otherProcess.Pid == conn.Process().Pid { conn.Accept("connection to self") conn.Internal = true - return + return true } } } } - // get profile - p := conn.Process().Profile() - if p == nil { - conn.Block("no profile") - return - } - - // check conn type - switch conn.Scope { - case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: - if p.BlockInbound() { - if conn.Scope == network.IncomingHost { - conn.Block("inbound connections blocked") - } else { - conn.Drop("inbound connections blocked") - } - return - } - case network.PeerLAN, network.PeerInternet, network.PeerInvalid: - // Important: PeerHost is and should be missing! - if p.BlockP2P() { - conn.Block("direct connections (P2P) blocked") - return - } - } - - // check scopes - if conn.Entity.IP != nil { - classification := netutils.ClassifyIP(conn.Entity.IP) - - switch classification { - case netutils.Global, netutils.GlobalMulticast: - if p.BlockScopeInternet() { - conn.Deny("Internet access blocked") // Block Outbound / Drop Inbound - return - } - case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast: - if p.BlockScopeLAN() { - conn.Block("LAN access blocked") // Block Outbound / Drop Inbound - return - } - case netutils.HostLocal: - if p.BlockScopeLocal() { - conn.Block("Localhost access blocked") // Block Outbound / Drop Inbound - return - } - default: // netutils.Invalid - conn.Deny("invalid IP") // Block Outbound / Drop Inbound - return - } - } else if conn.Entity.Domain != "" { - // DNS Query - // DNS is expected to resolve to LAN or Internet addresses - // TODO: handle domains mapped to localhost - if p.BlockScopeInternet() && p.BlockScopeLAN() { - conn.Block("Internet and LAN access blocked") - return - } - } - - if p.PreventBypassing() { - // check for bypass protection - result, reason := PreventBypassing(conn) - switch result { - case endpoints.Denied: - conn.Block("bypass prevention: " + reason) - return - case endpoints.Permitted: - conn.Accept("bypass prevention: " + reason) - return - case endpoints.NoMatch: - } + return false +} + +func checkProfileExists(conn *network.Connection, _ packet.Packet) bool { + if conn.Process().Profile() == nil { + conn.Block("unknown process or profile") + return true } + return false +} +func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool { var result endpoints.EPResult var reason endpoints.Reason + // there must always be a profile. + p := conn.Process().Profile() + // check endpoints list if conn.Inbound { result, reason = p.MatchServiceEndpoint(conn.Entity) @@ -169,167 +140,151 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: switch result { case endpoints.Denied: conn.DenyWithContext(reason.String(), reason.Context()) - return + return true case endpoints.Permitted: conn.AcceptWithContext(reason.String(), reason.Context()) - return + return true } - // continuing with result == NoMatch + return false +} + +func checkConnectionType(conn *network.Connection, _ packet.Packet) bool { + p := conn.Process().Profile() + + // check conn type + switch conn.Scope { + case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: + if p.BlockInbound() { + if conn.Scope == network.IncomingHost { + conn.Block("inbound connections blocked") + } else { + conn.Drop("inbound connections blocked") + } + return true + } + case network.PeerLAN, network.PeerInternet, network.PeerInvalid: + // Important: PeerHost is and should be missing! + if p.BlockP2P() { + conn.Block("direct connections (P2P) blocked") + return true + } + } + + return false +} + +func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool { + p := conn.Process().Profile() + + // check scopes + if conn.Entity.IP != nil { + classification := netutils.ClassifyIP(conn.Entity.IP) + + switch classification { + case netutils.Global, netutils.GlobalMulticast: + if p.BlockScopeInternet() { + conn.Deny("Internet access blocked") // Block Outbound / Drop Inbound + return true + } + case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast: + if p.BlockScopeLAN() { + conn.Block("LAN access blocked") // Block Outbound / Drop Inbound + return true + } + case netutils.HostLocal: + if p.BlockScopeLocal() { + conn.Block("Localhost access blocked") // Block Outbound / Drop Inbound + return true + } + default: // netutils.Invalid + conn.Deny("invalid IP") // Block Outbound / Drop Inbound + return true + } + } else if conn.Entity.Domain != "" { + // DNS Query + // DNS is expected to resolve to LAN or Internet addresses + // TODO: handle domains mapped to localhost + if p.BlockScopeInternet() && p.BlockScopeLAN() { + conn.Block("Internet and LAN access blocked") + return true + } + } + return false +} + +func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool { + if conn.Process().Profile().PreventBypassing() { + // check for bypass protection + result, reason := PreventBypassing(conn) + switch result { + case endpoints.Denied: + conn.Block("bypass prevention: " + reason) + return true + case endpoints.Permitted: + conn.Accept("bypass prevention: " + reason) + return true + case endpoints.NoMatch: + } + } + return false +} + +func checkFilterLists(conn *network.Connection, _ packet.Packet) bool { // apply privacy filter lists - result, reason = p.MatchFilterLists(conn.Entity) + p := conn.Process().Profile() + + result, reason := p.MatchFilterLists(conn.Entity) switch result { case endpoints.Denied: conn.DenyWithContext(reason.String(), reason.Context()) - return + return true case endpoints.NoMatch: // nothing to do default: log.Debugf("filter: filter lists returned unsupported verdict: %s", result) } + return false +} +func checkInbound(conn *network.Connection, _ packet.Packet) bool { // implicit default=block for inbound if conn.Inbound { conn.Drop("endpoint is not whitelisted (incoming is always default=block)") - return + return true } + return false +} +func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool { // check default action + p := conn.Process().Profile() if p.DefaultAction() == profile.DefaultActionPermit { conn.Accept("endpoint is not blacklisted (default=permit)") - return + return true } + return false +} - // check relation +func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool { + p := conn.Process().Profile() if !p.DisableAutoPermit() { related, reason := checkRelation(conn) if related { conn.Accept(reason) - return + return true } } - - // prompt - if p.DefaultAction() == profile.DefaultActionAsk { - prompt(conn, pkt) - return - } - - // DefaultAction == DefaultActionBlock - conn.Deny("endpoint is not whitelisted (default=block)") + return false } -// FilterDNSResponse filters a dns response according to the application profile and settings. -func FilterDNSResponse(conn *network.Connection, q *resolver.Query, rrCache *resolver.RRCache) *resolver.RRCache { //nolint:gocognit // TODO - // do not modify own queries - if conn.Process().Pid == os.Getpid() { - return rrCache - } - - // get profile +func checkDefaultAction(conn *network.Connection, pkt packet.Packet) bool { p := conn.Process().Profile() - if p == nil { - conn.Block("no profile") - return nil + if p.DefaultAction() == profile.DefaultActionAsk { + prompt(conn, pkt) + return true } - - // check if DNS response filtering is completely turned off - if !p.RemoveOutOfScopeDNS() && !p.RemoveBlockedDNS() { - return rrCache - } - - // duplicate entry - rrCache = rrCache.ShallowCopy() - rrCache.FilteredEntries = make([]string, 0) - - // change information - var addressesRemoved int - var addressesOk int - - // loop vars - var classification int8 - var ip net.IP - - // filter function - filterEntries := func(entries []dns.RR) (goodEntries []dns.RR) { - goodEntries = make([]dns.RR, 0, len(entries)) - - for _, rr := range entries { - - // get IP and classification - switch v := rr.(type) { - case *dns.A: - ip = v.A - case *dns.AAAA: - ip = v.AAAA - default: - // add non A/AAAA entries - goodEntries = append(goodEntries, rr) - continue - } - classification = netutils.ClassifyIP(ip) - - if p.RemoveOutOfScopeDNS() { - switch { - case classification == netutils.HostLocal: - // No DNS should return localhost addresses - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - case rrCache.ServerScope == netutils.Global && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): - // No global DNS should return LAN addresses - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - } - } - - if p.RemoveBlockedDNS() { - // filter by flags - switch { - case p.BlockScopeInternet() && classification == netutils.Global: - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - case p.BlockScopeLAN() && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - case p.BlockScopeLocal() && classification == netutils.HostLocal: - addressesRemoved++ - rrCache.FilteredEntries = append(rrCache.FilteredEntries, rr.String()) - continue - } - - // TODO: filter by endpoint list (IP only) - } - - // if survived, add to good entries - addressesOk++ - goodEntries = append(goodEntries, rr) - } - return - } - - rrCache.Answer = filterEntries(rrCache.Answer) - rrCache.Extra = filterEntries(rrCache.Extra) - - if addressesRemoved > 0 { - rrCache.Filtered = true - if addressesOk == 0 { - conn.Block("no addresses returned for this domain are permitted") - return nil - } - } - - if rrCache.Filtered { - log.Infof("filter: filtered DNS replies for %s: %s", conn, strings.Join(rrCache.FilteredEntries, ", ")) - } - - // TODO: Gate17 integration - // tunnelInfo, err := AssignTunnelIP(fqdn) - - return rrCache + return false } // checkRelation tries to find a relation between a process and a communication. This is for better out of the box experience and is _not_ meant to thwart intentional malware. diff --git a/intel/block_reason.go b/intel/block_reason.go index 5a14c6cd..040d6c01 100644 --- a/intel/block_reason.go +++ b/intel/block_reason.go @@ -1,8 +1,12 @@ package intel import ( + "encoding/json" "fmt" "strings" + + "github.com/miekg/dns" + "github.com/safing/portbase/log" ) // ListMatch represents an entity that has been @@ -45,7 +49,49 @@ func (br ListBlockReason) String() string { // Context returns br wrapped into a map. It implements // the endpoints.Reason interface. func (br ListBlockReason) Context() interface{} { - return map[string]interface{}{ - "filterlists": br, - } + return br +} + +// MarshalJSON marshals the list block reason into a map +// prefixed with filterlists. +func (br ListBlockReason) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + // we convert to []ListMatch to avoid recursing + // here. + "filterlists": []ListMatch(br), + }) +} + +// ToRRs returns a set of dns TXT records that describe the +// block reason. +func (br ListBlockReason) ToRRs() []dns.RR { + rrs := make([]dns.RR, 0, len(br)) + + for _, lm := range br { + blockedBy, err := dns.NewRR(fmt.Sprintf( + "%s-blockedBy. 0 IN TXT %q", + strings.TrimRight(lm.Entity, "."), + strings.Join(lm.ActiveLists, ","), + )) + if err == nil { + rrs = append(rrs, blockedBy) + } else { + log.Errorf("intel: failed to create TXT RR for block reason: %s", err) + } + + if len(lm.InactiveLists) > 0 { + wouldBeBlockedBy, err := dns.NewRR(fmt.Sprintf( + "%s-wouldBeBlockedBy. 0 IN TXT %q", + strings.TrimRight(lm.Entity, "."), + strings.Join(lm.ActiveLists, ","), + )) + if err == nil { + rrs = append(rrs, wouldBeBlockedBy) + } else { + log.Errorf("intel: failed to create TXT RR for block reason: %s", err) + } + } + } + + return rrs } diff --git a/intel/entity.go b/intel/entity.go index 8d2e21fb..af96343d 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -101,10 +101,10 @@ func (e *Entity) ResetLists() { // list right now so we could be more efficient by keeping // the other lists around. - // FIXME - //e.Lists = nil - //e.ListsMap = nil + e.BlockedByLists = nil + e.BlockedEntities = nil e.ListOccurences = nil + e.domainListLoaded = false e.ipListLoaded = false e.countryListLoaded = false @@ -421,15 +421,12 @@ func (e *Entity) getIPLists() { }) } -// LoadLists searches all filterlists for all occurences of +// LoadLists searches all filterlists for all occurrences of // this entity. func (e *Entity) LoadLists() bool { e.getLists() - if e.ListOccurences == nil { - return false - } - return true + return e.ListOccurences != nil } // MatchLists matches the entities lists against a slice @@ -450,6 +447,7 @@ func (e *Entity) MatchLists(lists []string) bool { } makeDistinct(e.BlockedByLists) + makeDistinct(e.BlockedEntities) return len(e.BlockedByLists) > 0 } @@ -503,7 +501,7 @@ func mergeStringList(a, b []string) []string { func makeDistinct(slice []string) []string { m := make(map[string]struct{}, len(slice)) - var result []string + result := make([]string, 0, len(slice)) for _, v := range slice { if _, ok := m[v]; ok { diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 9b3d8016..03d71701 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -7,7 +7,6 @@ import ( "net" "strings" - "github.com/safing/portbase/database" "github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/log" @@ -17,7 +16,6 @@ import ( "github.com/safing/portmaster/netenv" "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/netutils" - "github.com/safing/portmaster/profile/endpoints" "github.com/safing/portmaster/resolver" "github.com/miekg/dns" @@ -91,12 +89,27 @@ func stop() error { return nil } -func returnNXDomain(w dns.ResponseWriter, query *dns.Msg, reason string) { +func returnNXDomain(w dns.ResponseWriter, query *dns.Msg, reason string, reasonContext interface{}) { m := new(dns.Msg) m.SetRcode(query, dns.RcodeNameError) - rr, _ := dns.NewRR("portmaster.block.reason. 0 IN TXT " + fmt.Sprintf("%q", reason)) + rr, _ := dns.NewRR("portmaster.block-reason. 0 IN TXT " + fmt.Sprintf("%q", reason)) m.Extra = []dns.RR{rr} - _ = w.WriteMsg(m) + + if reasonContext != nil { + if v, ok := reasonContext.(interface { + ToRRs() []dns.RR + }); ok { + m.Extra = append(m.Extra, v.ToRRs()...) + } else if v, ok := reasonContext.(interface { + ToRR() dns.RR + }); ok { + m.Extra = append(m.Extra, v.ToRR()) + } + } + + if err := w.WriteMsg(m); err != nil { + log.Errorf("nameserver: failed to send response: %s", err) + } } func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) { @@ -132,7 +145,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er if question.Qclass != dns.ClassINET { // we only serve IN records, return nxdomain log.Warningf("nameserver: only IN record requests are supported but received Qclass %d, returning NXDOMAIN", question.Qclass) - returnNXDomain(w, query, "wrong type") + returnNXDomain(w, query, "wrong type", nil) return nil } @@ -172,7 +185,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // check if valid domain name if !netutils.IsValidFqdn(q.FQDN) { log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", q.FQDN) - returnNXDomain(w, query, "invalid domain") + returnNXDomain(w, query, "invalid domain", nil) return nil } @@ -208,10 +221,10 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // TODO: this has been obsoleted due to special profiles if conn.Process().Profile() == nil { tracer.Infof("nameserver: failed to find process for request %s, returning NXDOMAIN", conn) - returnNXDomain(w, query, "unknown process") // NOTE(ppacher): saving unknown process connection might end up in a lot of // processes. Consider disabling that via config. conn.Failed("Unknown process") + returnNXDomain(w, query, "unknown process", conn.ReasonContext) return nil } @@ -224,8 +237,8 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms) if lms < 10 { tracer.Warningf("nameserver: possible data tunnel by %s: %s has lms score of %f, returning nxdomain", conn.Process(), q.FQDN, lms) - returnNXDomain(w, query, "lms") conn.Block("Possible data tunnel") + returnNXDomain(w, query, "lms", conn.ReasonContext) return nil } @@ -235,7 +248,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er switch conn.Verdict { case network.VerdictBlock: tracer.Infof("nameserver: %s blocked, returning nxdomain", conn) - returnNXDomain(w, query, conn.Reason) + returnNXDomain(w, query, conn.Reason, conn.ReasonContext) return nil case network.VerdictDrop, network.VerdictFailed: tracer.Infof("nameserver: %s dropped, not replying", conn) @@ -254,45 +267,16 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er conn.Failed("failed to resolve: " + err.Error()) } - returnNXDomain(w, query, conn.Reason) + returnNXDomain(w, query, conn.Reason, conn.ReasonContext) return nil } - // filter DNS response - rrCache = firewall.FilterDNSResponse(conn, q, rrCache) - // TODO: FilterDNSResponse also sets a connection verdict + rrCache = firewall.DecideOnResolvedDNS(conn, q, rrCache) if rrCache == nil { - tracer.Infof("nameserver: %s implicitly denied by filtering the dns response, returning nxdomain", conn) - returnNXDomain(w, query, conn.Reason) - conn.Block("DNS response filtered") + returnNXDomain(w, query, conn.Reason, conn.ReasonContext) return nil } - updateIPsAndCNAMEs(q, rrCache, conn) - - // if we have CNAMEs and the profile is configured to filter them - // we need to re-check the lists and endpoints here - if conn.Process().Profile().FilterCNAMEs() { - conn.Entity.ResetLists() - conn.Entity.EnableCNAMECheck(true) - - result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) - if result == endpoints.Denied { - conn.BlockWithContext(reason.String(), reason.Context()) - returnNXDomain(w, query, conn.Reason) - return nil - } - - if result == endpoints.NoMatch { - result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) - if result == endpoints.Denied { - conn.BlockWithContext(reason.String(), reason.Context()) - returnNXDomain(w, query, conn.Reason) - return nil - } - } - } - // reply to query m := new(dns.Msg) m.SetReply(query) @@ -311,67 +295,3 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er return nil } - -func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) { - // save IP addresses to IPInfo - cnames := make(map[string]string) - ips := make(map[string]struct{}) - - for _, rr := range append(rrCache.Answer, rrCache.Extra...) { - switch v := rr.(type) { - case *dns.CNAME: - cnames[v.Hdr.Name] = v.Target - - case *dns.A: - ips[v.A.String()] = struct{}{} - - case *dns.AAAA: - ips[v.AAAA.String()] = struct{}{} - } - } - - for ip := range ips { - record := resolver.ResolvedDomain{ - Domain: q.FQDN, - } - - // resolve all CNAMEs in the correct order. - var domain = q.FQDN - for { - nextDomain, isCNAME := cnames[domain] - if !isCNAME { - break - } - - record.CNAMEs = append(record.CNAMEs, nextDomain) - domain = nextDomain - } - - // update the entity to include the cnames - conn.Entity.CNAME = record.CNAMEs - - // get the existing IP info or create a new one - var save bool - info, err := resolver.GetIPInfo(ip) - if err != nil { - if err != database.ErrNotFound { - log.Errorf("nameserver: failed to search for IP info record: %s", err) - } - - info = &resolver.IPInfo{ - IP: ip, - } - save = true - } - - // and the new resolved domain record and save - if new := info.AddDomain(record); new { - save = true - } - if save { - if err := info.Save(); err != nil { - log.Errorf("nameserver: failed to save IP info record: %s", err) - } - } - } -} diff --git a/profile/endpoints/endpoint.go b/profile/endpoints/endpoint.go index 43741c40..76847ac7 100644 --- a/profile/endpoints/endpoint.go +++ b/profile/endpoints/endpoint.go @@ -24,13 +24,13 @@ type EndpointBase struct { //nolint:maligned // TODO Permitted bool } -func (ep *EndpointBase) match(s fmt.Stringer, entity *intel.Entity, value, desc string, keval ...interface{}) (EPResult, Reason) { +func (ep *EndpointBase) match(s fmt.Stringer, entity *intel.Entity, value, desc string, keyval ...interface{}) (EPResult, Reason) { result := ep.matchesPPP(entity) if result == Undeterminable || result == NoMatch { return result, nil } - return result, ep.makeReason(s, value, desc) + return result, ep.makeReason(s, value, desc, keyval...) } func (ep *EndpointBase) makeReason(s fmt.Stringer, value, desc string, keyval ...interface{}) Reason { @@ -43,7 +43,7 @@ func (ep *EndpointBase) makeReason(s fmt.Stringer, value, desc string, keyval .. r.Extra = make(map[string]interface{}) - for idx := 0; idx < int(len(keyval)/2); idx += 2 { + for idx := 0; idx < len(keyval)/2; idx += 2 { key := keyval[idx] val := keyval[idx+1] diff --git a/profile/endpoints/endpoints.go b/profile/endpoints/endpoints.go index d8af9423..f74edd24 100644 --- a/profile/endpoints/endpoints.go +++ b/profile/endpoints/endpoints.go @@ -24,7 +24,7 @@ const ( // IsDecision returns true if result represents a decision // and false if result is NoMatch or Undeterminable. func IsDecision(result EPResult) bool { - return result == Denied || result == Permitted + return result == Denied || result == Permitted || result == Undeterminable } // ParseEndpoints parses a list of endpoints and returns a list of Endpoints for matching. diff --git a/resolver/ipinfo.go b/resolver/ipinfo.go index 03c0b59b..0ecf9766 100644 --- a/resolver/ipinfo.go +++ b/resolver/ipinfo.go @@ -47,9 +47,9 @@ type ResolvedDomains []ResolvedDomain // String returns a string representation of all domains joined // to a single string. func (rds ResolvedDomains) String() string { - var domains []string - for _, n := range rds { - domains = append(domains, n.String()) + domains := make([]string, len(rds)) + for idx, n := range rds { + domains[idx] = n.String() } return strings.Join(domains, " or ") }