From f35872ec5126fdc317fddf244d65929646d52467 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 12 Dec 2018 19:18:23 +0100 Subject: [PATCH] Complete first alpha version --- README.md | 6 + firewall/firewall.go | 74 ++++++----- firewall/master.go | 79 ++++++------ global/databases.go | 3 +- intel/ipinfo.go | 21 +++- intel/ipinfo_test.go | 25 ++++ intel/resolver.go | 4 +- intel/rrcache.go | 4 +- main.go | 32 ++++- nameserver/nameserver.go | 31 +++-- network/clean.go | 26 ++-- network/connection.go | 81 ++++++++----- network/database.go | 21 ++-- network/environment/dbus_linux.go | 6 + network/link.go | 128 +++++++++++++++++--- network/netutils/ip.go | 14 +-- network/netutils/ip_test.go | 2 +- network/status.go | 2 + process/database.go | 19 +-- process/find.go | 21 +++- process/matching.go | 75 ++++++------ process/proc/gather.go | 36 +++++- process/proc/get.go | 16 +-- process/proc/sockets.go | 49 ++++---- process/proc/sockets_test.go | 10 +- process/process.go | 13 +- profile/database.go | 6 +- profile/index/index.go | 2 +- profile/module.go | 2 +- profile/ports_test.go | 4 +- profile/profile.go | 32 ++--- profile/{profileset.go => set.go} | 34 ++++-- profile/{profileset_test.go => set_test.go} | 0 profile/specialprofiles.go | 6 +- profile/updates.go | 8 +- status/module.go | 25 ++++ 36 files changed, 624 insertions(+), 293 deletions(-) create mode 100644 README.md create mode 100644 intel/ipinfo_test.go rename profile/{profileset.go => set.go} (81%) rename profile/{profileset_test.go => set_test.go} (100%) create mode 100644 status/module.go diff --git a/README.md b/README.md new file mode 100644 index 00000000..97f231d6 --- /dev/null +++ b/README.md @@ -0,0 +1,6 @@ +# Portmaster + +## Dependencies + +- Arch/based: libnetfilter_queue +x diff --git a/firewall/firewall.go b/firewall/firewall.go index 634002b0..49890737 100644 --- a/firewall/firewall.go +++ b/firewall/firewall.go @@ -38,7 +38,7 @@ var ( ) func init() { - modules.Register("firewall", prep, start, stop, "global", "network", "nameserver") + modules.Register("firewall", prep, start, stop, "global", "network", "nameserver", "profile") } func prep() (err error) { @@ -112,12 +112,6 @@ func handlePacket(pkt packet.Packet) { return } - // allow anything that goes to a tunnel entrypoint - if pkt.IsOutbound() && (pkt.GetIPHeader().Dst.Equal(tunnelEntry4) || pkt.GetIPHeader().Dst.Equal(tunnelEntry6)) { - pkt.PermanentAccept() - return - } - // log.Debugf("firewall: pkt %s has ID %s", pkt, pkt.GetConnectionID()) // use this to time how long it takes process packet @@ -147,7 +141,7 @@ func handlePacket(pkt packet.Packet) { link.HandlePacket(pkt) return } - verdict(pkt, link.Verdict) + verdict(pkt, link.GetVerdict()) } @@ -156,42 +150,45 @@ func initialHandler(pkt packet.Packet, link *network.Link) { // get Connection connection, err := network.GetConnectionByFirstPacket(pkt) if err != nil { - link.Lock() if err != process.ErrConnectionNotFound { log.Warningf("firewall: could not find process of packet (dropping link %s): %s", pkt.String(), err) - link.AddReason(fmt.Sprintf("could not find process or it does not exist (unsolicited packet): %s", err)) + link.Deny(fmt.Sprintf("could not find process or it does not exist (unsolicited packet): %s", err)) } else { log.Warningf("firewall: internal error finding process of packet (dropping link %s): %s", pkt.String(), err) - link.AddReason(fmt.Sprintf("internal error finding process: %s", err)) + link.Deny(fmt.Sprintf("internal error finding process: %s", err)) } - link.Unlock() if pkt.IsInbound() { network.UnknownIncomingConnection.AddLink(link) } else { network.UnknownDirectConnection.AddLink(link) } - verdict(pkt, link.Verdict) - return - } - // reroute dns requests to nameserver - if connection.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.GetTCPUDPHeader() != nil && !pkt.GetIPHeader().Dst.Equal(localhost) && pkt.GetTCPUDPHeader().DstPort == 53 { - pkt.RerouteToNameserver() + verdict(pkt, link.GetVerdict()) + link.StopFirewallHandler() + return } // add new Link to Connection (and save both) connection.AddLink(link) + // reroute dns requests to nameserver + if connection.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.GetTCPUDPHeader() != nil && !pkt.GetIPHeader().Dst.Equal(localhost) && pkt.GetTCPUDPHeader().DstPort == 53 { + link.RerouteToNameserver() + verdict(pkt, link.GetVerdict()) + link.StopFirewallHandler() + return + } + // make a decision if not made already - if connection.Verdict == network.UNDECIDED { + if connection.GetVerdict() == network.UNDECIDED { DecideOnConnection(connection, pkt) } - if connection.Verdict != network.CANTSAY { - link.UpdateVerdict(connection.Verdict) - } else { + if connection.GetVerdict() == network.ACCEPT { DecideOnLink(connection, link, pkt) + } else { + link.UpdateVerdict(connection.GetVerdict()) } // log decision @@ -205,7 +202,7 @@ func initialHandler(pkt packet.Packet, link *network.Link) { // // tunnel link, but also inspect (after reroute) // link.Tunneled = true // link.SetFirewallHandler(inspectThenVerdict) - // verdict(pkt, link.Verdict) + // verdict(pkt, link.GetVerdict()) // case port17Active: // // tunnel link, don't inspect // link.Tunneled = true @@ -216,7 +213,7 @@ func initialHandler(pkt packet.Packet, link *network.Link) { inspectThenVerdict(pkt, link) default: link.StopFirewallHandler() - verdict(pkt, link.Verdict) + verdict(pkt, link.GetVerdict()) } } @@ -225,10 +222,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) { pktVerdict, continueInspection := inspection.RunInspectors(pkt, link) if continueInspection { // do not allow to circumvent link decision: e.g. to ACCEPT packets from a DROP-ed link - if pktVerdict > link.Verdict { + linkVerdict := link.GetVerdict() + if pktVerdict > linkVerdict { verdict(pkt, pktVerdict) } else { - verdict(pkt, link.Verdict) + verdict(pkt, linkVerdict) } return } @@ -236,9 +234,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) { // we are done with inspecting link.StopFirewallHandler() + link.Lock() + defer link.Unlock() link.VerdictPermanent = permanentVerdicts() if link.VerdictPermanent { - link.Save() + go link.Save() permanentVerdict(pkt, link.Verdict) } else { verdict(pkt, link.Verdict) @@ -259,6 +259,12 @@ func permanentVerdict(pkt packet.Packet, action network.Verdict) { atomic.AddUint64(packetsDropped, 1) pkt.PermanentDrop() return + case network.RerouteToNameserver: + pkt.RerouteToNameserver() + return + case network.RerouteToTunnel: + pkt.RerouteToTunnel() + return } pkt.Drop() } @@ -277,6 +283,12 @@ func verdict(pkt packet.Packet, action network.Verdict) { atomic.AddUint64(packetsDropped, 1) pkt.Drop() return + case network.RerouteToNameserver: + pkt.RerouteToNameserver() + return + case network.RerouteToTunnel: + pkt.RerouteToTunnel() + return } pkt.Drop() } @@ -295,18 +307,22 @@ func verdict(pkt packet.Packet, action network.Verdict) { // } func logInitialVerdict(link *network.Link) { - // switch link.Verdict { + // switch link.GetVerdict() { // case network.ACCEPT: // log.Infof("firewall: accepting new link: %s", link.String()) // case network.BLOCK: // log.Infof("firewall: blocking new link: %s", link.String()) // case network.DROP: // log.Infof("firewall: dropping new link: %s", link.String()) + // case network.RerouteToNameserver: + // log.Infof("firewall: rerouting new link to nameserver: %s", link.String()) + // case network.RerouteToTunnel: + // log.Infof("firewall: rerouting new link to tunnel: %s", link.String()) // } } func logChangedVerdict(link *network.Link) { - // switch link.Verdict { + // switch link.GetVerdict() { // case network.ACCEPT: // log.Infof("firewall: change! - now accepting link: %s", link.String()) // case network.BLOCK: diff --git a/firewall/master.go b/firewall/master.go index 9e5be52c..61394bef 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -8,6 +8,7 @@ import ( "github.com/Safing/portmaster/intel" "github.com/Safing/portmaster/network" "github.com/Safing/portmaster/network/packet" + "github.com/Safing/portmaster/profile" "github.com/Safing/portmaster/status" "github.com/agext/levenshtein" @@ -59,10 +60,10 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) permitted, ok := profileSet.CheckDomain(fqdn) if ok { if permitted { - log.Infof("firewall: accepting connection %s, domain is whitelisted", connection, domainElement, processElement) + log.Infof("firewall: accepting connection %s, domain is whitelisted", connection) connection.Accept("domain is whitelisted") } else { - log.Infof("firewall: denying connection %s, domain is blacklisted", connection, domainElement, processElement) + log.Infof("firewall: denying connection %s, domain is blacklisted", connection) connection.Deny("domain is blacklisted") } return @@ -70,7 +71,7 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) switch profileSet.GetProfileMode() { case profile.Whitelist: - log.Infof("firewall: denying connection %s, domain is not whitelisted", connection, domainElement, processElement) + log.Infof("firewall: denying connection %s, domain is not whitelisted", connection) connection.Deny("domain is not whitelisted") case profile.Prompt: @@ -97,9 +98,9 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) break matchLoop } } - if levenshtein.Match(domainElement, profile.Name, nil) > 0.5 { + if levenshtein.Match(domainElement, profileSet.UserProfile().Name, nil) > 0.5 { matched = true - processElement = profile.Name + processElement = profileSet.UserProfile().Name break matchLoop } if levenshtein.Match(domainElement, connection.Process().Name, nil) > 0.5 { @@ -107,6 +108,11 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) processElement = connection.Process().Name break matchLoop } + if levenshtein.Match(domainElement, connection.Process().ExecName, nil) > 0.5 { + matched = true + processElement = connection.Process().ExecName + break matchLoop + } } if matched { @@ -115,15 +121,15 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) } } - if connection.Verdict != network.ACCEPT { + if connection.GetVerdict() != network.ACCEPT { // TODO - log.Infof("firewall: accepting connection %s, domain permitted (prompting is not yet implemented)", connection, domainElement, processElement) + log.Infof("firewall: accepting connection %s, domain permitted (prompting is not yet implemented)", connection) connection.Accept("domain permitted (prompting is not yet implemented)") } case profile.Blacklist: - log.Infof("firewall: denying connection %s, domain is not blacklisted", connection, domainElement, processElement) - connection.Deny("domain is not blacklisted") + log.Infof("firewall: accepting connection %s, domain is not blacklisted", connection) + connection.Accept("domain is not blacklisted") } } @@ -175,8 +181,8 @@ func DecideOnConnection(connection *network.Connection, pkt packet.Packet) { } // check if there is a profile - profileSet := connection.Process().ProfileSet - if profile == nil { + profileSet := connection.Process().ProfileSet() + if profileSet == nil { log.Errorf("firewall: denying connection %s, no profile set", connection) connection.Deny("no profile") return @@ -185,17 +191,17 @@ func DecideOnConnection(connection *network.Connection, pkt packet.Packet) { // check connection type switch connection.Domain { - case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid: + case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: if !profileSet.CheckFlag(profile.Service) { log.Infof("firewall: denying connection %s, not a service", connection) - if connection.Domain == IncomingHost { + if connection.Domain == network.IncomingHost { connection.Block("not a service") } else { connection.Drop("not a service") } return } - case PeerLAN, PeerInternet, PeerInvalid: // Important: PeerHost is and should be missing! + case network.PeerLAN, network.PeerInternet, network.PeerInvalid: // Important: PeerHost is and should be missing! if !profileSet.CheckFlag(profile.PeerToPeer) { log.Infof("firewall: denying connection %s, peer to peer connections (to an IP) not allowed", connection) connection.Deny("peer to peer connections (to an IP) not allowed") @@ -205,54 +211,54 @@ func DecideOnConnection(connection *network.Connection, pkt packet.Packet) { // check network scope switch connection.Domain { - case IncomingHost: + case network.IncomingHost: if !profileSet.CheckFlag(profile.Localhost) { log.Infof("firewall: denying connection %s, serving localhost not allowed", connection) connection.Block("serving localhost not allowed") return } - case IncomingLAN: + case network.IncomingLAN: if !profileSet.CheckFlag(profile.LAN) { log.Infof("firewall: denying connection %s, serving LAN not allowed", connection) connection.Deny("serving LAN not allowed") return } - case IncomingInternet: + case network.IncomingInternet: if !profileSet.CheckFlag(profile.Internet) { log.Infof("firewall: denying connection %s, serving Internet not allowed", connection) connection.Deny("serving Internet not allowed") return } - case IncomingInvalid: + case network.IncomingInvalid: log.Infof("firewall: denying connection %s, invalid IP address", connection) connection.Drop("invalid IP address") return - case PeerHost: + case network.PeerHost: if !profileSet.CheckFlag(profile.Localhost) { log.Infof("firewall: denying connection %s, accessing localhost not allowed", connection) connection.Block("accessing localhost not allowed") return } - case PeerLAN: + case network.PeerLAN: if !profileSet.CheckFlag(profile.LAN) { log.Infof("firewall: denying connection %s, accessing the LAN not allowed", connection) connection.Deny("accessing the LAN not allowed") return } - case PeerInternet: + case network.PeerInternet: if !profileSet.CheckFlag(profile.Internet) { log.Infof("firewall: denying connection %s, accessing the Internet not allowed", connection) connection.Deny("accessing the Internet not allowed") return } - case PeerInvalid: + case network.PeerInvalid: log.Infof("firewall: denying connection %s, invalid IP address", connection) connection.Deny("invalid IP address") return } log.Infof("firewall: accepting connection %s", connection) - connection.Accept() + connection.Accept("") } // DecideOnLink makes a decision about a link with the first packet. @@ -264,8 +270,8 @@ func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet // Profile.ListenPorts // check if there is a profile - profileSet := connection.Process().ProfileSet - if profile == nil { + profileSet := connection.Process().ProfileSet() + if profileSet == nil { log.Infof("firewall: no profile, denying %s", link) link.Block("no profile") return @@ -274,20 +280,20 @@ func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet // get remote Port protocol := pkt.GetIPHeader().Protocol - var remotePort uint16 - tcpUdpHeader := pkt.GetTCPUDPHeader() - if tcpUdpHeader != nil { - remotePort = tcpUdpHeader.DstPort + var dstPort uint16 + tcpUDPHeader := pkt.GetTCPUDPHeader() + if tcpUDPHeader != nil { + dstPort = tcpUDPHeader.DstPort } // check port list - permitted, ok := profileSet.CheckPort(connection.Direction, protocol, remotePort) + permitted, ok := profileSet.CheckPort(connection.Direction, uint8(protocol), dstPort) if ok { if permitted { log.Infof("firewall: accepting link %s", link) link.Accept("port whitelisted") } else { - log.Infof("firewall: denying link %s: port %d is blacklisted", link, remotePort) + log.Infof("firewall: denying link %s: port %d is blacklisted", link, dstPort) link.Deny("port blacklisted") } return @@ -295,14 +301,17 @@ func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet switch profileSet.GetProfileMode() { case profile.Whitelist: - log.Infof("firewall: denying link %s: port %d is not whitelisted", link, remotePort) + log.Infof("firewall: denying link %s: port %d is not whitelisted", link, dstPort) link.Deny("port is not whitelisted") + return case profile.Prompt: - log.Infof("firewall: denying link %s: port %d is blacklisted", link, remotePort) + log.Infof("firewall: accepting link %s: port %d is blacklisted", link, dstPort) link.Accept("port permitted (prompting is not yet implemented)") + return case profile.Blacklist: - log.Infof("firewall: denying link %s: port %d is blacklisted", link, remotePort) - link.Deny("port is not blacklisted") + log.Infof("firewall: accepting link %s: port %d is not blacklisted", link, dstPort) + link.Accept("port is not blacklisted") + return } log.Infof("firewall: accepting link %s", link) diff --git a/global/databases.go b/global/databases.go index ea3b9708..624370b6 100644 --- a/global/databases.go +++ b/global/databases.go @@ -7,10 +7,11 @@ import ( // module dependencies _ "github.com/Safing/portbase/database/dbmodule" _ "github.com/Safing/portbase/database/storage/badger" + _ "github.com/Safing/portmaster/status" ) func init() { - modules.Register("global", nil, start, nil, "database") + modules.Register("global", nil, start, nil, "database", "status") } func start() error { diff --git a/intel/ipinfo.go b/intel/ipinfo.go index eef7de29..e00921fa 100644 --- a/intel/ipinfo.go +++ b/intel/ipinfo.go @@ -59,10 +59,10 @@ func GetIPInfo(ip string) (*IPInfo, error) { // AddDomain adds a domain to the list and reports back if it was added, or was already present. func (ipi *IPInfo) AddDomain(domain string) (added bool) { + ipi.Lock() + defer ipi.Unlock() if !utils.StringInSlice(ipi.Domains, domain) { - newDomains := make([]string, 1, len(ipi.Domains)+1) - newDomains[0] = domain - ipi.Domains = append(newDomains, ipi.Domains...) + ipi.Domains = append([]string{domain}, ipi.Domains...) return true } return false @@ -70,11 +70,22 @@ func (ipi *IPInfo) AddDomain(domain string) (added bool) { // Save saves the IPInfo record to the database. func (ipi *IPInfo) Save() error { - ipi.SetKey(makeIPInfoKey(ipi.IP)) - return ipInfoDatabase.PutNew(ipi) + ipi.Lock() + if !ipi.KeyIsSet() { + ipi.SetKey(makeIPInfoKey(ipi.IP)) + } + ipi.Unlock() + return ipInfoDatabase.Put(ipi) } // FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " func (ipi *IPInfo) FmtDomains() string { return strings.Join(ipi.Domains, " or ") } + +// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " +func (ipi *IPInfo) String() string { + ipi.Lock() + defer ipi.Unlock() + return fmt.Sprintf("= 0 { diff --git a/main.go b/main.go index 5a5e1474..b34e0f95 100644 --- a/main.go +++ b/main.go @@ -1,10 +1,13 @@ package main import ( + "flag" "fmt" "os" "os/signal" + "runtime/pprof" "syscall" + "time" "github.com/Safing/portbase/info" "github.com/Safing/portbase/log" @@ -12,11 +15,21 @@ import ( // include packages here + _ "github.com/Safing/portbase/api" _ "github.com/Safing/portbase/database/dbmodule" _ "github.com/Safing/portbase/database/storage/badger" _ "github.com/Safing/portmaster/firewall" + _ "github.com/Safing/portmaster/nameserver" ) +var ( + printStackOnExit bool +) + +func init() { + flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down") +} + func main() { // Set Info @@ -51,7 +64,24 @@ func main() { case <-signalCh: fmt.Println(" ") log.Warning("main: program was interrupted, shutting down.") - modules.Shutdown() + + if printStackOnExit { + fmt.Println("=== PRINTING STACK ===") + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + fmt.Println("=== END STACK ===") + } + + go func() { + modules.Shutdown() + os.Exit(0) + }() + go func() { + time.Sleep(3 * time.Second) + fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====") + pprof.Lookup("goroutine").WriteTo(os.Stdout, 2) + os.Exit(1) + }() + case <-modules.ShuttingDown(): } diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 9e24a271..3cbd4940 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -12,10 +12,10 @@ import ( "github.com/Safing/portbase/modules" "github.com/Safing/portmaster/analytics/algs" + "github.com/Safing/portmaster/firewall" "github.com/Safing/portmaster/intel" "github.com/Safing/portmaster/network" "github.com/Safing/portmaster/network/netutils" - "github.com/Safing/portmaster/firewall" ) var ( @@ -138,19 +138,19 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { // log.Tracef("nameserver: took %s to get connection/process of %s request", time.Now().Sub(timed).String(), fqdn) // check profile before we even get intel and rr - if connection.Verdict == network.UNDECIDED { + if connection.GetVerdict() == network.UNDECIDED { // start = time.Now() firewall.DecideOnConnectionBeforeIntel(connection, fqdn) // log.Tracef("nameserver: took %s to make decision", time.Since(start)) } - if connection.Verdict == network.BLOCK || connection.Verdict == network.DROP { + if connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP { nxDomain(w, query) return } // get intel and RRs // start = time.Now() - domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().Profile.SecurityLevel) + domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().ProfileSet().SecurityLevel()) // log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start)) if rrCache == nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains @@ -160,14 +160,16 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { } // set intel + connection.Lock() connection.Intel = domainIntel + connection.Unlock() connection.Save() // do a full check with intel - if connection.Verdict == network.UNDECIDED { + if connection.GetVerdict() == network.UNDECIDED { rrCache = firewall.DecideOnConnectionAfterIntel(connection, fqdn, rrCache) } - if rrCache == nil || connection.Verdict == network.BLOCK || connection.Verdict == network.DROP { + if rrCache == nil || connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP { nxDomain(w, query) return } @@ -179,24 +181,27 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { ipInfo, err := intel.GetIPInfo(v.A.String()) if err != nil { ipInfo = &intel.IPInfo{ + IP: v.A.String(), Domains: []string{fqdn}, } - ipInfo.Create(v.A.String()) - } else { - ipInfo.Domains = append(ipInfo.Domains, fqdn) ipInfo.Save() + } else { + if ipInfo.AddDomain(fqdn) { + ipInfo.Save() + } } case *dns.AAAA: ipInfo, err := intel.GetIPInfo(v.AAAA.String()) if err != nil { ipInfo = &intel.IPInfo{ - IP: + IP: v.AAAA.String(), Domains: []string{fqdn}, } - ipInfo.Create(v.AAAA.String()) - } else { - ipInfo.Domains = append(ipInfo.Domains, fqdn) ipInfo.Save() + } else { + if ipInfo.AddDomain(fqdn) { + ipInfo.Save() + } } } } diff --git a/network/clean.go b/network/clean.go index 9b1b88e3..8aebf5d1 100644 --- a/network/clean.go +++ b/network/clean.go @@ -17,8 +17,11 @@ var ( func cleaner() { for { time.Sleep(cleanerTickDuration) + cleanLinks() + time.Sleep(10 * time.Second) cleanConnections() + time.Sleep(10 * time.Second) cleanProcesses() } } @@ -26,18 +29,21 @@ func cleaner() { func cleanLinks() { activeIDs := process.GetActiveConnectionIDs() - dataLock.Lock() - defer dataLock.Lock() - now := time.Now().Unix() deleteOlderThan := time.Now().Add(-deadLinksTimeout).Unix() + linksLock.RLock() + defer linksLock.RUnlock() + var found bool for key, link := range links { // delete dead links - if link.Ended > 0 && link.Ended < deleteOlderThan { - link.Delete() + link.Lock() + deleteThis := link.Ended > 0 && link.Ended < deleteOlderThan + link.Unlock() + if deleteThis { + go link.Delete() continue } @@ -53,21 +59,23 @@ func cleanLinks() { // mark end time if !found { link.Ended = now - link.Save() + go link.Save() } } } func cleanConnections() { - dataLock.Lock() - defer dataLock.Lock() + connectionsLock.RLock() + defer connectionsLock.RUnlock() threshold := time.Now().Add(-thresholdDuration).Unix() for _, conn := range connections { + conn.Lock() if conn.FirstLinkEstablished < threshold && conn.LinkCount == 0 { - conn.Delete() + go conn.Delete() } + conn.Unlock() } } diff --git a/network/connection.go b/network/connection.go index 112503be..6369f938 100644 --- a/network/connection.go +++ b/network/connection.go @@ -11,6 +11,7 @@ import ( "github.com/Safing/portbase/database/record" "github.com/Safing/portmaster/intel" + "github.com/Safing/portmaster/network/netutils" "github.com/Safing/portmaster/network/packet" "github.com/Safing/portmaster/process" ) @@ -35,17 +36,28 @@ type Connection struct { // Process returns the process that owns the connection. func (conn *Connection) Process() *process.Process { + conn.Lock() + defer conn.Unlock() + return conn.process } +// GetVerdict returns the current verdict. +func (conn *Connection) GetVerdict() Verdict { + conn.Lock() + defer conn.Unlock() + + return conn.Verdict +} + // Accept accepts the connection and adds the given reason. -func (conn *Link) Accept(reason string) { +func (conn *Connection) Accept(reason string) { conn.AddReason(reason) conn.UpdateVerdict(ACCEPT) } // Deny blocks or drops the connection depending on the connection direction and adds the given reason. -func (conn *Link) Deny(reason string) { +func (conn *Connection) Deny(reason string) { if conn.Direction { conn.Drop(reason) } else { @@ -54,13 +66,13 @@ func (conn *Link) Deny(reason string) { } // Block blocks the connection and adds the given reason. -func (conn *Link) Block(reason string) { +func (conn *Connection) Block(reason string) { conn.AddReason(reason) conn.UpdateVerdict(BLOCK) } // Drop drops the connection and adds the given reason. -func (conn *Link) Drop(reason string) { +func (conn *Connection) Drop(reason string) { conn.AddReason(reason) conn.UpdateVerdict(DROP) } @@ -72,7 +84,7 @@ func (conn *Connection) UpdateVerdict(newVerdict Verdict) { if newVerdict > conn.Verdict { conn.Verdict = newVerdict - conn.Save() + go conn.Save() } } @@ -103,13 +115,13 @@ func GetConnectionByFirstPacket(pkt packet.Packet) (*Connection, error) { // Incoming if direction { switch netutils.ClassifyIP(pkt.GetIPHeader().Src) { - case HostLocal: + case netutils.HostLocal: domain = IncomingHost - case LinkLocal, SiteLocal, LocalMulticast: + case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: domain = IncomingLAN - case Global, GlobalMulticast: + case netutils.Global, netutils.GlobalMulticast: domain = IncomingInternet - case Invalid: + case netutils.Invalid: domain = IncomingInvalid } @@ -135,13 +147,13 @@ func GetConnectionByFirstPacket(pkt packet.Packet) (*Connection, error) { // if no domain could be found, it must be a direct connection switch netutils.ClassifyIP(pkt.GetIPHeader().Dst) { - case HostLocal: + case netutils.HostLocal: domain = PeerHost - case LinkLocal, SiteLocal, LocalMulticast: + case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: domain = PeerLAN - case Global, GlobalMulticast: + case netutils.Global, netutils.GlobalMulticast: domain = PeerInternet - case Invalid: + case netutils.Invalid: domain = PeerInvalid } @@ -205,8 +217,8 @@ func GetConnectionByDNSRequest(ip net.IP, port uint16, fqdn string) (*Connection // GetConnection fetches a connection object from the internal storage. func GetConnection(pid int, domain string) (conn *Connection, ok bool) { - dataLock.RLock() - defer dataLock.RUnlock() + connectionsLock.RLock() + defer connectionsLock.RUnlock() conn, ok = connections[fmt.Sprintf("%d/%s", pid, domain)] return } @@ -217,58 +229,63 @@ func (conn *Connection) makeKey() string { // Save saves the connection object in the storage and propagates the change. func (conn *Connection) Save() error { + conn.Lock() + defer conn.Unlock() + if conn.process == nil { return errors.New("cannot save connection without process") } - if conn.DatabaseKey() == "" { + if !conn.KeyIsSet() { conn.SetKey(fmt.Sprintf("network:tree/%d/%s", conn.process.Pid, conn.Domain)) conn.CreateMeta() } key := conn.makeKey() - dataLock.RLock() + connectionsLock.RLock() _, ok := connections[key] - dataLock.RUnlock() + connectionsLock.RUnlock() if !ok { - dataLock.Lock() + connectionsLock.Lock() connections[key] = conn - dataLock.Unlock() + connectionsLock.Unlock() } - dbController.PushUpdate(conn) + go dbController.PushUpdate(conn) return nil } // Delete deletes a connection from the storage and propagates the change. func (conn *Connection) Delete() { - dataLock.Lock() - defer dataLock.Unlock() - delete(connections, conn.makeKey()) conn.Lock() - defer conn.Lock() + defer conn.Unlock() + + connectionsLock.Lock() + delete(connections, conn.makeKey()) + connectionsLock.Unlock() + conn.Meta().Delete() - dbController.PushUpdate(conn) + go dbController.PushUpdate(conn) conn.process.RemoveConnection() } // AddLink applies the connection to the link and increases sets counter and timestamps. func (conn *Connection) AddLink(link *Link) { link.Lock() - defer link.Unlock() link.connection = conn link.Verdict = conn.Verdict link.Inspect = conn.Inspect + link.Unlock() link.Save() conn.Lock() - defer conn.Unlock() conn.LinkCount++ conn.LastLinkEstablished = time.Now().Unix() if conn.FirstLinkEstablished == 0 { conn.FirstLinkEstablished = conn.LastLinkEstablished } + conn.Unlock() conn.Save() } @@ -276,6 +293,7 @@ func (conn *Connection) AddLink(link *Link) { func (conn *Connection) RemoveLink() { conn.Lock() defer conn.Unlock() + if conn.LinkCount > 0 { conn.LinkCount-- } @@ -283,13 +301,16 @@ func (conn *Connection) RemoveLink() { // String returns a string representation of Connection. func (conn *Connection) String() string { + conn.Lock() + defer conn.Unlock() + switch conn.Domain { - case "I": + case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid: if conn.process == nil { return "? <- *" } return fmt.Sprintf("%s <- *", conn.process.String()) - case "D": + case PeerHost, PeerLAN, PeerInternet, PeerInvalid: if conn.process == nil { return "? -> *" } diff --git a/network/database.go b/network/database.go index 1cd3f434..28147abf 100644 --- a/network/database.go +++ b/network/database.go @@ -14,9 +14,10 @@ import ( ) var ( - links map[string]*Link - connections map[string]*Connection - dataLock sync.RWMutex + links = make(map[string]*Link) + linksLock sync.RWMutex + connections = make(map[string]*Connection) + connectionsLock sync.RWMutex dbController *database.Controller ) @@ -29,9 +30,6 @@ type StorageInterface struct { // Get returns a database record. func (s *StorageInterface) Get(key string) (record.Record, error) { - dataLock.RLock() - defer dataLock.RUnlock() - splitted := strings.Split(key, "/") switch splitted[0] { case "tree": @@ -45,11 +43,15 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { } } case 3: + connectionsLock.RLock() + defer connectionsLock.RUnlock() conn, ok := connections[splitted[2]] if ok { return conn, nil } case 4: + linksLock.RLock() + defer linksLock.RUnlock() link, ok := links[splitted[3]] if ok { return link, nil @@ -77,22 +79,23 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { } } - dataLock.RLock() - defer dataLock.RUnlock() - // connections + connectionsLock.RLock() for _, conn := range connections { if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) { it.Next <- conn } } + connectionsLock.RUnlock() // links + linksLock.RLock() for _, link := range links { if strings.HasPrefix(link.DatabaseKey(), q.DatabaseKeyPrefix()) { it.Next <- link } } + linksLock.RUnlock() it.Finish(nil) } diff --git a/network/environment/dbus_linux.go b/network/environment/dbus_linux.go index 7ffe0905..8e6751c3 100644 --- a/network/environment/dbus_linux.go +++ b/network/environment/dbus_linux.go @@ -24,6 +24,9 @@ func getNameserversFromDbus() ([]Nameserver, error) { var nameservers []Nameserver var err error + dbusConnLock.Lock() + defer dbusConnLock.Unlock() + if dbusConn == nil { dbusConn, err = dbus.SystemBus() } @@ -158,6 +161,9 @@ func getNameserversFromDbus() ([]Nameserver, error) { func getConnectivityStateFromDbus() (uint8, error) { var err error + dbusConnLock.Lock() + defer dbusConnLock.Unlock() + if dbusConn == nil { dbusConn, err = dbus.SystemBus() } diff --git a/network/link.go b/network/link.go index 2b85ea34..118ed7dc 100644 --- a/network/link.go +++ b/network/link.go @@ -46,16 +46,33 @@ type Link struct { // Connection returns the Connection the Link is part of func (link *Link) Connection() *Connection { + link.Lock() + defer link.Unlock() + return link.connection } +// GetVerdict returns the current verdict. +func (link *Link) GetVerdict() Verdict { + link.Lock() + defer link.Unlock() + + return link.Verdict +} + // FirewallHandlerIsSet returns whether a firewall handler is set or not func (link *Link) FirewallHandlerIsSet() bool { + link.Lock() + defer link.Unlock() + return link.firewallHandler != nil } // SetFirewallHandler sets the firewall handler for this link func (link *Link) SetFirewallHandler(handler FirewallHandler) { + link.Lock() + defer link.Unlock() + if link.firewallHandler == nil { link.firewallHandler = handler link.pktQueue = make(chan packet.Packet, 1000) @@ -67,16 +84,22 @@ func (link *Link) SetFirewallHandler(handler FirewallHandler) { // StopFirewallHandler unsets the firewall handler func (link *Link) StopFirewallHandler() { + link.Lock() + link.firewallHandler = nil + link.Unlock() link.pktQueue <- nil } // HandlePacket queues packet of Link for handling func (link *Link) HandlePacket(pkt packet.Packet) { + link.Lock() + defer link.Unlock() + if link.firewallHandler != nil { link.pktQueue <- pkt return } - log.Criticalf("network: link %s does not have a firewallHandler (maybe it's a copy), dropping packet", link) + log.Criticalf("network: link %s does not have a firewallHandler, dropping packet", link) pkt.Drop() } @@ -88,7 +111,7 @@ func (link *Link) Accept(reason string) { // Deny blocks or drops the link depending on the connection direction and adds the given reason. func (link *Link) Deny(reason string) { - if link.connection.Direction { + if link.connection != nil && link.connection.Direction { link.Drop(reason) } else { link.Block(reason) @@ -107,6 +130,17 @@ func (link *Link) Drop(reason string) { link.UpdateVerdict(DROP) } +// RerouteToNameserver reroutes the link to the portmaster nameserver. +func (link *Link) RerouteToNameserver() { + link.UpdateVerdict(RerouteToNameserver) +} + +// RerouteToTunnel reroutes the link to the tunnel entrypoint and adds the given reason for accepting the connection. +func (link *Link) RerouteToTunnel(reason string) { + link.AddReason(reason) + link.UpdateVerdict(RerouteToTunnel) +} + // UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts func (link *Link) UpdateVerdict(newVerdict Verdict) { link.Lock() @@ -114,7 +148,7 @@ func (link *Link) UpdateVerdict(newVerdict Verdict) { if newVerdict > link.Verdict { link.Verdict = newVerdict - link.Save() + go link.Save() } } @@ -138,54 +172,103 @@ func (link *Link) packetHandler() { for { pkt := <-link.pktQueue if pkt == nil { - break + return + } + link.Lock() + fwH := link.firewallHandler + link.Unlock() + if fwH != nil { + fwH(pkt, link) + } else { + link.ApplyVerdict(pkt) + } + } +} + +// ApplyVerdict appies the link verdict to a packet. +func (link *Link) ApplyVerdict(pkt packet.Packet) { + link.Lock() + defer link.Unlock() + + if link.VerdictPermanent { + switch link.Verdict { + case ACCEPT: + pkt.PermanentAccept() + case BLOCK: + pkt.PermanentBlock() + case DROP: + pkt.PermanentDrop() + case RerouteToNameserver: + pkt.RerouteToNameserver() + case RerouteToTunnel: + pkt.RerouteToTunnel() + default: + pkt.Drop() + } + } else { + switch link.Verdict { + case ACCEPT: + pkt.Accept() + case BLOCK: + pkt.Block() + case DROP: + pkt.Drop() + case RerouteToNameserver: + pkt.RerouteToNameserver() + case RerouteToTunnel: + pkt.RerouteToTunnel() + default: + pkt.Drop() } - link.firewallHandler(pkt, link) } - link.firewallHandler = nil } // Save saves the link object in the storage and propagates the change. func (link *Link) Save() error { + link.Lock() + defer link.Unlock() + if link.connection == nil { return errors.New("cannot save link without connection") } - if link.DatabaseKey() == "" { + if !link.KeyIsSet() { link.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", link.connection.Process().Pid, link.connection.Domain, link.ID)) link.CreateMeta() } - dataLock.RLock() + linksLock.RLock() _, ok := links[link.ID] - dataLock.RUnlock() + linksLock.RUnlock() if !ok { - dataLock.Lock() + linksLock.Lock() links[link.ID] = link - dataLock.Unlock() + linksLock.Unlock() } - dbController.PushUpdate(link) + go dbController.PushUpdate(link) return nil } // Delete deletes a link from the storage and propagates the change. func (link *Link) Delete() { - dataLock.Lock() - defer dataLock.Unlock() - delete(links, link.ID) link.Lock() - defer link.Lock() + defer link.Unlock() + + linksLock.Lock() + delete(links, link.ID) + linksLock.Unlock() + link.Meta().Delete() - dbController.PushUpdate(link) + go dbController.PushUpdate(link) link.connection.RemoveLink() } // GetLink fetches a Link from the database from the default namespace for this object func GetLink(id string) (*Link, bool) { - dataLock.RLock() - defer dataLock.RUnlock() + linksLock.RLock() + defer linksLock.RUnlock() link, ok := links[id] return link, ok @@ -215,6 +298,7 @@ func CreateLinkFromPacket(pkt packet.Packet) *Link { func (link *Link) GetActiveInspectors() []bool { link.Lock() defer link.Unlock() + return link.activeInspectors } @@ -222,6 +306,7 @@ func (link *Link) GetActiveInspectors() []bool { func (link *Link) SetActiveInspectors(new []bool) { link.Lock() defer link.Unlock() + link.activeInspectors = new } @@ -229,6 +314,7 @@ func (link *Link) SetActiveInspectors(new []bool) { func (link *Link) GetInspectorData() map[uint8]interface{} { link.Lock() defer link.Unlock() + return link.inspectorData } @@ -236,11 +322,15 @@ func (link *Link) GetInspectorData() map[uint8]interface{} { func (link *Link) SetInspectorData(new map[uint8]interface{}) { link.Lock() defer link.Unlock() + link.inspectorData = new } // String returns a string representation of Link. func (link *Link) String() string { + link.Lock() + defer link.Unlock() + if link.connection == nil { return fmt.Sprintf("? <-> %s", link.RemoteAddress) } diff --git a/network/netutils/ip.go b/network/netutils/ip.go index c8066ba1..780c6901 100644 --- a/network/netutils/ip.go +++ b/network/netutils/ip.go @@ -15,8 +15,8 @@ const ( Invalid ) -// ClassifyAddress returns the classification for the given IP address. -func ClassifyAddress(ip net.IP) int8 { +// ClassifyIP returns the classification for the given IP address. +func ClassifyIP(ip net.IP) int8 { if ip4 := ip.To4(); ip4 != nil { // IPv4 switch { @@ -73,12 +73,12 @@ func ClassifyAddress(ip net.IP) int8 { // IPIsLocalhost returns whether the IP refers to the host itself. func IPIsLocalhost(ip net.IP) bool { - return ClassifyAddress(ip) == HostLocal + return ClassifyIP(ip) == HostLocal } // IPIsLAN returns true if the given IP is a site-local or link-local address. func IPIsLAN(ip net.IP) bool { - switch ClassifyAddress(ip) { + switch ClassifyIP(ip) { case SiteLocal: return true case LinkLocal: @@ -90,15 +90,15 @@ func IPIsLAN(ip net.IP) bool { // IPIsGlobal returns true if the given IP is a global address. func IPIsGlobal(ip net.IP) bool { - return ClassifyAddress(ip) == Global + return ClassifyIP(ip) == Global } // IPIsLinkLocal returns true if the given IP is a link-local address. func IPIsLinkLocal(ip net.IP) bool { - return ClassifyAddress(ip) == LinkLocal + return ClassifyIP(ip) == LinkLocal } // IPIsSiteLocal returns true if the given IP is a site-local address. func IPIsSiteLocal(ip net.IP) bool { - return ClassifyAddress(ip) == SiteLocal + return ClassifyIP(ip) == SiteLocal } diff --git a/network/netutils/ip_test.go b/network/netutils/ip_test.go index bc9848b4..b7ede062 100644 --- a/network/netutils/ip_test.go +++ b/network/netutils/ip_test.go @@ -13,7 +13,7 @@ func TestIPClassification(t *testing.T) { } func testClassification(t *testing.T, ip net.IP, expectedClassification int8) { - c := ClassifyAddress(ip) + c := ClassifyIP(ip) if c != expectedClassification { t.Errorf("%s is %s, expected %s", ip, classificationString(c), classificationString(expectedClassification)) } diff --git a/network/status.go b/network/status.go index 0763da23..a02756ae 100644 --- a/network/status.go +++ b/network/status.go @@ -12,6 +12,8 @@ const ( ACCEPT BLOCK DROP + RerouteToNameserver + RerouteToTunnel ) // Packer Directions diff --git a/process/database.go b/process/database.go index 4c3a8a81..a28ef40d 100644 --- a/process/database.go +++ b/process/database.go @@ -45,7 +45,7 @@ func (p *Process) Save() { p.Lock() defer p.Unlock() - if p.DatabaseKey() == "" { + if !p.KeyIsSet() { p.SetKey(fmt.Sprintf("network:tree/%d", p.Pid)) p.CreateMeta() } @@ -61,21 +61,22 @@ func (p *Process) Save() { } if dbControllerFlag.IsSet() { - dbController.PushUpdate(p) + go dbController.PushUpdate(p) } } // Delete deletes a process from the storage and propagates the change. func (p *Process) Delete() { - processesLock.Lock() - defer processesLock.Unlock() - delete(processes, p.Pid) p.Lock() defer p.Lock() - p.Meta().Delete() + processesLock.Lock() + delete(processes, p.Pid) + processesLock.Unlock() + + p.Meta().Delete() if dbControllerFlag.IsSet() { - dbController.PushUpdate(p) + go dbController.PushUpdate(p) } profile.DeactivateProfileSet(p.profileSet) @@ -88,9 +89,11 @@ func CleanProcessStorage(thresholdDuration time.Duration) { threshold := time.Now().Add(-thresholdDuration).Unix() for _, p := range processes { + p.Lock() if p.FirstConnectionEstablished < threshold && p.ConnectionCount == 0 { - p.Delete() + go p.Delete() } + p.Unlock() } } diff --git a/process/find.go b/process/find.go index b8e9989f..842922b9 100644 --- a/process/find.go +++ b/process/find.go @@ -4,13 +4,14 @@ import ( "errors" "net" + "github.com/Safing/portbase/log" "github.com/Safing/portmaster/network/packet" ) // Errors var ( - ErrConnectionNotFound = errors.New("could not find connection") - ErrProcessNotFound = errors.New("could not find process") + ErrConnectionNotFound = errors.New("could not find connection in system state tables") + ErrProcessNotFound = errors.New("could not find process in system state tables") ) // GetPidByPacket returns the pid of the owner of the packet. @@ -57,18 +58,23 @@ func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, er var pid int pid, direction, err = GetPidByPacket(pkt) - if pid < 0 { - return nil, direction, ErrConnectionNotFound - } if err != nil { return nil, direction, err } + if pid < 0 { + return nil, direction, ErrConnectionNotFound + } process, err = GetOrFindProcess(pid) if err != nil { return nil, direction, err } + err = process.FindProfiles() + if err != nil { + log.Errorf("failed to find profiles for process %s: %s", process.String(), err) + } + return process, direction, nil } @@ -113,6 +119,11 @@ func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, re return nil, err } + err = process.FindProfiles() + if err != nil { + log.Errorf("failed to find profiles for process %s: %s", process.String(), err) + } + return process, nil } diff --git a/process/matching.go b/process/matching.go index ccf31875..bce476a3 100644 --- a/process/matching.go +++ b/process/matching.go @@ -1,52 +1,59 @@ package process import ( - "fmt" - "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/query" "github.com/Safing/portbase/log" "github.com/Safing/portmaster/profile" - "github.com/Safing/portmaster/profile/index" +) + +var ( + profileDB = database.NewInterface(nil) ) // FindProfiles finds and assigns a profile set to the process. func (p *Process) FindProfiles() error { - // Get fingerprints of process + p.Lock() + defer p.Unlock() - // Check if user profile already exists, else create new - pathIdentifier := profile.GetPathIdentifier(p.Path) - indexRecord, err := index.Get(pathIdentifier) - if err != nil && err != database.ErrNotFound { - log.Errorf("process: could not get profile index for %s: %s", pathIdentifier, err) + // only find profiles if not already done. + if p.profileSet != nil { + return nil } - var possibleProfiles []*profile.Profile - if indexRecord != nil { - for _, profileID := range indexRecord.UserProfiles { - prof, err := profile.Get(profileID) - if err != nil { - log.Errorf("process: failed to load profile %s: %s", profileID, err) - } - possibleProfiles = append(possibleProfiles, prof) - } + // User Profile + it, err := profileDB.Query(query.New(profile.MakeProfileKey(profile.UserNamespace, "")).Where(query.Where("LinkedPath", query.SameAs, p.Path))) + if err != nil { + return err + } + + var userProfile *profile.Profile + for r := range it.Next { + it.Cancel() + userProfile, err = profile.EnsureProfile(r) + if err != nil { + return err } + break + } + if it.Err() != nil { + return it.Err() } - prof := selectProfile(p, possibleProfiles) - if prof == nil { + // create new profile if it does not exist. + if userProfile == nil { // create new profile - prof := profile.New() - prof.Name = p.ExecName - prof.AddFingerprint(&profile.Fingerprint{ - Type: "full_path", - Value: p.Path, - }) - // TODO: maybe add sha256_sum? - prof.MarkUsed() - prof.Save() + userProfile = profile.New() + userProfile.Name = p.ExecName + userProfile.LinkedPath = p.Path } + if userProfile.MarkUsed() { + userProfile.Save(profile.UserNamespace) + } + + // Stamp // Find/Re-evaluate Stamp profile // 1. check linked stamp profile // 2. if last check is was more than a week ago, fetch from stamp: @@ -56,13 +63,9 @@ func (p *Process) FindProfiles() error { // 6. link stamp profile to user profile // FIXME: implement! - if prof.MarkUsed() { - prof.Save() - } - - p.UserProfileKey = prof.Key() - p.profileSet = profile.NewSet(prof, nil) - p.Save() + p.UserProfileKey = userProfile.Key() + p.profileSet = profile.NewSet(userProfile, nil) + go p.Save() return nil } diff --git a/process/proc/gather.go b/process/proc/gather.go index 39a1fe60..98a83be4 100644 --- a/process/proc/gather.go +++ b/process/proc/gather.go @@ -13,14 +13,19 @@ const ( NoProcess ) -func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { +var ( + waitTime = 15 * time.Millisecond +) + +// GetPidOfConnection returns the PID of the given connection. +func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { uid, inode, ok := getConnectionSocket(localIP, localPort, protocol) if !ok { uid, inode, ok = getListeningSocket(localIP, localPort, protocol) for i := 0; i < 3 && !ok; i++ { // give kernel some time, then try again // log.Tracef("process: giving kernel some time to think") - time.Sleep(15 * time.Millisecond) + time.Sleep(waitTime) uid, inode, ok = getConnectionSocket(localIP, localPort, protocol) if !ok { uid, inode, ok = getListeningSocket(localIP, localPort, protocol) @@ -30,27 +35,48 @@ func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid return -1, NoSocket } } + pid, ok = GetPidOfInode(uid, inode) for i := 0; i < 3 && !ok; i++ { // give kernel some time, then try again // log.Tracef("process: giving kernel some time to think") - time.Sleep(15 * time.Millisecond) + time.Sleep(waitTime) pid, ok = GetPidOfInode(uid, inode) } if !ok { return -1, NoProcess } + return } -func GetPidOfIncomingConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { +// GetPidOfConnection returns the PID of the given incoming connection. +func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { uid, inode, ok := getListeningSocket(localIP, localPort, protocol) if !ok { - return -1, NoSocket + // for TCP4 and UDP4, also try TCP6 and UDP6, as linux sometimes treats them as a single dual socket, and shows the IPv6 version. + switch protocol { + case TCP4: + uid, inode, ok = getListeningSocket(localIP, localPort, TCP6) + case UDP4: + uid, inode, ok = getListeningSocket(localIP, localPort, UDP6) + } + + if !ok { + return -1, NoSocket + } } + pid, ok = GetPidOfInode(uid, inode) + for i := 0; i < 3 && !ok; i++ { + // give kernel some time, then try again + // log.Tracef("process: giving kernel some time to think") + time.Sleep(waitTime) + pid, ok = GetPidOfInode(uid, inode) + } if !ok { return -1, NoProcess } + return } diff --git a/process/proc/get.go b/process/proc/get.go index 9f3fc0ed..174ec823 100644 --- a/process/proc/get.go +++ b/process/proc/get.go @@ -6,39 +6,39 @@ import ( ) func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(TCP4, localIP, localPort, direction) + return search(TCP4, localIP, localPort, pktDirection) } func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(TCP6, localIP, localPort, direction) + return search(TCP6, localIP, localPort, pktDirection) } func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(UDP4, localIP, localPort, direction) + return search(UDP4, localIP, localPort, pktDirection) } func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(UDP6, localIP, localPort, direction) + return search(UDP6, localIP, localPort, pktDirection) } func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) { var status uint8 if pktDirection { - pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol) + pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol) if pid >= 0 { return pid, true, nil } - // pid, status = GetPidOfConnection(&localIP, localPort, protocol) + // pid, status = GetPidOfConnection(localIP, localPort, protocol) // if pid >= 0 { // return pid, false, nil // } } else { - pid, status = GetPidOfConnection(&localIP, localPort, protocol) + pid, status = GetPidOfConnection(localIP, localPort, protocol) if pid >= 0 { return pid, false, nil } - // pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol) + // pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol) // if pid >= 0 { // return pid, true, nil // } diff --git a/process/proc/sockets.go b/process/proc/sockets.go index 9dea4c00..49b443ea 100644 --- a/process/proc/sockets.go +++ b/process/proc/sockets.go @@ -81,7 +81,7 @@ var ( globalListeningUDP6 = make(map[uint16][]int) ) -func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int, int, bool) { +func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, int, bool) { // listeningSocketsLock.Lock() // defer listeningSocketsLock.Unlock() @@ -98,10 +98,10 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]})) case TCP6: procFile = TCP6Data - localIPHex = hex.EncodeToString([]byte(*localIP)) + localIPHex = hex.EncodeToString([]byte(localIP)) case UDP6: procFile = UDP6Data - localIPHex = hex.EncodeToString([]byte(*localIP)) + localIPHex = hex.EncodeToString([]byte(localIP)) } localPortHex := fmt.Sprintf("%04X", localPort) @@ -162,38 +162,38 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int } -func getListeningSocket(localIP *net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) { +func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) { listeningSocketsLock.Lock() defer listeningSocketsLock.Unlock() - var addressListening *map[string][]int - var globalListening *map[uint16][]int + var addressListening map[string][]int + var globalListening map[uint16][]int switch protocol { case TCP4: - addressListening = &addressListeningTCP4 - globalListening = &globalListeningTCP4 + addressListening = addressListeningTCP4 + globalListening = globalListeningTCP4 case UDP4: - addressListening = &addressListeningUDP4 - globalListening = &globalListeningUDP4 + addressListening = addressListeningUDP4 + globalListening = globalListeningUDP4 case TCP6: - addressListening = &addressListeningTCP6 - globalListening = &globalListeningTCP6 + addressListening = addressListeningTCP6 + globalListening = globalListeningTCP6 case UDP6: - addressListening = &addressListeningUDP6 - globalListening = &globalListeningUDP6 + addressListening = addressListeningUDP6 + globalListening = globalListeningUDP6 } - data, ok := (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)] + data, ok := addressListening[fmt.Sprintf("%s:%d", localIP, localPort)] if !ok { - data, ok = (*globalListening)[localPort] + data, ok = globalListening[localPort] } if ok { return data[0], data[1], true } updateListeners(protocol) - data, ok = (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)] + data, ok = addressListening[fmt.Sprintf("%s:%d", localIP, localPort)] if !ok { - data, ok = (*globalListening)[localPort] + data, ok = globalListening[localPort] } if ok { return data[0], data[1], true @@ -206,7 +206,7 @@ func procDelimiter(c rune) bool { return unicode.IsSpace(c) || c == ':' } -func convertIPv4(data string) *net.IP { +func convertIPv4(data string) net.IP { decoded, err := hex.DecodeString(data) if err != nil { log.Warningf("process: could not parse IPv4 %s: %s", data, err) @@ -217,10 +217,10 @@ func convertIPv4(data string) *net.IP { return nil } ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) - return &ip + return ip } -func convertIPv6(data string) *net.IP { +func convertIPv6(data string) net.IP { decoded, err := hex.DecodeString(data) if err != nil { log.Warningf("process: could not parse IPv6 %s: %s", data, err) @@ -231,7 +231,7 @@ func convertIPv6(data string) *net.IP { return nil } ip := net.IP(decoded) - return &ip + return ip } func updateListeners(protocol uint8) { @@ -247,7 +247,7 @@ func updateListeners(protocol uint8) { } } -func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) *net.IP) (map[string][]int, map[uint16][]int) { +func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) net.IP) (map[string][]int, map[uint16][]int) { addressListening := make(map[string][]int) globalListening := make(map[uint16][]int) @@ -312,6 +312,7 @@ func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter return addressListening, globalListening } +// GetActiveConnectionIDs returns all connection IDs that are still marked as active by the OS. func GetActiveConnectionIDs() []string { var connections []string @@ -323,7 +324,7 @@ func GetActiveConnectionIDs() []string { return connections } -func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) *net.IP) []string { +func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) net.IP) []string { var connections []string // open file diff --git a/process/proc/sockets_test.go b/process/proc/sockets_test.go index 305c73a3..6bc77bfc 100644 --- a/process/proc/sockets_test.go +++ b/process/proc/sockets_test.go @@ -22,14 +22,14 @@ func TestSockets(t *testing.T) { t.Logf("addressListeningUDP6: %v", addressListeningUDP6) t.Logf("globalListeningUDP6: %v", globalListeningUDP6) - getListeningSocket(&net.IPv4zero, 53, TCP4) - getListeningSocket(&net.IPv4zero, 53, UDP4) - getListeningSocket(&net.IPv6zero, 53, TCP6) - getListeningSocket(&net.IPv6zero, 53, UDP6) + getListeningSocket(net.IPv4zero, 53, TCP4) + getListeningSocket(net.IPv4zero, 53, UDP4) + getListeningSocket(net.IPv6zero, 53, TCP6) + getListeningSocket(net.IPv6zero, 53, UDP6) // spotify: 192.168.0.102:5353 192.121.140.65:80 localIP := net.IPv4(192, 168, 127, 10) - uid, inode, ok := getConnectionSocket(&localIP, 46634, TCP4) + uid, inode, ok := getConnectionSocket(localIP, 46634, TCP4) t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok) activeConnectionIDs := GetActiveConnectionIDs() diff --git a/process/process.go b/process/process.go index 89f8e3b1..65757dbe 100644 --- a/process/process.go +++ b/process/process.go @@ -49,11 +49,17 @@ type Process struct { // ProfileSet returns the assigned profile set. func (p *Process) ProfileSet() *profile.Set { + p.Lock() + defer p.Unlock() + return p.profileSet } // Strings returns a string represenation of process. func (p *Process) String() string { + p.Lock() + defer p.Unlock() + if p == nil { return "?" } @@ -64,6 +70,7 @@ func (p *Process) String() string { func (p *Process) AddConnection() { p.Lock() defer p.Unlock() + p.ConnectionCount++ p.LastConnectionEstablished = time.Now().Unix() if p.FirstConnectionEstablished == 0 { @@ -75,6 +82,7 @@ func (p *Process) AddConnection() { func (p *Process) RemoveConnection() { p.Lock() defer p.Unlock() + if p.ConnectionCount > 0 { p.ConnectionCount-- } @@ -235,9 +243,8 @@ func GetOrFindProcess(pid int) (*Process, error) { // Executable Information // FIXME: use os specific path seperator - splittedPath := strings.Split("/", new.Path) - new.ExecName = strings.ToTitle(splittedPath[len(splittedPath)-1]) - + splittedPath := strings.Split(new.Path, "/") + new.ExecName = splittedPath[len(splittedPath)-1] } // save to storage diff --git a/profile/database.go b/profile/database.go index 1e748c32..740ab4a9 100644 --- a/profile/database.go +++ b/profile/database.go @@ -12,9 +12,9 @@ import ( // Namespaces const ( - userNamespace = "user" - stampNamespace = "stamp" - specialNamespace = "special" + UserNamespace = "user" + StampNamespace = "stamp" + SpecialNamespace = "special" ) var ( diff --git a/profile/index/index.go b/profile/index/index.go index 17d13817..c5859990 100644 --- a/profile/index/index.go +++ b/profile/index/index.go @@ -90,7 +90,7 @@ func Get(fpType, id string) (*ProfileIndex, error) { // Save saves the Identifiers to the database func (pi *ProfileIndex) Save() error { - if pi.Key() == "" { + if !pi.KeyIsSet() { if pi.ID != "" { pi.SetKey(makeIndexRecordKey(pi.ID)) } else { diff --git a/profile/module.go b/profile/module.go index f6652aa1..37e6f709 100644 --- a/profile/module.go +++ b/profile/module.go @@ -7,7 +7,7 @@ var ( ) func init() { - modules.Register("profile", nil, start, stop, "database") + modules.Register("profile", nil, start, stop, "global", "database") } func start() error { diff --git a/profile/ports_test.go b/profile/ports_test.go index b068e40e..d40d3b7f 100644 --- a/profile/ports_test.go +++ b/profile/ports_test.go @@ -33,7 +33,9 @@ func TestPorts(t *testing.T) { }, }, } - if ports.String() != "TCP:[permit:22], securityLevel { - set.securityLevel = profileSecurityLevel + set.combinedSecurityLevel = profileSecurityLevel } else { - set.securityLevel = securityLevel + set.combinedSecurityLevel = securityLevel } - // update independence set.Unlock() + // update independence if set.CheckFlag(Independent) { set.Lock() set.independent = true @@ -72,6 +77,14 @@ func (set *Set) Update(securityLevel uint8) { } } +// SecurityLevel returns the applicable security level for the profile set. +func (set *Set) SecurityLevel() uint8 { + set.Lock() + defer set.Unlock() + + return set.combinedSecurityLevel +} + // GetProfileMode returns the active profile mode. func (set *Set) GetProfileMode() uint8 { switch { @@ -97,7 +110,7 @@ func (set *Set) CheckFlag(flag uint8) (active bool) { } if profile != nil { - active, ok := profile.Flags.Check(flag, set.securityLevel) + active, ok := profile.Flags.Check(flag, set.combinedSecurityLevel) if ok { return active } @@ -153,8 +166,11 @@ func (set *Set) CheckPort(listen bool, protocol uint8, port uint16) (permit, ok return false, false } -// SecurityLevel returns the highest prioritized security level. -func (set *Set) getProfileSecurityLevel() uint8 { +// getSecurityLevel returns the highest prioritized security level. +func (set *Set) getSecurityLevel() uint8 { + if set == nil { + return 0 + } for i, profile := range set.profiles { if i == 2 { diff --git a/profile/profileset_test.go b/profile/set_test.go similarity index 100% rename from profile/profileset_test.go rename to profile/set_test.go diff --git a/profile/specialprofiles.go b/profile/specialprofiles.go index 9af2559d..07ed8885 100644 --- a/profile/specialprofiles.go +++ b/profile/specialprofiles.go @@ -24,7 +24,7 @@ func initSpecialProfiles() (err error) { return err } globalProfile = makeDefaultGlobalProfile() - globalProfile.Save(specialNamespace) + globalProfile.Save(SpecialNamespace) } fallbackProfile, err = getSpecialProfile("fallback") @@ -33,12 +33,12 @@ func initSpecialProfiles() (err error) { return err } fallbackProfile = makeDefaultFallbackProfile() - fallbackProfile.Save(specialNamespace) + fallbackProfile.Save(SpecialNamespace) } return nil } func getSpecialProfile(ID string) (*Profile, error) { - return getProfile(specialNamespace, ID) + return getProfile(SpecialNamespace, ID) } diff --git a/profile/updates.go b/profile/updates.go index 55c60d3d..2bedce15 100644 --- a/profile/updates.go +++ b/profile/updates.go @@ -9,7 +9,7 @@ import ( ) func initUpdateListener() error { - sub, err := profileDB.Subscribe(query.New(makeProfileKey(specialNamespace, ""))) + sub, err := profileDB.Subscribe(query.New(MakeProfileKey(SpecialNamespace, ""))) if err != nil { return err } @@ -29,7 +29,7 @@ func updateListener(sub *database.Subscription) { continue } - profile, err := ensureProfile(r) + profile, err := EnsureProfile(r) if err != nil { log.Errorf("profile: received update for special profile, but could not read: %s", err) continue @@ -46,9 +46,9 @@ func updateListener(sub *database.Subscription) { specialProfileLock.Unlock() default: switch { - case strings.HasPrefix(profile.Key(), makeProfileKey(userNamespace, "")): + case strings.HasPrefix(profile.Key(), MakeProfileKey(UserNamespace, "")): updateActiveUserProfile(profile) - case strings.HasPrefix(profile.Key(), makeProfileKey(stampNamespace, "")): + case strings.HasPrefix(profile.Key(), MakeProfileKey(StampNamespace, "")): updateActiveStampProfile(profile) } } diff --git a/status/module.go b/status/module.go new file mode 100644 index 00000000..bd0c3a42 --- /dev/null +++ b/status/module.go @@ -0,0 +1,25 @@ +package status + +import ( + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/modules" +) + +func init() { + modules.Register("status", prep, nil, nil) +} + +func prep() error { + + if CurrentSecurityLevel() == SecurityLevelOff { + log.Infof("switching to default active security level: dynamic") + SetCurrentSecurityLevel(SecurityLevelDynamic) + } + + if SelectedSecurityLevel() == SecurityLevelOff { + log.Infof("switching to default selected security level: dynamic") + SetSelectedSecurityLevel(SecurityLevelDynamic) + } + + return nil +}