Merge pull request #174 from safing/feature/ipinfo-scoping

Add scoping to IPInfo
This commit is contained in:
Daniel
2020-10-15 11:39:14 +02:00
committed by GitHub
8 changed files with 174 additions and 129 deletions

View File

@@ -121,9 +121,9 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *res
err, err,
) )
} }
} else if rrCache.TTL > time.Now().Add(10*time.Second).Unix() { } else if rrCache.Expires > time.Now().Add(10*time.Second).Unix() {
// Set a low TTL of 10 seconds if TTL is higher than that. // Set a low TTL of 10 seconds if TTL is higher than that.
rrCache.TTL = time.Now().Add(10 * time.Second).Unix() rrCache.Expires = time.Now().Add(10 * time.Second).Unix()
err := rrCache.Save() err := rrCache.Save()
if err != nil { if err != nil {
log.Debugf( log.Debugf(
@@ -205,10 +205,19 @@ func mayBlockCNAMEs(conn *network.Connection) bool {
return false return false
} }
// updateIPsAndCNAMEs saves all the IP->Name mappings to the cache database and
// updates the CNAMEs in the Connection's Entity.
func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) { func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) {
// save IP addresses to IPInfo // Get profileID for scoping IPInfo.
var profileID string
proc := conn.Process()
if proc != nil {
profileID = proc.LocalProfileKey
}
// Collect IPs and CNAMEs.
cnames := make(map[string]string) cnames := make(map[string]string)
ips := make(map[string]struct{}) ips := make([]net.IP, 0, len(rrCache.Answer))
for _, rr := range append(rrCache.Answer, rrCache.Extra...) { for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
switch v := rr.(type) { switch v := rr.(type) {
@@ -216,19 +225,27 @@ func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw
cnames[v.Hdr.Name] = v.Target cnames[v.Hdr.Name] = v.Target
case *dns.A: case *dns.A:
ips[v.A.String()] = struct{}{} ips = append(ips, v.A)
case *dns.AAAA: case *dns.AAAA:
ips[v.AAAA.String()] = struct{}{} ips = append(ips, v.AAAA)
} }
} }
for ip := range ips { // Package IPs and CNAMEs into IPInfo structs.
record := resolver.ResolvedDomain{ for _, ip := range ips {
Domain: q.FQDN, // Never save domain attributions for localhost IPs.
if netutils.ClassifyIP(ip) == netutils.HostLocal {
continue
} }
// resolve all CNAMEs in the correct order. // Create new record for this IP.
record := resolver.ResolvedDomain{
Domain: q.FQDN,
Expires: rrCache.Expires,
}
// Resolve all CNAMEs in the correct order and add the to the record.
var domain = q.FQDN var domain = q.FQDN
for { for {
nextDomain, isCNAME := cnames[domain] nextDomain, isCNAME := cnames[domain]
@@ -240,31 +257,30 @@ func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw
domain = nextDomain domain = nextDomain
} }
// update the entity to include the cnames // Update the entity to include the CNAMEs of the query response.
conn.Entity.CNAME = record.CNAMEs conn.Entity.CNAME = record.CNAMEs
// get the existing IP info or create a new one // Check if there is an existing record for this DNS response.
var save bool // Else create a new one.
info, err := resolver.GetIPInfo(ip) ipString := ip.String()
info, err := resolver.GetIPInfo(profileID, ipString)
if err != nil { if err != nil {
if err != database.ErrNotFound { if err != database.ErrNotFound {
log.Errorf("nameserver: failed to search for IP info record: %s", err) log.Errorf("nameserver: failed to search for IP info record: %s", err)
} }
info = &resolver.IPInfo{ info = &resolver.IPInfo{
IP: ip, IP: ipString,
ProfileID: profileID,
} }
save = true
} }
// and the new resolved domain record and save // Add the new record to the resolved domains for this IP and scope.
if new := info.AddDomain(record); new { info.AddDomain(record)
save = true
} // Save if the record is new or has been updated.
if save { if err := info.Save(); err != nil {
if err := info.Save(); err != nil { log.Errorf("nameserver: failed to save IP info record: %s", err)
log.Errorf("nameserver: failed to save IP info record: %s", err)
}
} }
} }
} }

View File

@@ -139,9 +139,9 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
} }
// check if we can find a domain for that IP // check if we can find a domain for that IP
ipinfo, err := resolver.GetIPInfo(pkt.Info().Dst.String()) ipinfo, err := resolver.GetIPInfo(proc.LocalProfileKey, pkt.Info().Dst.String())
if err == nil { if err == nil {
lastResolvedDomain := ipinfo.ResolvedDomains.MostRecentDomain() lastResolvedDomain := ipinfo.MostRecentDomain()
if lastResolvedDomain != nil { if lastResolvedDomain != nil {
scope = lastResolvedDomain.Domain scope = lastResolvedDomain.Domain
entity.Domain = lastResolvedDomain.Domain entity.Domain = lastResolvedDomain.Domain

View File

@@ -7,12 +7,26 @@ import (
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
"github.com/safing/portbase/database/record" "github.com/safing/portbase/database/record"
"github.com/safing/portbase/utils" )
const (
// IPInfoProfileScopeGlobal is the profile scope used for unscoped IPInfo entries.
IPInfoProfileScopeGlobal = "global"
) )
var ( var (
ipInfoDatabase = database.NewInterface(&database.Options{ ipInfoDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 86400, // 24 hours Local: true,
Internal: true,
// Cache entries because new/updated entries will often be queries soon
// after inserted.
CacheSize: 256,
// We only use the cache database here, so we can delay and batch all our
// writes. Also, no one else accesses these records, so we are fine using
// this.
DelayCachedWrites: "cache",
}) })
) )
@@ -25,6 +39,11 @@ type ResolvedDomain struct {
// CNAMEs is a list of CNAMEs that have been resolved for // CNAMEs is a list of CNAMEs that have been resolved for
// Domain. // Domain.
CNAMEs []string CNAMEs []string
// Expires holds the timestamp when this entry expires.
// This does not mean that the entry may not be used anymore afterwards,
// but that this is used to calcuate the TTL of the database record.
Expires int64
} }
// String returns a string representation of ResolvedDomain including // String returns a string representation of ResolvedDomain including
@@ -54,29 +73,16 @@ func (rds ResolvedDomains) String() string {
return strings.Join(domains, " or ") return strings.Join(domains, " or ")
} }
// MostRecentDomain returns the most recent domain.
func (rds ResolvedDomains) MostRecentDomain() *ResolvedDomain {
if len(rds) == 0 {
return nil
}
// TODO(ppacher): we could also do that by using ResolvedAt()
mostRecent := rds[len(rds)-1]
return &mostRecent
}
// IPInfo represents various information about an IP. // IPInfo represents various information about an IP.
type IPInfo struct { type IPInfo struct {
record.Base record.Base
sync.Mutex sync.Mutex
// IP holds the acutal IP address. // IP holds the actual IP address.
IP string IP string
// Domains holds a list of domains that have been // ProfileID is used to scope this entry to a process group.
// resolved to IP. This field is deprecated and should ProfileID string
// be removed.
// DEPRECATED: remove with alpha.
Domains []string `json:"Domains,omitempty"`
// ResolvedDomain is a slice of domains that // ResolvedDomain is a slice of domains that
// have been requested by various applications // have been requested by various applications
@@ -84,35 +90,43 @@ type IPInfo struct {
ResolvedDomains ResolvedDomains ResolvedDomains ResolvedDomains
} }
// AddDomain adds a new resolved domain to ipi. // AddDomain adds a new resolved domain to IPInfo.
func (ipi *IPInfo) AddDomain(resolved ResolvedDomain) bool { func (info *IPInfo) AddDomain(resolved ResolvedDomain) {
for idx, d := range ipi.ResolvedDomains { info.Lock()
if d.Domain == resolved.Domain { defer info.Unlock()
if utils.StringSliceEqual(d.CNAMEs, resolved.CNAMEs) {
return false
}
// we have a different CNAME chain now, remove the previous // Delete old for the same domain.
// entry and add it at the end. for idx, d := range info.ResolvedDomains {
ipi.ResolvedDomains = append(ipi.ResolvedDomains[:idx], ipi.ResolvedDomains[idx+1:]...) if d.Domain == resolved.Domain {
ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved) info.ResolvedDomains = append(info.ResolvedDomains[:idx], info.ResolvedDomains[idx+1:]...)
return true break
} }
} }
ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved) // Add new entry to the end.
return true info.ResolvedDomains = append(info.ResolvedDomains, resolved)
} }
func makeIPInfoKey(ip string) string { // MostRecentDomain returns the most recent domain.
return fmt.Sprintf("cache:intel/ipInfo/%s", ip) func (info *IPInfo) MostRecentDomain() *ResolvedDomain {
info.Lock()
defer info.Unlock()
if len(info.ResolvedDomains) == 0 {
return nil
}
mostRecent := info.ResolvedDomains[len(info.ResolvedDomains)-1]
return &mostRecent
}
func makeIPInfoKey(profileID, ip string) string {
return fmt.Sprintf("cache:intel/ipInfo/%s/%s", profileID, ip)
} }
// GetIPInfo gets an IPInfo record from the database. // GetIPInfo gets an IPInfo record from the database.
func GetIPInfo(ip string) (*IPInfo, error) { func GetIPInfo(profileID, ip string) (*IPInfo, error) {
key := makeIPInfoKey(ip) r, err := ipInfoDatabase.Get(makeIPInfoKey(profileID, ip))
r, err := ipInfoDatabase.Get(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -126,18 +140,6 @@ func GetIPInfo(ip string) (*IPInfo, error) {
return nil, err return nil, err
} }
// Legacy support,
// DEPRECATED: remove with alpha
if len(new.Domains) > 0 && len(new.ResolvedDomains) == 0 {
for _, d := range new.Domains {
new.ResolvedDomains = append(new.ResolvedDomains, ResolvedDomain{
Domain: d,
// rest is empty...
})
}
new.Domains = nil // clean up so we remove it from the database
}
return new, nil return new, nil
} }
@@ -150,27 +152,38 @@ func GetIPInfo(ip string) (*IPInfo, error) {
} }
// Save saves the IPInfo record to the database. // Save saves the IPInfo record to the database.
func (ipi *IPInfo) Save() error { func (info *IPInfo) Save() error {
ipi.Lock() info.Lock()
if !ipi.KeyIsSet() {
ipi.SetKey(makeIPInfoKey(ipi.IP))
}
ipi.Unlock()
// Legacy support // Set database key if not yet set already.
// Ensure we don't write new Domain fields into the if !info.KeyIsSet() {
// database. // Default to global scope if scope is unset.
// DEPRECATED: remove with alpha if info.ProfileID == "" {
if len(ipi.Domains) > 0 { info.ProfileID = IPInfoProfileScopeGlobal
ipi.Domains = nil }
info.SetKey(makeIPInfoKey(info.ProfileID, info.IP))
} }
return ipInfoDatabase.Put(ipi) // Calculate and set cache expiry.
var expires int64 = 86400 // Minimum TTL of one day.
for _, rd := range info.ResolvedDomains {
if rd.Expires > expires {
expires = rd.Expires
}
}
info.UpdateMeta()
expires += 3600 // Add one hour to expiry as a buffer.
info.Meta().SetAbsoluteExpiry(expires)
info.Unlock()
return ipInfoDatabase.Put(info)
} }
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " // FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (ipi *IPInfo) String() string { func (info *IPInfo) String() string {
ipi.Lock() info.Lock()
defer ipi.Unlock() defer info.Unlock()
return fmt.Sprintf("<IPInfo[%s] %s: %s", ipi.Key(), ipi.IP, ipi.ResolvedDomains.String())
return fmt.Sprintf("<IPInfo[%s] %s: %s>", info.Key(), info.IP, info.ResolvedDomains.String())
} }

View File

@@ -15,7 +15,7 @@ func TestIPInfo(t *testing.T) {
CNAMEs: []string{"example.com"}, CNAMEs: []string{"example.com"},
} }
ipi := &IPInfo{ info := &IPInfo{
IP: "1.2.3.4", IP: "1.2.3.4",
ResolvedDomains: ResolvedDomains{ ResolvedDomains: ResolvedDomains{
example, example,
@@ -27,22 +27,18 @@ func TestIPInfo(t *testing.T) {
Domain: "sub2.example.com", Domain: "sub2.example.com",
CNAMEs: []string{"sub1.example.com", "example.com"}, CNAMEs: []string{"sub1.example.com", "example.com"},
} }
added := ipi.AddDomain(sub2Example) info.AddDomain(sub2Example)
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, info.ResolvedDomains)
assert.True(t, added)
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains)
// try again, should do nothing now // try again, should do nothing now
added = ipi.AddDomain(sub2Example) info.AddDomain(sub2Example)
assert.False(t, added) assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, info.ResolvedDomains)
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains)
subOverWrite := ResolvedDomain{ subOverWrite := ResolvedDomain{
Domain: "sub1.example.com", Domain: "sub1.example.com",
CNAMEs: []string{}, // now without CNAMEs CNAMEs: []string{}, // now without CNAMEs
} }
added = ipi.AddDomain(subOverWrite) info.AddDomain(subOverWrite)
assert.True(t, added) assert.Equal(t, ResolvedDomains{example, sub2Example, subOverWrite}, info.ResolvedDomains)
assert.Equal(t, ResolvedDomains{example, sub2Example, subOverWrite}, ipi.ResolvedDomains)
} }

View File

@@ -93,6 +93,9 @@ func start() error {
listenToMDNS, listenToMDNS,
) )
module.StartServiceWorker("name record delayed cache writer", 0, recordDatabase.DelayedCacheWriter)
module.StartServiceWorker("ip info delayed cache writer", 0, ipInfoDatabase.DelayedCacheWriter)
return nil return nil
} }

View File

@@ -12,10 +12,24 @@ import (
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
) )
const (
// databaseOvertime defines how much longer than the TTL name records are
// cached in the database.
databaseOvertime = 86400 * 14 // two weeks
)
var ( var (
recordDatabase = database.NewInterface(&database.Options{ recordDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 2592000, // 30 days Local: true,
CacheSize: 256, Internal: true,
// Cache entries because application often resolve domains multiple times.
CacheSize: 256,
// We only use the cache database here, so we can delay and batch all our
// writes. Also, no one else accesses these records, so we are fine using
// this.
DelayCachedWrites: "cache",
}) })
nameRecordsKeyPrefix = "cache:intel/nameRecord/" nameRecordsKeyPrefix = "cache:intel/nameRecord/"
@@ -32,7 +46,7 @@ type NameRecord struct {
Answer []string Answer []string
Ns []string Ns []string
Extra []string Extra []string
TTL int64 Expires int64
Server string Server string
ServerScope int8 ServerScope int8
@@ -84,6 +98,9 @@ func (rec *NameRecord) Save() error {
} }
rec.SetKey(makeNameRecordKey(rec.Domain, rec.Question)) rec.SetKey(makeNameRecordKey(rec.Domain, rec.Question))
rec.UpdateMeta()
rec.Meta().SetAbsoluteExpiry(rec.Expires + databaseOvertime)
return recordDatabase.PutNew(rec) return recordDatabase.PutNew(rec)
} }

View File

@@ -220,19 +220,19 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
log.Tracer(ctx).Tracef( log.Tracer(ctx).Tracef(
"resolver: cache for %s will expire in %s, refreshing async now", "resolver: cache for %s will expire in %s, refreshing async now",
q.ID(), q.ID(),
time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second), time.Until(time.Unix(rrCache.Expires, 0)).Round(time.Second),
) )
// resolve async // resolve async
module.StartWorker("resolve async", func(ctx context.Context) error { module.StartWorker("resolve async", func(asyncCtx context.Context) error {
ctx, tracer := log.AddTracer(ctx) tracingCtx, tracer := log.AddTracer(asyncCtx)
defer tracer.Submit() defer tracer.Submit()
tracer.Debugf("resolver: resolving %s async", q.ID()) tracer.Tracef("resolver: resolving %s async", q.ID())
_, err := resolveAndCache(ctx, q, nil) _, err := resolveAndCache(tracingCtx, q, nil)
if err != nil { if err != nil {
tracer.Warningf("resolver: async query for %s failed: %s", q.ID(), err) tracer.Warningf("resolver: async query for %s failed: %s", q.ID(), err)
} else { } else {
tracer.Debugf("resolver: async query for %s succeeded", q.ID()) tracer.Infof("resolver: async query for %s succeeded", q.ID())
} }
return nil return nil
}) })
@@ -242,7 +242,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
log.Tracer(ctx).Tracef( log.Tracer(ctx).Tracef(
"resolver: using cached RR (expires in %s)", "resolver: using cached RR (expires in %s)",
time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second), time.Until(time.Unix(rrCache.Expires, 0)).Round(time.Second),
) )
return rrCache return rrCache
} }

View File

@@ -25,10 +25,10 @@ type RRCache struct {
RCode int RCode int
// Response Content // Response Content
Answer []dns.RR Answer []dns.RR
Ns []dns.RR Ns []dns.RR
Extra []dns.RR Extra []dns.RR
TTL int64 Expires int64
// Source Information // Source Information
Server string Server string
@@ -54,12 +54,12 @@ func (rrCache *RRCache) ID() string {
// Expired returns whether the record has expired. // Expired returns whether the record has expired.
func (rrCache *RRCache) Expired() bool { func (rrCache *RRCache) Expired() bool {
return rrCache.TTL <= time.Now().Unix() return rrCache.Expires <= time.Now().Unix()
} }
// ExpiresSoon returns whether the record will expire soon and should already be refreshed. // ExpiresSoon returns whether the record will expire soon and should already be refreshed.
func (rrCache *RRCache) ExpiresSoon() bool { func (rrCache *RRCache) ExpiresSoon() bool {
return rrCache.TTL <= time.Now().Unix()+refreshTTL return rrCache.Expires <= time.Now().Unix()+refreshTTL
} }
// Clean sets all TTLs to 17 and sets cache expiry with specified minimum. // Clean sets all TTLs to 17 and sets cache expiry with specified minimum.
@@ -99,7 +99,7 @@ func (rrCache *RRCache) Clean(minExpires uint32) {
} }
// log.Tracef("lowest TTL is %d", lowestTTL) // log.Tracef("lowest TTL is %d", lowestTTL)
rrCache.TTL = time.Now().Unix() + int64(lowestTTL) rrCache.Expires = time.Now().Unix() + int64(lowestTTL)
} }
// ExportAllARecords return of a list of all A and AAAA IP addresses. // ExportAllARecords return of a list of all A and AAAA IP addresses.
@@ -131,7 +131,7 @@ func (rrCache *RRCache) ToNameRecord() *NameRecord {
Domain: rrCache.Domain, Domain: rrCache.Domain,
Question: rrCache.Question.String(), Question: rrCache.Question.String(),
RCode: rrCache.RCode, RCode: rrCache.RCode,
TTL: rrCache.TTL, Expires: rrCache.Expires,
Server: rrCache.Server, Server: rrCache.Server,
ServerScope: rrCache.ServerScope, ServerScope: rrCache.ServerScope,
ServerInfo: rrCache.ServerInfo, ServerInfo: rrCache.ServerInfo,
@@ -188,7 +188,7 @@ func GetRRCache(domain string, question dns.Type) (*RRCache, error) {
} }
rrCache.RCode = nameRecord.RCode rrCache.RCode = nameRecord.RCode
rrCache.TTL = nameRecord.TTL rrCache.Expires = nameRecord.Expires
for _, entry := range nameRecord.Answer { for _, entry := range nameRecord.Answer {
rrCache.Answer = parseRR(rrCache.Answer, entry) rrCache.Answer = parseRR(rrCache.Answer, entry)
} }
@@ -249,10 +249,10 @@ func (rrCache *RRCache) ShallowCopy() *RRCache {
Question: rrCache.Question, Question: rrCache.Question,
RCode: rrCache.RCode, RCode: rrCache.RCode,
Answer: rrCache.Answer, Answer: rrCache.Answer,
Ns: rrCache.Ns, Ns: rrCache.Ns,
Extra: rrCache.Extra, Extra: rrCache.Extra,
TTL: rrCache.TTL, Expires: rrCache.Expires,
Server: rrCache.Server, Server: rrCache.Server,
ServerScope: rrCache.ServerScope, ServerScope: rrCache.ServerScope,
@@ -310,9 +310,9 @@ func (rrCache *RRCache) GetExtraRRs(ctx context.Context, query *dns.Msg) (extra
// Add expiry and cache information. // Add expiry and cache information.
if rrCache.Expired() { if rrCache.Expired() {
extra = addExtra(ctx, extra, fmt.Sprintf("record expired since %s", time.Since(time.Unix(rrCache.TTL, 0)).Round(time.Second))) extra = addExtra(ctx, extra, fmt.Sprintf("record expired since %s", time.Since(time.Unix(rrCache.Expires, 0)).Round(time.Second)))
} else { } else {
extra = addExtra(ctx, extra, fmt.Sprintf("record valid for %s", time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second))) extra = addExtra(ctx, extra, fmt.Sprintf("record valid for %s", time.Until(time.Unix(rrCache.Expires, 0)).Round(time.Second)))
} }
if rrCache.RequestingNew { if rrCache.RequestingNew {
extra = addExtra(ctx, extra, "async request to refresh the cache has been started") extra = addExtra(ctx, extra, "async request to refresh the cache has been started")