diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..2dde10d1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +dnsonly +main + +*.exe diff --git a/README.md b/README.md index 7563da83..d39eb5d5 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,41 @@ # Portmaster -The Portmaster is currently being revamped. You can check out the latest changes in the `develop` branch. +The Portmaster enables you to protect your data on your device. You are back in charge of your outgoing connections: you choose what data you share and what data stays private. + +## Current Status + +The Portmaster is currently in alpha. Expect dragons. +Supported platforms: + +- linux_amd64 +- windows_amd64 (_soon_) +- darwin_amd64 (_later_) + +## Usage + +Just download the portmaster from the releases page. + + ./portmaster -db=/opt/pm_db + # this will add some rules to iptables for traffic interception via nfqueue (and will clean up afterwards!) + + # then start the ui + ./portmaster -db=/opt/pm_db -ui + # missing files will be automatically download when first needed + +## Documentation + +Documentation _in progress_ can be found here: [http://docs.safing.io/](http://docs.safing.io/) + +## Dependencies + +#### Linux +- libnetfilter_queue + - debian/ubuntu: `sudo apt-get install libnetfilter-queue1` + - fedora: `sudo yum install libnetfilter_queue` + - arch: `sudo pacman -S libnetfilter_queue` +- [Network Manager](https://wiki.gnome.org/Projects/NetworkManager) (_optional_) + +#### Windows +- Windows 7 (with update KB3033929) or up + - [KB3033929](https://docs.microsoft.com/en-us/security-updates/SecurityAdvisories/2015/3033929) (a 2015 security update) is required for correctly verifying the driver signature +- Windows Server 2016 systems must have secure boot disabled. (_clarification needed_) diff --git a/build b/build new file mode 100755 index 00000000..bbde1f9f --- /dev/null +++ b/build @@ -0,0 +1,52 @@ +#!/bin/bash + +# get build data +if [[ "$BUILD_COMMIT" == "" ]]; then + BUILD_COMMIT=$(git describe --all --long --abbrev=99 --dirty 2>/dev/null) +fi +if [[ "$BUILD_USER" == "" ]]; then + BUILD_USER=$(id -un) +fi +if [[ "$BUILD_HOST" == "" ]]; then + BUILD_HOST=$(hostname -f) +fi +if [[ "$BUILD_DATE" == "" ]]; then + BUILD_DATE=$(date +%d.%m.%Y) +fi +if [[ "$BUILD_SOURCE" == "" ]]; then + BUILD_SOURCE=$(git remote -v | grep origin | cut -f2 | cut -d" " -f1 | head -n 1) +fi +if [[ "$BUILD_SOURCE" == "" ]]; then + BUILD_SOURCE=$(git remote -v | cut -f2 | cut -d" " -f1 | head -n 1) +fi +BUILD_BUILDOPTIONS=$(echo $* | sed "s/ /§/g") + +# check +if [[ "$BUILD_COMMIT" == "" ]]; then + echo "could not automatically determine BUILD_COMMIT, please supply manually as environment variable." + exit 1 +fi +if [[ "$BUILD_USER" == "" ]]; then + echo "could not automatically determine BUILD_USER, please supply manually as environment variable." + exit 1 +fi +if [[ "$BUILD_HOST" == "" ]]; then + echo "could not automatically determine BUILD_HOST, please supply manually as environment variable." + exit 1 +fi +if [[ "$BUILD_DATE" == "" ]]; then + echo "could not automatically determine BUILD_DATE, please supply manually as environment variable." + exit 1 +fi +if [[ "$BUILD_SOURCE" == "" ]]; then + echo "could not automatically determine BUILD_SOURCE, please supply manually as environment variable." + exit 1 +fi + +echo "Please notice, that this build script includes metadata into the build." +echo "This information is useful for debugging and license compliance." +echo "Run the compiled binary with the -version flag to see the information included." + +# build +BUILD_PATH="github.com/Safing/portbase/info" +go build -ldflags "-X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" $* diff --git a/dnsonly.go b/dnsonly.go new file mode 100644 index 00000000..ca3e4a56 --- /dev/null +++ b/dnsonly.go @@ -0,0 +1,55 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/Safing/portbase/info" + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/modules" + + // include packages here + _ "github.com/Safing/portmaster/nameserver/only" +) + +func main() { + + // Set Info + info.Set("Portmaster (DNS only)", "0.2.0") + + // Start + err := modules.Start() + if err != nil { + if err == modules.ErrCleanExit { + os.Exit(0) + } else { + err = modules.Shutdown() + if err != nil { + log.Shutdown() + } + os.Exit(1) + } + } + + // Shutdown + // catch interrupt for clean shutdown + signalCh := make(chan os.Signal) + signal.Notify( + signalCh, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT, + ) + select { + case <-signalCh: + fmt.Println(" ") + log.Warning("main: program was interrupted, shutting down.") + modules.Shutdown() + case <-modules.ShuttingDown(): + } + +} diff --git a/firewall/config.go b/firewall/config.go new file mode 100644 index 00000000..b900a657 --- /dev/null +++ b/firewall/config.go @@ -0,0 +1,26 @@ +package firewall + +import ( + "github.com/Safing/portbase/config" +) + +var ( + permanentVerdicts config.BoolOption +) + +func registerConfig() error { + err := config.Register(&config.Option{ + Name: "Permanent Verdicts", + Key: "firewall/permanentVerdicts", + Description: "With permanent verdicts, control of a connection is fully handed back to the OS after the initial decision. This brings a great performance increase, but makes it impossible to change the decision of a link later on.", + ExpertiseLevel: config.ExpertiseLevelExpert, + OptType: config.OptTypeBool, + DefaultValue: true, + }) + if err != nil { + return err + } + permanentVerdicts = config.Concurrent.GetAsBool("firewall/permanentVerdicts", true) + + return nil +} diff --git a/firewall/firewall.go b/firewall/firewall.go index e8e52bca..49890737 100644 --- a/firewall/firewall.go +++ b/firewall/firewall.go @@ -1,28 +1,22 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package firewall import ( + "fmt" "net" "os" "sync/atomic" "time" - "github.com/Safing/safing-core/configuration" - "github.com/Safing/safing-core/firewall/inspection" - "github.com/Safing/safing-core/firewall/interception" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/modules" - "github.com/Safing/safing-core/network" - "github.com/Safing/safing-core/network/packet" - "github.com/Safing/safing-core/port17/entry" - "github.com/Safing/safing-core/port17/mode" - "github.com/Safing/safing-core/portmaster" - "github.com/Safing/safing-core/process" + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/modules" + "github.com/Safing/portmaster/firewall/inspection" + "github.com/Safing/portmaster/firewall/interception" + "github.com/Safing/portmaster/network" + "github.com/Safing/portmaster/network/packet" + "github.com/Safing/portmaster/process" ) var ( - firewallModule *modules.Module // localNet net.IPNet localhost net.IP dnsServer net.IPNet @@ -30,8 +24,6 @@ var ( packetsBlocked *uint64 packetsDropped *uint64 - config = configuration.Get() - localNet4 *net.IPNet // Yes, this would normally be 127.0.0.0/8 // TODO: figure out any side effects @@ -46,23 +38,30 @@ var ( ) func init() { + modules.Register("firewall", prep, start, stop, "global", "network", "nameserver", "profile") +} - var err error +func prep() (err error) { + + err = registerConfig() + if err != nil { + return err + } _, localNet4, err = net.ParseCIDR("127.0.0.0/24") // Yes, this would normally be 127.0.0.0/8 // TODO: figure out any side effects if err != nil { - log.Criticalf("firewall: failed to parse cidr 127.0.0.0/24: %s", err) + return fmt.Errorf("firewall: failed to parse cidr 127.0.0.0/24: %s", err) } _, tunnelNet4, err = net.ParseCIDR("127.17.0.0/16") if err != nil { - log.Criticalf("firewall: failed to parse cidr 127.17.0.0/16: %s", err) + return fmt.Errorf("firewall: failed to parse cidr 127.17.0.0/16: %s", err) } _, tunnelNet6, err = net.ParseCIDR("fd17::/64") if err != nil { - log.Criticalf("firewall: failed to parse cidr fd17::/64: %s", err) + return fmt.Errorf("firewall: failed to parse cidr fd17::/64: %s", err) } var pA uint64 @@ -71,20 +70,22 @@ func init() { packetsBlocked = &pB var pD uint64 packetsDropped = &pD + + return nil } -func Start() { - firewallModule = modules.Register("Firewall", 128) - defer firewallModule.StopComplete() - - // start interceptor - go interception.Start() +func start() error { go statLogger() + go run() + // go run() + // go run() + // go run() - // go run() - // go run() - // go run() - run() + return interception.Start() +} + +func stop() error { + return interception.Stop() } func handlePacket(pkt packet.Packet) { @@ -111,12 +112,6 @@ func handlePacket(pkt packet.Packet) { return } - // allow anything that goes to a tunnel entrypoint - if pkt.IsOutbound() && (pkt.GetIPHeader().Dst.Equal(tunnelEntry4) || pkt.GetIPHeader().Dst.Equal(tunnelEntry6)) { - pkt.PermanentAccept() - return - } - // log.Debugf("firewall: pkt %s has ID %s", pkt, pkt.GetConnectionID()) // use this to time how long it takes process packet @@ -124,16 +119,16 @@ func handlePacket(pkt packet.Packet) { // defer log.Tracef("firewall: took %s to process packet %s", time.Now().Sub(timed).String(), pkt) // check if packet is destined for tunnel - switch pkt.IPVersion() { - case packet.IPv4: - if portmaster.TunnelNet4 != nil && portmaster.TunnelNet4.Contains(pkt.GetIPHeader().Dst) { - tunnelHandler(pkt) - } - case packet.IPv6: - if portmaster.TunnelNet6 != nil && portmaster.TunnelNet6.Contains(pkt.GetIPHeader().Dst) { - tunnelHandler(pkt) - } - } + // switch pkt.IPVersion() { + // case packet.IPv4: + // if TunnelNet4 != nil && TunnelNet4.Contains(pkt.GetIPHeader().Dst) { + // tunnelHandler(pkt) + // } + // case packet.IPv6: + // if TunnelNet6 != nil && TunnelNet6.Contains(pkt.GetIPHeader().Dst) { + // tunnelHandler(pkt) + // } + // } // associate packet to link and handle link, created := network.GetOrCreateLinkByPacket(pkt) @@ -146,7 +141,7 @@ func handlePacket(pkt packet.Packet) { link.HandlePacket(pkt) return } - verdict(pkt, link.Verdict) + verdict(pkt, link.GetVerdict()) } @@ -157,57 +152,68 @@ func initialHandler(pkt packet.Packet, link *network.Link) { if err != nil { if err != process.ErrConnectionNotFound { log.Warningf("firewall: could not find process of packet (dropping link %s): %s", pkt.String(), err) + link.Deny(fmt.Sprintf("could not find process or it does not exist (unsolicited packet): %s", err)) + } else { + log.Warningf("firewall: internal error finding process of packet (dropping link %s): %s", pkt.String(), err) + link.Deny(fmt.Sprintf("internal error finding process: %s", err)) } - link.UpdateVerdict(network.DROP) - verdict(pkt, network.DROP) + + if pkt.IsInbound() { + network.UnknownIncomingConnection.AddLink(link) + } else { + network.UnknownDirectConnection.AddLink(link) + } + + verdict(pkt, link.GetVerdict()) + link.StopFirewallHandler() + return } + // add new Link to Connection (and save both) + connection.AddLink(link) + // reroute dns requests to nameserver if connection.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.GetTCPUDPHeader() != nil && !pkt.GetIPHeader().Dst.Equal(localhost) && pkt.GetTCPUDPHeader().DstPort == 53 { - pkt.RerouteToNameserver() + link.RerouteToNameserver() + verdict(pkt, link.GetVerdict()) + link.StopFirewallHandler() return } - // persist connection - connection.CreateInProcessNamespace() - - // add new Link to Connection - connection.AddLink(link, pkt) - // make a decision if not made already - if connection.Verdict == network.UNDECIDED { - portmaster.DecideOnConnection(connection, pkt) + if connection.GetVerdict() == network.UNDECIDED { + DecideOnConnection(connection, pkt) } - if connection.Verdict != network.CANTSAY { - link.UpdateVerdict(connection.Verdict) + if connection.GetVerdict() == network.ACCEPT { + DecideOnLink(connection, link, pkt) } else { - portmaster.DecideOnLink(connection, link, pkt) + link.UpdateVerdict(connection.GetVerdict()) } // log decision logInitialVerdict(link) // TODO: link this to real status - port17Active := mode.Client() + // port17Active := mode.Client() switch { - case port17Active && link.Inspect: - // tunnel link, but also inspect (after reroute) - link.Tunneled = true - link.SetFirewallHandler(inspectThenVerdict) - verdict(pkt, link.Verdict) - case port17Active: - // tunnel link, don't inspect - link.Tunneled = true - link.StopFirewallHandler() - permanentVerdict(pkt, network.ACCEPT) + // case port17Active && link.Inspect: + // // tunnel link, but also inspect (after reroute) + // link.Tunneled = true + // link.SetFirewallHandler(inspectThenVerdict) + // verdict(pkt, link.GetVerdict()) + // case port17Active: + // // tunnel link, don't inspect + // link.Tunneled = true + // link.StopFirewallHandler() + // permanentVerdict(pkt, network.ACCEPT) case link.Inspect: link.SetFirewallHandler(inspectThenVerdict) inspectThenVerdict(pkt, link) default: link.StopFirewallHandler() - verdict(pkt, link.Verdict) + verdict(pkt, link.GetVerdict()) } } @@ -216,10 +222,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) { pktVerdict, continueInspection := inspection.RunInspectors(pkt, link) if continueInspection { // do not allow to circumvent link decision: e.g. to ACCEPT packets from a DROP-ed link - if pktVerdict > link.Verdict { + linkVerdict := link.GetVerdict() + if pktVerdict > linkVerdict { verdict(pkt, pktVerdict) } else { - verdict(pkt, link.Verdict) + verdict(pkt, linkVerdict) } return } @@ -227,13 +234,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) { // we are done with inspecting link.StopFirewallHandler() - config.Changed() - config.RLock() - link.VerdictPermanent = config.PermanentVerdicts - config.RUnlock() - + link.Lock() + defer link.Unlock() + link.VerdictPermanent = permanentVerdicts() if link.VerdictPermanent { - link.Save() + go link.Save() permanentVerdict(pkt, link.Verdict) } else { verdict(pkt, link.Verdict) @@ -254,6 +259,12 @@ func permanentVerdict(pkt packet.Packet, action network.Verdict) { atomic.AddUint64(packetsDropped, 1) pkt.PermanentDrop() return + case network.RerouteToNameserver: + pkt.RerouteToNameserver() + return + case network.RerouteToTunnel: + pkt.RerouteToTunnel() + return } pkt.Drop() } @@ -272,36 +283,46 @@ func verdict(pkt packet.Packet, action network.Verdict) { atomic.AddUint64(packetsDropped, 1) pkt.Drop() return + case network.RerouteToNameserver: + pkt.RerouteToNameserver() + return + case network.RerouteToTunnel: + pkt.RerouteToTunnel() + return } pkt.Drop() } -func tunnelHandler(pkt packet.Packet) { - tunnelInfo := portmaster.GetTunnelInfo(pkt.GetIPHeader().Dst) - if tunnelInfo == nil { - pkt.Block() - return - } - - entry.CreateTunnel(pkt, tunnelInfo.Domain, tunnelInfo.RRCache.ExportAllARecords()) - log.Tracef("firewall: rerouting %s to tunnel entry point", pkt) - pkt.RerouteToTunnel() - return -} +// func tunnelHandler(pkt packet.Packet) { +// tunnelInfo := GetTunnelInfo(pkt.GetIPHeader().Dst) +// if tunnelInfo == nil { +// pkt.Block() +// return +// } +// +// entry.CreateTunnel(pkt, tunnelInfo.Domain, tunnelInfo.RRCache.ExportAllARecords()) +// log.Tracef("firewall: rerouting %s to tunnel entry point", pkt) +// pkt.RerouteToTunnel() +// return +// } func logInitialVerdict(link *network.Link) { - // switch link.Verdict { + // switch link.GetVerdict() { // case network.ACCEPT: // log.Infof("firewall: accepting new link: %s", link.String()) // case network.BLOCK: // log.Infof("firewall: blocking new link: %s", link.String()) // case network.DROP: // log.Infof("firewall: dropping new link: %s", link.String()) + // case network.RerouteToNameserver: + // log.Infof("firewall: rerouting new link to nameserver: %s", link.String()) + // case network.RerouteToTunnel: + // log.Infof("firewall: rerouting new link to tunnel: %s", link.String()) // } } func logChangedVerdict(link *network.Link) { - // switch link.Verdict { + // switch link.GetVerdict() { // case network.ACCEPT: // log.Infof("firewall: change! - now accepting link: %s", link.String()) // case network.BLOCK: @@ -312,25 +333,26 @@ func logChangedVerdict(link *network.Link) { } func run() { - -packetProcessingLoop: for { select { - case <-firewallModule.Stop: - break packetProcessingLoop + case <-modules.ShuttingDown(): + return case pkt := <-interception.Packets: handlePacket(pkt) } } - } func statLogger() { for { - time.Sleep(10 * time.Second) - log.Tracef("firewall: packets accepted %d, blocked %d, dropped %d", atomic.LoadUint64(packetsAccepted), atomic.LoadUint64(packetsBlocked), atomic.LoadUint64(packetsDropped)) - atomic.StoreUint64(packetsAccepted, 0) - atomic.StoreUint64(packetsBlocked, 0) - atomic.StoreUint64(packetsDropped, 0) + select { + case <-modules.ShuttingDown(): + return + case <-time.After(10 * time.Second): + log.Tracef("firewall: packets accepted %d, blocked %d, dropped %d", atomic.LoadUint64(packetsAccepted), atomic.LoadUint64(packetsBlocked), atomic.LoadUint64(packetsDropped)) + atomic.StoreUint64(packetsAccepted, 0) + atomic.StoreUint64(packetsBlocked, 0) + atomic.StoreUint64(packetsDropped, 0) + } } } diff --git a/firewall/inspection/inspection.go b/firewall/inspection/inspection.go index 58f02d7d..a0061925 100644 --- a/firewall/inspection/inspection.go +++ b/firewall/inspection/inspection.go @@ -3,9 +3,10 @@ package inspection import ( - "github.com/Safing/safing-core/network" - "github.com/Safing/safing-core/network/packet" "sync" + + "github.com/Safing/portmaster/network" + "github.com/Safing/portmaster/network/packet" ) const ( @@ -40,24 +41,28 @@ func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool // inspectorsLock.Lock() // defer inspectorsLock.Unlock() - if link.ActiveInspectors == nil { - link.ActiveInspectors = make([]bool, len(inspectors), len(inspectors)) + activeInspectors := link.GetActiveInspectors() + if activeInspectors == nil { + activeInspectors = make([]bool, len(inspectors), len(inspectors)) + link.SetActiveInspectors(activeInspectors) } - if link.InspectorData == nil { - link.InspectorData = make(map[uint8]interface{}) + inspectorData := link.GetInspectorData() + if inspectorData == nil { + inspectorData = make(map[uint8]interface{}) + link.SetInspectorData(inspectorData) } continueInspection := false verdict := network.UNDECIDED - for key, skip := range link.ActiveInspectors { + for key, skip := range activeInspectors { if skip { continue } if link.Verdict > inspectVerdicts[key] { - link.ActiveInspectors[key] = true + activeInspectors[key] = true continue } @@ -78,16 +83,16 @@ func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool continueInspection = true case BLOCK_LINK: link.UpdateVerdict(network.BLOCK) - link.ActiveInspectors[key] = true + activeInspectors[key] = true if verdict < network.BLOCK { verdict = network.BLOCK } case DROP_LINK: link.UpdateVerdict(network.DROP) - link.ActiveInspectors[key] = true + activeInspectors[key] = true verdict = network.DROP case STOP_INSPECTING: - link.ActiveInspectors[key] = true + activeInspectors[key] = true } } diff --git a/firewall/inspection/tls/const.go b/firewall/inspection/tls/const.go index 89298c5d..07657014 100644 --- a/firewall/inspection/tls/const.go +++ b/firewall/inspection/tls/const.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package tls var ( diff --git a/firewall/inspection/tls/tls.go b/firewall/inspection/tls/tls.go index 63aa65c9..36aade25 100644 --- a/firewall/inspection/tls/tls.go +++ b/firewall/inspection/tls/tls.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package tls import ( @@ -12,14 +10,13 @@ import ( "github.com/google/gopacket/layers" "github.com/google/gopacket/tcpassembly" - "github.com/Safing/safing-core/configuration" - "github.com/Safing/safing-core/crypto/verify" - "github.com/Safing/safing-core/firewall/inspection" - "github.com/Safing/safing-core/firewall/inspection/tls/tlslib" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/network" - "github.com/Safing/safing-core/network/netutils" - "github.com/Safing/safing-core/network/packet" + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/firewall/inspection" + "github.com/Safing/portmaster/firewall/inspection/tls/tlslib" + "github.com/Safing/portmaster/firewall/inspection/tls/verify" + "github.com/Safing/portmaster/network" + "github.com/Safing/portmaster/network/netutils" + "github.com/Safing/portmaster/network/packet" ) // TODO: @@ -31,8 +28,6 @@ var ( tlsInspectorIndex int assemblerManager *netutils.SimpleStreamAssemblerManager assembler *tcpassembly.Assembler - - config = configuration.Get() ) const ( diff --git a/firewall/inspection/tls/tls_test.go b/firewall/inspection/tls/tls_test.go index b3e89e0f..6e7b823b 100644 --- a/firewall/inspection/tls/tls_test.go +++ b/firewall/inspection/tls/tls_test.go @@ -6,7 +6,7 @@ import ( "fmt" "testing" - "github.com/Safing/safing-core/firewall/inspection/tls/tlslib" + "github.com/Safing/portmaster/firewall/inspection/tls/tlslib" ) var clientHelloSample = []byte{ diff --git a/firewall/inspection/tls/verify/cert.go b/firewall/inspection/tls/verify/cert.go index 0af7f120..27f92012 100644 --- a/firewall/inspection/tls/verify/cert.go +++ b/firewall/inspection/tls/verify/cert.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package verify import ( @@ -14,15 +12,15 @@ import ( "strings" "github.com/cloudflare/cfssl/crypto/pkcs7" - datastore "github.com/ipfs/go-datastore" - "github.com/Safing/safing-core/crypto/hash" - "github.com/Safing/safing-core/database" + "github.com/Safing/portbase/crypto/hash" + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/record" ) // Cert saves a certificate. type Cert struct { - database.Base + record.Record cert *x509.Certificate Raw []byte @@ -120,7 +118,7 @@ func (m *Cert) CreateRevokedCert(caID string, serialNumber *big.Int) error { } // CreateInNamespace saves Cert with the provided name in the provided namespace. -func (m *Cert) CreateInNamespace(namespace *datastore.Key, name string) error { +func (m *Cert) CreateInNamespace(namespace string, name string) error { return m.CreateObject(namespace, name, m) } @@ -140,7 +138,7 @@ func GetCertWithSPKI(spki []byte) (*Cert, error) { } // GetCertFromNamespace gets Cert with the provided name from the provided namespace. -func GetCertFromNamespace(namespace *datastore.Key, name string) (*Cert, error) { +func GetCertFromNamespace(namespace string, name string) (*Cert, error) { object, err := database.GetAndEnsureModel(namespace, name, certModel) if err != nil { return nil, err diff --git a/firewall/inspection/tls/verify/cert_test.go b/firewall/inspection/tls/verify/cert_test.go index 5a9d2312..d9ebe1fd 100644 --- a/firewall/inspection/tls/verify/cert_test.go +++ b/firewall/inspection/tls/verify/cert_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package verify import ( diff --git a/firewall/inspection/tls/verify/crl.go b/firewall/inspection/tls/verify/crl.go index 6f5136ef..7124ed52 100644 --- a/firewall/inspection/tls/verify/crl.go +++ b/firewall/inspection/tls/verify/crl.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package verify import ( @@ -14,16 +12,15 @@ import ( "sync" "time" - datastore "github.com/ipfs/go-datastore" - - "github.com/Safing/safing-core/crypto/hash" - "github.com/Safing/safing-core/database" - "github.com/Safing/safing-core/log" + "github.com/Safing/portbase/crypto/hash" + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/log" ) // CARevocationInfo saves Information on revokation of Certificates of a Certificate Authority. type CARevocationInfo struct { - database.Base + record.Record CRLDistributionPoints []string OCSPServers []string @@ -39,23 +36,17 @@ type CARevocationInfo struct { } var ( - caRevocationInfoModel *CARevocationInfo // only use this as parameter for database.EnsureModel-like functions - dupCrlReqMap = make(map[string]*sync.Mutex) dupCrlReqLock sync.Mutex ) -func init() { - database.RegisterModel(caRevocationInfoModel, func() database.Model { return new(CARevocationInfo) }) -} - // Create saves CARevocationInfo with the provided name in the default namespace. func (m *CARevocationInfo) Create(name string) error { return m.CreateObject(&database.CARevocationInfoCache, name, m) } // CreateInNamespace saves CARevocationInfo with the provided name in the provided namespace. -func (m *CARevocationInfo) CreateInNamespace(namespace *datastore.Key, name string) error { +func (m *CARevocationInfo) CreateInNamespace(namespace string, name string) error { return m.CreateObject(namespace, name, m) } @@ -78,7 +69,7 @@ func GetCARevocationInfo(name string) (*CARevocationInfo, error) { } // GetCARevocationInfoFromNamespace fetches CARevocationInfo with the provided name from the provided namespace. -func GetCARevocationInfoFromNamespace(namespace *datastore.Key, name string) (*CARevocationInfo, error) { +func GetCARevocationInfoFromNamespace(namespace string, name string) (*CARevocationInfo, error) { object, err := database.GetAndEnsureModel(namespace, name, caRevocationInfoModel) if err != nil { return nil, err diff --git a/firewall/inspection/tls/verify/ocsp.go b/firewall/inspection/tls/verify/ocsp.go index e638c947..b4459de5 100644 --- a/firewall/inspection/tls/verify/ocsp.go +++ b/firewall/inspection/tls/verify/ocsp.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package verify import ( @@ -16,8 +14,8 @@ import ( "golang.org/x/crypto/ocsp" - "github.com/Safing/safing-core/crypto/hash" - "github.com/Safing/safing-core/log" + "github.com/Safing/portbase/crypto/hash" + "github.com/Safing/portbase/log" ) var ( diff --git a/firewall/inspection/tls/verify/verify.go b/firewall/inspection/tls/verify/verify.go index ba1e6d1f..f162f842 100644 --- a/firewall/inspection/tls/verify/verify.go +++ b/firewall/inspection/tls/verify/verify.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package verify import ( @@ -8,9 +6,8 @@ import ( "fmt" "time" - "github.com/Safing/safing-core/configuration" - "github.com/Safing/safing-core/crypto/hash" - "github.com/Safing/safing-core/database" + "github.com/Safing/portbase/crypto/hash" + "github.com/Safing/portbase/database" ) // useful references: @@ -24,10 +21,6 @@ import ( // RE: https://www.grc.com/revocation/crlsets.htm // RE: RE: https://www.imperialviolet.org/2014/04/29/revocationagain.html -var ( - config = configuration.Get() -) - // FullCheckBytes does a full certificate check, certificates are provided as raw bytes. // It parses the raw certificates and calls FullCheck. func FullCheckBytes(name string, certBytes [][]byte) (bool, error) { diff --git a/firewall/interception/interception_linux.go b/firewall/interception/interception_linux.go index 88bdc0e5..4fd6897e 100644 --- a/firewall/interception/interception_linux.go +++ b/firewall/interception/interception_linux.go @@ -2,10 +2,19 @@ package interception -import "github.com/Safing/safing-core/network/packet" +import "github.com/Safing/portmaster/network/packet" -var Packets chan packet.Packet - -func init() { +var ( + // Packets channel for feeding the firewall. Packets = make(chan packet.Packet, 1000) +) + +// Start starts the interception. +func Start() error { + return StartNfqueueInterception() +} + +// Stop starts the interception. +func Stop() error { + return StopNfqueueInterception() } diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index 60be31a1..ddebb0b7 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -1,31 +1,31 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package interception import ( - "github.com/Safing/safing-core/firewall/interception/windivert" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/modules" - "github.com/Safing/safing-core/network/packet" + "fmt" + + "github.com/Safing/portmaster/firewall/interception/windivert" + "github.com/Safing/portmaster/network/packet" ) var Packets chan packet.Packet func init() { + // Packets channel for feeding the firewall. Packets = make(chan packet.Packet, 1000) } -func Start() { - - windivertModule := modules.Register("Firewall:Interception:WinDivert", 192) +// Start starts the interception. +func Start() error { wd, err := windivert.New("/WinDivert.dll", "") if err != nil { - log.Criticalf("firewall/interception: could not init windivert: %s", err) - } else { - wd.Packets(Packets) + return fmt.Errorf("firewall/interception: could not init windivert: %s", err) } - <-windivertModule.Stop - windivertModule.StopComplete() + return wd.Packets(Packets) +} + +// Stop starts the interception. +func Stop() error { + return nil } diff --git a/firewall/interception/nfqueue/multiqueue.go b/firewall/interception/nfqueue/multiqueue.go index 4ee43195..6fd4b71c 100644 --- a/firewall/interception/nfqueue/multiqueue.go +++ b/firewall/interception/nfqueue/multiqueue.go @@ -2,45 +2,48 @@ package nfqueue -import ( - "github.com/Safing/safing-core/network/packet" - "sync" -) +// suspended for now -type multiQueue struct { - qs []*nfQueue -} - -func NewMultiQueue(min, max uint16) (mq *multiQueue) { - mq = &multiQueue{make([]*nfQueue, 0, max-min)} - for i := min; i < max; i++ { - mq.qs = append(mq.qs, NewNFQueue(i)) - } - return mq -} - -func (mq *multiQueue) Process() <-chan packet.Packet { - var ( - wg sync.WaitGroup - out = make(chan packet.Packet, len(mq.qs)) - ) - for _, q := range mq.qs { - wg.Add(1) - go func(ch <-chan packet.Packet) { - for pkt := range ch { - out <- pkt - } - wg.Done() - }(q.Process()) - } - go func() { - wg.Wait() - close(out) - }() - return out -} -func (mq *multiQueue) Destroy() { - for _, q := range mq.qs { - q.Destroy() - } -} +// import ( +// "sync" +// +// "github.com/Safing/portmaster/network/packet" +// ) +// +// type multiQueue struct { +// qs []*NFQueue +// } +// +// func NewMultiQueue(min, max uint16) (mq *multiQueue) { +// mq = &multiQueue{make([]*NFQueue, 0, max-min)} +// for i := min; i < max; i++ { +// mq.qs = append(mq.qs, NewNFQueue(i)) +// } +// return mq +// } +// +// func (mq *multiQueue) Process() <-chan packet.Packet { +// var ( +// wg sync.WaitGroup +// out = make(chan packet.Packet, len(mq.qs)) +// ) +// for _, q := range mq.qs { +// wg.Add(1) +// go func(ch <-chan packet.Packet) { +// for pkt := range ch { +// out <- pkt +// } +// wg.Done() +// }(q.Process()) +// } +// go func() { +// wg.Wait() +// close(out) +// }() +// return out +// } +// func (mq *multiQueue) Destroy() { +// for _, q := range mq.qs { +// q.Destroy() +// } +// } diff --git a/firewall/interception/nfqueue/nfqueue.go b/firewall/interception/nfqueue/nfqueue.go index 1f991bea..aedd1850 100644 --- a/firewall/interception/nfqueue/nfqueue.go +++ b/firewall/interception/nfqueue/nfqueue.go @@ -17,17 +17,19 @@ import ( "syscall" "time" "unsafe" + "errors" + "fmt" - "github.com/Safing/safing-core/network/packet" + "github.com/Safing/portmaster/network/packet" ) -var queues map[uint16]*nfQueue +var queues map[uint16]*NFQueue func init() { - queues = make(map[uint16]*nfQueue) + queues = make(map[uint16]*NFQueue) } -type nfQueue struct { +type NFQueue struct { DefaultVerdict uint32 Timeout time.Duration qid uint16 @@ -38,83 +40,77 @@ type nfQueue struct { fd int lk sync.Mutex - pktch chan packet.Packet + Packets chan packet.Packet } -func NewNFQueue(qid uint16) (nfq *nfQueue) { +func NewNFQueue(qid uint16) (nfq *NFQueue, err error) { if os.Geteuid() != 0 { - panic("Must be ran by root.") + return nil, errors.New("must be root to intercept packets") } - nfq = &nfQueue{DefaultVerdict: NFQ_ACCEPT, Timeout: 100 * time.Millisecond, qid: qid, qidptr: &qid} + nfq = &NFQueue{DefaultVerdict: NFQ_ACCEPT, Timeout: 100 * time.Millisecond, qid: qid, qidptr: &qid} queues[nfq.qid] = nfq - return nfq -} -/* -This returns a channel that will recieve packets, -the user then must call pkt.Accept() or pkt.Drop() -*/ -func (this *nfQueue) Process() <-chan packet.Packet { - if this.h != nil { - return this.pktch + err = nfq.init() + if err != nil { + return nil, err } - this.init() go func() { runtime.LockOSThread() - C.loop_for_packets(this.h) + C.loop_for_packets(nfq.h) }() - return this.pktch + return nfq, nil } -func (this *nfQueue) init() { +func (this *NFQueue) init() error { var err error if this.h, err = C.nfq_open(); err != nil || this.h == nil { - panic(err) + fmt.Errorf("could not open nfqueue: %s", err) } //if this.qh, err = C.nfq_create_queue(this.h, qid, C.get_cb(), unsafe.Pointer(nfq)); err != nil || this.qh == nil { - this.pktch = make(chan packet.Packet, 1) + this.Packets = make(chan packet.Packet, 1) if C.nfq_unbind_pf(this.h, C.AF_INET) < 0 { this.Destroy() - panic("nfq_unbind_pf(AF_INET) failed, are you running root?.") + return errors.New("nfq_unbind_pf(AF_INET) failed, are you root?") } if C.nfq_unbind_pf(this.h, C.AF_INET6) < 0 { this.Destroy() - panic("nfq_unbind_pf(AF_INET6) failed.") + return errors.New("nfq_unbind_pf(AF_INET6) failed") } if C.nfq_bind_pf(this.h, C.AF_INET) < 0 { this.Destroy() - panic("nfq_bind_pf(AF_INET) failed.") + return errors.New("nfq_bind_pf(AF_INET) failed") } - if C.nfq_bind_pf(this.h, C.AF_INET6) < 0 { this.Destroy() - panic("nfq_bind_pf(AF_INET6) failed.") + return errors.New("nfq_bind_pf(AF_INET6) failed") } if this.qh, err = C.create_queue(this.h, C.uint16_t(this.qid)); err != nil || this.qh == nil { C.nfq_close(this.h) - panic(err) + return fmt.Errorf("could not create queue: %s", err) } this.fd = int(C.nfq_fd(this.h)) if C.nfq_set_mode(this.qh, C.NFQNL_COPY_PACKET, 0xffff) < 0 { this.Destroy() - panic("nfq_set_mode(NFQNL_COPY_PACKET) failed.") + return errors.New("nfq_set_mode(NFQNL_COPY_PACKET) failed") } if C.nfq_set_queue_maxlen(this.qh, 1024*8) < 0 { this.Destroy() - panic("nfq_set_queue_maxlen(1024 * 8) failed.") + return errors.New("nfq_set_queue_maxlen(1024 * 8) failed") } + + return nil } -func (this *nfQueue) Destroy() { +func (this *NFQueue) Destroy() { this.lk.Lock() defer this.lk.Unlock() @@ -131,12 +127,12 @@ func (this *nfQueue) Destroy() { } // TODO: don't close, we're exiting anyway - // if this.pktch != nil { - // close(this.pktch) + // if this.Packets != nil { + // close(this.Packets) // } } -func (this *nfQueue) Valid() bool { +func (this *NFQueue) Valid() bool { return this.h != nil && this.qh != nil } @@ -148,7 +144,7 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32, qidptr := (*uint16)(data) qid := uint16(*qidptr) - // nfq := (*nfQueue)(nfqptr) + // nfq := (*NFQueue)(nfqptr) new_version := version ipver := packet.IPVersion(new_version) ipsz := C.int(ipver.ByteSize()) @@ -187,7 +183,7 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32, // fmt.Printf("%s queuing packet\n", time.Now().Format("060102 15:04:05.000")) // BUG: "panic: send on closed channel" when shutting down - queues[qid].pktch <- &pkt + queues[qid].Packets <- &pkt select { case v = <-pkt.verdict: diff --git a/firewall/interception/nfqueue/packet.go b/firewall/interception/nfqueue/packet.go index 37e2ebd5..a05abc6d 100644 --- a/firewall/interception/nfqueue/packet.go +++ b/firewall/interception/nfqueue/packet.go @@ -5,7 +5,7 @@ package nfqueue import ( "fmt" - "github.com/Safing/safing-core/network/packet" + "github.com/Safing/portmaster/network/packet" ) var ( diff --git a/firewall/interception/nfqueue.go b/firewall/interception/nfqueue_linux.go similarity index 71% rename from firewall/interception/nfqueue.go rename to firewall/interception/nfqueue_linux.go index b9ef25e2..a2142494 100644 --- a/firewall/interception/nfqueue.go +++ b/firewall/interception/nfqueue_linux.go @@ -1,31 +1,33 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -// +build linux - package interception import ( + "fmt" "sort" "strings" "github.com/coreos/go-iptables/iptables" - "github.com/Safing/safing-core/firewall/interception/nfqueue" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/modules" + "github.com/Safing/portmaster/firewall/interception/nfqueue" ) // iptables -A OUTPUT -p icmp -j", "NFQUEUE", "--queue-num", "1", "--queue-bypass -var nfqueueModule *modules.Module +var ( + v4chains []string + v4rules []string + v4once []string -var v4chains []string -var v4rules []string -var v4once []string + v6chains []string + v6rules []string + v6once []string -var v6chains []string -var v6rules []string -var v6once []string + out4Queue *nfqueue.NFQueue + in4Queue *nfqueue.NFQueue + out6Queue *nfqueue.NFQueue + in6Queue *nfqueue.NFQueue + + shutdownSignal = make(chan struct{}) +) func init() { @@ -100,8 +102,8 @@ func init() { } // Reverse because we'd like to insert in a loop - sort.Reverse(sort.StringSlice(v4once)) - sort.Reverse(sort.StringSlice(v6once)) + _ = sort.Reverse(sort.StringSlice(v4once)) // silence vet (sort is used just like in the docs) + _ = sort.Reverse(sort.StringSlice(v6once)) // silence vet (sort is used just like in the docs) } @@ -127,9 +129,10 @@ func activateNfqueueFirewall() error { } } + var ok bool for _, rule := range v4once { splittedRule := strings.Split(rule, " ") - ok, err := ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...) + ok, err = ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...) if err != nil { return err } @@ -183,9 +186,10 @@ func deactivateNfqueueFirewall() error { return err } + var ok bool for _, rule := range v4once { splittedRule := strings.Split(rule, " ") - ok, err := ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...) + ok, err = ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...) if err != nil { return err } @@ -198,10 +202,10 @@ func deactivateNfqueueFirewall() error { for _, chain := range v4chains { splittedRule := strings.Split(chain, " ") - if err := ip4tables.ClearChain(splittedRule[0], splittedRule[1]); err != nil { + if err = ip4tables.ClearChain(splittedRule[0], splittedRule[1]); err != nil { return err } - if err := ip4tables.DeleteChain(splittedRule[0], splittedRule[1]); err != nil { + if err = ip4tables.DeleteChain(splittedRule[0], splittedRule[1]); err != nil { return err } } @@ -238,70 +242,84 @@ func deactivateNfqueueFirewall() error { return nil } -func Start() { +// StartNfqueueInterception starts the nfqueue interception. +func StartNfqueueInterception() (err error) { - nfqueueModule = modules.Register("Firewall:Interception:Nfqueue", 192) - - if err := activateNfqueueFirewall(); err != nil { - log.Criticalf("could not activate firewall for nfqueue: %q", err) + err = activateNfqueueFirewall() + if err != nil { + Stop() + return fmt.Errorf("could not initialize nfqueue: %s", err) } - out4Queue := nfqueue.NewNFQueue(17040) - in4Queue := nfqueue.NewNFQueue(17140) - out6Queue := nfqueue.NewNFQueue(17060) - in6Queue := nfqueue.NewNFQueue(17160) + out4Queue, err = nfqueue.NewNFQueue(17040) + if err != nil { + Stop() + return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err) + } + in4Queue, err = nfqueue.NewNFQueue(17140) + if err != nil { + Stop() + return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err) + } + out6Queue, err = nfqueue.NewNFQueue(17060) + if err != nil { + Stop() + return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err) + } + in6Queue, err = nfqueue.NewNFQueue(17160) + if err != nil { + Stop() + return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err) + } - out4Channel := out4Queue.Process() - // if err != nil { - // log.Criticalf("could not open nfqueue out4") - // } else { - defer out4Queue.Destroy() - // } - in4Channel := in4Queue.Process() - // if err != nil { - // log.Criticalf("could not open nfqueue in4") - // } else { - defer in4Queue.Destroy() - // } - out6Channel := out6Queue.Process() - // if err != nil { - // log.Criticalf("could not open nfqueue out6") - // } else { - defer out6Queue.Destroy() - // } - in6Channel := in6Queue.Process() - // if err != nil { - // log.Criticalf("could not open nfqueue in6") - // } else { - defer in6Queue.Destroy() - // } + go handleInterception() + return nil +} -packetInterceptionLoop: +// StopNfqueueInterception stops the nfqueue interception. +func StopNfqueueInterception() error { + defer close(shutdownSignal) + + if out4Queue != nil { + out4Queue.Destroy() + } + if in4Queue != nil { + in4Queue.Destroy() + } + if out6Queue != nil { + out6Queue.Destroy() + } + if in6Queue != nil { + in6Queue.Destroy() + } + + err := deactivateNfqueueFirewall() + if err != nil { + return fmt.Errorf("interception: error while deactivating nfqueue: %s", err) + } + + return nil +} + +func handleInterception() { for { select { - case <-nfqueueModule.Stop: - break packetInterceptionLoop - case pkt := <-out4Channel: + case <-shutdownSignal: + return + case pkt := <-out4Queue.Packets: pkt.SetOutbound() Packets <- pkt - case pkt := <-in4Channel: + case pkt := <-in4Queue.Packets: pkt.SetInbound() Packets <- pkt - case pkt := <-out6Channel: + case pkt := <-out6Queue.Packets: pkt.SetOutbound() Packets <- pkt - case pkt := <-in6Channel: + case pkt := <-in6Queue.Packets: pkt.SetInbound() Packets <- pkt } } - - if err := deactivateNfqueueFirewall(); err != nil { - log.Criticalf("could not deactivate firewall for nfqueue: %q", err) - } - - nfqueueModule.StopComplete() - } func stringInSlice(slice []string, value string) bool { diff --git a/firewall/interception/windivert/test/main.exe b/firewall/interception/windivert/test/main.exe deleted file mode 100644 index 3e51b896..00000000 Binary files a/firewall/interception/windivert/test/main.exe and /dev/null differ diff --git a/firewall/master.go b/firewall/master.go new file mode 100644 index 00000000..bc96c6f0 --- /dev/null +++ b/firewall/master.go @@ -0,0 +1,340 @@ +package firewall + +import ( + "fmt" + "os" + "strings" + + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/intel" + "github.com/Safing/portmaster/network" + "github.com/Safing/portmaster/network/packet" + "github.com/Safing/portmaster/profile" + "github.com/Safing/portmaster/status" + + "github.com/agext/levenshtein" +) + +// Call order: +// +// 1. DecideOnConnectionBeforeIntel (if connecting to domain) +// is called when a DNS query is made, before the query is resolved +// 2. DecideOnConnectionAfterIntel (if connecting to domain) +// is called when a DNS query is made, after the query is resolved +// 3. DecideOnConnection +// is called when the first packet of the first link of the connection arrives +// 4. DecideOnLink +// is called when when the first packet of a link arrives only if connection has verdict UNDECIDED or CANTSAY + +// DecideOnConnectionBeforeIntel makes a decision about a connection before the dns query is resolved and intel is gathered. +func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) { + // check: + // Profile.DomainWhitelist + // Profile.Flags + // - process specific: System, Admin, User + // - network specific: Internet, LocalNet + + // grant self + if connection.Process().Pid == os.Getpid() { + log.Infof("firewall: granting own connection %s", connection) + connection.Accept("") + return + } + + // check if there is a profile + profileSet := connection.Process().ProfileSet() + if profileSet == nil { + log.Errorf("firewall: denying connection %s, no profile set", connection) + connection.Deny("no profile set") + return + } + profileSet.Update(status.CurrentSecurityLevel()) + + // check for any network access + if !profileSet.CheckFlag(profile.Internet) && !profileSet.CheckFlag(profile.LAN) { + log.Infof("firewall: denying connection %s, accessing Internet or LAN not allowed", connection) + connection.Deny("accessing Internet or LAN not allowed") + return + } + + // check domain list + permitted, reason, ok := profileSet.CheckEndpoint(fqdn, 0, 0, false) + if ok { + if permitted { + log.Infof("firewall: accepting connection %s, endpoint is whitelisted: %s", connection, reason) + connection.Accept(fmt.Sprintf("endpoint is whitelisted: %s", reason)) + } else { + log.Infof("firewall: denying connection %s, endpoint is blacklisted", connection) + connection.Deny("endpoint is blacklisted") + } + return + } + + switch profileSet.GetProfileMode() { + case profile.Whitelist: + log.Infof("firewall: denying connection %s, domain is not whitelisted", connection) + connection.Deny("domain is not whitelisted") + case profile.Prompt: + + // check Related flag + // TODO: improve this! + if profileSet.CheckFlag(profile.Related) { + matched := false + pathElements := strings.Split(connection.Process().Path, "/") // FIXME: path seperator + // only look at the last two path segments + if len(pathElements) > 2 { + pathElements = pathElements[len(pathElements)-2:] + } + domainElements := strings.Split(fqdn, ".") + + var domainElement string + var processElement string + + matchLoop: + for _, domainElement = range domainElements { + for _, pathElement := range pathElements { + if levenshtein.Match(domainElement, pathElement, nil) > 0.5 { + matched = true + processElement = pathElement + break matchLoop + } + } + if levenshtein.Match(domainElement, profileSet.UserProfile().Name, nil) > 0.5 { + matched = true + processElement = profileSet.UserProfile().Name + break matchLoop + } + if levenshtein.Match(domainElement, connection.Process().Name, nil) > 0.5 { + matched = true + processElement = connection.Process().Name + break matchLoop + } + if levenshtein.Match(domainElement, connection.Process().ExecName, nil) > 0.5 { + matched = true + processElement = connection.Process().ExecName + break matchLoop + } + } + + if matched { + log.Infof("firewall: accepting connection %s, match to domain was found: %s ~= %s", connection, domainElement, processElement) + connection.Accept("domain is related to process") + } + } + + if connection.GetVerdict() != network.ACCEPT { + // TODO + log.Infof("firewall: accepting connection %s, domain permitted (prompting is not yet implemented)", connection) + connection.Accept("domain permitted (prompting is not yet implemented)") + } + + case profile.Blacklist: + log.Infof("firewall: accepting connection %s, domain is not blacklisted", connection) + connection.Accept("domain is not blacklisted") + } + +} + +// DecideOnConnectionAfterIntel makes a decision about a connection after the dns query is resolved and intel is gathered. +func DecideOnConnectionAfterIntel(connection *network.Connection, fqdn string, rrCache *intel.RRCache) *intel.RRCache { + + // grant self + if connection.Process().Pid == os.Getpid() { + log.Infof("firewall: granting own connection %s", connection) + connection.Accept("") + return rrCache + } + + // check if there is a profile + profileSet := connection.Process().ProfileSet() + if profileSet == nil { + log.Errorf("firewall: denying connection %s, no profile set", connection) + connection.Deny("no profile") + return rrCache + } + profileSet.Update(status.CurrentSecurityLevel()) + + // TODO: Stamp integration + + // TODO: Gate17 integration + // tunnelInfo, err := AssignTunnelIP(fqdn) + + rrCache.Duplicate().FilterEntries(profileSet.CheckFlag(profile.Internet), profileSet.CheckFlag(profile.LAN), false) + if len(rrCache.Answer) == 0 { + if profileSet.CheckFlag(profile.Internet) { + connection.Deny("server is located in the LAN, but LAN access is not permitted") + } else { + connection.Deny("server is located in the Internet, but Internet access is not permitted") + } + } + + return rrCache +} + +// DeciceOnConnection makes a decision about a connection with its first packet. +func DecideOnConnection(connection *network.Connection, pkt packet.Packet) { + + // grant self + if connection.Process().Pid == os.Getpid() { + log.Infof("firewall: granting own connection %s", connection) + connection.Accept("") + return + } + + // check if there is a profile + profileSet := connection.Process().ProfileSet() + if profileSet == nil { + log.Errorf("firewall: denying connection %s, no profile set", connection) + connection.Deny("no profile") + return + } + profileSet.Update(status.CurrentSecurityLevel()) + + // check connection type + switch connection.Domain { + case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: + if !profileSet.CheckFlag(profile.Service) { + log.Infof("firewall: denying connection %s, not a service", connection) + if connection.Domain == network.IncomingHost { + connection.Block("not a service") + } else { + connection.Drop("not a service") + } + return + } + case network.PeerLAN, network.PeerInternet, network.PeerInvalid: // Important: PeerHost is and should be missing! + if !profileSet.CheckFlag(profile.PeerToPeer) { + log.Infof("firewall: denying connection %s, peer to peer connections (to an IP) not allowed", connection) + connection.Deny("peer to peer connections (to an IP) not allowed") + return + } + default: + + } + + // check network scope + switch connection.Domain { + case network.IncomingHost: + if !profileSet.CheckFlag(profile.Localhost) { + log.Infof("firewall: denying connection %s, serving localhost not allowed", connection) + connection.Block("serving localhost not allowed") + return + } + case network.IncomingLAN: + if !profileSet.CheckFlag(profile.LAN) { + log.Infof("firewall: denying connection %s, serving LAN not allowed", connection) + connection.Deny("serving LAN not allowed") + return + } + case network.IncomingInternet: + if !profileSet.CheckFlag(profile.Internet) { + log.Infof("firewall: denying connection %s, serving Internet not allowed", connection) + connection.Deny("serving Internet not allowed") + return + } + case network.IncomingInvalid: + log.Infof("firewall: denying connection %s, invalid IP address", connection) + connection.Drop("invalid IP address") + return + case network.PeerHost: + if !profileSet.CheckFlag(profile.Localhost) { + log.Infof("firewall: denying connection %s, accessing localhost not allowed", connection) + connection.Block("accessing localhost not allowed") + return + } + case network.PeerLAN: + if !profileSet.CheckFlag(profile.LAN) { + log.Infof("firewall: denying connection %s, accessing the LAN not allowed", connection) + connection.Deny("accessing the LAN not allowed") + return + } + case network.PeerInternet: + if !profileSet.CheckFlag(profile.Internet) { + log.Infof("firewall: denying connection %s, accessing the Internet not allowed", connection) + connection.Deny("accessing the Internet not allowed") + return + } + case network.PeerInvalid: + log.Infof("firewall: denying connection %s, invalid IP address", connection) + connection.Deny("invalid IP address") + return + } + + log.Infof("firewall: accepting connection %s", connection) + connection.Accept("") +} + +// DecideOnLink makes a decision about a link with the first packet. +func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet.Packet) { + // check: + // Profile.Flags + // - network specific: Internet, LocalNet + // Profile.ConnectPorts + // Profile.ListenPorts + + // grant self + if connection.Process().Pid == os.Getpid() { + log.Infof("firewall: granting own link %s", connection) + connection.Accept("") + return + } + + // check if there is a profile + profileSet := connection.Process().ProfileSet() + if profileSet == nil { + log.Infof("firewall: no profile, denying %s", link) + link.Block("no profile") + return + } + profileSet.Update(status.CurrentSecurityLevel()) + + // get host + var domainOrIP string + switch { + case strings.HasSuffix(connection.Domain, "."): + domainOrIP = connection.Domain + case connection.Direction: + domainOrIP = pkt.GetIPHeader().Src.String() + default: + domainOrIP = pkt.GetIPHeader().Dst.String() + } + + // get protocol / destination port + protocol := pkt.GetIPHeader().Protocol + var dstPort uint16 + tcpUDPHeader := pkt.GetTCPUDPHeader() + if tcpUDPHeader != nil { + dstPort = tcpUDPHeader.DstPort + } + + // check endpoints list + permitted, reason, ok := profileSet.CheckEndpoint(domainOrIP, uint8(protocol), dstPort, connection.Direction) + if ok { + if permitted { + log.Infof("firewall: accepting link %s, endpoint is whitelisted: %s", link, reason) + link.Accept(fmt.Sprintf("port whitelisted: %s", reason)) + } else { + log.Infof("firewall: denying link %s: port %d is blacklisted", link, dstPort) + link.Deny("port blacklisted") + } + return + } + + switch profileSet.GetProfileMode() { + case profile.Whitelist: + log.Infof("firewall: denying link %s: endpoint %d is not whitelisted", link, dstPort) + link.Deny("endpoint is not whitelisted") + return + case profile.Prompt: + log.Infof("firewall: accepting link %s: endpoint %d is blacklisted", link, dstPort) + link.Accept("endpoint permitted (prompting is not yet implemented)") + return + case profile.Blacklist: + log.Infof("firewall: accepting link %s: endpoint %d is not blacklisted", link, dstPort) + link.Accept("endpoint is not blacklisted") + return + } + + log.Infof("firewall: accepting link %s", link) + link.Accept("") +} diff --git a/portmaster/tunnel.go b/firewall/tunnel.go similarity index 95% rename from portmaster/tunnel.go rename to firewall/tunnel.go index 3269eb51..badbb98b 100644 --- a/portmaster/tunnel.go +++ b/firewall/tunnel.go @@ -1,4 +1,4 @@ -package portmaster +package firewall import ( "errors" @@ -7,8 +7,8 @@ import ( "sync" "time" - "github.com/Safing/safing-core/crypto/random" - "github.com/Safing/safing-core/intel" + "github.com/Safing/portbase/crypto/random" + "github.com/Safing/portmaster/intel" "github.com/miekg/dns" ) diff --git a/global/databases.go b/global/databases.go new file mode 100644 index 00000000..624370b6 --- /dev/null +++ b/global/databases.go @@ -0,0 +1,49 @@ +package global + +import ( + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/modules" + + // module dependencies + _ "github.com/Safing/portbase/database/dbmodule" + _ "github.com/Safing/portbase/database/storage/badger" + _ "github.com/Safing/portmaster/status" +) + +func init() { + modules.Register("global", nil, start, nil, "database", "status") +} + +func start() error { + _, err := database.Register(&database.Database{ + Name: "core", + Description: "Holds core data, such as settings and profiles", + StorageType: "badger", + PrimaryAPI: "", + }) + if err != nil { + return err + } + + _, err = database.Register(&database.Database{ + Name: "cache", + Description: "Cached data, such as Intelligence and DNS Records", + StorageType: "badger", + PrimaryAPI: "", + }) + if err != nil { + return err + } + + // _, err = database.Register(&database.Database{ + // Name: "history", + // Description: "Historic event data", + // StorageType: "badger", + // PrimaryAPI: "", + // }) + // if err != nil { + // return err + // } + + return nil +} diff --git a/intel/clients.go b/intel/clients.go new file mode 100644 index 00000000..18b34b02 --- /dev/null +++ b/intel/clients.go @@ -0,0 +1,94 @@ +package intel + +import ( + "crypto/tls" + "sync" + "time" + + "github.com/miekg/dns" +) + +type clientManager struct { + dnsClient *dns.Client + factory func() *dns.Client + + lock sync.Mutex + refreshAfter time.Time + ttl time.Duration // force refresh of connection to reduce traceability +} + +// ref: https://godoc.org/github.com/miekg/dns#Client + +func newDNSClientManager(resolver *Resolver) *clientManager { + return &clientManager{ + ttl: -1 * time.Minute, + factory: func() *dns.Client { + return &dns.Client{ + Timeout: 5 * time.Second, + } + }, + } +} + +func newTCPClientManager(resolver *Resolver) *clientManager { + return &clientManager{ + ttl: -15 * time.Minute, + factory: func() *dns.Client { + return &dns.Client{ + Net: "tcp", + Timeout: 5 * time.Second, + } + }, + } +} + +func newTLSClientManager(resolver *Resolver) *clientManager { + return &clientManager{ + ttl: -15 * time.Minute, + factory: func() *dns.Client { + return &dns.Client{ + Net: "tcp-tls", + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: resolver.VerifyDomain, + // TODO: use custom random + // Rand: io.Reader, + }, + Timeout: 5 * time.Second, + } + }, + } +} + +func newHTTPSClientManager(resolver *Resolver) *clientManager { + return &clientManager{ + ttl: -15 * time.Minute, + factory: func() *dns.Client { + new := &dns.Client{ + Net: "https", + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + // TODO: use custom random + // Rand: io.Reader, + }, + Timeout: 5 * time.Second, + } + if resolver.VerifyDomain != "" { + new.TLSConfig.ServerName = resolver.VerifyDomain + } + return new + }, + } +} + +func (cm *clientManager) getDNSClient() *dns.Client { + cm.lock.Lock() + defer cm.lock.Unlock() + + if cm.dnsClient == nil || time.Now().After(cm.refreshAfter) { + cm.dnsClient = cm.factory() + cm.refreshAfter = time.Now().Add(cm.ttl) + } + + return cm.dnsClient +} diff --git a/intel/config.go b/intel/config.go new file mode 100644 index 00000000..d7a9525c --- /dev/null +++ b/intel/config.go @@ -0,0 +1,100 @@ +package intel + +import ( + "github.com/Safing/portbase/config" + "github.com/Safing/portmaster/status" +) + +var ( + configuredNameServers config.StringArrayOption + defaultNameServers = []string{ + "tls|1.1.1.1:853|cloudflare-dns.com", // Cloudflare + "tls|1.0.0.1:853|cloudflare-dns.com", // Cloudflare + "tls|9.9.9.9:853|dns.quad9.net", // Quad9 + // "https|cloudflare-dns.com/dns-query", // HTTPS still experimental + "dns|1.1.1.1:53", // Cloudflare + "dns|1.0.0.1:53", // Cloudflare + "dns|9.9.9.9:53", // Quad9 + } + + nameserverRetryRate config.IntOption + doNotUseMulticastDNS status.SecurityLevelOption + doNotUseAssignedNameservers status.SecurityLevelOption + doNotResolveSpecialDomains status.SecurityLevelOption +) + +func prep() error { + err := config.Register(&config.Option{ + Name: "Nameservers (DNS)", + Key: "intel/nameservers", + Description: "Nameserver to use for resolving DNS requests.", + ExpertiseLevel: config.ExpertiseLevelExpert, + OptType: config.OptTypeStringArray, + DefaultValue: defaultNameServers, + ValidationRegex: "^(dns|tcp|tls|https)$", + }) + if err != nil { + return err + } + configuredNameServers = config.Concurrent.GetAsStringArray("intel/nameservers", defaultNameServers) + + err = config.Register(&config.Option{ + Name: "Nameserver Retry Rate", + Key: "intel/nameserverRetryRate", + Description: "Rate at which to retry failed nameservers, in seconds.", + ExpertiseLevel: config.ExpertiseLevelExpert, + OptType: config.OptTypeInt, + DefaultValue: 600, + }) + if err != nil { + return err + } + nameserverRetryRate = config.Concurrent.GetAsInt("intel/nameserverRetryRate", 0) + + err = config.Register(&config.Option{ + Name: "Do not use Multicast DNS", + Key: "intel/doNotUseMulticastDNS", + Description: "Multicast DNS queries other devices in the local network", + ExpertiseLevel: config.ExpertiseLevelExpert, + OptType: config.OptTypeInt, + ExternalOptType: "security level", + DefaultValue: 3, + ValidationRegex: "^(1|2|3)$", + }) + if err != nil { + return err + } + doNotUseMulticastDNS = status.ConfigIsActiveConcurrent("intel/doNotUseMulticastDNS") + + err = config.Register(&config.Option{ + Name: "Do not use assigned Nameservers", + Key: "intel/doNotUseAssignedNameservers", + Description: "that were acquired by the network (dhcp) or system", + ExpertiseLevel: config.ExpertiseLevelExpert, + OptType: config.OptTypeInt, + ExternalOptType: "security level", + DefaultValue: 3, + ValidationRegex: "^(1|2|3)$", + }) + if err != nil { + return err + } + doNotUseAssignedNameservers = status.ConfigIsActiveConcurrent("intel/doNotUseAssignedNameservers") + + err = config.Register(&config.Option{ + Name: "Do not resolve special domains", + Key: "intel/doNotResolveSpecialDomains", + Description: "Do not resolve special (top level) domains: example, example.com, example.net, example.org, invalid, test, onion. (RFC6761, RFC7686)", + ExpertiseLevel: config.ExpertiseLevelExpert, + OptType: config.OptTypeInt, + ExternalOptType: "security level", + DefaultValue: 3, + ValidationRegex: "^(1|2|3)$", + }) + if err != nil { + return err + } + doNotResolveSpecialDomains = status.ConfigIsActiveConcurrent("intel/doNotResolveSpecialDomains") + + return nil +} diff --git a/intel/data.go b/intel/data.go deleted file mode 100644 index 67074f22..00000000 --- a/intel/data.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package intel - -import ( - "github.com/Safing/safing-core/database" - - datastore "github.com/ipfs/go-datastore" -) - -// EntityClassification holds classification information about an internet entity. -type EntityClassification struct { - lists []byte -} - -// Intel holds intelligence data for a domain. -type Intel struct { - database.Base - Domain string - DomainOwner string - CertOwner string - Classification *EntityClassification -} - -var intelModel *Intel // only use this as parameter for database.EnsureModel-like functions - -func init() { - database.RegisterModel(intelModel, func() database.Model { return new(Intel) }) -} - -// Create saves the Intel with the provided name in the default namespace. -func (m *Intel) Create(name string) error { - return m.CreateObject(&database.IntelCache, name, m) -} - -// CreateInNamespace saves the Intel with the provided name in the provided namespace. -func (m *Intel) CreateInNamespace(namespace *datastore.Key, name string) error { - return m.CreateObject(namespace, name, m) -} - -// Save saves the Intel. -func (m *Intel) Save() error { - return m.SaveObject(m) -} - -// getIntel fetches the Intel with the provided name in the default namespace. -func getIntel(name string) (*Intel, error) { - return getIntelFromNamespace(&database.IntelCache, name) -} - -// getIntelFromNamespace fetches the Intel with the provided name in the provided namespace. -func getIntelFromNamespace(namespace *datastore.Key, name string) (*Intel, error) { - object, err := database.GetAndEnsureModel(namespace, name, intelModel) - if err != nil { - return nil, err - } - model, ok := object.(*Intel) - if !ok { - return nil, database.NewMismatchError(object, intelModel) - } - return model, nil -} diff --git a/intel/dns.go b/intel/dns.go deleted file mode 100644 index 2d15f13b..00000000 --- a/intel/dns.go +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package intel - -import ( - "fmt" - "net" - "time" - - "github.com/Safing/safing-core/database" - - datastore "github.com/ipfs/go-datastore" - "github.com/miekg/dns" -) - -// RRCache is used to cache DNS data -type RRCache struct { - Answer []dns.RR - Ns []dns.RR - Extra []dns.RR - Expires int64 - Modified int64 - servedFromCache bool - requestingNew bool -} - -func (m *RRCache) Clean(minExpires uint32) { - - var lowestTTL uint32 = 0xFFFFFFFF - var header *dns.RR_Header - - // set TTLs to 17 - // TODO: double append? is there something more elegant? - for _, rr := range append(m.Answer, append(m.Ns, m.Extra...)...) { - header = rr.Header() - if lowestTTL > header.Ttl { - lowestTTL = header.Ttl - } - header.Ttl = 17 - } - - // TTL must be at least minExpires - if lowestTTL < minExpires { - lowestTTL = minExpires - } - - m.Expires = time.Now().Unix() + int64(lowestTTL) - m.Modified = time.Now().Unix() - -} - -func (m *RRCache) ExportAllARecords() (ips []net.IP) { - for _, rr := range m.Answer { - if rr.Header().Class == dns.ClassINET && rr.Header().Rrtype == dns.TypeA { - aRecord, ok := rr.(*dns.A) - if ok { - ips = append(ips, aRecord.A) - } - } else if rr.Header().Class == dns.ClassINET && rr.Header().Rrtype == dns.TypeAAAA { - aRecord, ok := rr.(*dns.AAAA) - if ok { - ips = append(ips, aRecord.AAAA) - } - } - } - return -} - -func (m *RRCache) ToRRSave() *RRSave { - var s RRSave - s.Expires = m.Expires - s.Modified = m.Modified - for _, entry := range m.Answer { - s.Answer = append(s.Answer, entry.String()) - } - for _, entry := range m.Ns { - s.Ns = append(s.Ns, entry.String()) - } - for _, entry := range m.Extra { - s.Extra = append(s.Extra, entry.String()) - } - return &s -} - -func (m *RRCache) Create(name string) error { - s := m.ToRRSave() - return s.CreateObject(&database.DNSCache, name, s) -} - -func (m *RRCache) CreateWithType(name string, qtype dns.Type) error { - s := m.ToRRSave() - return s.Create(fmt.Sprintf("%s%s", name, qtype.String())) -} - -func (m *RRCache) Save() error { - s := m.ToRRSave() - return s.SaveObject(s) -} - -func GetRRCache(domain string, qtype dns.Type) (*RRCache, error) { - return GetRRCacheFromNamespace(&database.DNSCache, domain, qtype) -} - -func GetRRCacheFromNamespace(namespace *datastore.Key, domain string, qtype dns.Type) (*RRCache, error) { - var m RRCache - - rrSave, err := GetRRSaveFromNamespace(namespace, domain, qtype) - if err != nil { - return nil, err - } - - m.Expires = rrSave.Expires - m.Modified = rrSave.Modified - for _, entry := range rrSave.Answer { - rr, err := dns.NewRR(entry) - if err == nil { - m.Answer = append(m.Answer, rr) - } - } - for _, entry := range rrSave.Ns { - rr, err := dns.NewRR(entry) - if err == nil { - m.Ns = append(m.Ns, rr) - } - } - for _, entry := range rrSave.Extra { - rr, err := dns.NewRR(entry) - if err == nil { - m.Extra = append(m.Extra, rr) - } - } - - m.servedFromCache = true - return &m, nil -} - -// ServedFromCache marks the RRCache as served from cache. -func (m *RRCache) ServedFromCache() bool { - return m.servedFromCache -} - -// RequestingNew informs that it has expired and new RRs are being fetched. -func (m *RRCache) RequestingNew() bool { - return m.requestingNew -} - -// Flags formats ServedFromCache and RequestingNew to a condensed, flag-like format. -func (m *RRCache) Flags() string { - switch { - case m.servedFromCache && m.requestingNew: - return " [CR]" - case m.servedFromCache: - return " [C]" - case m.requestingNew: - return " [R]" // theoretically impossible, but let's leave it here, just in case - default: - return "" - } -} - -// IsNXDomain returnes whether the result is nxdomain. -func (m *RRCache) IsNXDomain() bool { - return len(m.Answer) == 0 -} - -// RRSave is helper struct to RRCache to better save data to the database. -type RRSave struct { - database.Base - Answer []string - Ns []string - Extra []string - Expires int64 - Modified int64 -} - -var rrSaveModel *RRSave // only use this as parameter for database.EnsureModel-like functions - -func init() { - database.RegisterModel(rrSaveModel, func() database.Model { return new(RRSave) }) -} - -// Create saves RRSave with the provided name in the default namespace. -func (m *RRSave) Create(name string) error { - return m.CreateObject(&database.DNSCache, name, m) -} - -// CreateWithType saves RRSave with the provided name and type in the default namespace. -func (m *RRSave) CreateWithType(name string, qtype dns.Type) error { - return m.Create(fmt.Sprintf("%s%s", name, qtype.String())) -} - -// CreateInNamespace saves RRSave with the provided name in the provided namespace. -func (m *RRSave) CreateInNamespace(namespace *datastore.Key, name string) error { - return m.CreateObject(namespace, name, m) -} - -// Save saves RRSave. -func (m *RRSave) Save() error { - return m.SaveObject(m) -} - -// GetRRSave fetches RRSave with the provided name in the default namespace. -func GetRRSave(name string, qtype dns.Type) (*RRSave, error) { - return GetRRSaveFromNamespace(&database.DNSCache, name, qtype) -} - -// GetRRSaveFromNamespace fetches RRSave with the provided name in the provided namespace. -func GetRRSaveFromNamespace(namespace *datastore.Key, name string, qtype dns.Type) (*RRSave, error) { - object, err := database.GetAndEnsureModel(namespace, fmt.Sprintf("%s%s", name, qtype.String()), rrSaveModel) - if err != nil { - return nil, err - } - model, ok := object.(*RRSave) - if !ok { - return nil, database.NewMismatchError(object, rrSaveModel) - } - return model, nil -} diff --git a/intel/domainfronting.go b/intel/domainfronting.go deleted file mode 100644 index 516fc0d1..00000000 --- a/intel/domainfronting.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package intel - -import ( - "github.com/Safing/safing-core/log" - "sync" - - "github.com/miekg/dns" -) - -var ( - dfMap = make(map[string]string) - dfMapLock sync.RWMutex -) - -func checkDomainFronting(hidden string, qtype dns.Type, securityLevel int8) (*RRCache, bool) { - dfMapLock.RLock() - front, ok := dfMap[hidden] - dfMapLock.RUnlock() - if !ok { - return nil, false - } - log.Tracef("intel: applying domain fronting %s -> %s", hidden, front) - // get domain name - rrCache := resolveAndCache(front, qtype, securityLevel) - if rrCache == nil { - return nil, true - } - // replace domain name - var header *dns.RR_Header - for _, rr := range rrCache.Answer { - header = rr.Header() - if header.Name == front { - header.Name = hidden - } - } - // save under front - rrCache.CreateWithType(hidden, qtype) - return rrCache, true -} - -func addDomainFronting(hidden string, front string) { - dfMapLock.Lock() - dfMap[hidden] = front - dfMapLock.Unlock() - return -} diff --git a/intel/intel.go b/intel/intel.go index b6bf3f3c..c398710e 100644 --- a/intel/intel.go +++ b/intel/intel.go @@ -3,44 +3,66 @@ package intel import ( - "github.com/Safing/safing-core/database" - "github.com/Safing/safing-core/modules" + "fmt" + "sync" - "github.com/miekg/dns" + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/record" ) var ( - intelModule *modules.Module + intelDatabase = database.NewInterface(&database.Options{ + AlwaysSetRelativateExpiry: 2592000, // 30 days + }) ) -func init() { - intelModule = modules.Register("Intel", 128) - go Start() +// Intel holds intelligence data for a domain. +type Intel struct { + record.Base + sync.Mutex + + Domain string } -// GetIntel returns an Intel object of the given domain. The returned Intel object MUST not be modified. -func GetIntel(domain string) *Intel { - fqdn := dns.Fqdn(domain) - intel, err := getIntel(fqdn) +func makeIntelKey(domain string) string { + return fmt.Sprintf("cache:intel/domain/%s", domain) +} + +// GetIntelFromDB gets an Intel record from the database. +func GetIntelFromDB(domain string) (*Intel, error) { + key := makeIntelKey(domain) + + r, err := intelDatabase.Get(key) if err != nil { - if err == database.ErrNotFound { - intel = &Intel{Domain: fqdn} - intel.Create(fqdn) - } else { - return nil - } + return nil, err } - return intel + + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + new := &Intel{} + err = record.Unwrap(r, new) + if err != nil { + return nil, err + } + return new, nil + } + + // or adjust type + new, ok := r.(*Intel) + if !ok { + return nil, fmt.Errorf("record not of type *Intel, but %T", r) + } + return new, nil } -func GetIntelAndRRs(domain string, qtype dns.Type, securityLevel int8) (intel *Intel, rrs *RRCache) { - intel = GetIntel(domain) - rrs = Resolve(domain, qtype, securityLevel) - return +// Save saves the Intel record to the database. +func (intel *Intel) Save() error { + intel.SetKey(makeIntelKey(intel.Domain)) + return intelDatabase.PutNew(intel) } -func Start() { - // mocking until intel has its own goroutines - defer intelModule.StopComplete() - <-intelModule.Stop +// GetIntel fetches intelligence data for the given domain. +func GetIntel(domain string) (*Intel, error) { + return &Intel{Domain: domain}, nil } diff --git a/intel/ipinfo.go b/intel/ipinfo.go index 007e67c3..e00921fa 100644 --- a/intel/ipinfo.go +++ b/intel/ipinfo.go @@ -1,61 +1,91 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package intel import ( + "fmt" "strings" + "sync" - "github.com/Safing/safing-core/database" + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/utils" +) - datastore "github.com/ipfs/go-datastore" +var ( + ipInfoDatabase = database.NewInterface(&database.Options{ + AlwaysSetRelativateExpiry: 86400, // 24 hours + }) ) // IPInfo represents various information about an IP. type IPInfo struct { - database.Base + record.Base + sync.Mutex + + IP string Domains []string } -var ipInfoModel *IPInfo // only use this as parameter for database.EnsureModel-like functions - -func init() { - database.RegisterModel(ipInfoModel, func() database.Model { return new(IPInfo) }) +func makeIPInfoKey(ip string) string { + return fmt.Sprintf("cache:intel/ipInfo/%s", ip) } -// Create saves the IPInfo with the provided name in the default namespace. -func (m *IPInfo) Create(name string) error { - return m.CreateObject(&database.IPInfoCache, name, m) -} +// GetIPInfo gets an IPInfo record from the database. +func GetIPInfo(ip string) (*IPInfo, error) { + key := makeIPInfoKey(ip) -// CreateInNamespace saves the IPInfo with the provided name in the provided namespace. -func (m *IPInfo) CreateInNamespace(namespace *datastore.Key, name string) error { - return m.CreateObject(namespace, name, m) -} - -// Save saves the IPInfo. -func (m *IPInfo) Save() error { - return m.SaveObject(m) -} - -// GetIPInfo fetches the IPInfo with the provided name in the default namespace. -func GetIPInfo(name string) (*IPInfo, error) { - return GetIPInfoFromNamespace(&database.IPInfoCache, name) -} - -// GetIPInfoFromNamespace fetches the IPInfo with the provided name in the provided namespace. -func GetIPInfoFromNamespace(namespace *datastore.Key, name string) (*IPInfo, error) { - object, err := database.GetAndEnsureModel(namespace, name, ipInfoModel) + r, err := ipInfoDatabase.Get(key) if err != nil { return nil, err } - model, ok := object.(*IPInfo) - if !ok { - return nil, database.NewMismatchError(object, ipInfoModel) + + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + new := &IPInfo{} + err = record.Unwrap(r, new) + if err != nil { + return nil, err + } + return new, nil } - return model, nil + + // or adjust type + new, ok := r.(*IPInfo) + if !ok { + return nil, fmt.Errorf("record not of type *IPInfo, but %T", r) + } + return new, nil +} + +// AddDomain adds a domain to the list and reports back if it was added, or was already present. +func (ipi *IPInfo) AddDomain(domain string) (added bool) { + ipi.Lock() + defer ipi.Unlock() + if !utils.StringInSlice(ipi.Domains, domain) { + ipi.Domains = append([]string{domain}, ipi.Domains...) + return true + } + return false +} + +// Save saves the IPInfo record to the database. +func (ipi *IPInfo) Save() error { + ipi.Lock() + if !ipi.KeyIsSet() { + ipi.SetKey(makeIPInfoKey(ipi.IP)) + } + ipi.Unlock() + return ipInfoDatabase.Put(ipi) } // FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " -func (m *IPInfo) FmtDomains() string { - return strings.Join(m.Domains, " or ") +func (ipi *IPInfo) FmtDomains() string { + return strings.Join(ipi.Domains, " or ") +} + +// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " +func (ipi *IPInfo) String() string { + ipi.Lock() + defer ipi.Unlock() + return fmt.Sprintf(" local scopes, global // special -> local scopes, local -type Resolver struct { - // static - Server string - ServerAddress string - IP *net.IP - Port uint16 - Resolve func(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) - Search *[]string - AllowedSecurityLevel int8 - SkipFqdnBeforeInit string - HTTPClient *http.Client - Source string - - // atomic - Initialized *abool.AtomicBool - InitLock sync.Mutex - LastFail *int64 - Expires *int64 - - // must be locked - LockReason sync.Mutex - FailReason string - - // TODO: add: - // Expiration (for server got from DHCP / ICMPv6) - // bootstrapping (first query is already sent, wait for it to either succeed or fail - think about http bootstrapping here!) - // expanded server info: type, server address, server port, options - so we do not have to parse this every time! -} - -func (r *Resolver) String() string { - return r.Server -} - -func (r *Resolver) Address() string { - return urlFormatAddress(r.IP, r.Port) -} - -type Scope struct { - Domain string - Resolvers []*Resolver -} - -var ( - config = configuration.Get() - - globalResolvers []*Resolver // all resolvers - localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges - localScopes []Scope // list of scopes with a list of local resolvers that can resolve the scope - mDNSResolver *Resolver // holds a reference to the mDNS resolver - resolversLock sync.RWMutex - - env = environment.NewInterface() - - dupReqMap = make(map[string]*sync.Mutex) - dupReqLock sync.Mutex -) - -func init() { - loadResolvers(false) -} - -func indexOfResolver(server string, list []*Resolver) int { - for k, v := range list { - if v.Server == server { - return k - } - } - return -1 -} - -func indexOfScope(domain string, list *[]Scope) int { - for k, v := range *list { - if v.Domain == domain { - return k - } - } - return -1 -} - -func parseAddress(server string) (*net.IP, uint16, error) { - delimiter := strings.LastIndex(server, ":") - if delimiter < 0 { - return nil, 0, errors.New("port missing") - } - ip := net.ParseIP(strings.Trim(server[:delimiter], "[]")) - if ip == nil { - return nil, 0, errors.New("invalid IP address") - } - port, err := strconv.Atoi(server[delimiter+1:]) - if err != nil || port < 1 || port > 65536 { - return nil, 0, errors.New("invalid port") - } - return &ip, uint16(port), nil -} - -func urlFormatAddress(ip *net.IP, port uint16) string { - var address string - if ipv4 := ip.To4(); ipv4 != nil { - address = fmt.Sprintf("%s:%d", ipv4.String(), port) - } else { - address = fmt.Sprintf("[%s]:%d", ip.String(), port) - } - return address -} - -func loadResolvers(resetResolvers bool) { - // TODO: what happens when a lot of processes want to reload at once? we do not need to run this multiple times in a short time frame. - resolversLock.Lock() - defer resolversLock.Unlock() - - var newResolvers []*Resolver - -configuredServersLoop: - for _, server := range config.DNSServers { - key := indexOfResolver(server, newResolvers) - if key >= 0 { - continue configuredServersLoop - } - key = indexOfResolver(server, globalResolvers) - if resetResolvers || key == -1 { - parts := strings.Split(server, "|") - if len(parts) < 2 { - log.Warningf("intel: invalid DNS server in config: %s (invalid format)", server) - continue configuredServersLoop - } - var lastFail int64 - new := &Resolver{ - Server: server, - ServerAddress: parts[1], - LastFail: &lastFail, - Source: "config", - Initialized: abool.NewBool(false), - } - ip, port, err := parseAddress(parts[1]) - if err != nil { - new.IP = ip - new.Port = port - } - switch { - case strings.HasPrefix(server, "DNS|"): - new.Resolve = queryDNS - new.AllowedSecurityLevel = configuration.SecurityLevelFortress - case strings.HasPrefix(server, "DoH|"): - new.Resolve = queryDNSoverHTTPS - new.AllowedSecurityLevel = configuration.SecurityLevelFortress - new.SkipFqdnBeforeInit = dns.Fqdn(strings.Split(parts[1], ":")[0]) - - tls := &tls.Config{ - // TODO: use custom random - // Rand: io.Reader, - } - tr := &http.Transport{ - MaxIdleConnsPerHost: 100, - TLSClientConfig: tls, - // TODO: use custom resolver as of Go1.9 - } - if len(parts) == 3 && strings.HasPrefix(parts[2], "df:") { - // activate domain fronting - tls.ServerName = parts[2][3:] - addDomainFronting(new.SkipFqdnBeforeInit, dns.Fqdn(tls.ServerName)) - new.SkipFqdnBeforeInit = dns.Fqdn(tls.ServerName) - } - new.HTTPClient = &http.Client{Transport: tr} - - default: - log.Warningf("intel: invalid DNS server in config: %s (not starting with a valid identifier)", server) - continue configuredServersLoop - } - newResolvers = append(newResolvers, new) - } else { - newResolvers = append(newResolvers, globalResolvers[key]) - } - } - - // add local resolvers - assignedNameservers := environment.Nameservers() -assignedServersLoop: - for _, nameserver := range assignedNameservers { - server := fmt.Sprintf("DNS|%s", urlFormatAddress(&nameserver.IP, 53)) - key := indexOfResolver(server, newResolvers) - if key >= 0 { - continue assignedServersLoop - } - key = indexOfResolver(server, globalResolvers) - if resetResolvers || key == -1 { - var lastFail int64 - new := &Resolver{ - Server: server, - ServerAddress: urlFormatAddress(&nameserver.IP, 53), - IP: &nameserver.IP, - Port: 53, - LastFail: &lastFail, - Resolve: queryDNS, - AllowedSecurityLevel: configuration.SecurityLevelFortress, - Initialized: abool.NewBool(false), - Source: "dhcp", - } - if netutils.IPIsLocal(nameserver.IP) && len(nameserver.Search) > 0 { - // only allow searches for local resolvers - var newSearch []string - for _, value := range nameserver.Search { - newSearch = append(newSearch, fmt.Sprintf(".%s.", strings.Trim(value, "."))) - } - new.Search = &newSearch - } - newResolvers = append(newResolvers, new) - } else { - newResolvers = append(newResolvers, globalResolvers[key]) - } - } - - // save resolvers - globalResolvers = newResolvers - if len(globalResolvers) == 0 { - log.Criticalf("intel: no (valid) dns servers found in configuration and system") - } - - // make list with local resolvers - localResolvers = make([]*Resolver, 0) - for _, resolver := range globalResolvers { - if resolver.IP != nil && netutils.IPIsLocal(*resolver.IP) { - localResolvers = append(localResolvers, resolver) - } - } - - // add resolvers to every scope the cover - localScopes = make([]Scope, 0) - for _, resolver := range globalResolvers { - - if resolver.Search != nil { - // add resolver to custom searches - for _, search := range *resolver.Search { - if search == "." { - continue - } - key := indexOfScope(search, &localScopes) - if key == -1 { - localScopes = append(localScopes, Scope{ - Domain: search, - Resolvers: []*Resolver{resolver}, - }) - } else { - localScopes[key].Resolvers = append(localScopes[key].Resolvers, resolver) - } - } - - } - } - - // init mdns resolver - if mDNSResolver == nil { - cannotFail := int64(-1) - mDNSResolver = &Resolver{ - Server: "mDNS", - Resolve: queryMulticastDNS, - AllowedSecurityLevel: config.DoNotUseMDNS.Level(), - Initialized: abool.NewBool(false), - Source: "static", - LastFail: &cannotFail, - } - } - - // sort scopes by length - sort.Slice(localScopes, - func(i, j int) bool { - return len(localScopes[i].Domain) > len(localScopes[j].Domain) - }, - ) - - log.Trace("intel: loaded global resolvers:") - for _, resolver := range globalResolvers { - log.Tracef("intel: %s", resolver.Server) - } - log.Trace("intel: loaded local resolvers:") - for _, resolver := range localResolvers { - log.Tracef("intel: %s", resolver.Server) - } - log.Trace("intel: loaded scopes:") - for _, scope := range localScopes { - var scopeServers []string - for _, resolver := range scope.Resolvers { - scopeServers = append(scopeServers, resolver.Server) - } - log.Tracef("intel: %s: %s", scope.Domain, strings.Join(scopeServers, ", ")) - } - -} - // Resolve resolves the given query for a domain and type and returns a RRCache object or nil, if the query failed. -func Resolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { +func Resolve(fqdn string, qtype dns.Type, securityLevel uint8) *RRCache { fqdn = dns.Fqdn(fqdn) // use this to time how long it takes resolve this domain // timed := time.Now() // defer log.Tracef("intel: took %s to get resolve %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String()) - // handle request for localhost - if fqdn == "localhost." { - var rr dns.RR - var err error - switch uint16(qtype) { - case dns.TypeA: - rr, err = dns.NewRR("localhost. 3600 IN A 127.0.0.1") - case dns.TypeAAAA: - rr, err = dns.NewRR("localhost. 3600 IN AAAA ::1") - default: - return nil - } - if err != nil { - return nil - } - return &RRCache{ - Answer: []dns.RR{rr}, - } - } - // check cache rrCache, err := GetRRCache(fqdn, qtype) if err != nil { @@ -406,7 +88,8 @@ func Resolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { return resolveAndCache(fqdn, qtype, securityLevel) } - if rrCache.Expires <= time.Now().Unix() { + if rrCache.TTL <= time.Now().Unix() { + log.Tracef("intel: serving cache, requesting new. TTL=%d, now=%d", rrCache.TTL, time.Now().Unix()) rrCache.requestingNew = true go resolveAndCache(fqdn, qtype, securityLevel) } @@ -420,17 +103,9 @@ func Resolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { return rrCache } -func resolveAndCache(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { +func resolveAndCache(fqdn string, qtype dns.Type, securityLevel uint8) (rrCache *RRCache) { // log.Tracef("intel: resolving %s%s", fqdn, qtype.String()) - rrCache, ok := checkDomainFronting(fqdn, qtype, securityLevel) - if ok { - if rrCache == nil { - return nil - } - return rrCache - } - // dedup requests dupKey := fmt.Sprintf("%s%s", fqdn, qtype.String()) dupReqLock.Lock() @@ -456,7 +131,7 @@ func resolveAndCache(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { } defer func() { dupReqLock.Lock() - delete(dupReqMap, fqdn) + delete(dupReqMap, dupKey) dupReqLock.Unlock() mutex.Unlock() }() @@ -469,29 +144,29 @@ func resolveAndCache(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { // persist to database rrCache.Clean(600) - rrCache.CreateWithType(fqdn, qtype) + rrCache.Save() return rrCache } -func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { +func intelligentResolve(fqdn string, qtype dns.Type, securityLevel uint8) *RRCache { // TODO: handle being offline // TODO: handle multiple network connections - if config.Changed() { - log.Info("intel: config changed, reloading resolvers") - loadResolvers(false) - } else if env.NetworkChanged() { - log.Info("intel: network changed, reloading resolvers") - loadResolvers(true) - } - config.RLock() - defer config.RUnlock() + // TODO: handle these in a separate goroutine + // if config.Changed() { + // log.Info("intel: config changed, reloading resolvers") + // loadResolvers(false) + // } else if env.NetworkChanged() { + // log.Info("intel: network changed, reloading resolvers") + // loadResolvers(true) + // } + resolversLock.RLock() defer resolversLock.RUnlock() - lastFailBoundary := time.Now().Unix() - config.DNSServerRetryRate + lastFailBoundary := time.Now().Unix() - nameserverRetryRate() preDottedFqdn := "." + fqdn // resolve: @@ -510,11 +185,14 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCach } } // check config - if config.DoNotUseMDNS.IsSetWithLevel(securityLevel) { + if doNotUseMulticastDNS(securityLevel) { return nil } // try mdns - rrCache, _ := tryResolver(mDNSResolver, lastFailBoundary, fqdn, qtype, securityLevel) + rrCache, err := queryMulticastDNS(fqdn, qtype) + if err != nil { + log.Errorf("intel: failed to query mdns: %s", err) + } return rrCache } @@ -533,15 +211,18 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCach switch { case strings.HasSuffix(preDottedFqdn, ".local."): // check config - if config.DoNotUseMDNS.IsSetWithLevel(securityLevel) { + if doNotUseMulticastDNS(securityLevel) { return nil } // try mdns - rrCache, _ := tryResolver(mDNSResolver, lastFailBoundary, fqdn, qtype, securityLevel) + rrCache, err := queryMulticastDNS(fqdn, qtype) + if err != nil { + log.Errorf("intel: failed to query mdns: %s", err) + } return rrCache case domainInScopes(preDottedFqdn, specialScopes): // check config - if config.DoNotForwardSpecialDomains.IsSetWithLevel(securityLevel) { + if doNotResolveSpecialDomains(securityLevel) { return nil } // try local resolvers @@ -568,15 +249,15 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCach } -func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype dns.Type, securityLevel int8) (*RRCache, bool) { +func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype dns.Type, securityLevel uint8) (*RRCache, bool) { // skip if not allowed in current security level - if resolver.AllowedSecurityLevel < config.SecurityLevel() || resolver.AllowedSecurityLevel < securityLevel { - log.Tracef("intel: skipping resolver %s, because it isn't allowed to operate on the current security level: %d|%d", resolver, config.SecurityLevel(), securityLevel) + if resolver.AllowedSecurityLevel < status.CurrentSecurityLevel() || resolver.AllowedSecurityLevel < securityLevel { + log.Tracef("intel: skipping resolver %s, because it isn't allowed to operate on the current security level: %d|%d", resolver, status.CurrentSecurityLevel(), securityLevel) return nil, false } // skip if not security level denies assigned dns servers - if config.DoNotUseAssignedDNS.IsSetWithLevel(securityLevel) && resolver.Source == "dhcp" { - log.Tracef("intel: skipping resolver %s, because assigned nameservers are not allowed on the current security level: %d|%d (%d)", resolver, config.SecurityLevel(), securityLevel, int8(config.DoNotUseAssignedDNS)) + if doNotUseAssignedNameservers(securityLevel) && resolver.Source == "dhcp" { + log.Tracef("intel: skipping resolver %s, because assigned nameservers are not allowed on the current security level: %d|%d", resolver, status.CurrentSecurityLevel(), securityLevel) return nil, false } // check if failed recently @@ -606,7 +287,7 @@ func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype } // resolve log.Tracef("intel: trying to resolve %s%s with %s", fqdn, qtype.String(), resolver.Server) - rrCache, err := resolver.Resolve(resolver, fqdn, qtype) + rrCache, err := query(resolver, fqdn, qtype) if err != nil { // check if failing is disabled if atomic.LoadInt64(resolver.LastFail) == -1 { @@ -622,126 +303,67 @@ func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype return nil, false } resolver.Initialized.SetToIf(false, true) + + // remove localhost entries, remove LAN entries if server is in global IP space. + if resolver.ServerIPScope == netutils.Global { + rrCache.FilterEntries(true, false, false) + } else { + rrCache.FilterEntries(true, true, false) + } + return rrCache, true } -func queryDNS(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) { +func query(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) { q := new(dns.Msg) q.SetQuestion(fqdn, uint16(qtype)) var reply *dns.Msg var err error - for i := 0; i < 5; i++ { - client := new(dns.Client) - reply, _, err = client.Exchange(q, resolver.ServerAddress) + for i := 0; i < 3; i++ { + + // log query time + // qStart := time.Now() + reply, _, err = resolver.clientManager.getDNSClient().Exchange(q, resolver.ServerAddress) + // log.Tracef("intel: query to %s took %s", resolver.Server, time.Now().Sub(qStart)) + + // error handling if err != nil { + log.Tracef("intel: query to %s encountered error: %s", resolver.Server, err) // TODO: handle special cases // 1. connect: network is unreachable // 2. timeout + // temporary error if nerr, ok := err.(net.Error); ok && nerr.Timeout() { log.Tracef("intel: retrying to resolve %s%s with %s, error was: %s", fqdn, qtype.String(), resolver.Server, err) continue } + + // permanent error break } + + // no error + break } if err != nil { - log.Warningf("resolving %s%s failed: %s", fqdn, qtype.String(), err) - return nil, fmt.Errorf("resolving %s%s failed: %s", fqdn, qtype.String(), err) + err = fmt.Errorf("resolving %s%s failed: %s", fqdn, qtype.String(), err) + log.Warning(err.Error()) + return nil, err } new := &RRCache{ - Answer: reply.Answer, - Ns: reply.Ns, - Extra: reply.Extra, + Domain: fqdn, + Question: qtype, + Answer: reply.Answer, + Ns: reply.Ns, + Extra: reply.Extra, } // TODO: check if reply.Answer is valid return new, nil } - -type DnsOverHttpsReply struct { - Status uint32 - Truncated bool `json:"TC"` - Answer []DohRR - Additional []DohRR -} - -type DohRR struct { - Name string `json:"name"` - Qtype uint16 `json:"type"` - TTL uint32 `json:"TTL"` - Data string `json:"data"` -} - -func queryDNSoverHTTPS(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) { - - // API documentation: https://developers.google.com/speed/public-dns/docs/dns-over-https - - payload := url.Values{} - payload.Add("name", fqdn) - payload.Add("type", strconv.Itoa(int(qtype))) - payload.Add("edns_client_subnet", "0.0.0.0/0") - // TODO: add random - only use upper- and lower-case letters, digits, hyphen, period, underscore and tilde - // payload.Add("random_padding", "") - - resp, err := resolver.HTTPClient.Get(fmt.Sprintf("https://%s/resolve?%s", resolver.ServerAddress, payload.Encode())) - if err != nil { - return nil, fmt.Errorf("resolving %s%s failed: http error: %s", fqdn, qtype.String(), err) - // TODO: handle special cases - // 1. connect: network is unreachable - // intel: resolver DoH|dns.google.com:443|df:www.google.com failed (resolving discovery-v4-4.syncthing.net.A failed: http error: Get https://dns.google.com:443/resolve?edns_client_subnet=0.0.0.0%2F0&name=discovery-v4-4.syncthing.net.&type=1: dial tcp [2a00:1450:4001:819::2004]:443: connect: network is unreachable), moving to next - // 2. timeout - } - if resp.StatusCode != 200 { - return nil, fmt.Errorf("resolving %s%s failed: request was unsuccessful, got code %d", fqdn, qtype.String(), resp.StatusCode) - } - - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("resolving %s%s failed: error reading response body: %s", fqdn, qtype.String(), err) - } - - var reply DnsOverHttpsReply - err = json.Unmarshal(body, &reply) - if err != nil { - return nil, fmt.Errorf("resolving %s%s failed: error parsing response body: %s", fqdn, qtype.String(), err) - } - - if reply.Status != 0 { - // this happens if there is a server error (e.g. DNSSEC fail), ignore for now - // TODO: do something more intelligent - } - - new := new(RRCache) - - // TODO: handle TXT records - - for _, entry := range reply.Answer { - rr, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data)) - if err != nil { - log.Warningf("intel: resolving %s%s failed: failed to parse record to DNS: %s %d IN %s %s", fqdn, qtype.String(), entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data) - continue - } - new.Answer = append(new.Answer, rr) - } - - for _, entry := range reply.Additional { - rr, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data)) - if err != nil { - log.Warningf("intel: resolving %s%s failed: failed to parse record to DNS: %s %d IN %s %s", fqdn, qtype.String(), entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data) - continue - } - new.Extra = append(new.Extra, rr) - } - - return new, nil -} - -// TODO: implement T-DNS: DNS over TCP/TLS -// server list: https://dnsprivacy.org/wiki/display/DP/DNS+Privacy+Test+Servers diff --git a/intel/resolve_test.go b/intel/resolve_test.go index c9754b16..b206bdcf 100644 --- a/intel/resolve_test.go +++ b/intel/resolve_test.go @@ -2,14 +2,16 @@ package intel -import ( - "testing" - "time" +// DISABLE TESTING FOR NOW: find a way to have tests with the module system - "github.com/miekg/dns" -) +// import ( +// "testing" +// "time" +// +// "github.com/miekg/dns" +// ) -func TestResolve(t *testing.T) { - Resolve("google.com.", dns.Type(dns.TypeA), 0) - time.Sleep(200 * time.Millisecond) -} +// func TestResolve(t *testing.T) { +// Resolve("google.com.", dns.Type(dns.TypeA), 0) +// time.Sleep(200 * time.Millisecond) +// } diff --git a/intel/resolver.go b/intel/resolver.go new file mode 100644 index 00000000..668197cf --- /dev/null +++ b/intel/resolver.go @@ -0,0 +1,295 @@ +package intel + +import ( + "errors" + "fmt" + "net" + "sort" + "strconv" + "strings" + "sync" + + "github.com/miekg/dns" + "github.com/tevino/abool" + + "github.com/Safing/portbase/log" + + "github.com/Safing/portmaster/network/environment" + "github.com/Safing/portmaster/network/netutils" + "github.com/Safing/portmaster/status" +) + +// Resolver holds information about an active resolver. +type Resolver struct { + // static + Server string + ServerType string + ServerAddress string + ServerIP net.IP + ServerIPScope int8 + ServerPort uint16 + VerifyDomain string + Source string + clientManager *clientManager + + Search *[]string + AllowedSecurityLevel uint8 + SkipFqdnBeforeInit string + + // atomic + Initialized *abool.AtomicBool + InitLock sync.Mutex + LastFail *int64 + Expires *int64 + + // must be locked + LockReason sync.Mutex + FailReason string + + // TODO: add: + // Expiration (for server got from DHCP / ICMPv6) + // bootstrapping (first query is already sent, wait for it to either succeed or fail - think about http bootstrapping here!) + // expanded server info: type, server address, server port, options - so we do not have to parse this every time! +} + +func (r *Resolver) String() string { + return r.Server +} + +// Scope defines a domain scope and which resolvers can resolve it. +type Scope struct { + Domain string + Resolvers []*Resolver +} + +var ( + globalResolvers []*Resolver // all resolvers + localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges + localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope + resolversLock sync.RWMutex + + env = environment.NewInterface() + + dupReqMap = make(map[string]*sync.Mutex) + dupReqLock sync.Mutex +) + +func indexOfResolver(server string, list []*Resolver) int { + for k, v := range list { + if v.Server == server { + return k + } + } + return -1 +} + +func indexOfScope(domain string, list []*Scope) int { + for k, v := range list { + if v.Domain == domain { + return k + } + } + return -1 +} + +func parseAddress(server string) (net.IP, uint16, error) { + delimiter := strings.LastIndex(server, ":") + if delimiter < 0 { + return nil, 0, errors.New("port missing") + } + ip := net.ParseIP(strings.Trim(server[:delimiter], "[]")) + if ip == nil { + return nil, 0, errors.New("invalid IP address") + } + port, err := strconv.Atoi(server[delimiter+1:]) + if err != nil || port < 1 || port > 65536 { + return nil, 0, errors.New("invalid port") + } + return ip, uint16(port), nil +} + +func urlFormatAddress(ip net.IP, port uint16) string { + var address string + if ipv4 := ip.To4(); ipv4 != nil { + address = fmt.Sprintf("%s:%d", ipv4.String(), port) + } else { + address = fmt.Sprintf("[%s]:%d", ip.String(), port) + } + return address +} + +func loadResolvers(resetResolvers bool) { + // TODO: what happens when a lot of processes want to reload at once? we do not need to run this multiple times in a short time frame. + resolversLock.Lock() + defer resolversLock.Unlock() + + var newResolvers []*Resolver + +configuredServersLoop: + for _, server := range configuredNameServers() { + key := indexOfResolver(server, newResolvers) + if key >= 0 { + continue configuredServersLoop + } + key = indexOfResolver(server, globalResolvers) + if resetResolvers || key == -1 { + + parts := strings.Split(server, "|") + if len(parts) < 2 { + log.Warningf("intel: nameserver format invalid: %s", server) + continue configuredServersLoop + } + + ip, port, err := parseAddress(parts[1]) + if err != nil && strings.ToLower(parts[0]) != "https" { + log.Warningf("intel: nameserver (%s) address invalid: %s", server, err) + continue configuredServersLoop + } + + var lastFail int64 + new := &Resolver{ + Server: server, + ServerType: parts[0], + ServerAddress: parts[1], + ServerIP: ip, + ServerIPScope: netutils.ClassifyIP(ip), + ServerPort: port, + LastFail: &lastFail, + Source: "config", + Initialized: abool.NewBool(false), + } + + switch strings.ToLower(parts[0]) { + case "dns": + new.clientManager = newDNSClientManager(new) + case "tcp": + new.clientManager = newTCPClientManager(new) + case "tls": + new.AllowedSecurityLevel = status.SecurityLevelFortress + if len(parts) < 3 { + log.Warningf("intel: nameserver missing verification domain as third parameter: %s", server) + continue configuredServersLoop + } + new.VerifyDomain = parts[2] + new.clientManager = newTLSClientManager(new) + case "https": + new.AllowedSecurityLevel = status.SecurityLevelFortress + new.SkipFqdnBeforeInit = dns.Fqdn(strings.Split(parts[1], ":")[0]) + if len(parts) > 2 { + new.VerifyDomain = parts[2] + } + new.clientManager = newHTTPSClientManager(new) + default: + log.Warningf("intel: nameserver (%s) type invalid: %s", server, parts[0]) + continue configuredServersLoop + } + newResolvers = append(newResolvers, new) + } else { + newResolvers = append(newResolvers, globalResolvers[key]) + } + } + + // add local resolvers + assignedNameservers := environment.Nameservers() +assignedServersLoop: + for _, nameserver := range assignedNameservers { + server := fmt.Sprintf("dns|%s", urlFormatAddress(nameserver.IP, 53)) + key := indexOfResolver(server, newResolvers) + if key >= 0 { + continue assignedServersLoop + } + key = indexOfResolver(server, globalResolvers) + if resetResolvers || key == -1 { + + var lastFail int64 + new := &Resolver{ + Server: server, + ServerType: "dns", + ServerAddress: urlFormatAddress(nameserver.IP, 53), + ServerIP: nameserver.IP, + ServerIPScope: netutils.ClassifyIP(nameserver.IP), + ServerPort: 53, + LastFail: &lastFail, + Source: "dhcp", + Initialized: abool.NewBool(false), + AllowedSecurityLevel: status.SecurityLevelSecure, + } + new.clientManager = newDNSClientManager(new) + + if netutils.IPIsLAN(nameserver.IP) && len(nameserver.Search) > 0 { + // only allow searches for local resolvers + var newSearch []string + for _, value := range nameserver.Search { + newSearch = append(newSearch, fmt.Sprintf(".%s.", strings.Trim(value, "."))) + } + new.Search = &newSearch + } + newResolvers = append(newResolvers, new) + } else { + newResolvers = append(newResolvers, globalResolvers[key]) + } + } + + // save resolvers + globalResolvers = newResolvers + if len(globalResolvers) == 0 { + log.Criticalf("intel: no (valid) dns servers found in configuration and system") + } + + // make list with local resolvers + localResolvers = make([]*Resolver, 0) + for _, resolver := range globalResolvers { + if resolver.ServerIP != nil && netutils.IPIsLAN(resolver.ServerIP) { + localResolvers = append(localResolvers, resolver) + } + } + + // add resolvers to every scope the cover + localScopes = make([]*Scope, 0) + for _, resolver := range globalResolvers { + + if resolver.Search != nil { + // add resolver to custom searches + for _, search := range *resolver.Search { + if search == "." { + continue + } + key := indexOfScope(search, localScopes) + if key == -1 { + localScopes = append(localScopes, &Scope{ + Domain: search, + Resolvers: []*Resolver{resolver}, + }) + } else { + localScopes[key].Resolvers = append(localScopes[key].Resolvers, resolver) + } + } + + } + } + + // sort scopes by length + sort.Slice(localScopes, + func(i, j int) bool { + return len(localScopes[i].Domain) > len(localScopes[j].Domain) + }, + ) + + log.Trace("intel: loaded global resolvers:") + for _, resolver := range globalResolvers { + log.Tracef("intel: %s", resolver.Server) + } + log.Trace("intel: loaded local resolvers:") + for _, resolver := range localResolvers { + log.Tracef("intel: %s", resolver.Server) + } + log.Trace("intel: loaded scopes:") + for _, scope := range localScopes { + var scopeServers []string + for _, resolver := range scope.Resolvers { + scopeServers = append(scopeServers, resolver.Server) + } + log.Tracef("intel: %s: %s", scope.Domain, strings.Join(scopeServers, ", ")) + } + +} diff --git a/intel/reverse.go b/intel/reverse.go new file mode 100644 index 00000000..21f811bc --- /dev/null +++ b/intel/reverse.go @@ -0,0 +1,72 @@ +package intel + +import ( + "errors" + "strings" + + "github.com/Safing/portbase/log" + "github.com/miekg/dns" +) + +// ResolveIPAndValidate finds (reverse DNS), validates (forward DNS) and returns the domain name assigned to the given IP. +func ResolveIPAndValidate(ip string, securityLevel uint8) (domain string, err error) { + // get reversed DNS address + rQ, err := dns.ReverseAddr(ip) + if err != nil { + log.Tracef("intel: failed to get reverse address of %s: %s", ip, err) + return "", err + } + + // get PTR record + rrCache := Resolve(rQ, dns.Type(dns.TypePTR), securityLevel) + if rrCache == nil { + return "", errors.New("querying for PTR record failed (may be NXDomain)") + } + + // get result from record + var ptrName string + for _, rr := range rrCache.Answer { + ptrRec, ok := rr.(*dns.PTR) + if ok { + ptrName = ptrRec.Ptr + break + } + } + + // check for nxDomain + if ptrName == "" { + return "", errors.New("no PTR record for IP (nxDomain)") + } + + log.Infof("ptrName: %s", ptrName) + + // get forward record + if strings.Contains(ip, ":") { + rrCache = Resolve(ptrName, dns.Type(dns.TypeAAAA), securityLevel) + } else { + rrCache = Resolve(ptrName, dns.Type(dns.TypeA), securityLevel) + } + if rrCache == nil { + return "", errors.New("querying for A/AAAA record failed (may be NXDomain)") + } + + // check for matching A/AAAA record + log.Infof("rr: %s", rrCache) + for _, rr := range rrCache.Answer { + switch v := rr.(type) { + case *dns.A: + log.Infof("A: %s", v.A.String()) + if ip == v.A.String() { + return ptrName, nil + } + case *dns.AAAA: + log.Infof("AAAA: %s", v.AAAA.String()) + if ip == v.AAAA.String() { + return ptrName, nil + } + } + } + + // no match + return "", errors.New("validation failed") +} diff --git a/intel/reverse_test.go b/intel/reverse_test.go new file mode 100644 index 00000000..e53d1a6f --- /dev/null +++ b/intel/reverse_test.go @@ -0,0 +1,28 @@ +package intel + +import "testing" + +func testReverse(t *testing.T, ip, result, expectedErr string) { + domain, err := ResolveIPAndValidate(ip, 0) + if err != nil { + if expectedErr == "" || err.Error() != expectedErr { + t.Errorf("reverse-validating %s: unexpected error: %s", ip, err) + } + return + } + + if domain != result { + t.Errorf("reverse-validating %s: unexpected result: %s", ip, domain) + } +} + +func TestResolveIPAndValidate(t *testing.T) { + testReverse(t, "198.41.0.4", "a.root-servers.net.", "") + testReverse(t, "9.9.9.9", "dns.quad9.net.", "") + testReverse(t, "2620:fe::fe", "dns.quad9.net.", "") + testReverse(t, "1.1.1.1", "one.one.one.one.", "") + testReverse(t, "2606:4700:4700::1111", "one.one.one.one.", "") + + testReverse(t, "93.184.216.34", "example.com.", "no PTR record for IP (nxDomain)") + testReverse(t, "185.199.109.153", "sites.github.io.", "no PTR record for IP (nxDomain)") +} diff --git a/intel/rrcache.go b/intel/rrcache.go new file mode 100644 index 00000000..d5fd5009 --- /dev/null +++ b/intel/rrcache.go @@ -0,0 +1,256 @@ +// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. + +package intel + +import ( + "fmt" + "net" + "strings" + "time" + + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/network/netutils" + "github.com/miekg/dns" +) + +// RRCache is used to cache DNS data +type RRCache struct { + Domain string + Question dns.Type + + Answer []dns.RR + Ns []dns.RR + Extra []dns.RR + TTL int64 + + updated int64 + servedFromCache bool + requestingNew bool + Filtered bool +} + +// Clean sets all TTLs to 17 and sets cache expiry with specified minimum. +func (m *RRCache) Clean(minExpires uint32) { + var lowestTTL uint32 = 0xFFFFFFFF + var header *dns.RR_Header + + // set TTLs to 17 + // TODO: double append? is there something more elegant? + for _, rr := range append(m.Answer, append(m.Ns, m.Extra...)...) { + header = rr.Header() + if lowestTTL > header.Ttl { + lowestTTL = header.Ttl + } + header.Ttl = 17 + } + + // TTL must be at least minExpires + if lowestTTL < minExpires { + lowestTTL = minExpires + } + + // log.Tracef("lowest TTL is %d", lowestTTL) + m.TTL = time.Now().Unix() + int64(lowestTTL) +} + +// ExportAllARecords return of a list of all A and AAAA IP addresses. +func (m *RRCache) ExportAllARecords() (ips []net.IP) { + for _, rr := range m.Answer { + if rr.Header().Class != dns.ClassINET { + continue + } + + switch rr.Header().Rrtype { + case dns.TypeA: + aRecord, ok := rr.(*dns.A) + if ok { + ips = append(ips, aRecord.A) + } + case dns.TypeAAAA: + aaaaRecord, ok := rr.(*dns.AAAA) + if ok { + ips = append(ips, aaaaRecord.AAAA) + } + } + } + return +} + +// ToNameRecord converts the RRCache to a NameRecord for cleaner persistence. +func (m *RRCache) ToNameRecord() *NameRecord { + new := &NameRecord{ + Domain: m.Domain, + Question: m.Question.String(), + TTL: m.TTL, + Filtered: m.Filtered, + } + + // stringify RR entries + for _, entry := range m.Answer { + new.Answer = append(new.Answer, entry.String()) + } + for _, entry := range m.Ns { + new.Ns = append(new.Ns, entry.String()) + } + for _, entry := range m.Extra { + new.Extra = append(new.Extra, entry.String()) + } + + return new +} + +// Save saves the RRCache to the database as a NameRecord. +func (m *RRCache) Save() error { + return m.ToNameRecord().Save() +} + +// GetRRCache tries to load the corresponding NameRecord from the database and convert it. +func GetRRCache(domain string, question dns.Type) (*RRCache, error) { + rrCache := &RRCache{ + Domain: domain, + Question: question, + } + + nameRecord, err := GetNameRecord(domain, question.String()) + if err != nil { + return nil, err + } + + rrCache.TTL = nameRecord.TTL + for _, entry := range nameRecord.Answer { + rr, err := dns.NewRR(entry) + if err == nil { + rrCache.Answer = append(rrCache.Answer, rr) + } + } + for _, entry := range nameRecord.Ns { + rr, err := dns.NewRR(entry) + if err == nil { + rrCache.Ns = append(rrCache.Ns, rr) + } + } + for _, entry := range nameRecord.Extra { + rr, err := dns.NewRR(entry) + if err == nil { + rrCache.Extra = append(rrCache.Extra, rr) + } + } + + rrCache.Filtered = nameRecord.Filtered + rrCache.servedFromCache = true + return rrCache, nil +} + +// ServedFromCache marks the RRCache as served from cache. +func (m *RRCache) ServedFromCache() bool { + return m.servedFromCache +} + +// RequestingNew informs that it has expired and new RRs are being fetched. +func (m *RRCache) RequestingNew() bool { + return m.requestingNew +} + +// Flags formats ServedFromCache and RequestingNew to a condensed, flag-like format. +func (m *RRCache) Flags() string { + var s string + if m.servedFromCache { + s += "C" + } + if m.requestingNew { + s += "R" + } + if m.Filtered { + s += "F" + } + + if s != "" { + return fmt.Sprintf(" [%s]", s) + } + return "" +} + +// IsNXDomain returnes whether the result is nxdomain. +func (m *RRCache) IsNXDomain() bool { + return len(m.Answer) == 0 +} + +// Duplicate returns a duplicate of the cache. slices are not copied, but referenced. +func (m *RRCache) Duplicate() *RRCache { + return &RRCache{ + Domain: m.Domain, + Question: m.Question, + Answer: m.Answer, + Ns: m.Ns, + Extra: m.Extra, + TTL: m.TTL, + updated: m.updated, + servedFromCache: m.servedFromCache, + requestingNew: m.requestingNew, + Filtered: m.Filtered, + } +} + +// FilterEntries filters resource records according to the given permission scope. +func (m *RRCache) FilterEntries(internet, lan, host bool) { + var filtered bool + + m.Answer, filtered = filterEntries(m, m.Answer, internet, lan, host) + if filtered { + m.Filtered = true + } + m.Extra, filtered = filterEntries(m, m.Extra, internet, lan, host) + if filtered { + m.Filtered = true + } +} + +func filterEntries(m *RRCache, entries []dns.RR, internet, lan, host bool) (filteredEntries []dns.RR, filtered bool) { + filteredEntries = make([]dns.RR, 0, len(entries)) + var classification int8 + var deletedEntries []string + +entryLoop: + for _, rr := range entries { + + classification = -1 + switch v := rr.(type) { + case *dns.A: + classification = netutils.ClassifyIP(v.A) + case *dns.AAAA: + classification = netutils.ClassifyIP(v.AAAA) + } + + if classification >= 0 { + switch { + case !internet && classification == netutils.Global: + filtered = true + deletedEntries = append(deletedEntries, rr.String()) + continue entryLoop + case !lan && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): + filtered = true + deletedEntries = append(deletedEntries, rr.String()) + continue entryLoop + case !host && classification == netutils.HostLocal: + filtered = true + deletedEntries = append(deletedEntries, rr.String()) + continue entryLoop + } + } + + filteredEntries = append(filteredEntries, rr) + } + + if len(deletedEntries) > 0 { + log.Infof("intel: filtered DNS replies for %s%s: %s (Settings: Int=%v LAN=%v Host=%v)", + m.Domain, + m.Question.String(), + strings.Join(deletedEntries, ", "), + internet, + lan, + host, + ) + } + + return +} diff --git a/intel/special.go b/intel/special.go index 476e457c..c979599b 100644 --- a/intel/special.go +++ b/intel/special.go @@ -5,7 +5,7 @@ package intel import "strings" var ( - localReverseScopes = &[]string{ + localReverseScopes = []string{ ".10.in-addr.arpa.", ".16.172.in-addr.arpa.", ".17.172.in-addr.arpa.", @@ -31,7 +31,8 @@ var ( ".b.e.f.ip6.arpa.", } - specialScopes = &[]string{ + // RFC6761, RFC7686 + specialScopes = []string{ ".example.", ".example.com.", ".example.net.", @@ -42,8 +43,8 @@ var ( } ) -func domainInScopes(fqdn string, list *[]string) bool { - for _, scope := range *list { +func domainInScopes(fqdn string, list []string) bool { + for _, scope := range list { if strings.HasSuffix(fqdn, scope) { return true } diff --git a/main.go b/main.go new file mode 100644 index 00000000..7f04bbb9 --- /dev/null +++ b/main.go @@ -0,0 +1,87 @@ +package main + +import ( + "flag" + "fmt" + "os" + "os/signal" + "runtime/pprof" + "syscall" + "time" + + "github.com/Safing/portbase/info" + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/modules" + + // include packages here + + _ "github.com/Safing/portbase/api" + _ "github.com/Safing/portbase/database/dbmodule" + _ "github.com/Safing/portbase/database/storage/badger" + _ "github.com/Safing/portmaster/firewall" + _ "github.com/Safing/portmaster/nameserver" + _ "github.com/Safing/portmaster/ui" +) + +var ( + printStackOnExit bool +) + +func init() { + flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down") +} + +func main() { + + // Set Info + info.Set("Portmaster", "0.2.0") + + // Start + err := modules.Start() + if err != nil { + if err == modules.ErrCleanExit { + os.Exit(0) + } else { + err = modules.Shutdown() + if err != nil { + log.Shutdown() + } + os.Exit(1) + } + } + + // Shutdown + // catch interrupt for clean shutdown + signalCh := make(chan os.Signal) + signal.Notify( + signalCh, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT, + ) + select { + case <-signalCh: + fmt.Println(" ") + log.Warning("main: program was interrupted, shutting down.") + + if printStackOnExit { + fmt.Println("=== PRINTING STACK ===") + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + fmt.Println("=== END STACK ===") + } + + go func() { + time.Sleep(3 * time.Second) + fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====") + pprof.Lookup("goroutine").WriteTo(os.Stdout, 2) + os.Exit(1) + }() + modules.Shutdown() + os.Exit(0) + + case <-modules.ShuttingDown(): + } + +} diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index f2a814fc..3cbd4940 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -4,38 +4,60 @@ package nameserver import ( "net" + "time" "github.com/miekg/dns" - "github.com/Safing/safing-core/analytics/algs" - "github.com/Safing/safing-core/intel" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/modules" - "github.com/Safing/safing-core/network" - "github.com/Safing/safing-core/network/netutils" - "github.com/Safing/safing-core/portmaster" + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/modules" + + "github.com/Safing/portmaster/analytics/algs" + "github.com/Safing/portmaster/firewall" + "github.com/Safing/portmaster/intel" + "github.com/Safing/portmaster/network" + "github.com/Safing/portmaster/network/netutils" ) var ( - nameserverModule *modules.Module + localhostIPs []dns.RR ) func init() { - nameserverModule = modules.Register("Nameserver", 128) + modules.Register("nameserver", prep, start, nil, "intel") } -func Start() { +func prep() error { + localhostIPv4, err := dns.NewRR("localhost. 17 IN A 127.0.0.1") + if err != nil { + return err + } + + localhostIPv6, err := dns.NewRR("localhost. 17 IN AAAA ::1") + if err != nil { + return err + } + + localhostIPs = []dns.RR{localhostIPv4, localhostIPv6} + + return nil +} + +func start() error { server := &dns.Server{Addr: "127.0.0.1:53", Net: "udp"} dns.HandleFunc(".", handleRequest) - go func() { + go run(server) + return nil +} + +func run(server *dns.Server) { + for { err := server.ListenAndServe() if err != nil { log.Errorf("nameserver: server failed: %s", err) + log.Info("nameserver: restarting server in 10 seconds") + time.Sleep(10 * time.Second) } - }() - // TODO: stop mocking - defer nameserverModule.StopComplete() - <-nameserverModule.Stop + } } func nxDomain(w dns.ResponseWriter, query *dns.Msg) { @@ -47,7 +69,6 @@ func nxDomain(w dns.ResponseWriter, query *dns.Msg) { func handleRequest(w dns.ResponseWriter, query *dns.Msg) { // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain - // TODO: handle securityLevelOff // only process first question, that's how everyone does it. question := query.Question[0] @@ -82,6 +103,14 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { return } + // handle request for localhost + if fqdn == "localhost." { + m := new(dns.Msg) + m.SetReply(query) + m.Answer = localhostIPs + w.WriteMsg(m) + } + // get remote address // start := time.Now() rAddr, ok := w.RemoteAddr().(*net.UDPAddr) @@ -109,19 +138,19 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { // log.Tracef("nameserver: took %s to get connection/process of %s request", time.Now().Sub(timed).String(), fqdn) // check profile before we even get intel and rr - if connection.Verdict == network.UNDECIDED { + if connection.GetVerdict() == network.UNDECIDED { // start = time.Now() - portmaster.DecideOnConnectionBeforeIntel(connection, fqdn) + firewall.DecideOnConnectionBeforeIntel(connection, fqdn) // log.Tracef("nameserver: took %s to make decision", time.Since(start)) } - if connection.Verdict == network.BLOCK || connection.Verdict == network.DROP { + if connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP { nxDomain(w, query) return } // get intel and RRs // start = time.Now() - domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().Profile.SecurityLevel) + domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().ProfileSet().SecurityLevel()) // log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start)) if rrCache == nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains @@ -131,14 +160,16 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { } // set intel + connection.Lock() connection.Intel = domainIntel + connection.Unlock() connection.Save() // do a full check with intel - if connection.Verdict == network.UNDECIDED { - rrCache = portmaster.DecideOnConnectionAfterIntel(connection, fqdn, rrCache) + if connection.GetVerdict() == network.UNDECIDED { + rrCache = firewall.DecideOnConnectionAfterIntel(connection, fqdn, rrCache) } - if rrCache == nil || connection.Verdict == network.BLOCK || connection.Verdict == network.DROP { + if rrCache == nil || connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP { nxDomain(w, query) return } @@ -150,23 +181,27 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { ipInfo, err := intel.GetIPInfo(v.A.String()) if err != nil { ipInfo = &intel.IPInfo{ + IP: v.A.String(), Domains: []string{fqdn}, } - ipInfo.Create(v.A.String()) - } else { - ipInfo.Domains = append(ipInfo.Domains, fqdn) ipInfo.Save() + } else { + if ipInfo.AddDomain(fqdn) { + ipInfo.Save() + } } case *dns.AAAA: ipInfo, err := intel.GetIPInfo(v.AAAA.String()) if err != nil { ipInfo = &intel.IPInfo{ + IP: v.AAAA.String(), Domains: []string{fqdn}, } - ipInfo.Create(v.AAAA.String()) - } else { - ipInfo.Domains = append(ipInfo.Domains, fqdn) ipInfo.Save() + } else { + if ipInfo.AddDomain(fqdn) { + ipInfo.Save() + } } } } diff --git a/nameserver/only/nameserver.go b/nameserver/only/nameserver.go new file mode 100644 index 00000000..4c1edc27 --- /dev/null +++ b/nameserver/only/nameserver.go @@ -0,0 +1,100 @@ +package only + +import ( + "time" + + "github.com/miekg/dns" + + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/modules" + + "github.com/Safing/portmaster/analytics/algs" + "github.com/Safing/portmaster/intel" + "github.com/Safing/portmaster/network/netutils" +) + +func init() { + modules.Register("nameserver", nil, start, nil, "intel") +} + +func start() error { + server := &dns.Server{Addr: "127.0.0.1:53", Net: "udp"} + dns.HandleFunc(".", handleRequest) + go run(server) + return nil +} + +func run(server *dns.Server) { + for { + err := server.ListenAndServe() + if err != nil { + log.Errorf("nameserver: server failed: %s", err) + log.Info("nameserver: restarting server in 10 seconds") + time.Sleep(10 * time.Second) + } + } +} + +func nxDomain(w dns.ResponseWriter, query *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(query, dns.RcodeNameError) + w.WriteMsg(m) +} + +func handleRequest(w dns.ResponseWriter, query *dns.Msg) { + + // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain + // TODO: handle securityLevelOff + + // only process first question, that's how everyone does it. + question := query.Question[0] + fqdn := dns.Fqdn(question.Name) + qtype := dns.Type(question.Qtype) + + // use this to time how long it takes process this request + // timed := time.Now() + // defer log.Tracef("nameserver: took %s to handle request for %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String()) + + // check if valid domain name + if !netutils.IsValidFqdn(fqdn) { + log.Tracef("nameserver: domain name %s is invalid, returning nxdomain", fqdn) + nxDomain(w, query) + return + } + + // check for possible DNS tunneling / data transmission + // TODO: improve this + lms := algs.LmsScoreOfDomain(fqdn) + // log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms) + if lms < 10 { + log.Tracef("nameserver: possible data tunnel: %s has lms score of %f, returning nxdomain", fqdn, lms) + nxDomain(w, query) + return + } + + // check class + if question.Qclass != dns.ClassINET { + // we only serve IN records, send NXDOMAIN + nxDomain(w, query) + return + } + + // get intel and RRs + // start = time.Now() + _, rrCache := intel.GetIntelAndRRs(fqdn, qtype, 0) + // log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start)) + if rrCache == nil { + // TODO: analyze nxdomain requests, malware could be trying DGA-domains + log.Infof("nameserver: %s is nxdomain", fqdn) + nxDomain(w, query) + return + } + + // reply to query + m := new(dns.Msg) + m.SetReply(query) + m.Answer = rrCache.Answer + m.Ns = rrCache.Ns + m.Extra = rrCache.Extra + w.WriteMsg(m) +} diff --git a/network/clean.go b/network/clean.go index 762aa619..e7e671d2 100644 --- a/network/clean.go +++ b/network/clean.go @@ -5,34 +5,52 @@ package network import ( "time" - "github.com/Safing/safing-core/process" + "github.com/Safing/portmaster/process" ) -func init() { - go cleaner() -} +var ( + cleanerTickDuration = 10 * time.Second + deadLinksTimeout = 3 * time.Minute + thresholdDuration = 3 * time.Minute +) func cleaner() { - time.Sleep(15 * time.Second) for { - markDeadLinks() - purgeDeadFor(5 * time.Minute) - time.Sleep(15 * time.Second) + time.Sleep(cleanerTickDuration) + + cleanLinks() + time.Sleep(2 * time.Second) + cleanConnections() + time.Sleep(2 * time.Second) + cleanProcesses() } } -func markDeadLinks() { +func cleanLinks() { activeIDs := process.GetActiveConnectionIDs() - allLinksLock.RLock() - defer allLinksLock.RUnlock() - now := time.Now().Unix() - var found bool - for key, link := range allLinks { + deleteOlderThan := time.Now().Add(-deadLinksTimeout).Unix() - // skip dead links + // log.Tracef("network.clean: now=%d", now) + // log.Tracef("network.clean: deleteOlderThan=%d", deleteOlderThan) + + linksLock.RLock() + defer linksLock.RUnlock() + + var found bool + for key, link := range links { + + // delete dead links if link.Ended > 0 { + link.Lock() + deleteThis := link.Ended < deleteOlderThan + link.Unlock() + if deleteThis { + // log.Tracef("network.clean: deleted %s", link.DatabaseKey()) + go link.Delete() + } + continue } @@ -48,56 +66,28 @@ func markDeadLinks() { // mark end time if !found { link.Ended = now - link.Save() + // log.Tracef("network.clean: marked %s as ended.", link.DatabaseKey()) + go link.Save() } } } -func purgeDeadFor(age time.Duration) { - connections := make(map[*Connection]bool) - processes := make(map[*process.Process]bool) +func cleanConnections() { + connectionsLock.RLock() + defer connectionsLock.RUnlock() - allLinksLock.Lock() - defer allLinksLock.Unlock() - - // delete old dead links - // make a list of connections without links - ageAgo := time.Now().Add(-1 * age).Unix() - for key, link := range allLinks { - if link.Ended != 0 && link.Ended < ageAgo { - link.Delete() - delete(allLinks, key) - _, ok := connections[link.Connection()] - if !ok { - connections[link.Connection()] = false - } - } else { - connections[link.Connection()] = true + threshold := time.Now().Add(-thresholdDuration).Unix() + for _, conn := range connections { + conn.Lock() + if conn.FirstLinkEstablished < threshold && conn.LinkCount == 0 { + // log.Tracef("network.clean: deleted %s", conn.DatabaseKey()) + go conn.Delete() } + conn.Unlock() } - - // delete connections without links - // make a list of processes without connections - for conn, active := range connections { - if conn != nil { - if !active { - conn.Delete() - _, ok := processes[conn.Process()] - if !ok { - processes[conn.Process()] = false - } - } else { - processes[conn.Process()] = true - } - } - } - - // delete processes without connections - for proc, active := range processes { - if proc != nil && !active { - proc.Delete() - } - } - +} + +func cleanProcesses() { + process.CleanProcessStorage(thresholdDuration) } diff --git a/network/connection.go b/network/connection.go index d0cb84f0..418350ea 100644 --- a/network/connection.go +++ b/network/connection.go @@ -3,147 +3,187 @@ package network import ( + "errors" "fmt" "net" + "sync" "time" - "github.com/Safing/safing-core/database" - "github.com/Safing/safing-core/intel" - "github.com/Safing/safing-core/network/packet" - "github.com/Safing/safing-core/process" - - datastore "github.com/ipfs/go-datastore" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portmaster/intel" + "github.com/Safing/portmaster/network/netutils" + "github.com/Safing/portmaster/network/packet" + "github.com/Safing/portmaster/process" ) // Connection describes a connection between a process and a domain type Connection struct { - database.Base - Domain string - Direction bool - Intel *intel.Intel - process *process.Process - Verdict Verdict - Reason string - Inspect bool + record.Base + sync.Mutex + + Domain string + Direction bool + Intel *intel.Intel + process *process.Process + Verdict Verdict + Reason string + Inspect bool + FirstLinkEstablished int64 + LastLinkEstablished int64 + LinkCount uint } -var connectionModel *Connection // only use this as parameter for database.EnsureModel-like functions +// Process returns the process that owns the connection. +func (conn *Connection) Process() *process.Process { + conn.Lock() + defer conn.Unlock() -func init() { - database.RegisterModel(connectionModel, func() database.Model { return new(Connection) }) + return conn.process } -func (m *Connection) Process() *process.Process { - return m.process +// GetVerdict returns the current verdict. +func (conn *Connection) GetVerdict() Verdict { + conn.Lock() + defer conn.Unlock() + + return conn.Verdict } -// Create creates a new database entry in the database in the default namespace for this object -func (m *Connection) Create(name string) error { - return m.CreateObject(&database.OrphanedConnection, name, m) +// Accept accepts the connection and adds the given reason. +func (conn *Connection) Accept(reason string) { + conn.AddReason(reason) + conn.UpdateVerdict(ACCEPT) } -// CreateInProcessNamespace creates a new database entry in the namespace of the connection's process -func (m *Connection) CreateInProcessNamespace() error { - if m.process != nil { - return m.CreateObject(m.process.GetKey(), m.Domain, m) +// Deny blocks or drops the connection depending on the connection direction and adds the given reason. +func (conn *Connection) Deny(reason string) { + if conn.Direction { + conn.Drop(reason) + } else { + conn.Block(reason) } - return m.CreateObject(&database.OrphanedConnection, m.Domain, m) } -// Save saves the object to the database (It must have been either already created or loaded from the database) -func (m *Connection) Save() error { - return m.SaveObject(m) +// Block blocks the connection and adds the given reason. +func (conn *Connection) Block(reason string) { + conn.AddReason(reason) + conn.UpdateVerdict(BLOCK) } -func (m *Connection) CantSay() { - if m.Verdict != CANTSAY { - m.Verdict = CANTSAY - m.SaveObject(m) +// Drop drops the connection and adds the given reason. +func (conn *Connection) Drop(reason string) { + conn.AddReason(reason) + conn.UpdateVerdict(DROP) +} + +// UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts +func (conn *Connection) UpdateVerdict(newVerdict Verdict) { + conn.Lock() + defer conn.Unlock() + + if newVerdict > conn.Verdict { + conn.Verdict = newVerdict + go conn.Save() } - return -} - -func (m *Connection) Drop() { - if m.Verdict != DROP { - m.Verdict = DROP - m.SaveObject(m) - } - return -} - -func (m *Connection) Block() { - if m.Verdict != BLOCK { - m.Verdict = BLOCK - m.SaveObject(m) - } - return -} - -func (m *Connection) Accept() { - if m.Verdict != ACCEPT { - m.Verdict = ACCEPT - m.SaveObject(m) - } - return } // AddReason adds a human readable string as to why a certain verdict was set in regard to this connection -func (m *Connection) AddReason(newReason string) { - if m.Reason != "" { - m.Reason += " | " +func (conn *Connection) AddReason(reason string) { + if reason == "" { + return } - m.Reason += newReason + + conn.Lock() + defer conn.Unlock() + + if conn.Reason != "" { + conn.Reason += " | " + } + conn.Reason += reason } +// GetConnectionByFirstPacket returns the matching connection from the internal storage. func GetConnectionByFirstPacket(pkt packet.Packet) (*Connection, error) { // get Process proc, direction, err := process.GetProcessByPacket(pkt) if err != nil { return nil, err } + var domain string - // if INBOUND + // Incoming if direction { - connection, err := GetConnectionFromProcessNamespace(proc, "I") - if err != nil { + switch netutils.ClassifyIP(pkt.GetIPHeader().Src) { + case netutils.HostLocal: + domain = IncomingHost + case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: + domain = IncomingLAN + case netutils.Global, netutils.GlobalMulticast: + domain = IncomingInternet + case netutils.Invalid: + domain = IncomingInvalid + } + + connection, ok := GetConnection(proc.Pid, domain) + if !ok { connection = &Connection{ - Domain: "I", - Direction: true, + Domain: domain, + Direction: Inbound, process: proc, Inspect: true, FirstLinkEstablished: time.Now().Unix(), } } + connection.process.AddConnection() return connection, nil } // get domain ipinfo, err := intel.GetIPInfo(pkt.FmtRemoteIP()) + + // PeerToPeer if err != nil { // if no domain could be found, it must be a direct connection - connection, err := GetConnectionFromProcessNamespace(proc, "D") - if err != nil { + + switch netutils.ClassifyIP(pkt.GetIPHeader().Dst) { + case netutils.HostLocal: + domain = PeerHost + case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: + domain = PeerLAN + case netutils.Global, netutils.GlobalMulticast: + domain = PeerInternet + case netutils.Invalid: + domain = PeerInvalid + } + + connection, ok := GetConnection(proc.Pid, domain) + if !ok { connection = &Connection{ - Domain: "D", + Domain: domain, + Direction: Outbound, process: proc, Inspect: true, FirstLinkEstablished: time.Now().Unix(), } } + connection.process.AddConnection() return connection, nil } + // To Domain // FIXME: how to handle multiple possible domains? - connection, err := GetConnectionFromProcessNamespace(proc, ipinfo.Domains[0]) - if err != nil { + connection, ok := GetConnection(proc.Pid, ipinfo.Domains[0]) + if !ok { connection = &Connection{ Domain: ipinfo.Domains[0], + Direction: Outbound, process: proc, Inspect: true, FirstLinkEstablished: time.Now().Unix(), } } + connection.process.AddConnection() return connection, nil } @@ -154,6 +194,7 @@ var ( dnsPort uint16 = 53 ) +// GetConnectionByDNSRequest returns the matching connection from the internal storage. func GetConnectionByDNSRequest(ip net.IP, port uint16, fqdn string) (*Connection, error) { // get Process proc, err := process.GetProcessByEndpoints(ip, port, dnsAddress, dnsPort, packet.UDP) @@ -161,70 +202,124 @@ func GetConnectionByDNSRequest(ip net.IP, port uint16, fqdn string) (*Connection return nil, err } - connection, err := GetConnectionFromProcessNamespace(proc, fqdn) - if err != nil { + connection, ok := GetConnection(proc.Pid, fqdn) + if !ok { connection = &Connection{ Domain: fqdn, process: proc, Inspect: true, } - connection.CreateInProcessNamespace() + connection.process.AddConnection() + connection.Save() } return connection, nil } -// GetConnection fetches a Connection from the database from the default namespace for this object -func GetConnection(name string) (*Connection, error) { - return GetConnectionFromNamespace(&database.OrphanedConnection, name) +// GetConnection fetches a connection object from the internal storage. +func GetConnection(pid int, domain string) (conn *Connection, ok bool) { + connectionsLock.RLock() + defer connectionsLock.RUnlock() + conn, ok = connections[fmt.Sprintf("%d/%s", pid, domain)] + return } -// GetConnectionFromProcessNamespace fetches a Connection from the namespace of its process -func GetConnectionFromProcessNamespace(process *process.Process, domain string) (*Connection, error) { - return GetConnectionFromNamespace(process.GetKey(), domain) +func (conn *Connection) makeKey() string { + return fmt.Sprintf("%d/%s", conn.process.Pid, conn.Domain) } -// GetConnectionFromNamespace fetches a Connection form the database, but from a custom namespace -func GetConnectionFromNamespace(namespace *datastore.Key, name string) (*Connection, error) { - object, err := database.GetAndEnsureModel(namespace, name, connectionModel) - if err != nil { - return nil, err +// Save saves the connection object in the storage and propagates the change. +func (conn *Connection) Save() error { + conn.Lock() + defer conn.Unlock() + + if conn.process == nil { + return errors.New("cannot save connection without process") } - model, ok := object.(*Connection) + + if !conn.KeyIsSet() { + conn.SetKey(fmt.Sprintf("network:tree/%d/%s", conn.process.Pid, conn.Domain)) + conn.CreateMeta() + } + + key := conn.makeKey() + connectionsLock.RLock() + _, ok := connections[key] + connectionsLock.RUnlock() + if !ok { - return nil, database.NewMismatchError(object, connectionModel) + connectionsLock.Lock() + connections[key] = conn + connectionsLock.Unlock() } - return model, nil + + go dbController.PushUpdate(conn) + return nil } -func (m *Connection) AddLink(link *Link, pkt packet.Packet) { - link.connection = m - link.Verdict = m.Verdict - link.Inspect = m.Inspect - if m.FirstLinkEstablished == 0 { - m.FirstLinkEstablished = time.Now().Unix() - m.Save() - } - link.CreateInConnectionNamespace(pkt.GetConnectionID()) +// Delete deletes a connection from the storage and propagates the change. +func (conn *Connection) Delete() { + conn.Lock() + defer conn.Unlock() + + connectionsLock.Lock() + delete(connections, conn.makeKey()) + connectionsLock.Unlock() + + conn.Meta().Delete() + go dbController.PushUpdate(conn) + conn.process.RemoveConnection() + go conn.process.Save() } -// FORMATTING +// AddLink applies the connection to the link and increases sets counter and timestamps. +func (conn *Connection) AddLink(link *Link) { + link.Lock() + link.connection = conn + link.Verdict = conn.Verdict + link.Inspect = conn.Inspect + link.Unlock() + link.Save() -func (m *Connection) String() string { - switch m.Domain { - case "I": - if m.process == nil { + conn.Lock() + conn.LinkCount++ + conn.LastLinkEstablished = time.Now().Unix() + if conn.FirstLinkEstablished == 0 { + conn.FirstLinkEstablished = conn.LastLinkEstablished + } + conn.Unlock() + conn.Save() +} + +// RemoveLink lowers the link counter by one. +func (conn *Connection) RemoveLink() { + conn.Lock() + defer conn.Unlock() + + if conn.LinkCount > 0 { + conn.LinkCount-- + } +} + +// String returns a string representation of Connection. +func (conn *Connection) String() string { + conn.Lock() + defer conn.Unlock() + + switch conn.Domain { + case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid: + if conn.process == nil { return "? <- *" } - return fmt.Sprintf("%s <- *", m.process.String()) - case "D": - if m.process == nil { + return fmt.Sprintf("%s <- *", conn.process.String()) + case PeerHost, PeerLAN, PeerInternet, PeerInvalid: + if conn.process == nil { return "? -> *" } - return fmt.Sprintf("%s -> *", m.process.String()) + return fmt.Sprintf("%s -> *", conn.process.String()) default: - if m.process == nil { - return fmt.Sprintf("? -> %s", m.Domain) + if conn.process == nil { + return fmt.Sprintf("? -> %s", conn.Domain) } - return fmt.Sprintf("%s -> %s", m.process.String(), m.Domain) + return fmt.Sprintf("%s -> %s", conn.process.String(), conn.Domain) } } diff --git a/network/database.go b/network/database.go new file mode 100644 index 00000000..28147abf --- /dev/null +++ b/network/database.go @@ -0,0 +1,122 @@ +package network + +import ( + "strconv" + "strings" + "sync" + + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/iterator" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/database/storage" + "github.com/Safing/portmaster/process" +) + +var ( + links = make(map[string]*Link) + linksLock sync.RWMutex + connections = make(map[string]*Connection) + connectionsLock sync.RWMutex + + dbController *database.Controller +) + +// StorageInterface provices a storage.Interface to the configuration manager. +type StorageInterface struct { + storage.InjectBase +} + +// Get returns a database record. +func (s *StorageInterface) Get(key string) (record.Record, error) { + + splitted := strings.Split(key, "/") + switch splitted[0] { + case "tree": + switch len(splitted) { + case 2: + pid, err := strconv.Atoi(splitted[1]) + if err == nil { + proc, ok := process.GetProcessFromStorage(pid) + if ok { + return proc, nil + } + } + case 3: + connectionsLock.RLock() + defer connectionsLock.RUnlock() + conn, ok := connections[splitted[2]] + if ok { + return conn, nil + } + case 4: + linksLock.RLock() + defer linksLock.RUnlock() + link, ok := links[splitted[3]] + if ok { + return link, nil + } + } + } + + return nil, storage.ErrNotFound +} + +// Query returns a an iterator for the supplied query. +func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + it := iterator.New() + go s.processQuery(q, it) + // TODO: check local and internal + + return it, nil +} + +func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { + // processes + for _, proc := range process.All() { + if strings.HasPrefix(proc.DatabaseKey(), q.DatabaseKeyPrefix()) { + it.Next <- proc + } + } + + // connections + connectionsLock.RLock() + for _, conn := range connections { + if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) { + it.Next <- conn + } + } + connectionsLock.RUnlock() + + // links + linksLock.RLock() + for _, link := range links { + if strings.HasPrefix(link.DatabaseKey(), q.DatabaseKeyPrefix()) { + it.Next <- link + } + } + linksLock.RUnlock() + + it.Finish(nil) +} + +func registerAsDatabase() error { + _, err := database.Register(&database.Database{ + Name: "network", + Description: "Network and Firewall Data", + StorageType: "injected", + PrimaryAPI: "", + }) + if err != nil { + return err + } + + controller, err := database.InjectDatabase("network", &StorageInterface{}) + if err != nil { + return err + } + + dbController = controller + process.SetDBController(dbController) + return nil +} diff --git a/network/environment/addresses.go b/network/environment/addresses.go index 66a962e7..b8828d0f 100644 --- a/network/environment/addresses.go +++ b/network/environment/addresses.go @@ -4,7 +4,7 @@ import ( "net" "strings" - "github.com/Safing/safing-core/network/netutils" + "github.com/Safing/portmaster/network/netutils" ) func GetAssignedAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) { diff --git a/network/environment/dbus_linux.go b/network/environment/dbus_linux.go index 7ffe0905..8e6751c3 100644 --- a/network/environment/dbus_linux.go +++ b/network/environment/dbus_linux.go @@ -24,6 +24,9 @@ func getNameserversFromDbus() ([]Nameserver, error) { var nameservers []Nameserver var err error + dbusConnLock.Lock() + defer dbusConnLock.Unlock() + if dbusConn == nil { dbusConn, err = dbus.SystemBus() } @@ -158,6 +161,9 @@ func getNameserversFromDbus() ([]Nameserver, error) { func getConnectivityStateFromDbus() (uint8, error) { var err error + dbusConnLock.Lock() + defer dbusConnLock.Unlock() + if dbusConn == nil { dbusConn, err = dbus.SystemBus() } diff --git a/network/environment/environment.go b/network/environment/environment.go index 4c4926bd..e67065aa 100644 --- a/network/environment/environment.go +++ b/network/environment/environment.go @@ -11,7 +11,7 @@ import ( "sync/atomic" "time" - "github.com/Safing/safing-core/log" + "github.com/Safing/portbase/log" ) // TODO: find a good way to identify a network diff --git a/network/environment/environment_darwin.go b/network/environment/environment_darwin.go new file mode 100644 index 00000000..40cd6ee8 --- /dev/null +++ b/network/environment/environment_darwin.go @@ -0,0 +1,27 @@ +package environment + +import "net" + +func Nameservers() []Nameserver { + return nil +} + +func Gateways() []*net.IP { + return nil +} + +// TODO: implement using +// ifconfig +// scutil --nwi +// scutil --proxy +// networksetup -listallnetworkservices +// networksetup -listnetworkserviceorder +// networksetup -getdnsservers "Wi-Fi" +// networksetup -getsearchdomains +// networksetup -getftpproxy +// networksetup -getwebproxy +// networksetup -getsecurewebproxy +// networksetup -getstreamingproxy +// networksetup -getgopherproxy +// networksetup -getsocksfirewallproxy +// route -n get default diff --git a/network/environment/environment_linux.go b/network/environment/environment_linux.go index f739e123..76be75c5 100644 --- a/network/environment/environment_linux.go +++ b/network/environment/environment_linux.go @@ -12,8 +12,8 @@ import ( "github.com/miekg/dns" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/network/netutils" + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/network/netutils" ) // Gateways returns the currently active gateways diff --git a/network/environment/environment_test.go b/network/environment/environment_test.go index 3f4aded8..c4db1685 100644 --- a/network/environment/environment_test.go +++ b/network/environment/environment_test.go @@ -1,4 +1,4 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. +// +build linux package environment diff --git a/network/environment/location.go b/network/environment/location.go index 57a03263..cccc3c09 100644 --- a/network/environment/location.go +++ b/network/environment/location.go @@ -8,9 +8,10 @@ import ( "log" "net" "os" - "github.com/Safing/safing-core/network/netutils" "time" + "github.com/Safing/portmaster/network/netutils" + "golang.org/x/net/icmp" "golang.org/x/net/ipv4" ) @@ -99,7 +100,7 @@ next: if ip == nil { return nil, errors.New(fmt.Sprintf("failed to parse IP: %s", peer.String())) } - if !netutils.IPIsLocal(ip) { + if !netutils.IPIsLAN(ip) { return ip, nil } continue next diff --git a/network/link.go b/network/link.go index b26c0451..a8deae6b 100644 --- a/network/link.go +++ b/network/link.go @@ -3,28 +3,30 @@ package network import ( + "errors" "fmt" "sync" "time" - datastore "github.com/ipfs/go-datastore" - - "github.com/Safing/safing-core/database" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/network/packet" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/network/packet" ) +// FirewallHandler defines the function signature for a firewall handle function type FirewallHandler func(pkt packet.Packet, link *Link) var ( - linkTimeout = 10 * time.Minute - allLinks = make(map[string]*Link) - allLinksLock sync.RWMutex + linkTimeout = 10 * time.Minute ) -// Link describes an distinct physical connection (e.g. TCP connection) - like an instance - of a Connection +// Link describes a distinct physical connection (e.g. TCP connection) - like an instance - of a Connection. type Link struct { - database.Base + record.Base + sync.Mutex + + ID string + Verdict Verdict Reason string Tunneled bool @@ -32,180 +34,322 @@ type Link struct { Inspect bool Started int64 Ended int64 - connection *Connection RemoteAddress string - ActiveInspectors []bool `json:"-" bson:"-"` - InspectorData map[uint8]interface{} `json:"-" bson:"-"` pktQueue chan packet.Packet firewallHandler FirewallHandler -} + connection *Connection -var linkModel *Link // only use this as parameter for database.EnsureModel-like functions - -func init() { - database.RegisterModel(linkModel, func() database.Model { return new(Link) }) + activeInspectors []bool + inspectorData map[uint8]interface{} } // Connection returns the Connection the Link is part of -func (m *Link) Connection() *Connection { - return m.connection +func (link *Link) Connection() *Connection { + link.Lock() + defer link.Unlock() + + return link.connection +} + +// GetVerdict returns the current verdict. +func (link *Link) GetVerdict() Verdict { + link.Lock() + defer link.Unlock() + + return link.Verdict } // FirewallHandlerIsSet returns whether a firewall handler is set or not -func (m *Link) FirewallHandlerIsSet() bool { - return m.firewallHandler != nil +func (link *Link) FirewallHandlerIsSet() bool { + link.Lock() + defer link.Unlock() + + return link.firewallHandler != nil } // SetFirewallHandler sets the firewall handler for this link -func (m *Link) SetFirewallHandler(handler FirewallHandler) { - if m.firewallHandler == nil { - m.firewallHandler = handler - m.pktQueue = make(chan packet.Packet, 1000) - go m.packetHandler() +func (link *Link) SetFirewallHandler(handler FirewallHandler) { + link.Lock() + defer link.Unlock() + + if link.firewallHandler == nil { + link.firewallHandler = handler + link.pktQueue = make(chan packet.Packet, 1000) + go link.packetHandler() return } - m.firewallHandler = handler + link.firewallHandler = handler } // StopFirewallHandler unsets the firewall handler -func (m *Link) StopFirewallHandler() { - m.pktQueue <- nil +func (link *Link) StopFirewallHandler() { + link.Lock() + link.firewallHandler = nil + link.Unlock() + link.pktQueue <- nil } // HandlePacket queues packet of Link for handling -func (m *Link) HandlePacket(pkt packet.Packet) { - if m.firewallHandler != nil { - m.pktQueue <- pkt +func (link *Link) HandlePacket(pkt packet.Packet) { + link.Lock() + defer link.Unlock() + + if link.firewallHandler != nil { + link.pktQueue <- pkt return } - log.Criticalf("network: link %s does not have a firewallHandler, maybe its a copy, dropping packet", m) + log.Criticalf("network: link %s does not have a firewallHandler, dropping packet", link) pkt.Drop() } +// Accept accepts the link and adds the given reason. +func (link *Link) Accept(reason string) { + link.AddReason(reason) + link.UpdateVerdict(ACCEPT) +} + +// Deny blocks or drops the link depending on the connection direction and adds the given reason. +func (link *Link) Deny(reason string) { + if link.connection != nil && link.connection.Direction { + link.Drop(reason) + } else { + link.Block(reason) + } +} + +// Block blocks the link and adds the given reason. +func (link *Link) Block(reason string) { + link.AddReason(reason) + link.UpdateVerdict(BLOCK) +} + +// Drop drops the link and adds the given reason. +func (link *Link) Drop(reason string) { + link.AddReason(reason) + link.UpdateVerdict(DROP) +} + +// RerouteToNameserver reroutes the link to the portmaster nameserver. +func (link *Link) RerouteToNameserver() { + link.UpdateVerdict(RerouteToNameserver) +} + +// RerouteToTunnel reroutes the link to the tunnel entrypoint and adds the given reason for accepting the connection. +func (link *Link) RerouteToTunnel(reason string) { + link.AddReason(reason) + link.UpdateVerdict(RerouteToTunnel) +} + // UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts -func (m *Link) UpdateVerdict(newVerdict Verdict) { - if newVerdict > m.Verdict { - m.Verdict = newVerdict - m.Save() +func (link *Link) UpdateVerdict(newVerdict Verdict) { + link.Lock() + defer link.Unlock() + + if newVerdict > link.Verdict { + link.Verdict = newVerdict + go link.Save() } } // AddReason adds a human readable string as to why a certain verdict was set in regard to this link -func (m *Link) AddReason(newReason string) { - if m.Reason != "" { - m.Reason += " | " +func (link *Link) AddReason(reason string) { + if reason == "" { + return } - m.Reason += newReason + + link.Lock() + defer link.Unlock() + + if link.Reason != "" { + link.Reason += " | " + } + link.Reason += reason } // packetHandler sequentially handles queued packets -func (m *Link) packetHandler() { +func (link *Link) packetHandler() { for { - pkt := <-m.pktQueue + pkt := <-link.pktQueue if pkt == nil { - break + return + } + link.Lock() + fwH := link.firewallHandler + link.Unlock() + if fwH != nil { + fwH(pkt, link) + } else { + link.ApplyVerdict(pkt) } - m.firewallHandler(pkt, m) } - m.firewallHandler = nil } -// Create creates a new database entry in the database in the default namespace for this object -func (m *Link) Create(name string) error { - m.CreateShallow(name) - return m.CreateObject(&database.OrphanedLink, name, m) -} +// ApplyVerdict appies the link verdict to a packet. +func (link *Link) ApplyVerdict(pkt packet.Packet) { + link.Lock() + defer link.Unlock() -// Create creates a new database entry in the database in the default namespace for this object -func (m *Link) CreateShallow(name string) { - allLinksLock.Lock() - allLinks[name] = m - allLinksLock.Unlock() -} - -// CreateWithDefaultKey creates a new database entry in the database in the default namespace for this object using the default key -func (m *Link) CreateInConnectionNamespace(name string) error { - if m.connection != nil { - return m.CreateObject(m.connection.GetKey(), name, m) + if link.VerdictPermanent { + switch link.Verdict { + case ACCEPT: + pkt.PermanentAccept() + case BLOCK: + pkt.PermanentBlock() + case DROP: + pkt.PermanentDrop() + case RerouteToNameserver: + pkt.RerouteToNameserver() + case RerouteToTunnel: + pkt.RerouteToTunnel() + default: + pkt.Drop() + } + } else { + switch link.Verdict { + case ACCEPT: + pkt.Accept() + case BLOCK: + pkt.Block() + case DROP: + pkt.Drop() + case RerouteToNameserver: + pkt.RerouteToNameserver() + case RerouteToTunnel: + pkt.RerouteToTunnel() + default: + pkt.Drop() + } } - return m.CreateObject(&database.OrphanedLink, name, m) } -// Save saves the object to the database (It must have been either already created or loaded from the database) -func (m *Link) Save() error { - return m.SaveObject(m) +// Save saves the link object in the storage and propagates the change. +func (link *Link) Save() error { + link.Lock() + defer link.Unlock() + + if link.connection == nil { + return errors.New("cannot save link without connection") + } + + if !link.KeyIsSet() { + link.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", link.connection.Process().Pid, link.connection.Domain, link.ID)) + link.CreateMeta() + } + + linksLock.RLock() + _, ok := links[link.ID] + linksLock.RUnlock() + + if !ok { + linksLock.Lock() + links[link.ID] = link + linksLock.Unlock() + } + + go dbController.PushUpdate(link) + return nil +} + +// Delete deletes a link from the storage and propagates the change. +func (link *Link) Delete() { + link.Lock() + defer link.Unlock() + + linksLock.Lock() + delete(links, link.ID) + linksLock.Unlock() + + link.Meta().Delete() + go dbController.PushUpdate(link) + link.connection.RemoveLink() + go link.connection.Save() } // GetLink fetches a Link from the database from the default namespace for this object -func GetLink(name string) (*Link, error) { - allLinksLock.RLock() - link, ok := allLinks[name] - allLinksLock.RUnlock() - if !ok { - return nil, database.ErrNotFound - } - return link, nil - // return GetLinkFromNamespace(&database.RunningLink, name) -} +func GetLink(id string) (*Link, bool) { + linksLock.RLock() + defer linksLock.RUnlock() -func SaveInCache(link *Link) { - -} - -// GetLinkFromNamespace fetches a Link form the database, but from a custom namespace -func GetLinkFromNamespace(namespace *datastore.Key, name string) (*Link, error) { - object, err := database.GetAndEnsureModel(namespace, name, linkModel) - if err != nil { - return nil, err - } - model, ok := object.(*Link) - if !ok { - return nil, database.NewMismatchError(object, linkModel) - } - return model, nil + link, ok := links[id] + return link, ok } // GetOrCreateLinkByPacket returns the associated Link for a packet and a bool expressing if the Link was newly created func GetOrCreateLinkByPacket(pkt packet.Packet) (*Link, bool) { - link, err := GetLink(pkt.GetConnectionID()) - if err != nil { - return CreateLinkFromPacket(pkt), true + link, ok := GetLink(pkt.GetLinkID()) + if ok { + return link, false } - return link, false + return CreateLinkFromPacket(pkt), true } -// CreateLinkFromPacket creates a new Link based on Packet. The Link is shallowly saved and SHOULD be saved to the database as soon more information is available +// CreateLinkFromPacket creates a new Link based on Packet. func CreateLinkFromPacket(pkt packet.Packet) *Link { link := &Link{ + ID: pkt.GetLinkID(), Verdict: UNDECIDED, Started: time.Now().Unix(), RemoteAddress: pkt.FmtRemoteAddress(), } - link.CreateShallow(pkt.GetConnectionID()) return link } -// FORMATTING -func (m *Link) String() string { - if m.connection == nil { - return fmt.Sprintf("? <-> %s", m.RemoteAddress) +// GetActiveInspectors returns the list of active inspectors. +func (link *Link) GetActiveInspectors() []bool { + link.Lock() + defer link.Unlock() + + return link.activeInspectors +} + +// SetActiveInspectors sets the list of active inspectors. +func (link *Link) SetActiveInspectors(new []bool) { + link.Lock() + defer link.Unlock() + + link.activeInspectors = new +} + +// GetInspectorData returns the list of inspector data. +func (link *Link) GetInspectorData() map[uint8]interface{} { + link.Lock() + defer link.Unlock() + + return link.inspectorData +} + +// SetInspectorData set the list of inspector data. +func (link *Link) SetInspectorData(new map[uint8]interface{}) { + link.Lock() + defer link.Unlock() + + link.inspectorData = new +} + +// String returns a string representation of Link. +func (link *Link) String() string { + link.Lock() + defer link.Unlock() + + if link.connection == nil { + return fmt.Sprintf("? <-> %s", link.RemoteAddress) } - switch m.connection.Domain { + switch link.connection.Domain { case "I": - if m.connection.process == nil { - return fmt.Sprintf("? <- %s", m.RemoteAddress) + if link.connection.process == nil { + return fmt.Sprintf("? <- %s", link.RemoteAddress) } - return fmt.Sprintf("%s <- %s", m.connection.process.String(), m.RemoteAddress) + return fmt.Sprintf("%s <- %s", link.connection.process.String(), link.RemoteAddress) case "D": - if m.connection.process == nil { - return fmt.Sprintf("? -> %s", m.RemoteAddress) + if link.connection.process == nil { + return fmt.Sprintf("? -> %s", link.RemoteAddress) } - return fmt.Sprintf("%s -> %s", m.connection.process.String(), m.RemoteAddress) + return fmt.Sprintf("%s -> %s", link.connection.process.String(), link.RemoteAddress) default: - if m.connection.process == nil { - return fmt.Sprintf("? -> %s (%s)", m.connection.Domain, m.RemoteAddress) + if link.connection.process == nil { + return fmt.Sprintf("? -> %s (%s)", link.connection.Domain, link.RemoteAddress) } - return fmt.Sprintf("%s to %s (%s)", m.connection.process.String(), m.connection.Domain, m.RemoteAddress) + return fmt.Sprintf("%s to %s (%s)", link.connection.process.String(), link.connection.Domain, link.RemoteAddress) } } diff --git a/network/module.go b/network/module.go new file mode 100644 index 00000000..806d7ff9 --- /dev/null +++ b/network/module.go @@ -0,0 +1,14 @@ +package network + +import ( + "github.com/Safing/portbase/modules" +) + +func init() { + modules.Register("network", nil, start, nil, "database") +} + +func start() error { + go cleaner() + return registerAsDatabase() +} diff --git a/network/netutils/cleandns.go b/network/netutils/cleandns.go index 67c9a1e3..66bcb24a 100644 --- a/network/netutils/cleandns.go +++ b/network/netutils/cleandns.go @@ -11,6 +11,7 @@ var ( cleanDomainRegex = regexp.MustCompile("^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\\.[a-z]{2,}\\.)$") ) +// IsValidFqdn returns whether the given string is a valid fqdn. func IsValidFqdn(fqdn string) bool { return cleanDomainRegex.MatchString(fqdn) } diff --git a/network/netutils/ip.go b/network/netutils/ip.go index 96dc92a9..780c6901 100644 --- a/network/netutils/ip.go +++ b/network/netutils/ip.go @@ -4,95 +4,101 @@ package netutils import "net" -// IP types +// IP classifications const ( - hostLocal int8 = iota - linkLocal - siteLocal - global - localMulticast - globalMulticast - invalid + HostLocal int8 = iota + LinkLocal + SiteLocal + Global + LocalMulticast + GlobalMulticast + Invalid ) -func classifyAddress(ip net.IP) int8 { +// ClassifyIP returns the classification for the given IP address. +func ClassifyIP(ip net.IP) int8 { if ip4 := ip.To4(); ip4 != nil { // IPv4 switch { case ip4[0] == 127: // 127.0.0.0/8 - return hostLocal + return HostLocal case ip4[0] == 169 && ip4[1] == 254: // 169.254.0.0/16 - return linkLocal + return LinkLocal case ip4[0] == 10: // 10.0.0.0/8 - return siteLocal + return SiteLocal case ip4[0] == 172 && ip4[1]&0xf0 == 16: // 172.16.0.0/12 - return siteLocal + return SiteLocal case ip4[0] == 192 && ip4[1] == 168: // 192.168.0.0/16 - return siteLocal + return SiteLocal case ip4[0] == 224: // 224.0.0.0/8 - return localMulticast + return LocalMulticast case ip4[0] >= 225 && ip4[0] <= 239: // 225.0.0.0/8 - 239.0.0.0/8 - return globalMulticast + return GlobalMulticast case ip4[0] >= 240: // 240.0.0.0/8 - 255.0.0.0/8 - return invalid + return Invalid default: - return global + return Global } } else if len(ip) == net.IPv6len { // IPv6 switch { case ip.Equal(net.IPv6loopback): - return hostLocal + return HostLocal case ip[0]&0xfe == 0xfc: // fc00::/7 - return siteLocal + return SiteLocal case ip[0] == 0xfe && ip[1]&0xc0 == 0x80: // fe80::/10 - return linkLocal + return LinkLocal case ip[0] == 0xff && ip[1] <= 0x05: // ff00::/16 - ff05::/16 - return localMulticast + return LocalMulticast case ip[0] == 0xff: // other ff00::/8 - return globalMulticast + return GlobalMulticast default: - return global + return Global } } - return invalid + return Invalid } -// IPIsLocal returns true if the given IP is a site-local or link-local address -func IPIsLocal(ip net.IP) bool { - switch classifyAddress(ip) { - case siteLocal: +// IPIsLocalhost returns whether the IP refers to the host itself. +func IPIsLocalhost(ip net.IP) bool { + return ClassifyIP(ip) == HostLocal +} + +// IPIsLAN returns true if the given IP is a site-local or link-local address. +func IPIsLAN(ip net.IP) bool { + switch ClassifyIP(ip) { + case SiteLocal: return true - case linkLocal: + case LinkLocal: return true default: return false } } -// IPIsGlobal returns true if the given IP is a global address +// IPIsGlobal returns true if the given IP is a global address. func IPIsGlobal(ip net.IP) bool { - return classifyAddress(ip) == global + return ClassifyIP(ip) == Global } -// IPIsLinkLocal returns true if the given IP is a link-local address +// IPIsLinkLocal returns true if the given IP is a link-local address. func IPIsLinkLocal(ip net.IP) bool { - return classifyAddress(ip) == linkLocal + return ClassifyIP(ip) == LinkLocal } -// IPIsSiteLocal returns true if the given IP is a site-local address +// IPIsSiteLocal returns true if the given IP is a site-local address. func IPIsSiteLocal(ip net.IP) bool { - return classifyAddress(ip) == siteLocal + return ClassifyIP(ip) == SiteLocal } diff --git a/network/netutils/ip_test.go b/network/netutils/ip_test.go index a479f7ab..b7ede062 100644 --- a/network/netutils/ip_test.go +++ b/network/netutils/ip_test.go @@ -6,14 +6,14 @@ import ( ) func TestIPClassification(t *testing.T) { - testClassification(t, net.IPv4(71, 87, 113, 211), global) - testClassification(t, net.IPv4(127, 0, 0, 1), hostLocal) - testClassification(t, net.IPv4(127, 255, 255, 1), hostLocal) - testClassification(t, net.IPv4(192, 168, 172, 24), siteLocal) + testClassification(t, net.IPv4(71, 87, 113, 211), Global) + testClassification(t, net.IPv4(127, 0, 0, 1), HostLocal) + testClassification(t, net.IPv4(127, 255, 255, 1), HostLocal) + testClassification(t, net.IPv4(192, 168, 172, 24), SiteLocal) } func testClassification(t *testing.T, ip net.IP, expectedClassification int8) { - c := classifyAddress(ip) + c := ClassifyIP(ip) if c != expectedClassification { t.Errorf("%s is %s, expected %s", ip, classificationString(c), classificationString(expectedClassification)) } @@ -21,19 +21,19 @@ func testClassification(t *testing.T, ip net.IP, expectedClassification int8) { func classificationString(c int8) string { switch c { - case hostLocal: + case HostLocal: return "hostLocal" - case linkLocal: + case LinkLocal: return "linkLocal" - case siteLocal: + case SiteLocal: return "siteLocal" - case global: + case Global: return "global" - case localMulticast: + case LocalMulticast: return "localMulticast" - case globalMulticast: + case GlobalMulticast: return "globalMulticast" - case invalid: + case Invalid: return "invalid" default: return "unknown" diff --git a/network/packet/packet.go b/network/packet/packet.go index 3d3ad336..16dcf166 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -106,10 +106,10 @@ type TCPUDPHeader struct { } type PacketBase struct { - connectionID string - Direction bool - InTunnel bool - Payload []byte + linkID string + Direction bool + InTunnel bool + Payload []byte *IPHeader *TCPUDPHeader } @@ -146,25 +146,25 @@ func (pkt *PacketBase) IPVersion() IPVersion { return pkt.Version } -func (pkt *PacketBase) GetConnectionID() string { - if pkt.connectionID == "" { - pkt.createConnectionID() +func (pkt *PacketBase) GetLinkID() string { + if pkt.linkID == "" { + pkt.createLinkID() } - return pkt.connectionID + return pkt.linkID } -func (pkt *PacketBase) createConnectionID() { +func (pkt *PacketBase) createLinkID() { if pkt.IPHeader.Protocol == TCP || pkt.IPHeader.Protocol == UDP { if pkt.Direction { - pkt.connectionID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Dst, pkt.DstPort, pkt.Src, pkt.SrcPort) + pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Dst, pkt.DstPort, pkt.Src, pkt.SrcPort) } else { - pkt.connectionID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort) + pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort) } } else { if pkt.Direction { - pkt.connectionID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Dst, pkt.Src) + pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Dst, pkt.Src) } else { - pkt.connectionID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Src, pkt.Dst) + pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Src, pkt.Dst) } } } @@ -299,7 +299,7 @@ type Packet interface { IsOutbound() bool SetInbound() SetOutbound() - GetConnectionID() string + GetLinkID() string IPVersion() IPVersion // MATCHING diff --git a/network/reference/protocols.go b/network/reference/protocols.go new file mode 100644 index 00000000..326ea885 --- /dev/null +++ b/network/reference/protocols.go @@ -0,0 +1,45 @@ +package reference + +import "strconv" + +var ( + protocolNames = map[uint8]string{ + 1: "ICMP", + 2: "IGMP", + 6: "TCP", + 17: "UDP", + 27: "RDP", + 58: "ICMPv6", + 33: "DCCP", + 136: "UDPLite", + } + + protocolNumbers = map[string]uint8{ + "ICMP": 1, + "IGMP": 2, + "TCP": 6, + "UDP": 17, + "RDP": 27, + "DCCP": 33, + "ICMPv6": 58, + "UDPLite": 136, + } +) + +// GetProtocolName returns the name of a IP protocol number. +func GetProtocolName(protocol uint8) (name string) { + name, ok := protocolNames[protocol] + if ok { + return name + } + return strconv.Itoa(int(protocol)) +} + +// GetProtocolNumber returns the number of a IP protocol name. +func GetProtocolNumber(protocol string) (number uint8, ok bool) { + number, ok = protocolNumbers[protocol] + if ok { + return number, true + } + return 0, false +} diff --git a/network/status.go b/network/status.go index a9976945..a02756ae 100644 --- a/network/status.go +++ b/network/status.go @@ -2,20 +2,34 @@ package network -// Status describes the status of a connection. +// Verdict describes the decision made about a connection or link. type Verdict uint8 // List of values a Status can have const ( // UNDECIDED is the default status of new connections UNDECIDED Verdict = iota - CANTSAY ACCEPT BLOCK DROP + RerouteToNameserver + RerouteToTunnel ) +// Packer Directions const ( Inbound = true Outbound = false ) + +// Non-Domain Connections +const ( + IncomingHost = "IH" + IncomingLAN = "IL" + IncomingInternet = "II" + IncomingInvalid = "IX" + PeerHost = "PH" + PeerLAN = "PL" + PeerInternet = "PI" + PeerInvalid = "PX" +) diff --git a/network/unknown.go b/network/unknown.go new file mode 100644 index 00000000..fc31aca8 --- /dev/null +++ b/network/unknown.go @@ -0,0 +1,31 @@ +package network + +import "github.com/Safing/portmaster/process" + +// Static reasons +const ( + ReasonUnknownProcess = "unknown connection owner: process could not be found" +) + +var ( + UnknownDirectConnection = &Connection{ + Domain: "PI", + Direction: Outbound, + Verdict: DROP, + Reason: ReasonUnknownProcess, + process: process.UnknownProcess, + } + + UnknownIncomingConnection = &Connection{ + Domain: "II", + Direction: Inbound, + Verdict: DROP, + Reason: ReasonUnknownProcess, + process: process.UnknownProcess, + } +) + +func init() { + UnknownDirectConnection.Save() + UnknownIncomingConnection.Save() +} diff --git a/portmaster/master.go b/portmaster/master.go deleted file mode 100644 index 42878df4..00000000 --- a/portmaster/master.go +++ /dev/null @@ -1,395 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package portmaster - -import ( - "net" - "os" - "strings" - - "github.com/Safing/safing-core/intel" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/network" - "github.com/Safing/safing-core/network/netutils" - "github.com/Safing/safing-core/network/packet" - "github.com/Safing/safing-core/port17/mode" - "github.com/Safing/safing-core/profiles" - - "github.com/agext/levenshtein" -) - -// use https://github.com/agext/levenshtein - -// Call order: -// -// 1. DecideOnConnectionBeforeIntel (if connecting to domain) -// is called when a DNS query is made, before the query is resolved -// 2. DecideOnConnectionAfterIntel (if connecting to domain) -// is called when a DNS query is made, after the query is resolved -// 3. DecideOnConnection -// is called when the first packet of the first link of the connection arrives -// 4. DecideOnLink -// is called when when the first packet of a link arrives only if connection has verdict UNDECIDED or CANTSAY - -func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) { - // check: - // Profile.DomainWhitelist - // Profile.Flags - // - process specific: System, Admin, User - // - network specific: Internet, LocalNet - - // grant self - if connection.Process().Pid == os.Getpid() { - log.Infof("sheriff: granting own connection %s", connection) - connection.Accept() - return - } - - // check if there is a profile - profile := connection.Process().Profile - if profile == nil { - log.Infof("sheriff: no profile, denying connection %s", connection) - connection.AddReason("no profile") - connection.Block() - return - } - - // check user class - if profile.Flags.Has(profiles.System) { - if !connection.Process().IsSystem() { - log.Infof("sheriff: denying connection %s, profile has System flag set, but process is not executed by System", connection) - connection.AddReason("must be executed by system") - connection.Block() - return - } - } - if profile.Flags.Has(profiles.Admin) { - if !connection.Process().IsAdmin() { - log.Infof("sheriff: denying connection %s, profile has Admin flag set, but process is not executed by Admin", connection) - connection.AddReason("must be executed by admin") - connection.Block() - return - } - } - if profile.Flags.Has(profiles.User) { - if !connection.Process().IsUser() { - log.Infof("sheriff: denying connection %s, profile has User flag set, but process is not executed by a User", connection) - connection.AddReason("must be executed by user") - connection.Block() - return - } - } - - // check for any network access - if !profile.Flags.Has(profiles.Internet) && !profile.Flags.Has(profiles.LocalNet) { - log.Infof("sheriff: denying connection %s, profile denies Internet and local network access", connection) - connection.Block() - return - } - - // check domain whitelist/blacklist - if len(profile.DomainWhitelist) > 0 { - matched := false - for _, entry := range profile.DomainWhitelist { - if !strings.HasSuffix(entry, ".") { - entry += "." - } - if strings.HasPrefix(entry, "*") { - if strings.HasSuffix(fqdn, strings.Trim(entry, "*")) { - matched = true - break - } - } else { - if entry == fqdn { - matched = true - break - } - } - } - if matched { - if profile.DomainWhitelistIsBlacklist { - log.Infof("sheriff: denying connection %s, profile has %s in domain blacklist", connection, fqdn) - connection.AddReason("domain blacklisted") - connection.Block() - return - } - } else { - if !profile.DomainWhitelistIsBlacklist { - log.Infof("sheriff: denying connection %s, profile does not have %s in domain whitelist", connection, fqdn) - connection.AddReason("domain not in whitelist") - connection.Block() - return - } - } - } - -} - -func DecideOnConnectionAfterIntel(connection *network.Connection, fqdn string, rrCache *intel.RRCache) *intel.RRCache { - // check: - // TODO: Profile.ClassificationBlacklist - // TODO: Profile.ClassificationWhitelist - // Profile.Flags - // - network specific: Strict - - // check if there is a profile - profile := connection.Process().Profile - if profile == nil { - log.Infof("sheriff: no profile, denying connection %s", connection) - connection.AddReason("no profile") - connection.Block() - return rrCache - } - - // check Strict flag - // TODO: drastically improve this! - if profile.Flags.Has(profiles.Strict) { - matched := false - pathElements := strings.Split(connection.Process().Path, "/") - if len(pathElements) > 2 { - pathElements = pathElements[len(pathElements)-2:] - } - domainElements := strings.Split(fqdn, ".") - matchLoop: - for _, domainElement := range domainElements { - for _, pathElement := range pathElements { - if levenshtein.Match(domainElement, pathElement, nil) > 0.5 { - matched = true - break matchLoop - } - } - if levenshtein.Match(domainElement, profile.Name, nil) > 0.5 { - matched = true - break matchLoop - } - if levenshtein.Match(domainElement, connection.Process().Name, nil) > 0.5 { - matched = true - break matchLoop - } - } - if !matched { - log.Infof("sheriff: denying connection %s, profile has declared Strict flag and no match to domain was found", connection) - connection.AddReason("domain does not relate to process") - connection.Block() - return rrCache - } - } - - // tunneling - // TODO: link this to real status - port17Active := mode.Client() - if port17Active { - tunnelInfo, err := AssignTunnelIP(fqdn) - if err != nil { - log.Errorf("portmaster: could not get tunnel IP for routing %s: %s", connection, err) - return nil // return nxDomain - } - // save original reply - tunnelInfo.RRCache = rrCache - // return tunnel IP - return tunnelInfo.ExportTunnelIP() - } - - return rrCache -} - -func DecideOnConnection(connection *network.Connection, pkt packet.Packet) { - // check: - // Profile.Flags - // - process specific: System, Admin, User - // - network specific: Internet, LocalNet, Service, Directconnect - - // grant self - if connection.Process().Pid == os.Getpid() { - log.Infof("sheriff: granting own connection %s", connection) - connection.Accept() - return - } - - // check if there is a profile - profile := connection.Process().Profile - if profile == nil { - log.Infof("sheriff: no profile, denying connection %s", connection) - connection.AddReason("no profile") - connection.Block() - return - } - - // check user class - if profile.Flags.Has(profiles.System) { - if !connection.Process().IsSystem() { - log.Infof("sheriff: denying connection %s, profile has System flag set, but process is not executed by System", connection) - connection.AddReason("must be executed by system") - connection.Block() - return - } - } - if profile.Flags.Has(profiles.Admin) { - if !connection.Process().IsAdmin() { - log.Infof("sheriff: denying connection %s, profile has Admin flag set, but process is not executed by Admin", connection) - connection.AddReason("must be executed by admin") - connection.Block() - return - } - } - if profile.Flags.Has(profiles.User) { - if !connection.Process().IsUser() { - log.Infof("sheriff: denying connection %s, profile has User flag set, but process is not executed by a User", connection) - connection.AddReason("must be executed by user") - connection.Block() - return - } - } - - // check for any network access - if !profile.Flags.Has(profiles.Internet) && !profile.Flags.Has(profiles.LocalNet) { - log.Infof("sheriff: denying connection %s, profile denies Internet and local network access", connection) - connection.AddReason("no network access allowed") - connection.Block() - return - } - - switch connection.Domain { - case "I": - // check Service flag - if !profile.Flags.Has(profiles.Service) { - log.Infof("sheriff: denying connection %s, profile does not declare service", connection) - connection.AddReason("not a service") - connection.Drop() - return - } - // check if incoming connections are allowed on any port, but only if there no other restrictions - if !!profile.Flags.Has(profiles.Internet) && !!profile.Flags.Has(profiles.LocalNet) && len(profile.ListenPorts) == 0 { - log.Infof("sheriff: granting connection %s, profile allows incoming connections from anywhere and on any port", connection) - connection.Accept() - return - } - case "D": - // check Directconnect flag - if !profile.Flags.Has(profiles.Directconnect) { - log.Infof("sheriff: denying connection %s, profile does not declare direct connections", connection) - connection.AddReason("direct connections (without DNS) not allowed") - connection.Drop() - return - } - } - - log.Infof("sheriff: could not decide on connection %s, deciding on per-link basis", connection) - connection.CantSay() -} - -func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet.Packet) { - // check: - // Profile.Flags - // - network specific: Internet, LocalNet - // Profile.ConnectPorts - // Profile.ListenPorts - - // check if there is a profile - profile := connection.Process().Profile - if profile == nil { - log.Infof("sheriff: no profile, denying %s", link) - link.AddReason("no profile") - link.UpdateVerdict(network.BLOCK) - return - } - - // check LocalNet and Internet flags - var remoteIP net.IP - if connection.Direction { - remoteIP = pkt.GetIPHeader().Src - } else { - remoteIP = pkt.GetIPHeader().Dst - } - if netutils.IPIsLocal(remoteIP) { - if !profile.Flags.Has(profiles.LocalNet) { - log.Infof("sheriff: dropping link %s, profile does not allow communication in the local network", link) - link.AddReason("profile does not allow access to local network") - link.UpdateVerdict(network.BLOCK) - return - } - } else { - if !profile.Flags.Has(profiles.Internet) { - log.Infof("sheriff: dropping link %s, profile does not allow communication with the Internet", link) - link.AddReason("profile does not allow access to the Internet") - link.UpdateVerdict(network.BLOCK) - return - } - } - - // check connect ports - if connection.Domain != "I" && len(profile.ConnectPorts) > 0 { - - tcpUdpHeader := pkt.GetTCPUDPHeader() - if tcpUdpHeader == nil { - log.Infof("sheriff: blocking link %s, profile has declared connect port whitelist, but link is not TCP/UDP", link) - link.AddReason("profile has declared connect port whitelist, but link is not TCP/UDP") - link.UpdateVerdict(network.BLOCK) - return - } - - // packet *should* be outbound, but we could be deciding on an already active connection. - var remotePort uint16 - if connection.Direction { - remotePort = tcpUdpHeader.SrcPort - } else { - remotePort = tcpUdpHeader.DstPort - } - - matched := false - for _, port := range profile.ConnectPorts { - if remotePort == port { - matched = true - break - } - } - - if !matched { - log.Infof("sheriff: blocking link %s, remote port %d not in profile connect port whitelist", link, remotePort) - link.AddReason("destination port not in whitelist") - link.UpdateVerdict(network.BLOCK) - return - } - - } - - // check listen ports - if connection.Domain == "I" && len(profile.ListenPorts) > 0 { - - tcpUdpHeader := pkt.GetTCPUDPHeader() - if tcpUdpHeader == nil { - log.Infof("sheriff: dropping link %s, profile has declared listen port whitelist, but link is not TCP/UDP", link) - link.AddReason("profile has declared listen port whitelist, but link is not TCP/UDP") - link.UpdateVerdict(network.DROP) - return - } - - // packet *should* be inbound, but we could be deciding on an already active connection. - var localPort uint16 - if connection.Direction { - localPort = tcpUdpHeader.DstPort - } else { - localPort = tcpUdpHeader.SrcPort - } - - matched := false - for _, port := range profile.ListenPorts { - if localPort == port { - matched = true - break - } - } - - if !matched { - log.Infof("sheriff: blocking link %s, local port %d not in profile listen port whitelist", link, localPort) - link.AddReason("listen port not in whitelist") - link.UpdateVerdict(network.BLOCK) - return - } - - } - - log.Infof("sheriff: accepting link %s", link) - link.UpdateVerdict(network.ACCEPT) - -} diff --git a/process/database.go b/process/database.go new file mode 100644 index 00000000..3cb87eed --- /dev/null +++ b/process/database.go @@ -0,0 +1,107 @@ +package process + +import ( + "fmt" + "sync" + "time" + + "github.com/Safing/portbase/database" + "github.com/Safing/portmaster/profile" + "github.com/tevino/abool" +) + +var ( + processes = make(map[int]*Process) + processesLock sync.RWMutex + + dbController *database.Controller + dbControllerFlag = abool.NewBool(false) +) + +// GetProcessFromStorage returns a process from the internal storage. +func GetProcessFromStorage(pid int) (*Process, bool) { + processesLock.RLock() + defer processesLock.RUnlock() + + p, ok := processes[pid] + return p, ok +} + +// All returns a copy of all process objects. +func All() []*Process { + processesLock.RLock() + defer processesLock.RUnlock() + + all := make([]*Process, 0, len(processes)) + for _, proc := range processes { + all = append(all, proc) + } + + return all +} + +// Save saves the process to the internal state and pushes an update. +func (p *Process) Save() { + p.Lock() + defer p.Unlock() + + if !p.KeyIsSet() { + p.SetKey(fmt.Sprintf("network:tree/%d", p.Pid)) + p.CreateMeta() + } + + processesLock.RLock() + _, ok := processes[p.Pid] + processesLock.RUnlock() + + if !ok { + processesLock.Lock() + processes[p.Pid] = p + processesLock.Unlock() + } + + if dbControllerFlag.IsSet() { + go dbController.PushUpdate(p) + } +} + +// Delete deletes a process from the storage and propagates the change. +func (p *Process) Delete() { + p.Lock() + defer p.Unlock() + + processesLock.Lock() + delete(processes, p.Pid) + processesLock.Unlock() + + p.Meta().Delete() + if dbControllerFlag.IsSet() { + go dbController.PushUpdate(p) + } + + // TODO: this should not be necessary, as processes should always have a profileSet. + if p.profileSet != nil { + profile.DeactivateProfileSet(p.profileSet) + } +} + +// CleanProcessStorage cleans the storage from old processes. +func CleanProcessStorage(thresholdDuration time.Duration) { + processesLock.Lock() + defer processesLock.Unlock() + + threshold := time.Now().Add(-thresholdDuration).Unix() + for _, p := range processes { + p.Lock() + if p.FirstConnectionEstablished < threshold && p.ConnectionCount == 0 { + go p.Delete() + } + p.Unlock() + } +} + +// SetDBController sets the database controller and allows the package to push database updates on a save. It must be set by the package that registers the "network" database. +func SetDBController(controller *database.Controller) { + dbController = controller + dbControllerFlag.Set() +} diff --git a/process/doc.go b/process/doc.go index bc77fb9c..6db99ce2 100644 --- a/process/doc.go +++ b/process/doc.go @@ -1,21 +1,7 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - /* - -Profiles - -Profiles describe the network behaviour - - -Profiles are found in 3 different paths: -- /Me/Profiles/: Profiles used for this system -- /Data/Profiles/: Profiles supplied by Safing -- /Company/Profiles/: Profiles supplied by the company - -When a program wants to use the network for the first time, Safing first searches for a Profile in the Company namespace, then in the Data namespace. If neither is found, it searches for a default profile in the same order. - -Default profiles are profiles with a path ending with a "/". The default profile with the longest matching path is chosen. +Package process fetches process and socket information from the operating system. +It can find the process owning a network connection. */ package process diff --git a/process/executable.go b/process/executable.go new file mode 100644 index 00000000..3bb6c00d --- /dev/null +++ b/process/executable.go @@ -0,0 +1,43 @@ +// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. + +package process + +import ( + "crypto" + "encoding/hex" + "hash" + "io" + "os" +) + +// GetExecHash returns the hash of the executable with the given algorithm. +func (p *Process) GetExecHash(algorithm string) (string, error) { + sum, ok := p.ExecHashes[algorithm] + if ok { + return sum, nil + } + + var hasher hash.Hash + switch algorithm { + case "md5": + hasher = crypto.MD5.New() + case "sha1": + hasher = crypto.SHA1.New() + case "sha256": + hasher = crypto.SHA256.New() + } + + file, err := os.Open(p.Path) + if err != nil { + return "", err + } + + _, err = io.Copy(hasher, file) + if err != nil { + return "", err + } + + sum = hex.EncodeToString(hasher.Sum(nil)) + p.ExecHashes[algorithm] = sum + return sum, nil +} diff --git a/process/fileinfo.go b/process/fileinfo.go deleted file mode 100644 index beb67bec..00000000 --- a/process/fileinfo.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package process - -import ( - "github.com/Safing/safing-core/database" - "strings" - "time" - - datastore "github.com/ipfs/go-datastore" -) - -// ExecutableSignature stores a signature of an executable. -type ExecutableSignature []byte - -// FileInfo stores (security) information about a file. -type FileInfo struct { - database.Base - HumanName string - Owners []string - ApproxLastSeen int64 - Signature *ExecutableSignature -} - -var fileInfoModel *FileInfo // only use this as parameter for database.EnsureModel-like functions - -func init() { - database.RegisterModel(fileInfoModel, func() database.Model { return new(FileInfo) }) -} - -// Create saves FileInfo with the provided name in the default namespace. -func (m *FileInfo) Create(name string) error { - return m.CreateObject(&database.FileInfoCache, name, m) -} - -// CreateInNamespace saves FileInfo with the provided name in the provided namespace. -func (m *FileInfo) CreateInNamespace(namespace *datastore.Key, name string) error { - return m.CreateObject(namespace, name, m) -} - -// Save saves FileInfo. -func (m *FileInfo) Save() error { - return m.SaveObject(m) -} - -// getFileInfo fetches FileInfo with the provided name from the default namespace. -func getFileInfo(name string) (*FileInfo, error) { - return getFileInfoFromNamespace(&database.FileInfoCache, name) -} - -// getFileInfoFromNamespace fetches FileInfo with the provided name from the provided namespace. -func getFileInfoFromNamespace(namespace *datastore.Key, name string) (*FileInfo, error) { - object, err := database.GetAndEnsureModel(namespace, name, fileInfoModel) - if err != nil { - return nil, err - } - model, ok := object.(*FileInfo) - if !ok { - return nil, database.NewMismatchError(object, fileInfoModel) - } - return model, nil -} - -// GetFileInfo gathers information about a file and returns *FileInfo -func GetFileInfo(path string) *FileInfo { - // TODO: actually get file information - // TODO: try to load from DB - // TODO: save to DB (key: hash of some sorts) - splittedPath := strings.Split("/", path) - return &FileInfo{ - HumanName: splittedPath[len(splittedPath)-1], - ApproxLastSeen: time.Now().Unix(), - } -} diff --git a/process/find.go b/process/find.go index 9df4840f..842922b9 100644 --- a/process/find.go +++ b/process/find.go @@ -4,14 +4,17 @@ import ( "errors" "net" - "github.com/Safing/safing-core/network/packet" + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/network/packet" ) +// Errors var ( - ErrConnectionNotFound = errors.New("could not find connection") - ErrProcessNotFound = errors.New("could not find process") + ErrConnectionNotFound = errors.New("could not find connection in system state tables") + ErrProcessNotFound = errors.New("could not find process in system state tables") ) +// GetPidByPacket returns the pid of the owner of the packet. func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) { var localIP net.IP @@ -50,26 +53,33 @@ func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) { } +// GetProcessByPacket returns the process that owns the given packet. func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) { var pid int pid, direction, err = GetPidByPacket(pkt) - if pid < 0 { - return nil, direction, ErrConnectionNotFound - } if err != nil { return nil, direction, err } + if pid < 0 { + return nil, direction, ErrConnectionNotFound + } process, err = GetOrFindProcess(pid) if err != nil { return nil, direction, err } + err = process.FindProfiles() + if err != nil { + log.Errorf("failed to find profiles for process %s: %s", process.String(), err) + } + return process, direction, nil } +// GetPidByEndpoints returns the pid of the owner of the described link. func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (pid int, direction bool, err error) { ipVersion := packet.IPv4 @@ -92,6 +102,7 @@ func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remote } +// GetProcessByEndpoints returns the process that owns the described link. func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) { var pid int @@ -108,41 +119,16 @@ func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, re return nil, err } + err = process.FindProfiles() + if err != nil { + log.Errorf("failed to find profiles for process %s: %s", process.String(), err) + } + return process, nil } +// GetActiveConnectionIDs returns a list of all active connection IDs. func GetActiveConnectionIDs() []string { return getActiveConnectionIDs() } - -// func GetProcessByPid(pid int) *Process { -// process, err := GetOrFindProcess(pid) -// if err != nil { -// log.Warningf("process: failed to get process %d: %s", pid, err) -// return nil -// } -// return process -// } - -// func GetProcessOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (process *Process, status uint8) { -// pid, status := GetPidOfConnection(localIP, localPort, protocol) -// if status == Success { -// process = GetProcessByPid(pid) -// if process == nil { -// return nil, NoProcessInfo -// } -// } -// return -// } - -// func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, status uint8) { -// pid, direction, status := GetPidByPacket(pkt) -// if status == Success { -// process = GetProcessByPid(pid) -// if process == nil { -// return nil, direction, NoProcessInfo -// } -// } -// return -// } diff --git a/process/getpid_linux.go b/process/getpid_linux.go index 098b443e..49f26e17 100644 --- a/process/getpid_linux.go +++ b/process/getpid_linux.go @@ -1,6 +1,8 @@ package process -import "github.com/Safing/safing-core/process/proc" +import ( + "github.com/Safing/portmaster/process/proc" +) var ( getTCP4PacketInfo = proc.GetTCP4PacketInfo diff --git a/process/iphelper/test/test.exe b/process/iphelper/test/test.exe deleted file mode 100644 index 7848762d..00000000 Binary files a/process/iphelper/test/test.exe and /dev/null differ diff --git a/process/matching.go b/process/matching.go new file mode 100644 index 00000000..bce476a3 --- /dev/null +++ b/process/matching.go @@ -0,0 +1,115 @@ +package process + +import ( + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/profile" +) + +var ( + profileDB = database.NewInterface(nil) +) + +// FindProfiles finds and assigns a profile set to the process. +func (p *Process) FindProfiles() error { + + p.Lock() + defer p.Unlock() + + // only find profiles if not already done. + if p.profileSet != nil { + return nil + } + + // User Profile + it, err := profileDB.Query(query.New(profile.MakeProfileKey(profile.UserNamespace, "")).Where(query.Where("LinkedPath", query.SameAs, p.Path))) + if err != nil { + return err + } + + var userProfile *profile.Profile + for r := range it.Next { + it.Cancel() + userProfile, err = profile.EnsureProfile(r) + if err != nil { + return err + } + break + } + if it.Err() != nil { + return it.Err() + } + + // create new profile if it does not exist. + if userProfile == nil { + // create new profile + userProfile = profile.New() + userProfile.Name = p.ExecName + userProfile.LinkedPath = p.Path + } + + if userProfile.MarkUsed() { + userProfile.Save(profile.UserNamespace) + } + + // Stamp + // Find/Re-evaluate Stamp profile + // 1. check linked stamp profile + // 2. if last check is was more than a week ago, fetch from stamp: + // 3. send path identifier to stamp + // 4. evaluate all returned profiles + // 5. select best + // 6. link stamp profile to user profile + // FIXME: implement! + + p.UserProfileKey = userProfile.Key() + p.profileSet = profile.NewSet(userProfile, nil) + go p.Save() + + return nil +} + +func selectProfile(p *Process, profs []*profile.Profile) (selectedProfile *profile.Profile) { + var highestScore int + for _, prof := range profs { + score := matchProfile(p, prof) + if score > highestScore { + selectedProfile = prof + } + } + return +} + +func matchProfile(p *Process, prof *profile.Profile) (score int) { + for _, fp := range prof.Fingerprints { + score += matchFingerprint(p, fp) + } + return +} + +func matchFingerprint(p *Process, fp *profile.Fingerprint) (score int) { + if !fp.MatchesOS() { + return 0 + } + + switch fp.Type { + case "full_path": + if p.Path == fp.Value { + } + return profile.GetFingerprintWeight(fp.Type) + case "partial_path": + // FIXME: if full_path matches, do not match partial paths + return profile.GetFingerprintWeight(fp.Type) + case "md5_sum", "sha1_sum", "sha256_sum": + // FIXME: one sum is enough, check sums in a grouped form, start with the best + sum, err := p.GetExecHash(fp.Type) + if err != nil { + log.Errorf("process: failed to get hash of executable: %s", err) + } else if sum == fp.Value { + return profile.GetFingerprintWeight(fp.Type) + } + } + + return 0 +} diff --git a/process/proc/gather.go b/process/proc/gather.go index 39a1fe60..98a83be4 100644 --- a/process/proc/gather.go +++ b/process/proc/gather.go @@ -13,14 +13,19 @@ const ( NoProcess ) -func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { +var ( + waitTime = 15 * time.Millisecond +) + +// GetPidOfConnection returns the PID of the given connection. +func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { uid, inode, ok := getConnectionSocket(localIP, localPort, protocol) if !ok { uid, inode, ok = getListeningSocket(localIP, localPort, protocol) for i := 0; i < 3 && !ok; i++ { // give kernel some time, then try again // log.Tracef("process: giving kernel some time to think") - time.Sleep(15 * time.Millisecond) + time.Sleep(waitTime) uid, inode, ok = getConnectionSocket(localIP, localPort, protocol) if !ok { uid, inode, ok = getListeningSocket(localIP, localPort, protocol) @@ -30,27 +35,48 @@ func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid return -1, NoSocket } } + pid, ok = GetPidOfInode(uid, inode) for i := 0; i < 3 && !ok; i++ { // give kernel some time, then try again // log.Tracef("process: giving kernel some time to think") - time.Sleep(15 * time.Millisecond) + time.Sleep(waitTime) pid, ok = GetPidOfInode(uid, inode) } if !ok { return -1, NoProcess } + return } -func GetPidOfIncomingConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { +// GetPidOfConnection returns the PID of the given incoming connection. +func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { uid, inode, ok := getListeningSocket(localIP, localPort, protocol) if !ok { - return -1, NoSocket + // for TCP4 and UDP4, also try TCP6 and UDP6, as linux sometimes treats them as a single dual socket, and shows the IPv6 version. + switch protocol { + case TCP4: + uid, inode, ok = getListeningSocket(localIP, localPort, TCP6) + case UDP4: + uid, inode, ok = getListeningSocket(localIP, localPort, UDP6) + } + + if !ok { + return -1, NoSocket + } } + pid, ok = GetPidOfInode(uid, inode) + for i := 0; i < 3 && !ok; i++ { + // give kernel some time, then try again + // log.Tracef("process: giving kernel some time to think") + time.Sleep(waitTime) + pid, ok = GetPidOfInode(uid, inode) + } if !ok { return -1, NoProcess } + return } diff --git a/process/proc/get.go b/process/proc/get.go index 9f3fc0ed..174ec823 100644 --- a/process/proc/get.go +++ b/process/proc/get.go @@ -6,39 +6,39 @@ import ( ) func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(TCP4, localIP, localPort, direction) + return search(TCP4, localIP, localPort, pktDirection) } func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(TCP6, localIP, localPort, direction) + return search(TCP6, localIP, localPort, pktDirection) } func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(UDP4, localIP, localPort, direction) + return search(UDP4, localIP, localPort, pktDirection) } func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(UDP6, localIP, localPort, direction) + return search(UDP6, localIP, localPort, pktDirection) } func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) { var status uint8 if pktDirection { - pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol) + pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol) if pid >= 0 { return pid, true, nil } - // pid, status = GetPidOfConnection(&localIP, localPort, protocol) + // pid, status = GetPidOfConnection(localIP, localPort, protocol) // if pid >= 0 { // return pid, false, nil // } } else { - pid, status = GetPidOfConnection(&localIP, localPort, protocol) + pid, status = GetPidOfConnection(localIP, localPort, protocol) if pid >= 0 { return pid, false, nil } - // pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol) + // pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol) // if pid >= 0 { // return pid, true, nil // } diff --git a/process/proc/processfinder.go b/process/proc/processfinder.go index 00edf3cb..b2c25ef2 100644 --- a/process/proc/processfinder.go +++ b/process/proc/processfinder.go @@ -10,7 +10,7 @@ import ( "sync" "syscall" - "github.com/Safing/safing-core/log" + "github.com/Safing/portbase/log" ) var ( diff --git a/process/proc/sockets.go b/process/proc/sockets.go index e453ca12..49b443ea 100644 --- a/process/proc/sockets.go +++ b/process/proc/sockets.go @@ -13,7 +13,7 @@ import ( "sync" "unicode" - "github.com/Safing/safing-core/log" + "github.com/Safing/portbase/log" ) /* @@ -81,7 +81,7 @@ var ( globalListeningUDP6 = make(map[uint16][]int) ) -func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int, int, bool) { +func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, int, bool) { // listeningSocketsLock.Lock() // defer listeningSocketsLock.Unlock() @@ -98,10 +98,10 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]})) case TCP6: procFile = TCP6Data - localIPHex = hex.EncodeToString([]byte(*localIP)) + localIPHex = hex.EncodeToString([]byte(localIP)) case UDP6: procFile = UDP6Data - localIPHex = hex.EncodeToString([]byte(*localIP)) + localIPHex = hex.EncodeToString([]byte(localIP)) } localPortHex := fmt.Sprintf("%04X", localPort) @@ -162,38 +162,38 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int } -func getListeningSocket(localIP *net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) { +func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) { listeningSocketsLock.Lock() defer listeningSocketsLock.Unlock() - var addressListening *map[string][]int - var globalListening *map[uint16][]int + var addressListening map[string][]int + var globalListening map[uint16][]int switch protocol { case TCP4: - addressListening = &addressListeningTCP4 - globalListening = &globalListeningTCP4 + addressListening = addressListeningTCP4 + globalListening = globalListeningTCP4 case UDP4: - addressListening = &addressListeningUDP4 - globalListening = &globalListeningUDP4 + addressListening = addressListeningUDP4 + globalListening = globalListeningUDP4 case TCP6: - addressListening = &addressListeningTCP6 - globalListening = &globalListeningTCP6 + addressListening = addressListeningTCP6 + globalListening = globalListeningTCP6 case UDP6: - addressListening = &addressListeningUDP6 - globalListening = &globalListeningUDP6 + addressListening = addressListeningUDP6 + globalListening = globalListeningUDP6 } - data, ok := (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)] + data, ok := addressListening[fmt.Sprintf("%s:%d", localIP, localPort)] if !ok { - data, ok = (*globalListening)[localPort] + data, ok = globalListening[localPort] } if ok { return data[0], data[1], true } updateListeners(protocol) - data, ok = (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)] + data, ok = addressListening[fmt.Sprintf("%s:%d", localIP, localPort)] if !ok { - data, ok = (*globalListening)[localPort] + data, ok = globalListening[localPort] } if ok { return data[0], data[1], true @@ -206,7 +206,7 @@ func procDelimiter(c rune) bool { return unicode.IsSpace(c) || c == ':' } -func convertIPv4(data string) *net.IP { +func convertIPv4(data string) net.IP { decoded, err := hex.DecodeString(data) if err != nil { log.Warningf("process: could not parse IPv4 %s: %s", data, err) @@ -217,10 +217,10 @@ func convertIPv4(data string) *net.IP { return nil } ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) - return &ip + return ip } -func convertIPv6(data string) *net.IP { +func convertIPv6(data string) net.IP { decoded, err := hex.DecodeString(data) if err != nil { log.Warningf("process: could not parse IPv6 %s: %s", data, err) @@ -231,7 +231,7 @@ func convertIPv6(data string) *net.IP { return nil } ip := net.IP(decoded) - return &ip + return ip } func updateListeners(protocol uint8) { @@ -247,7 +247,7 @@ func updateListeners(protocol uint8) { } } -func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) *net.IP) (map[string][]int, map[uint16][]int) { +func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) net.IP) (map[string][]int, map[uint16][]int) { addressListening := make(map[string][]int) globalListening := make(map[uint16][]int) @@ -312,6 +312,7 @@ func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter return addressListening, globalListening } +// GetActiveConnectionIDs returns all connection IDs that are still marked as active by the OS. func GetActiveConnectionIDs() []string { var connections []string @@ -323,7 +324,7 @@ func GetActiveConnectionIDs() []string { return connections } -func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) *net.IP) []string { +func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) net.IP) []string { var connections []string // open file diff --git a/process/proc/sockets_test.go b/process/proc/sockets_test.go index 305c73a3..6bc77bfc 100644 --- a/process/proc/sockets_test.go +++ b/process/proc/sockets_test.go @@ -22,14 +22,14 @@ func TestSockets(t *testing.T) { t.Logf("addressListeningUDP6: %v", addressListeningUDP6) t.Logf("globalListeningUDP6: %v", globalListeningUDP6) - getListeningSocket(&net.IPv4zero, 53, TCP4) - getListeningSocket(&net.IPv4zero, 53, UDP4) - getListeningSocket(&net.IPv6zero, 53, TCP6) - getListeningSocket(&net.IPv6zero, 53, UDP6) + getListeningSocket(net.IPv4zero, 53, TCP4) + getListeningSocket(net.IPv4zero, 53, UDP4) + getListeningSocket(net.IPv6zero, 53, TCP6) + getListeningSocket(net.IPv6zero, 53, UDP6) // spotify: 192.168.0.102:5353 192.121.140.65:80 localIP := net.IPv4(192, 168, 127, 10) - uid, inode, ok := getConnectionSocket(&localIP, 46634, TCP4) + uid, inode, ok := getConnectionSocket(localIP, 46634, TCP4) t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok) activeConnectionIDs := GetActiveConnectionIDs() diff --git a/process/process.go b/process/process.go index 2043f84a..65757dbe 100644 --- a/process/process.go +++ b/process/process.go @@ -5,89 +5,93 @@ package process import ( "fmt" "runtime" - "strconv" "strings" + "sync" + "time" - datastore "github.com/ipfs/go-datastore" processInfo "github.com/shirou/gopsutil/process" - "github.com/Safing/safing-core/database" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/profiles" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/profile" ) // A Process represents a process running on the operating system type Process struct { - database.Base - UserID int - UserName string - UserHome string - Pid int - ParentPid int - Path string - Cwd string - FileInfo *FileInfo - CmdLine string - FirstArg string - ProfileKey string - Profile *profiles.Profile - Name string - Icon string + record.Base + sync.Mutex + + UserID int + UserName string + UserHome string + Pid int + ParentPid int + Path string + Cwd string + CmdLine string + FirstArg string + + ExecName string + ExecHashes map[string]string + // ExecOwner ... + // ExecSignature ... + + UserProfileKey string + profileSet *profile.Set + Name string + Icon string // Icon is a path to the icon and is either prefixed "f:" for filepath, "d:" for database cache path or "c:"/"a:" for a the icon key to fetch it from a company / authoritative node and cache it in its own cache. + + FirstConnectionEstablished int64 + LastConnectionEstablished int64 + ConnectionCount uint } -var processModel *Process // only use this as parameter for database.EnsureModel-like functions +// ProfileSet returns the assigned profile set. +func (p *Process) ProfileSet() *profile.Set { + p.Lock() + defer p.Unlock() -func init() { - database.RegisterModel(processModel, func() database.Model { return new(Process) }) + return p.profileSet } -// Create saves Process with the provided name in the default namespace. -func (m *Process) Create(name string) error { - return m.CreateObject(&database.Processes, name, m) -} +// Strings returns a string represenation of process. +func (p *Process) String() string { + p.Lock() + defer p.Unlock() -// CreateInNamespace saves Process with the provided name in the provided namespace. -func (m *Process) CreateInNamespace(namespace *datastore.Key, name string) error { - return m.CreateObject(namespace, name, m) -} - -// Save saves Process. -func (m *Process) Save() error { - return m.SaveObject(m) -} - -// GetProcess fetches Process with the provided name from the default namespace. -func GetProcess(name string) (*Process, error) { - return GetProcessFromNamespace(&database.Processes, name) -} - -// GetProcessFromNamespace fetches Process with the provided name from the provided namespace. -func GetProcessFromNamespace(namespace *datastore.Key, name string) (*Process, error) { - object, err := database.GetAndEnsureModel(namespace, name, processModel) - if err != nil { - return nil, err - } - model, ok := object.(*Process) - if !ok { - return nil, database.NewMismatchError(object, processModel) - } - return model, nil -} - -func (m *Process) String() string { - if m == nil { + if p == nil { return "?" } - if m.Profile != nil && !m.Profile.Default { - return fmt.Sprintf("%s:%s:%d", m.UserName, m.Profile, m.Pid) - } - return fmt.Sprintf("%s:%s:%d", m.UserName, m.Path, m.Pid) + return fmt.Sprintf("%s:%s:%d", p.UserName, p.Path, p.Pid) } +// AddConnection increases the connection counter and the last connection timestamp. +func (p *Process) AddConnection() { + p.Lock() + defer p.Unlock() + + p.ConnectionCount++ + p.LastConnectionEstablished = time.Now().Unix() + if p.FirstConnectionEstablished == 0 { + p.FirstConnectionEstablished = p.LastConnectionEstablished + } +} + +// RemoveConnection lowers the connection counter by one. +func (p *Process) RemoveConnection() { + p.Lock() + defer p.Unlock() + + if p.ConnectionCount > 0 { + p.ConnectionCount-- + } +} + +// GetOrFindProcess returns the process for the given PID. func GetOrFindProcess(pid int) (*Process, error) { - process, err := GetProcess(strconv.Itoa(pid)) - if err == nil { + process, ok := GetProcessFromStorage(pid) + if ok { return process, nil } @@ -96,13 +100,9 @@ func GetOrFindProcess(pid int) (*Process, error) { } switch { - case (pid == 0 && runtime.GOOS == "linux") || (pid == 4 && runtime.GOOS == "windows"): + case new.IsKernel(): new.UserName = "Kernel" new.Name = "Operating System" - new.Profile = &profiles.Profile{ - Name: "OS", - Flags: []int8{profiles.Internet, profiles.LocalNet, profiles.Directconnect, profiles.Service}, - } default: pInfo, err := processInfo.NewProcess(int32(pid)) @@ -113,7 +113,8 @@ func GetOrFindProcess(pid int) (*Process, error) { // UID // net yet implemented for windows if runtime.GOOS == "linux" { - uids, err := pInfo.Uids() + var uids []int32 + uids, err = pInfo.Uids() if err != nil { log.Warningf("process: failed to get UID: %s", err) } else { @@ -167,85 +168,87 @@ func GetOrFindProcess(pid int) (*Process, error) { // new.Icon, err = // get Profile - processPath := new.Path - var applyProfile *profiles.Profile - iterations := 0 - for applyProfile == nil { + // processPath := new.Path + // var applyProfile *profiles.Profile + // iterations := 0 + // for applyProfile == nil { + // + // iterations++ + // if iterations > 10 { + // log.Warningf("process: got into loop while getting profile for %s", new) + // break + // } + // + // applyProfile, err = profiles.GetActiveProfileByPath(processPath) + // if err == database.ErrNotFound { + // applyProfile, err = profiles.FindProfileByPath(processPath, new.UserHome) + // } + // if err != nil { + // log.Warningf("process: could not get profile for %s: %s", new, err) + // } else if applyProfile == nil { + // log.Warningf("process: no default profile found for %s", new) + // } else { + // + // // TODO: there is a lot of undefined behaviour if chaining framework profiles + // + // // process framework + // if applyProfile.Framework != nil { + // if applyProfile.Framework.FindParent > 0 { + // var ppid int32 + // for i := uint8(1); i < applyProfile.Framework.FindParent; i++ { + // parent, err := pInfo.Parent() + // if err != nil { + // return nil, err + // } + // ppid = parent.Pid + // } + // if applyProfile.Framework.MergeWithParent { + // return GetOrFindProcess(int(ppid)) + // } + // // processPath, err = os.Readlink(fmt.Sprintf("/proc/%d/exe", pid)) + // // if err != nil { + // // return nil, fmt.Errorf("could not read /proc/%d/exe: %s", pid, err) + // // } + // continue + // } + // + // newCommand, err := applyProfile.Framework.GetNewPath(new.CmdLine, new.Cwd) + // if err != nil { + // return nil, err + // } + // + // // assign + // new.CmdLine = newCommand + // new.Path = strings.SplitN(newCommand, " ", 2)[0] + // processPath = new.Path + // + // // make sure we loop + // applyProfile = nil + // continue + // } + // + // // apply profile to process + // log.Debugf("process: applied profile to %s: %s", new, applyProfile) + // new.Profile = applyProfile + // new.ProfileKey = applyProfile.GetKey().String() + // + // // update Profile with Process icon if Profile does not have one + // if !new.Profile.Default && new.Icon != "" && new.Profile.Icon == "" { + // new.Profile.Icon = new.Icon + // new.Profile.Save() + // } + // } + // } - iterations++ - if iterations > 10 { - log.Warningf("process: got into loop while getting profile for %s", new) - break - } - - applyProfile, err = profiles.GetActiveProfileByPath(processPath) - if err == database.ErrNotFound { - applyProfile, err = profiles.FindProfileByPath(processPath, new.UserHome) - } - if err != nil { - log.Warningf("process: could not get profile for %s: %s", new, err) - } else if applyProfile == nil { - log.Warningf("process: no default profile found for %s", new) - } else { - - // TODO: there is a lot of undefined behaviour if chaining framework profiles - - // process framework - if applyProfile.Framework != nil { - if applyProfile.Framework.FindParent > 0 { - var ppid int32 - for i := uint8(1); i < applyProfile.Framework.FindParent; i++ { - parent, err := pInfo.Parent() - if err != nil { - return nil, err - } - ppid = parent.Pid - } - if applyProfile.Framework.MergeWithParent { - return GetOrFindProcess(int(ppid)) - } - // processPath, err = os.Readlink(fmt.Sprintf("/proc/%d/exe", pid)) - // if err != nil { - // return nil, fmt.Errorf("could not read /proc/%d/exe: %s", pid, err) - // } - continue - } - - newCommand, err := applyProfile.Framework.GetNewPath(new.CmdLine, new.Cwd) - if err != nil { - return nil, err - } - - // assign - new.CmdLine = newCommand - new.Path = strings.SplitN(newCommand, " ", 2)[0] - processPath = new.Path - - // make sure we loop - applyProfile = nil - continue - } - - // apply profile to process - log.Debugf("process: applied profile to %s: %s", new, applyProfile) - new.Profile = applyProfile - new.ProfileKey = applyProfile.GetKey().String() - - // update Profile with Process icon if Profile does not have one - if !new.Profile.Default && new.Icon != "" && new.Profile.Icon == "" { - new.Profile.Icon = new.Icon - new.Profile.Save() - } - } - } - - // get FileInfo - new.FileInfo = GetFileInfo(new.Path) + // Executable Information + // FIXME: use os specific path seperator + splittedPath := strings.Split(new.Path, "/") + new.ExecName = splittedPath[len(splittedPath)-1] } - // save to DB - new.Create(strconv.Itoa(new.Pid)) + // save to storage + new.Save() return new, nil } diff --git a/process/process_linux.go b/process/process_linux.go index edfaa9be..ebfa4104 100644 --- a/process/process_linux.go +++ b/process/process_linux.go @@ -1,13 +1,21 @@ package process +// IsUser returns whether the process is run by a normal user. func (m *Process) IsUser() bool { return m.UserID >= 1000 } +// IsAdmin returns whether the process is run by an admin user. func (m *Process) IsAdmin() bool { return m.UserID >= 0 } +// IsSystem returns whether the process is run by the operating system. func (m *Process) IsSystem() bool { return m.UserID == 0 } + +// IsKernel returns whether the process is the Kernel. +func (m *Process) IsKernel() bool { + return m.Pid == 0 +} diff --git a/process/process_windows.go b/process/process_windows.go index 9a16fe51..b1e38349 100644 --- a/process/process_windows.go +++ b/process/process_windows.go @@ -2,15 +2,23 @@ package process import "strings" +// IsUser returns whether the process is run by a normal user. func (m *Process) IsUser() bool { return m.Pid != 4 && // Kernel !strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!) } +// IsAdmin returns whether the process is run by an admin user. func (m *Process) IsAdmin() bool { return strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!) } +// IsSystem returns whether the process is run by the operating system. func (m *Process) IsSystem() bool { return m.Pid == 4 } + +// IsKernel returns whether the process is the Kernel. +func (m *Process) IsKernel() bool { + return m.Pid == 4 +} diff --git a/process/unknown.go b/process/unknown.go new file mode 100644 index 00000000..07f62056 --- /dev/null +++ b/process/unknown.go @@ -0,0 +1,16 @@ +package process + +var ( + // UnknownProcess is used when a process cannot be found. + UnknownProcess = &Process{ + UserID: -1, + UserName: "Unknown", + Pid: -1, + ParentPid: -1, + Name: "Unknown Processes", + } +) + +func init() { + UnknownProcess.Save() +} diff --git a/profile/active.go b/profile/active.go new file mode 100644 index 00000000..fdf1f057 --- /dev/null +++ b/profile/active.go @@ -0,0 +1,54 @@ +package profile + +import "sync" + +var ( + activeProfileSets = make(map[string]*Set) + activeProfileSetsLock sync.RWMutex +) + +func activateProfileSet(set *Set) { + set.Lock() + defer set.Unlock() + activeProfileSetsLock.Lock() + defer activeProfileSetsLock.Unlock() + activeProfileSets[set.profiles[0].ID] = set +} + +// DeactivateProfileSet marks a profile set as not active. +func DeactivateProfileSet(set *Set) { + set.Lock() + defer set.Unlock() + activeProfileSetsLock.Lock() + defer activeProfileSetsLock.Unlock() + delete(activeProfileSets, set.profiles[0].ID) +} + +func updateActiveUserProfile(profile *Profile) { + activeProfileSetsLock.RLock() + defer activeProfileSetsLock.RUnlock() + activeSet, ok := activeProfileSets[profile.ID] + if ok { + activeSet.Lock() + defer activeSet.Unlock() + activeSet.profiles[0] = profile + } +} + +func updateActiveStampProfile(profile *Profile) { + activeProfileSetsLock.RLock() + defer activeProfileSetsLock.RUnlock() + + for _, activeSet := range activeProfileSets { + activeSet.Lock() + activeProfile := activeSet.profiles[2] + if activeProfile != nil { + activeProfile.Lock() + if activeProfile.ID == profile.ID { + activeSet.profiles[2] = profile + } + activeProfile.Unlock() + } + activeSet.Unlock() + } +} diff --git a/profile/const.go b/profile/const.go new file mode 100644 index 00000000..8eac794d --- /dev/null +++ b/profile/const.go @@ -0,0 +1,8 @@ +package profile + +// Platform identifiers +const ( + PlatformLinux = "linux" + PlatformWindows = "windows" + PlatformMac = "macos" +) diff --git a/profile/const_darwin.go b/profile/const_darwin.go new file mode 100644 index 00000000..5654364a --- /dev/null +++ b/profile/const_darwin.go @@ -0,0 +1,6 @@ +package profile + +// OS Identifier +const ( + osIdentifier = PlatformMac +) diff --git a/profile/const_linux.go b/profile/const_linux.go new file mode 100644 index 00000000..0b8f8874 --- /dev/null +++ b/profile/const_linux.go @@ -0,0 +1,6 @@ +package profile + +// OS Identifier +const ( + osIdentifier = PlatformLinux +) diff --git a/profile/database.go b/profile/database.go new file mode 100644 index 00000000..740ab4a9 --- /dev/null +++ b/profile/database.go @@ -0,0 +1,22 @@ +package profile + +import ( + "github.com/Safing/portbase/database" +) + +// core:profiles/user/12345-1234-125-1234-1235 +// core:profiles/special/default +// /global +// core:profiles/stamp/12334-1235-1234-5123-1234 +// core:profiles/identifier/base64 + +// Namespaces +const ( + UserNamespace = "user" + StampNamespace = "stamp" + SpecialNamespace = "special" +) + +var ( + profileDB = database.NewInterface(nil) +) diff --git a/profile/defaults.go b/profile/defaults.go new file mode 100644 index 00000000..e4551fee --- /dev/null +++ b/profile/defaults.go @@ -0,0 +1,44 @@ +package profile + +import ( + "github.com/Safing/portmaster/status" +) + +func makeDefaultGlobalProfile() *Profile { + return &Profile{ + ID: "global", + Name: "Global Profile", + } +} + +func makeDefaultFallbackProfile() *Profile { + return &Profile{ + ID: "fallback", + Name: "Fallback Profile", + Flags: map[uint8]uint8{ + // Profile Modes + Blacklist: status.SecurityLevelDynamic, + Prompt: status.SecurityLevelSecure, + Whitelist: status.SecurityLevelFortress, + + // Network Locations + Internet: status.SecurityLevelsDynamicAndSecure, + LAN: status.SecurityLevelsDynamicAndSecure, + Localhost: status.SecurityLevelsAll, + + // Specials + Related: status.SecurityLevelDynamic, + PeerToPeer: status.SecurityLevelDynamic, + }, + ServiceEndpoints: []*EndpointPermission{ + &EndpointPermission{ + DomainOrIP: "", + Wildcard: true, + Protocol: 0, + StartPort: 0, + EndPort: 0, + Permit: false, + }, + }, + } +} diff --git a/profile/endpoints.go b/profile/endpoints.go new file mode 100644 index 00000000..3716f824 --- /dev/null +++ b/profile/endpoints.go @@ -0,0 +1,138 @@ +package profile + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Safing/portmaster/intel" +) + +// Endpoints is a list of permitted or denied endpoints. +type Endpoints []*EndpointPermission + +// EndpointPermission holds a decision about an endpoint. +type EndpointPermission struct { + DomainOrIP string + Wildcard bool + Protocol uint8 + StartPort uint16 + EndPort uint16 + Permit bool + Created int64 +} + +// IsSet returns whether the Endpoints object is "set". +func (e Endpoints) IsSet() bool { + if len(e) > 0 { + return true + } + return false +} + +// Check checks if the given domain is governed in the list of domains and returns whether it is permitted. +// If getDomainOfIP (returns reverse and forward dns matching domain name) is supplied, an IP will be resolved to a domain, if necessary. +func (e Endpoints) Check(domainOrIP string, protocol uint8, port uint16, checkReverseIP bool, securityLevel uint8) (permit bool, reason string, ok bool) { + + // ip resolving + var cachedGetDomainOfIP func() string + if checkReverseIP { + var ipResolved bool + var ipName string + // setup caching wrapper + cachedGetDomainOfIP = func() string { + if !ipResolved { + result, err := intel.ResolveIPAndValidate(domainOrIP, securityLevel) + if err != nil { + // log.Debug() + ipName = result + } + ipResolved = true + } + return ipName + } + } + + isDomain := strings.HasSuffix(domainOrIP, ".") + + for _, entry := range e { + if entry != nil { + if ok, reason := entry.Matches(domainOrIP, protocol, port, isDomain, cachedGetDomainOfIP); ok { + return entry.Permit, reason, true + } + } + } + + return false, "", false +} + +func isSubdomainOf(domain, subdomain string) bool { + dotPrefixedDomain := "." + domain + return strings.HasSuffix(subdomain, dotPrefixedDomain) +} + +// Matches checks whether the given endpoint has a managed permission. If getDomainOfIP (returns reverse and forward dns matching domain name) is supplied, this declares an incoming connection. +func (ep EndpointPermission) Matches(domainOrIP string, protocol uint8, port uint16, isDomain bool, getDomainOfIP func() string) (match bool, reason string) { + if ep.Protocol > 0 && protocol != ep.Protocol { + return false, "" + } + + if ep.StartPort > 0 && (port < ep.StartPort || port > ep.EndPort) { + return false, "" + } + + switch { + case ep.Wildcard && len(ep.DomainOrIP) == 0: + // host wildcard + return true, fmt.Sprintf("%s matches %s", domainOrIP, ep) + case domainOrIP == ep.DomainOrIP: + // host match + return true, fmt.Sprintf("%s matches %s", domainOrIP, ep) + case isDomain && ep.Wildcard && isSubdomainOf(ep.DomainOrIP, domainOrIP): + // subdomain match + return true, fmt.Sprintf("%s matches %s", domainOrIP, ep) + case !isDomain && getDomainOfIP != nil && getDomainOfIP() == ep.DomainOrIP: + // resolved IP match + return true, fmt.Sprintf("%s->%s matches %s", domainOrIP, getDomainOfIP(), ep) + case !isDomain && getDomainOfIP != nil && ep.Wildcard && isSubdomainOf(ep.DomainOrIP, getDomainOfIP()): + // resolved IP subdomain match + return true, fmt.Sprintf("%s->%s matches %s", domainOrIP, getDomainOfIP(), ep) + default: + // no match + return false, "" + } +} + +func (e Endpoints) String() string { + var s []string + for _, entry := range e { + s = append(s, entry.String()) + } + return fmt.Sprintf("[%s]", strings.Join(s, ", ")) +} + +func (ep EndpointPermission) String() string { + s := ep.DomainOrIP + + s += " " + + if ep.Protocol > 0 { + s += strconv.Itoa(int(ep.Protocol)) + } else { + s += "*" + } + + s += "/" + + if ep.StartPort > 0 { + if ep.StartPort == ep.EndPort { + s += strconv.Itoa(int(ep.StartPort)) + } else { + s += fmt.Sprintf("%d-%d", ep.StartPort, ep.EndPort) + } + } else { + s += "*" + } + + return s +} diff --git a/profile/endpoints_test.go b/profile/endpoints_test.go new file mode 100644 index 00000000..a87ab0da --- /dev/null +++ b/profile/endpoints_test.go @@ -0,0 +1,61 @@ +package profile + +import ( + "testing" +) + +// TODO: RETIRED +// func testdeMatcher(t *testing.T, value string, expectedResult bool) { +// if domainEndingMatcher.MatchString(value) != expectedResult { +// if expectedResult { +// t.Errorf("domainEndingMatcher should match %s", value) +// } else { +// t.Errorf("domainEndingMatcher should not match %s", value) +// } +// } +// } +// +// func TestdomainEndingMatcher(t *testing.T) { +// testdeMatcher(t, "example.com", true) +// testdeMatcher(t, "com", true) +// testdeMatcher(t, "example.xn--lgbbat1ad8j", true) +// testdeMatcher(t, "xn--lgbbat1ad8j", true) +// testdeMatcher(t, "fe80::beef", false) +// testdeMatcher(t, "fe80::dead:beef", false) +// testdeMatcher(t, "10.2.3.4", false) +// testdeMatcher(t, "4", false) +// } + +func TestEPString(t *testing.T) { + var endpoints Endpoints + endpoints = []*EndpointPermission{ + &EndpointPermission{ + DomainOrIP: "example.com", + Wildcard: false, + Protocol: 6, + Permit: true, + }, + &EndpointPermission{ + DomainOrIP: "8.8.8.8", + Protocol: 17, // TCP + StartPort: 53, // DNS + EndPort: 53, + Permit: false, + }, + &EndpointPermission{ + DomainOrIP: "google.com", + Wildcard: true, + Permit: false, + }, + } + if endpoints.String() != "[example.com 6/*, 8.8.8.8 17/53, google.com */*]" { + t.Errorf("unexpected result: %s", endpoints.String()) + } + + var noEndpoints Endpoints + noEndpoints = []*EndpointPermission{} + if noEndpoints.String() != "[]" { + t.Errorf("unexpected result: %s", noEndpoints.String()) + } + +} diff --git a/profile/fingerprint.go b/profile/fingerprint.go new file mode 100644 index 00000000..3d4b6dc9 --- /dev/null +++ b/profile/fingerprint.go @@ -0,0 +1,48 @@ +package profile + +import "time" + +var ( + fingerprintWeights = map[string]int{ + "full_path": 2, + "partial_path": 1, + "md5_sum": 4, + "sha1_sum": 5, + "sha256_sum": 6, + } +) + +// Fingerprint links processes to profiles. +type Fingerprint struct { + OS string + Type string + Value string + Comment string + LastUsed int64 +} + +// MatchesOS returns whether the Fingerprint is applicable for the current OS. +func (fp *Fingerprint) MatchesOS() bool { + return fp.OS == osIdentifier +} + +// GetFingerprintWeight returns the weight of the given fingerprint type. +func GetFingerprintWeight(fpType string) (weight int) { + weight, ok := fingerprintWeights[fpType] + if ok { + return weight + } + return 0 +} + +// AddFingerprint adds the given fingerprint to the profile. +func (p *Profile) AddFingerprint(fp *Fingerprint) { + if fp.OS == "" { + fp.OS = osIdentifier + } + if fp.LastUsed == 0 { + fp.LastUsed = time.Now().Unix() + } + + p.Fingerprints = append(p.Fingerprints, fp) +} diff --git a/profile/flags.go b/profile/flags.go new file mode 100644 index 00000000..fe8f6c1f --- /dev/null +++ b/profile/flags.go @@ -0,0 +1,130 @@ +package profile + +import ( + "errors" + "fmt" + "strings" + + "github.com/Safing/portmaster/status" +) + +// Flags are used to quickly add common attributes to profiles +type Flags map[uint8]uint8 + +// Profile Flags +const ( + // Profile Modes + Prompt uint8 = 0 // Prompt first-seen connections + Blacklist uint8 = 1 // Allow everything not explicitly denied + Whitelist uint8 = 2 // Only allow everything explicitly allowed + + // Network Locations + Internet uint8 = 16 // Allow connections to the Internet + LAN uint8 = 17 // Allow connections to the local area network + Localhost uint8 = 18 // Allow connections on the local host + + // Specials + Related uint8 = 32 // If and before prompting, allow domains that are related to the program + PeerToPeer uint8 = 33 // Allow program to directly communicate with peers, without resolving DNS first + Service uint8 = 34 // Allow program to accept incoming connections + Independent uint8 = 35 // Ignore profile settings coming from the Community + RequireGate17 uint8 = 36 // Require all connections to go over Gate17 +) + +var ( + // ErrFlagsParseFailed is returned if a an invalid flag is encountered while parsing + ErrFlagsParseFailed = errors.New("profiles: failed to parse flags") + + sortedFlags = []uint8{ + Prompt, + Blacklist, + Whitelist, + Internet, + LAN, + Localhost, + Related, + PeerToPeer, + Service, + Independent, + RequireGate17, + } + + flagIDs = map[string]uint8{ + "Prompt": Prompt, + "Blacklist": Blacklist, + "Whitelist": Whitelist, + "Internet": Internet, + "LAN": LAN, + "Localhost": Localhost, + "Related": Related, + "PeerToPeer": PeerToPeer, + "Service": Service, + "Independent": Independent, + "RequireGate17": RequireGate17, + } + + flagNames = map[uint8]string{ + Prompt: "Prompt", + Blacklist: "Blacklist", + Whitelist: "Whitelist", + Internet: "Internet", + LAN: "LAN", + Localhost: "Localhost", + Related: "Related", + PeerToPeer: "PeerToPeer", + Service: "Service", + Independent: "Independent", + RequireGate17: "RequireGate17", + } +) + +// Check checks if a flag is set at all and if it's active in the given security level. +func (flags Flags) Check(flag, level uint8) (active bool, ok bool) { + if flags == nil { + return false, false + } + + setting, ok := flags[flag] + if ok { + if setting&level > 0 { + return true, true + } + return false, true + } + return false, false +} + +func getLevelMarker(levels, level uint8) string { + if levels&level > 0 { + return "+" + } + return "-" +} + +// String return a string representation of Flags +func (flags Flags) String() string { + var markedFlags []string + for _, flag := range sortedFlags { + levels, ok := flags[flag] + if ok { + s := flagNames[flag] + if levels != status.SecurityLevelsAll { + s += getLevelMarker(levels, status.SecurityLevelDynamic) + s += getLevelMarker(levels, status.SecurityLevelSecure) + s += getLevelMarker(levels, status.SecurityLevelFortress) + } + markedFlags = append(markedFlags, s) + } + } + return fmt.Sprintf("[%s]", strings.Join(markedFlags, ", ")) +} + +// Add adds a flag to the Flags with the given level. +func (flags Flags) Add(flag, levels uint8) { + flags[flag] = levels +} + +// Remove removes a flag from the Flags. +func (flags Flags) Remove(flag uint8) { + delete(flags, flag) +} diff --git a/profile/flags_test.go b/profile/flags_test.go new file mode 100644 index 00000000..75b85b5f --- /dev/null +++ b/profile/flags_test.go @@ -0,0 +1,69 @@ +package profile + +import ( + "testing" + + "github.com/Safing/portmaster/status" +) + +func TestProfileFlags(t *testing.T) { + + // check if all IDs have a name + for key, entry := range flagIDs { + if _, ok := flagNames[entry]; !ok { + t.Errorf("could not find entry for %s in flagNames", key) + } + } + + // check if all names have an ID + for key, entry := range flagNames { + if _, ok := flagIDs[entry]; !ok { + t.Errorf("could not find entry for %d in flagNames", key) + } + } + + testFlags := Flags{ + Prompt: status.SecurityLevelsAll, + Internet: status.SecurityLevelsDynamicAndSecure, + LAN: status.SecurityLevelsDynamicAndSecure, + Localhost: status.SecurityLevelsAll, + Related: status.SecurityLevelDynamic, + RequireGate17: status.SecurityLevelsSecureAndFortress, + } + + if testFlags.String() != "[Prompt, Internet++-, LAN++-, Localhost, Related+--, RequireGate17-++]" { + t.Errorf("unexpected output: %s", testFlags.String()) + } + + // // check Has + // emptyFlags := ProfileFlags{} + // for flag, name := range flagNames { + // if !sortedFlags.Has(flag) { + // t.Errorf("sortedFlags should have flag %s (%d)", name, flag) + // } + // if emptyFlags.Has(flag) { + // t.Errorf("emptyFlags should not have flag %s (%d)", name, flag) + // } + // } + // + // // check ProfileFlags creation from strings + // var allFlagStrings []string + // for _, flag := range *sortedFlags { + // allFlagStrings = append(allFlagStrings, flagNames[flag]) + // } + // newFlags, err := FlagsFromNames(allFlagStrings) + // if err != nil { + // t.Errorf("error while parsing flags: %s", err) + // } + // if newFlags.String() != sortedFlags.String() { + // t.Errorf("parsed flags are not correct (or tests have not been updated to reflect the right number), expected %v, got %v", *sortedFlags, *newFlags) + // } + // + // // check ProfileFlags Stringer + // flagString := newFlags.String() + // check := strings.Join(allFlagStrings, ",") + // if flagString != check { + // t.Errorf("flag string is not correct, expected %s, got %s", check, flagString) + // } + +} diff --git a/profile/framework.go b/profile/framework.go new file mode 100644 index 00000000..64cbbefc --- /dev/null +++ b/profile/framework.go @@ -0,0 +1,76 @@ +package profile + +// DEACTIVATED + +// import ( +// "fmt" +// "os" +// "path/filepath" +// "regexp" +// "strings" +// +// "github.com/Safing/portbase/log" +// ) +// +// type Framework struct { +// // go hirarchy up +// FindParent uint8 `json:",omitempty bson:",omitempty"` +// // get path from parent, amount of levels to go up the tree (1 means parent, 2 means parent of parents, and so on) +// MergeWithParent bool `json:",omitempty bson:",omitempty"` +// // instead of getting the path of the parent, merge with it by presenting connections as if they were from that parent +// +// // go hirarchy down +// Find string `json:",omitempty bson:",omitempty"` +// // Regular expression for finding path elements +// Build string `json:",omitempty bson:",omitempty"` +// // Path definitions for building path +// Virtual bool `json:",omitempty bson:",omitempty"` +// // Treat resulting path as virtual, do not check if valid +// } +// +// func (f *Framework) GetNewPath(command string, cwd string) (string, error) { +// // "/usr/bin/python script" +// // to +// // "/path/to/script" +// regex, err := regexp.Compile(f.Find) +// if err != nil { +// return "", fmt.Errorf("profiles(framework): failed to compile framework regex: %s", err) +// } +// matched := regex.FindAllStringSubmatch(command, -1) +// if len(matched) == 0 || len(matched[0]) < 2 { +// return "", fmt.Errorf("profiles(framework): regex \"%s\" for constructing path did not match command \"%s\"", f.Find, command) +// } +// +// var lastError error +// var buildPath string +// for _, buildPath = range strings.Split(f.Build, "|") { +// +// buildPath = strings.Replace(buildPath, "{CWD}", cwd, -1) +// for i := 1; i < len(matched[0]); i++ { +// buildPath = strings.Replace(buildPath, fmt.Sprintf("{%d}", i), matched[0][i], -1) +// } +// +// buildPath = filepath.Clean(buildPath) +// +// if !f.Virtual { +// if !strings.HasPrefix(buildPath, "~/") && !filepath.IsAbs(buildPath) { +// lastError = fmt.Errorf("constructed path \"%s\" from framework is not absolute", buildPath) +// continue +// } +// if _, err := os.Stat(buildPath); os.IsNotExist(err) { +// lastError = fmt.Errorf("constructed path \"%s\" does not exist", buildPath) +// continue +// } +// } +// +// lastError = nil +// break +// +// } +// +// if lastError != nil { +// return "", fmt.Errorf("profiles(framework): failed to construct valid path, last error: %s", lastError) +// } +// log.Tracef("profiles(framework): transformed \"%s\" (%s) to \"%s\"", command, cwd, buildPath) +// return buildPath, nil +// } diff --git a/profile/framework_test.go b/profile/framework_test.go new file mode 100644 index 00000000..52e423c2 --- /dev/null +++ b/profile/framework_test.go @@ -0,0 +1,30 @@ +package profile + +// DEACTIVATED + +// import ( +// "testing" +// ) +// +// func testGetNewPath(t *testing.T, f *Framework, command, cwd, expect string) { +// newPath, err := f.GetNewPath(command, cwd) +// if err != nil { +// t.Errorf("GetNewPath failed: %s", err) +// } +// if newPath != expect { +// t.Errorf("GetNewPath return unexpected result: got %s, expected %s", newPath, expect) +// } +// } +// +// func TestFramework(t *testing.T) { +// f1 := &Framework{ +// Find: "([^ ]+)$", +// Build: "{CWD}/{1}", +// } +// testGetNewPath(t, f1, "/usr/bin/python bash", "/bin", "/bin/bash") +// f2 := &Framework{ +// Find: "([^ ]+)$", +// Build: "{1}|{CWD}/{1}", +// } +// testGetNewPath(t, f2, "/usr/bin/python /bin/bash", "/tmp", "/bin/bash") +// } diff --git a/profile/identifier_linux.go b/profile/identifier_linux.go new file mode 100644 index 00000000..d8020e5e --- /dev/null +++ b/profile/identifier_linux.go @@ -0,0 +1,47 @@ +package profile + +import ( + "path/filepath" + "strings" + + "github.com/Safing/portbase/utils" +) + +// GetPathIdentifier returns the identifier from the given path +func GetPathIdentifier(path string) string { + // clean path + // TODO: is this necessary? + cleanedPath, err := filepath.EvalSymlinks(path) + if err == nil { + path = cleanedPath + } else { + path = filepath.Clean(path) + } + + splittedPath := strings.Split(path, "/") + + // strip sensitive data + switch { + case strings.HasPrefix(path, "/home/"): + splittedPath = splittedPath[3:] + case strings.HasPrefix(path, "/root/"): + splittedPath = splittedPath[2:] + } + + // common directories with executable + if i := utils.IndexOfString(splittedPath, "bin"); i > 0 { + splittedPath = splittedPath[i:] + return strings.Join(splittedPath, "/") + } + if i := utils.IndexOfString(splittedPath, "sbin"); i > 0 { + splittedPath = splittedPath[i:] + return strings.Join(splittedPath, "/") + } + + // shorten to max 3 + if len(splittedPath) > 3 { + splittedPath = splittedPath[len(splittedPath)-3:] + } + + return strings.Join(splittedPath, "/") +} diff --git a/profile/identifier_linux_test.go b/profile/identifier_linux_test.go new file mode 100644 index 00000000..fef1157a --- /dev/null +++ b/profile/identifier_linux_test.go @@ -0,0 +1,23 @@ +package profile + +import "testing" + +func testPathID(t *testing.T, execPath, identifierPath string) { + result := GetPathIdentifier(execPath) + if result != identifierPath { + t.Errorf("unexpected identifier path for %s: got %s, expected %s", execPath, result, identifierPath) + } +} + +func TestGetPathIdentifier(t *testing.T) { + testPathID(t, "/bin/bash", "bin/bash") + testPathID(t, "/home/user/bin/bash", "bin/bash") + testPathID(t, "/home/user/project/main", "project/main") + testPathID(t, "/root/project/main", "project/main") + testPathID(t, "/tmp/a/b/c/d/install.sh", "c/d/install.sh") + testPathID(t, "/sbin/init", "sbin/init") + testPathID(t, "/lib/systemd/systemd-udevd", "lib/systemd/systemd-udevd") + testPathID(t, "/bundle/ruby/2.4.0/bin/passenger", "bin/passenger") + testPathID(t, "/usr/sbin/cron", "sbin/cron") + testPathID(t, "/usr/local/bin/python", "bin/python") +} diff --git a/profile/index/index.go b/profile/index/index.go new file mode 100644 index 00000000..c5859990 --- /dev/null +++ b/profile/index/index.go @@ -0,0 +1,102 @@ +package index + +import ( + "encoding/base64" + "errors" + "fmt" + "sync" + + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/utils" +) + +// ProfileIndex links an Identifier to Profiles +type ProfileIndex struct { + record.Base + sync.Mutex + + ID string + + UserProfiles []string + StampProfiles []string +} + +func makeIndexRecordKey(fpType, id string) string { + return fmt.Sprintf("index:profiles/%s:%s", fpType, base64.RawURLEncoding.EncodeToString([]byte(id))) +} + +// NewIndex returns a new ProfileIndex. +func NewIndex(id string) *ProfileIndex { + return &ProfileIndex{ + ID: id, + } +} + +// AddUserProfile adds a User Profile to the index. +func (pi *ProfileIndex) AddUserProfile(identifier string) (changed bool) { + if !utils.StringInSlice(pi.UserProfiles, identifier) { + pi.UserProfiles = append(pi.UserProfiles, identifier) + return true + } + return false +} + +// AddStampProfile adds a Stamp Profile to the index. +func (pi *ProfileIndex) AddStampProfile(identifier string) (changed bool) { + if !utils.StringInSlice(pi.StampProfiles, identifier) { + pi.StampProfiles = append(pi.StampProfiles, identifier) + return true + } + return false +} + +// RemoveUserProfile removes a profile from the index. +func (pi *ProfileIndex) RemoveUserProfile(id string) { + pi.UserProfiles = utils.RemoveFromStringSlice(pi.UserProfiles, id) +} + +// RemoveStampProfile removes a profile from the index. +func (pi *ProfileIndex) RemoveStampProfile(id string) { + pi.StampProfiles = utils.RemoveFromStringSlice(pi.StampProfiles, id) +} + +// Get gets a ProfileIndex from the database. +func Get(fpType, id string) (*ProfileIndex, error) { + key := makeIndexRecordKey(fpType, id) + + r, err := indexDB.Get(key) + if err != nil { + return nil, err + } + + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + new := &ProfileIndex{} + err = record.Unwrap(r, new) + if err != nil { + return nil, err + } + return new, nil + } + + // or adjust type + new, ok := r.(*ProfileIndex) + if !ok { + return nil, fmt.Errorf("record not of type *ProfileIndex, but %T", r) + } + return new, nil +} + +// Save saves the Identifiers to the database +func (pi *ProfileIndex) Save() error { + if !pi.KeyIsSet() { + if pi.ID != "" { + pi.SetKey(makeIndexRecordKey(pi.ID)) + } else { + return errors.New("missing identification Key") + } + } + + return indexDB.Put(pi) +} diff --git a/profile/index/indexer.go b/profile/index/indexer.go new file mode 100644 index 00000000..0776d24f --- /dev/null +++ b/profile/index/indexer.go @@ -0,0 +1,103 @@ +package index + +import ( + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/modules" + + "github.com/Safing/portmaster/profile" +) + +// FIXME: listen for profile changes and update the index + +var ( + indexDB = database.NewInterface(&database.Options{ + Local: true, // we want to access crownjewel records + AlwaysMakeCrownjewel: true, // never sync the index + }) + indexSub *database.Subscription + + shutdownIndexer = make(chan struct{}) +) + +func init() { + modules.Register("profile:index", nil, start, stop, "profile", "database") +} + +func start() (err error) { + indexSub, err = indexDB.Subscribe(query.New("core:profiles/user/")) + if err != nil { + return err + } + + return nil +} + +func stop() error { + close(shutdownIndexer) + indexSub.Cancel() + return nil +} + +func indexer() { + for { + select { + case <-shutdownIndexer: + return + case r := <-indexSub.Feed: + if r == nil { + return + } + + prof := ensureProfile(r) + if prof != nil { + for _, fp := range prof.Fingerprints { + if fp.MatchesOS() && fp.Type == "full_path" { + + // get Profile and ensure identifier is set + pi, err := Get("full_path", fp.Value) + if err != nil { + if err == database.ErrNotFound { + pi = NewIndex(id) + } else { + log.Errorf("profile/index: could not save updated profile index: %s", err) + } + } + + if pi.AddUserProfile(prof.ID) { + err := pi.Save() + if err != nil { + log.Errorf("profile/index: could not save updated profile index: %s", err) + } + } + + } + } + } + } + } +} + +func ensureProfile(r record.Record) *profile.Profile { + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + new := &profile.Profile{} + err := record.Unwrap(r, new) + if err != nil { + log.Errorf("profile/index: could not unwrap Profile: %s", err) + return nil + } + return new + } + + // or adjust type + new, ok := r.(*profile.Profile) + if !ok { + log.Errorf("profile/index: record not of type *Profile, but %T", r) + return nil + } + return new +} diff --git a/profile/matching/database.go b/profile/matching/database.go new file mode 100644 index 00000000..0196d0c5 --- /dev/null +++ b/profile/matching/database.go @@ -0,0 +1,17 @@ +package matching + +import ( + "github.com/Safing/portbase/database" +) + +// core:profiles/user/12345-1234-125-1234-1235 +// core:profiles/special/default +// /global +// core:profiles/stamp/12334-1235-1234-5123-1234 +// core:profiles/identifier/base64 + +var ( + profileDB = database.NewInterface(&database.Options{ + Local: true, // we want to access crownjewel records (indexes are) + }) +) diff --git a/profile/matching/fingerprints.go b/profile/matching/fingerprints.go new file mode 100644 index 00000000..e77b7df6 --- /dev/null +++ b/profile/matching/fingerprints.go @@ -0,0 +1,23 @@ +package matcher + +import ( + "strings" + + "github.com/Safing/portmaster/process" + "github.com/Safing/portmaster/profile" +) + +// CheckFingerprints checks what fingerprints match and returns the total score. +func CheckFingerprints(proc *process.Process, prof *profile.Profile) (score int, err error) { + // FIXME: kinda a dummy for now + + for _, fp := range prof.Fingerprints { + if strings.HasPrefix(fp, "fullpath:") { + if fp[9:] == proc.Path { + return 3, nil + } + } + } + + return 0, nil +} diff --git a/profile/matching/identpath_linux.go b/profile/matching/identpath_linux.go new file mode 100644 index 00000000..b5a27262 --- /dev/null +++ b/profile/matching/identpath_linux.go @@ -0,0 +1,18 @@ +package matcher + +import ( + "fmt" + "strings" + + "github.com/Safing/portmaster/process" + "github.com/Safing/portmaster/profile" +) + +// GetIdentificationPath returns the identifier for the given process (linux edition). +func GetIdentificationPath(p *process.Process) string { + splittedPath := strings.Split(p.Path, "/") + if len(splittedPath) > 3 { + return fmt.Sprintf("%s%s", profile.IdentifierPrefix, strings.Join(splittedPath[len(splittedPath)-3:len(splittedPath)], "/")) + } + return fmt.Sprintf("%s%s", profile.IdentifierPrefix, p.Path) +} diff --git a/profile/matching/identpath_linux_test.go b/profile/matching/identpath_linux_test.go new file mode 100644 index 00000000..2b480f00 --- /dev/null +++ b/profile/matching/identpath_linux_test.go @@ -0,0 +1,25 @@ +package matcher + +import ( + "testing" + + "github.com/Safing/portmaster/process" +) + +func TestGetIdentifierLinux(t *testing.T) { + p := &process.Process{ + Path: "/usr/lib/firefox/firefox", + } + + if GetIdentificationPath(p) != "lin:lib/firefox/firefox" { + t.Fatal("mismatch!") + } + + p = &process.Process{ + Path: "/opt/start", + } + + if GetIdentificationPath(p) != "lin:/opt/start" { + t.Fatal("mismatch!") + } +} diff --git a/profile/matching/matcher.go b/profile/matching/matcher.go new file mode 100644 index 00000000..b9c69a91 --- /dev/null +++ b/profile/matching/matcher.go @@ -0,0 +1,61 @@ +package matcher + +import ( + "fmt" + "strings" + + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/process" + "github.com/Safing/portmaster/profile" + "github.com/Safing/portmaster/profile/index" +) + +// GetProfileSet finds a local profile. +func GetProfileSet(proc *process.Process) (set *profile.ProfileSet, err error) { + + identPath := GetIdentificationPath(proc) + pi, err := index.GetIndex(identPath) + + var bestScore int + var bestProfile *profile.Profile + + for _, id := range pi.UserProfiles { + prof, err := profile.GetUserProfile(id) + if err != nil { + log.Errorf("profile/matcher: failed to load profile: %s", err) + continue + } + + score, err := CheckFingerprints(proc, prof) + if score > bestScore { + bestScore = score + bestProfile = prof + } + } + + if bestProfile == nil { + bestProfile = ProfileFromProcess(proc) + } + + // FIXME: fetch stamp profile + set = profile.NewSet(bestProfile, nil) + return set, nil +} + +// ProfileFromProcess creates an initial profile based on the given process. +func ProfileFromProcess(proc *process.Process) *profile.Profile { + new := profile.New() + + splittedPath := strings.Split(proc.Path, "/") + new.Name = strings.ToTitle(splittedPath[len(splittedPath)-1]) + + new.Identifiers = append(new.Identifiers, GetIdentificationPath(proc)) + new.Fingerprints = append(new.Fingerprints, fmt.Sprintf("fullpath:%s", proc.Path)) + + err := new.Save(profile.UserNamespace) + if err != nil { + log.Errorf("profile/matcher: could not save new profile: %s", new.Name) + } + + return new +} diff --git a/profile/module.go b/profile/module.go new file mode 100644 index 00000000..37e6f709 --- /dev/null +++ b/profile/module.go @@ -0,0 +1,24 @@ +package profile + +import "github.com/Safing/portbase/modules" + +var ( + shutdownSignal = make(chan struct{}) +) + +func init() { + modules.Register("profile", nil, start, stop, "global", "database") +} + +func start() error { + err := initSpecialProfiles() + if err != nil { + return err + } + return initUpdateListener() +} + +func stop() error { + close(shutdownSignal) + return nil +} diff --git a/profile/profile.go b/profile/profile.go new file mode 100644 index 00000000..5be0d1a8 --- /dev/null +++ b/profile/profile.go @@ -0,0 +1,141 @@ +package profile + +import ( + "fmt" + "sync" + "time" + + uuid "github.com/satori/go.uuid" + + "github.com/Safing/portbase/database/record" + "github.com/Safing/portmaster/status" +) + +var ( + lastUsedUpdateThreshold = 1 * time.Hour +) + +// Profile is used to predefine a security profile for applications. +type Profile struct { + record.Base + sync.Mutex + + // Profile Metadata + ID string + Name string + Description string + Homepage string + // Icon is a path to the icon and is either prefixed "f:" for filepath, "d:" for a database path or "e:" for the encoded data. + Icon string + + // User Profile Only + LinkedPath string + StampProfileID string + StampProfileAssigned int64 + + // Fingerprints + Fingerprints []*Fingerprint + + // The mininum security level to apply to connections made with this profile + SecurityLevel uint8 + Flags Flags + Endpoints Endpoints + ServiceEndpoints Endpoints + + // If a Profile is declared as a Framework (i.e. an Interpreter and the likes), then the real process must be found + // Framework *Framework `json:",omitempty bson:",omitempty"` + + // When this Profile was approximately last used (for performance reasons not every single usage is saved) + Created int64 + ApproxLastUsed int64 +} + +// New returns a new Profile. +func New() *Profile { + return &Profile{ + Created: time.Now().Unix(), + } +} + +// MakeProfileKey creates the correct key for a profile with the given namespace and ID. +func MakeProfileKey(namespace, ID string) string { + return fmt.Sprintf("core:profiles/%s/%s", namespace, ID) +} + +// Save saves the profile to the database +func (profile *Profile) Save(namespace string) error { + if profile.ID == "" { + u, err := uuid.NewV4() + if err != nil { + return err + } + profile.ID = u.String() + } + + if !profile.KeyIsSet() { + if namespace == "" { + return fmt.Errorf("no key or namespace defined for profile %s", profile.String()) + } + profile.SetKey(MakeProfileKey(namespace, profile.ID)) + } + + return profileDB.Put(profile) +} + +// MarkUsed marks the profile as used, eventually. +func (profile *Profile) MarkUsed() (updated bool) { + if time.Now().Add(-lastUsedUpdateThreshold).Unix() > profile.ApproxLastUsed { + profile.ApproxLastUsed = time.Now().Unix() + return true + } + return false +} + +// String returns a string representation of the Profile. +func (profile *Profile) String() string { + return profile.Name +} + +// DetailedString returns a more detailed string representation of theProfile. +func (profile *Profile) DetailedString() string { + return fmt.Sprintf("%s(SL=%s Flags=%s Endpoints=%s)", profile.Name, status.FmtSecurityLevel(profile.SecurityLevel), profile.Flags.String(), profile.Endpoints.String()) +} + +// GetUserProfile loads a profile from the database. +func GetUserProfile(ID string) (*Profile, error) { + return getProfile(UserNamespace, ID) +} + +// GetStampProfile loads a profile from the database. +func GetStampProfile(ID string) (*Profile, error) { + return getProfile(StampNamespace, ID) +} + +func getProfile(namespace, ID string) (*Profile, error) { + r, err := profileDB.Get(MakeProfileKey(namespace, ID)) + if err != nil { + return nil, err + } + return EnsureProfile(r) +} + +// EnsureProfile ensures that the given record is a *Profile, and returns it. +func EnsureProfile(r record.Record) (*Profile, error) { + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + new := &Profile{} + err := record.Unwrap(r, new) + if err != nil { + return nil, err + } + return new, nil + } + + // or adjust type + new, ok := r.(*Profile) + if !ok { + return nil, fmt.Errorf("record not of type *Example, but %T", r) + } + return new, nil +} diff --git a/profile/set.go b/profile/set.go new file mode 100644 index 00000000..9cb53ff1 --- /dev/null +++ b/profile/set.go @@ -0,0 +1,168 @@ +package profile + +import ( + "sync" + + "github.com/Safing/portmaster/status" +) + +var ( + emptyFlags = Flags{} +) + +// Set handles Profile chaining. +type Set struct { + sync.Mutex + + profiles [4]*Profile + // Application + // Global + // Stamp + // Default + + combinedSecurityLevel uint8 + independent bool +} + +// NewSet returns a new profile set with given the profiles. +func NewSet(user, stamp *Profile) *Set { + new := &Set{ + profiles: [4]*Profile{ + user, // Application + nil, // Global + stamp, // Stamp + nil, // Default + }, + } + activateProfileSet(new) + new.Update(status.SecurityLevelFortress) + return new +} + +// UserProfile returns the user profile. +func (set *Set) UserProfile() *Profile { + return set.profiles[0] +} + +// Update gets the new global and default profile and updates the independence status. It must be called when reusing a profile set for a series of calls. +func (set *Set) Update(securityLevel uint8) { + set.Lock() + + specialProfileLock.RLock() + defer specialProfileLock.RUnlock() + + // update profiles + set.profiles[1] = globalProfile + set.profiles[3] = fallbackProfile + + // update security level + profileSecurityLevel := set.getSecurityLevel() + if profileSecurityLevel > securityLevel { + set.combinedSecurityLevel = profileSecurityLevel + } else { + set.combinedSecurityLevel = securityLevel + } + + set.Unlock() + // update independence + if set.CheckFlag(Independent) { + set.Lock() + set.independent = true + set.Unlock() + } else { + set.Lock() + set.independent = false + set.Unlock() + } +} + +// SecurityLevel returns the applicable security level for the profile set. +func (set *Set) SecurityLevel() uint8 { + set.Lock() + defer set.Unlock() + + return set.combinedSecurityLevel +} + +// GetProfileMode returns the active profile mode. +func (set *Set) GetProfileMode() uint8 { + switch { + case set.CheckFlag(Whitelist): + return Whitelist + case set.CheckFlag(Prompt): + return Prompt + case set.CheckFlag(Blacklist): + return Blacklist + default: + return Whitelist + } +} + +// CheckFlag returns whether a given flag is set. +func (set *Set) CheckFlag(flag uint8) (active bool) { + set.Lock() + defer set.Unlock() + + for i, profile := range set.profiles { + if i == 2 && set.independent { + continue + } + + if profile != nil { + active, ok := profile.Flags.Check(flag, set.combinedSecurityLevel) + if ok { + return active + } + } + } + + return false +} + +// CheckEndpoint checks if the given protocol and port are governed in any the lists of ports and returns whether it is permitted. +func (set *Set) CheckEndpoint(domainOrIP string, protocol uint8, port uint16, inbound bool) (permit bool, reason string, ok bool) { + set.Lock() + defer set.Unlock() + + for i, profile := range set.profiles { + if i == 2 && set.independent { + continue + } + + if profile != nil { + if inbound { + if permit, reason, ok = profile.ServiceEndpoints.Check(domainOrIP, protocol, port, inbound, set.combinedSecurityLevel); ok { + return + } + } else { + if permit, reason, ok = profile.Endpoints.Check(domainOrIP, protocol, port, inbound, set.combinedSecurityLevel); ok { + return + } + } + } + } + + return false, "", false +} + +// getSecurityLevel returns the highest prioritized security level. +func (set *Set) getSecurityLevel() uint8 { + if set == nil { + return 0 + } + + for i, profile := range set.profiles { + if i == 2 { + // Stamp profiles do not have the SecurityLevel setting + continue + } + + if profile != nil { + if profile.SecurityLevel > 0 { + return profile.SecurityLevel + } + } + } + + return 0 +} diff --git a/profile/set_test.go b/profile/set_test.go new file mode 100644 index 00000000..332ec35c --- /dev/null +++ b/profile/set_test.go @@ -0,0 +1,160 @@ +package profile + +import ( + "testing" + "time" + + "github.com/Safing/portmaster/status" +) + +var ( + testUserProfile *Profile + testStampProfile *Profile +) + +func init() { + specialProfileLock.Lock() + defer specialProfileLock.Unlock() + + globalProfile = makeDefaultGlobalProfile() + fallbackProfile = makeDefaultFallbackProfile() + + testUserProfile = &Profile{ + ID: "unit-test-user", + Name: "Unit Test User Profile", + SecurityLevel: status.SecurityLevelDynamic, + Flags: map[uint8]uint8{ + Independent: status.SecurityLevelFortress, + }, + Endpoints: []*EndpointPermission{ + &EndpointPermission{ + DomainOrIP: "good.bad.example.com.", + Wildcard: false, + Permit: true, + Created: time.Now().Unix(), + }, + &EndpointPermission{ + DomainOrIP: "bad.example.com.", + Wildcard: true, + Permit: false, + Created: time.Now().Unix(), + }, + &EndpointPermission{ + DomainOrIP: "example.com.", + Wildcard: false, + Permit: true, + Created: time.Now().Unix(), + }, + &EndpointPermission{ + DomainOrIP: "", + Wildcard: true, + Permit: true, + Protocol: 6, + StartPort: 22000, + EndPort: 22000, + Created: time.Now().Unix(), + }, + }, + } + + testStampProfile = &Profile{ + ID: "unit-test-stamp", + Name: "Unit Test Stamp Profile", + SecurityLevel: status.SecurityLevelFortress, + // Flags: map[uint8]uint8{ + // Internet: status.SecurityLevelsAll, + // }, + Endpoints: []*EndpointPermission{ + &EndpointPermission{ + DomainOrIP: "bad2.example.com.", + Wildcard: true, + Permit: false, + Created: time.Now().Unix(), + }, + &EndpointPermission{ + DomainOrIP: "", + Wildcard: true, + Permit: true, + Protocol: 6, + StartPort: 80, + EndPort: 80, + Created: time.Now().Unix(), + }, + }, + ServiceEndpoints: []*EndpointPermission{ + &EndpointPermission{ + DomainOrIP: "", + Wildcard: true, + Permit: true, + Protocol: 17, + StartPort: 12345, + EndPort: 12347, + Created: time.Now().Unix(), + }, + &EndpointPermission{ // default deny + DomainOrIP: "", + Wildcard: true, + Permit: false, + Created: time.Now().Unix(), + }, + }, + } +} + +func testFlag(t *testing.T, set *Set, flag uint8, shouldBeActive bool) { + active := set.CheckFlag(flag) + if active != shouldBeActive { + t.Errorf("unexpected result: flag %s: permitted=%v, expected=%v", flagNames[flag], active, shouldBeActive) + } +} + +func testEndpoint(t *testing.T, set *Set, domainOrIP string, protocol uint8, port uint16, inbound bool, shouldBePermitted bool) { + var permitted, ok bool + permitted, _, ok = set.CheckEndpoint(domainOrIP, protocol, port, inbound) + if !ok { + t.Errorf("endpoint %s/%d/%d/%v should be in test profile set", domainOrIP, protocol, port, inbound) + } + if permitted != shouldBePermitted { + t.Errorf("unexpected result for endpoint %s/%d/%d/%v: permitted=%v, expected=%v", domainOrIP, protocol, port, inbound, permitted, shouldBePermitted) + } +} + +func testUnregulatedEndpoint(t *testing.T, set *Set, domainOrIP string, protocol uint8, port uint16, inbound bool) { + _, _, ok := set.CheckEndpoint(domainOrIP, protocol, port, inbound) + if ok { + t.Errorf("endpoint %s/%d/%d/%v should not be in test profile set", domainOrIP, protocol, port, inbound) + } +} + +func TestProfileSet(t *testing.T) { + + set := NewSet(testUserProfile, testStampProfile) + + set.Update(status.SecurityLevelDynamic) + testFlag(t, set, Whitelist, false) + // testFlag(t, set, Internet, true) + testEndpoint(t, set, "example.com.", 0, 0, false, true) + testEndpoint(t, set, "bad.example.com.", 0, 0, false, false) + testEndpoint(t, set, "other.bad.example.com.", 0, 0, false, false) + testEndpoint(t, set, "good.bad.example.com.", 0, 0, false, true) + testEndpoint(t, set, "bad2.example.com.", 0, 0, false, false) + testEndpoint(t, set, "10.2.3.4", 6, 22000, false, true) + testEndpoint(t, set, "fd00::1", 6, 22000, false, true) + testEndpoint(t, set, "test.local.", 6, 22000, false, true) + testUnregulatedEndpoint(t, set, "other.example.com.", 0, 0, false) + testUnregulatedEndpoint(t, set, "10.2.3.4", 17, 53, false) + testUnregulatedEndpoint(t, set, "10.2.3.4", 17, 443, false) + testUnregulatedEndpoint(t, set, "10.2.3.4", 6, 12346, false) + testEndpoint(t, set, "10.2.3.4", 17, 12345, true, true) + testEndpoint(t, set, "fd00::1", 17, 12347, true, true) + + set.Update(status.SecurityLevelSecure) + // testFlag(t, set, Internet, true) + + set.Update(status.SecurityLevelFortress) // Independent! + testFlag(t, set, Whitelist, true) + testEndpoint(t, set, "10.2.3.4", 17, 12345, true, false) + testEndpoint(t, set, "fd00::1", 17, 12347, true, false) + testUnregulatedEndpoint(t, set, "10.2.3.4", 6, 80, false) + testUnregulatedEndpoint(t, set, "bad2.example.com.", 0, 0, false) +} diff --git a/profile/specialprofiles.go b/profile/specialprofiles.go new file mode 100644 index 00000000..d2e5ceb6 --- /dev/null +++ b/profile/specialprofiles.go @@ -0,0 +1,71 @@ +package profile + +import ( + "sync" + + "github.com/Safing/portbase/database" +) + +var ( + globalProfile *Profile + fallbackProfile *Profile + + specialProfileLock sync.RWMutex +) + +func initSpecialProfiles() (err error) { + + specialProfileLock.Lock() + defer specialProfileLock.Unlock() + + globalProfile, err = getSpecialProfile("global") + if err != nil { + if err != database.ErrNotFound { + return err + } + globalProfile = makeDefaultGlobalProfile() + globalProfile.Save(SpecialNamespace) + } + + fallbackProfile, err = getSpecialProfile("fallback") + if err != nil { + if err != database.ErrNotFound { + return err + } + fallbackProfile = makeDefaultFallbackProfile() + ensureServiceEndpointsDenyAll(fallbackProfile) + fallbackProfile.Save(SpecialNamespace) + } + ensureServiceEndpointsDenyAll(fallbackProfile) + + return nil +} + +func getSpecialProfile(ID string) (*Profile, error) { + return getProfile(SpecialNamespace, ID) +} + +func ensureServiceEndpointsDenyAll(p *Profile) (changed bool) { + for _, ep := range p.ServiceEndpoints { + if ep != nil { + if ep.DomainOrIP == "" && + ep.Wildcard == true && + ep.Protocol == 0 && + ep.StartPort == 0 && + ep.EndPort == 0 && + ep.Permit == false { + return false + } + } + } + + p.ServiceEndpoints = append(p.ServiceEndpoints, &EndpointPermission{ + DomainOrIP: "", + Wildcard: true, + Protocol: 0, + StartPort: 0, + EndPort: 0, + Permit: false, + }) + return true +} diff --git a/profile/updates.go b/profile/updates.go new file mode 100644 index 00000000..1aa0ddf3 --- /dev/null +++ b/profile/updates.go @@ -0,0 +1,66 @@ +package profile + +import ( + "strings" + + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/log" +) + +func initUpdateListener() error { + sub, err := profileDB.Subscribe(query.New(MakeProfileKey(SpecialNamespace, ""))) + if err != nil { + return err + } + + go updateListener(sub) + return nil +} + +func updateListener(sub *database.Subscription) { + for { + select { + case <-shutdownSignal: + return + case r := <-sub.Feed: + + if r.Meta().IsDeleted() { + continue + } + + profile, err := EnsureProfile(r) + if err != nil { + log.Errorf("profile: received update for special profile, but could not read: %s", err) + continue + } + + switch profile.DatabaseKey() { + case "profiles/special/global": + specialProfileLock.Lock() + globalProfile = profile + specialProfileLock.Unlock() + case "profiles/special/fallback": + profile.Lock() + if ensureServiceEndpointsDenyAll(profile) { + profile.Unlock() + profile.Save(SpecialNamespace) + continue + } + profile.Unlock() + + specialProfileLock.Lock() + fallbackProfile = profile + specialProfileLock.Unlock() + default: + switch { + case strings.HasPrefix(profile.Key(), MakeProfileKey(UserNamespace, "")): + updateActiveUserProfile(profile) + case strings.HasPrefix(profile.Key(), MakeProfileKey(StampNamespace, "")): + updateActiveStampProfile(profile) + } + } + + } + } +} diff --git a/profiles/framework.go b/profiles/framework.go deleted file mode 100644 index 3bb19aad..00000000 --- a/profiles/framework.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package profiles - -import ( - "fmt" - "os" - "path/filepath" - "regexp" - "strings" - - "github.com/Safing/safing-core/log" -) - -type Framework struct { - // go hirarchy up - FindParent uint8 `json:",omitempty bson:",omitempty"` - // get path from parent, amount of levels to go up the tree (1 means parent, 2 means parent of parents, and so on) - MergeWithParent bool `json:",omitempty bson:",omitempty"` - // instead of getting the path of the parent, merge with it by presenting connections as if they were from that parent - - // go hirarchy down - Find string `json:",omitempty bson:",omitempty"` - // Regular expression for finding path elements - Build string `json:",omitempty bson:",omitempty"` - // Path definitions for building path - Virtual bool `json:",omitempty bson:",omitempty"` - // Treat resulting path as virtual, do not check if valid -} - -func (f *Framework) GetNewPath(command string, cwd string) (string, error) { - // "/usr/bin/python script" - // to - // "/path/to/script" - regex, err := regexp.Compile(f.Find) - if err != nil { - return "", fmt.Errorf("profiles(framework): failed to compile framework regex: %s", err) - } - matched := regex.FindAllStringSubmatch(command, -1) - if len(matched) == 0 || len(matched[0]) < 2 { - return "", fmt.Errorf("profiles(framework): regex \"%s\" for constructing path did not match command \"%s\"", f.Find, command) - } - - var lastError error - var buildPath string - for _, buildPath = range strings.Split(f.Build, "|") { - - buildPath = strings.Replace(buildPath, "{CWD}", cwd, -1) - for i := 1; i < len(matched[0]); i++ { - buildPath = strings.Replace(buildPath, fmt.Sprintf("{%d}", i), matched[0][i], -1) - } - - buildPath = filepath.Clean(buildPath) - - if !f.Virtual { - if !strings.HasPrefix(buildPath, "~/") && !filepath.IsAbs(buildPath) { - lastError = fmt.Errorf("constructed path \"%s\" from framework is not absolute", buildPath) - continue - } - if _, err := os.Stat(buildPath); os.IsNotExist(err) { - lastError = fmt.Errorf("constructed path \"%s\" does not exist", buildPath) - continue - } - } - - lastError = nil - break - - } - - if lastError != nil { - return "", fmt.Errorf("profiles(framework): failed to construct valid path, last error: %s", lastError) - } - log.Tracef("profiles(framework): transformed \"%s\" (%s) to \"%s\"", command, cwd, buildPath) - return buildPath, nil -} diff --git a/profiles/framework_test.go b/profiles/framework_test.go deleted file mode 100644 index 2a6e739e..00000000 --- a/profiles/framework_test.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package profiles - -import ( - "testing" -) - -func testGetNewPath(t *testing.T, f *Framework, command, cwd, expect string) { - newPath, err := f.GetNewPath(command, cwd) - if err != nil { - t.Errorf("GetNewPath failed: %s", err) - } - if newPath != expect { - t.Errorf("GetNewPath return unexpected result: got %s, expected %s", newPath, expect) - } -} - -func TestFramework(t *testing.T) { - f1 := &Framework{ - Find: "([^ ]+)$", - Build: "{CWD}/{1}", - } - testGetNewPath(t, f1, "/usr/bin/python bash", "/bin", "/bin/bash") - f2 := &Framework{ - Find: "([^ ]+)$", - Build: "{1}|{CWD}/{1}", - } - testGetNewPath(t, f2, "/usr/bin/python /bin/bash", "/tmp", "/bin/bash") -} diff --git a/profiles/profile.go b/profiles/profile.go deleted file mode 100644 index 15284ad6..00000000 --- a/profiles/profile.go +++ /dev/null @@ -1,265 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package profiles - -import ( - "encoding/hex" - "strings" - - datastore "github.com/ipfs/go-datastore" - dsq "github.com/ipfs/go-datastore/query" - - "github.com/Safing/safing-core/database" - "github.com/Safing/safing-core/intel" - "github.com/Safing/safing-core/log" -) - -// Profile is used to predefine a security profile for applications. -type Profile struct { - database.Base - Name string - Path string - Description string `json:",omitempty bson:",omitempty"` - Icon string `json:",omitempty bson:",omitempty"` - // Icon is a path to the icon and is either prefixed "f:" for filepath, "d:" for database cache path or "c:"/"a:" for a the icon key to fetch it from a company / authoritative node and cache it in its own cache. - - // TODO: Think more about using one profile for multiple paths - // Refer string `json:",omitempty bson:",omitempty"` - - // If a Profile is declared as a Framework (i.e. an Interpreter and the likes), then the real process must be found - Framework *Framework `json:",omitempty bson:",omitempty"` - // The format how to real process is to be found is yet to be defined. - // Ideas: - // - Regex for finding the executed script in the arguments, prepend working directory if path is not absolute - // - Parent Process? - // Use Cases: - // - Interpreters (Python, Java, ...) - // - Sandboxes (Flatpak, Snapd, Docker, ...) - // - Subprocesses of main application process - - SecurityLevel int8 `json:",omitempty bson:",omitempty"` - // The mininum security level to apply to connections made with this profile - Flags ProfileFlags - - ClassificationBlacklist *intel.EntityClassification `json:",omitempty bson:",omitempty"` - ClassificationWhitelist *intel.EntityClassification `json:",omitempty bson:",omitempty"` - DomainWhitelistIsBlacklist bool `json:",omitempty bson:",omitempty"` - DomainWhitelist []string `json:",omitempty bson:",omitempty"` - - ConnectPorts []uint16 `json:",omitempty bson:",omitempty"` - ListenPorts []uint16 `json:",omitempty bson:",omitempty"` - - Default bool `json:",omitempty bson:",omitempty"` - // This flag indicates that this profile is a default profile. If no profile is found for a process, the default profile with the longest matching prefix path is used. - PromptUserToAdapt bool `json:",omitempty bson:",omitempty"` - // This flag is only valid for default profiles and indicates that should this profile be used for a process, the user will be prompted to adapt it for the process and save a new profile. - Authoritative bool `json:",omitempty bson:",omitempty"` - // This flag indicates that this profile was loaded from an authoritative source - the Safing Community or the Company. Authoritative Profiles that have been modified can be reverted back to their original state. - Locked bool `json:",omitempty bson:",omitempty"` - // This flag indicates that this profile was locked by the company. This means that the profile may not be edited by the user. If an authoritative default profile is locked, it wins over non-authoritative profiles and the user will not be prompted to adapt the profile, should the PromptUserToAdapt flag be set. - Modified bool `json:",omitempty bson:",omitempty"` - // This flag indicates that this profile has been modified by the user. Non-modified authoritative profiles will be automatically overwritten with new versions. - Orphaned bool `json:",omitempty bson:",omitempty"` - // This flag indicates that the associated program (on path) does not exist (Either this entry was manually created, or the program has been uninstalled). Only valid for non-default profiles. - - ApproxLastUsed int64 `json:",omitempty bson:",omitempty"` - // When this Profile was approximately last used (for performance reasons not every single usage is saved) -} - -var profileModel *Profile // only use this as parameter for database.EnsureModel-like functions - -func init() { - database.RegisterModel(profileModel, func() database.Model { return new(Profile) }) -} - -// Create saves Profile with the provided name in the Profiles namespace. -func (m *Profile) Create() error { - name := hex.EncodeToString([]byte(m.Path)) - if m.Default { - name = "d-" + name - } - return m.CreateObject(&database.Profiles, name, m) -} - -// CreateInNamespace saves Profile with the provided name in the provided namespace. -func (m *Profile) CreateInNamespace(namespace *datastore.Key) error { - name := hex.EncodeToString([]byte(m.Path)) - if m.Default { - name = "d-" + name - } - return m.CreateObject(namespace, name, m) -} - -// CreateInDist saves Profile with the (hash of the) path as the name in the Dist namespace. -func (m *Profile) CreateInDist() error { - return m.CreateInNamespace(&database.DistProfiles) -} - -// CreateInCompany saves Profile with the (hash of the) path as the name in the Company namespace. -func (m *Profile) CreateInCompany() error { - return m.CreateInNamespace(&database.CompanyProfiles) -} - -// Save saves Profile. -func (m *Profile) Save() error { - return m.SaveObject(m) -} - -// String returns a string representation of Profile. -func (m *Profile) String() string { - if m.Default { - return "[D] " + m.Name - } - return m.Name -} - -// GetProfile fetches Profile with the provided name from the default namespace. -func GetProfile(name string) (*Profile, error) { - return GetProfileFromNamespace(&database.Profiles, name) -} - -// GetProfileFromNamespace fetches Profile with the provided name from the provided namespace. -func GetProfileFromNamespace(namespace *datastore.Key, name string) (*Profile, error) { - object, err := database.GetAndEnsureModel(namespace, name, profileModel) - if err != nil { - return nil, err - } - model, ok := object.(*Profile) - if !ok { - return nil, database.NewMismatchError(object, profileModel) - } - return model, nil -} - -// GetActiveProfileByPath fetches Profile with the (hash of the) path as the name from the default namespace. -func GetActiveProfileByPath(path string) (*Profile, error) { - return GetProfileFromNamespace(&database.Profiles, hex.EncodeToString([]byte(path))) - // TODO: check for locked authoritative default profiles -} - -// FindProfileByPath looks for a Profile first in the Company namespace and then in the Dist namespace. Should no Profile be available it searches for a Default Profile. -func FindProfileByPath(path, homeDir string) (profile *Profile, err error) { - name := hex.EncodeToString([]byte(path)) - var homeName string - slashedHomeDir := strings.TrimRight(homeDir, "/") + "/" - if homeDir != "" && strings.HasPrefix(path, slashedHomeDir) { - homeName = hex.EncodeToString([]byte("~/" + path[len(slashedHomeDir):])) - } - - // check for available company profiles - profile, err = GetProfileFromNamespace(&database.CompanyProfiles, name) - if err != database.ErrNotFound { - if err == nil { - return profile.Activate() - } - return - } - if homeName != "" { - profile, err = GetProfileFromNamespace(&database.CompanyProfiles, homeName) - if err != database.ErrNotFound { - if err == nil { - return profile.Activate() - } - return - } - } - - // check for available dist profiles - profile, err = GetProfileFromNamespace(&database.DistProfiles, name) - if err != database.ErrNotFound { - if err == nil { - return profile.Activate() - } - return - } - if homeName != "" { - profile, err = GetProfileFromNamespace(&database.DistProfiles, homeName) - if err != database.ErrNotFound { - if err == nil { - return profile.Activate() - } - return - } - } - - // search for best-matching default profile - err = nil - profile, _ = SearchForDefaultProfile(name, homeName, len(slashedHomeDir)-2, &database.Profiles) - return -} - -func (m *Profile) Activate() (*Profile, error) { - return m, m.Create() -} - -func SearchForDefaultProfile(matchKey, matchHomeKey string, addHomeDirLen int, namespace *datastore.Key) (*Profile, int) { - - // log.Tracef("profiles: searching for default profile with %s", matchKey) - - query := dsq.Query{ - Prefix: namespace.ChildString("Profile:d-").String(), - } - - // filter := LongestMatch{ - // Offset: len(query.Prefix), - // Longest: 0, - // Match: hex.EncodeToString([]byte(path)), - // } - // query.Filters = []dsq.Filter{ - // filter, - // } - - prefixOffset := len(query.Prefix) - longest := 0 - var longestMatch interface{} - - currentLongestIsHomeBased := false - currentLength := 0 - - iterator, err := database.Query(query) - if err != nil { - return nil, 0 - } - - for entry, ok := iterator.NextSync(); ok; entry, ok = iterator.NextSync() { - // log.Tracef("profiles: checking %s for default profile", entry.Key) - // TODO: prioritize locked profiles - prefix := entry.Key[prefixOffset:] - // skip directly if current longest match is longer than the key - // log.Tracef("profiles: comparing %s to %s", matchKey, prefix) - - switch { - case strings.HasPrefix(matchKey, prefix): - currentLength = len(prefix) - currentLongestIsHomeBased = false - case strings.HasPrefix(matchHomeKey, prefix): - currentLength = len(prefix) + addHomeDirLen - currentLongestIsHomeBased = true - default: - continue - } - // longest wins, if a root-based and home-based tie, root-based wins. - if currentLength > longest || (currentLongestIsHomeBased && currentLength == longest) { - longest = currentLength - longestMatch = entry.Value - // log.Tracef("profiles: found new longest (%d) default profile match: %s", currentLength, entry.Key) - } - - } - - if longestMatch == nil { - return nil, 0 - } - matched, ok := longestMatch.(database.Model) - if !ok { - log.Warningf("profiles: matched default profile is not of type database.Model") - return nil, 0 - } - - profile, ok := database.SilentEnsureModel(matched, profileModel).(*Profile) - if !ok { - log.Warningf("profiles: matched default profile is not of type *Profile") - return nil, 0 - } - return profile, longest -} diff --git a/profiles/profileflags.go b/profiles/profileflags.go deleted file mode 100644 index 219f350d..00000000 --- a/profiles/profileflags.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package profiles - -import ( - "errors" - "strings" -) - -// ProfileFlags are used to quickly add common attributes to profiles -type ProfileFlags []int8 - -const ( - // Who? - // System apps must be run by system user, else deny - System int8 = iota + 1 - // Admin apps must be run by user with admin privileges, else deny - Admin - // User apps must be run by user (identified by having an active safing UI), else deny - User - - // Where? - // Internet apps may connect to the Internet, if unset, all connections to the Internet are denied - Internet - // LocalNet apps may connect to the local network (i.e. private IP address spaces), if unset, all connections to the local network are denied - LocalNet - - // How? - // Strict apps may only connect to domains that are related to themselves - Strict - // Gateway apps will connect to user-defined servers - Gateway - // Browser apps connect to multitudes of different servers and require special handling - Browser - // Directconnect apps may connect to any IP without dns association (e.g. P2P apps, network analysis tools) - Directconnect - // Service apps may accept incoming connections - Service -) - -var ( - // ErrProfileFlagsParseFailed is returned if a an invalid flag is encountered while parsing - ErrProfileFlagsParseFailed = errors.New("profiles: failed to parse flags") - - sortedFlags = &ProfileFlags{ - System, - Admin, - User, - Internet, - LocalNet, - Strict, - Gateway, - Service, - Directconnect, - Browser, - } - - flagIDs = map[string]int8{ - "System": System, - "Admin": Admin, - "User": User, - "Internet": Internet, - "LocalNet": LocalNet, - "Strict": Strict, - "Gateway": Gateway, - "Service": Service, - "Directconnect": Directconnect, - "Browser": Browser, - } - - flagNames = map[int8]string{ - System: "System", - Admin: "Admin", - User: "User", - Internet: "Internet", - LocalNet: "LocalNet", - Strict: "Strict", - Gateway: "Gateway", - Service: "Service", - Directconnect: "Directconnect", - Browser: "Browser", - } -) - -// FlagsFromNames creates ProfileFlags from a comma seperated list of flagnames (e.g. "System,Strict,Secure") -func FlagsFromNames(words []string) (*ProfileFlags, error) { - var flags ProfileFlags - for _, entry := range words { - flag, ok := flagIDs[entry] - if !ok { - return nil, ErrProfileFlagsParseFailed - } - flags = append(flags, flag) - } - return &flags, nil -} - -// Has checks if a ProfileFlags object has a flag -func (pf *ProfileFlags) Has(searchFlag int8) bool { - for _, flag := range *pf { - if flag == searchFlag { - return true - } - } - return false -} - -// String return a string representation of ProfileFlags -func (pf *ProfileFlags) String() string { - var namedFlags []string - for _, flag := range *pf { - namedFlags = append(namedFlags, flagNames[flag]) - } - return strings.Join(namedFlags, ",") -} diff --git a/profiles/profileflags_test.go b/profiles/profileflags_test.go deleted file mode 100644 index a2258d75..00000000 --- a/profiles/profileflags_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package profiles - -import ( - "strings" - "testing" -) - -func TestProfileFlags(t *testing.T) { - - // check if SYSTEM is 1 - if System != 1 { - t.Errorf("System ist first const and must be 1") - } - if Admin != 2 { - t.Errorf("Admin ist second const and must be 2") - } - - // check if all IDs have a name - for key, entry := range flagIDs { - if _, ok := flagNames[entry]; !ok { - t.Errorf("could not find entry for %s in flagNames", key) - } - } - - // check if all names have an ID - for key, entry := range flagNames { - if _, ok := flagIDs[entry]; !ok { - t.Errorf("could not find entry for %d in flagNames", key) - } - } - - // check Has - emptyFlags := ProfileFlags{} - for flag, name := range flagNames { - if !sortedFlags.Has(flag) { - t.Errorf("sortedFlags should have flag %s (%d)", name, flag) - } - if emptyFlags.Has(flag) { - t.Errorf("emptyFlags should not have flag %s (%d)", name, flag) - } - } - - // check ProfileFlags creation from strings - var allFlagStrings []string - for _, flag := range *sortedFlags { - allFlagStrings = append(allFlagStrings, flagNames[flag]) - } - newFlags, err := FlagsFromNames(allFlagStrings) - if err != nil { - t.Errorf("error while parsing flags: %s", err) - } - if newFlags.String() != sortedFlags.String() { - t.Errorf("parsed flags are not correct (or tests have not been updated to reflect the right number), expected %v, got %v", *sortedFlags, *newFlags) - } - - // check ProfileFlags Stringer - flagString := newFlags.String() - check := strings.Join(allFlagStrings, ",") - if flagString != check { - t.Errorf("flag string is not correct, expected %s, got %s", check, flagString) - } - -} diff --git a/profiles/sampledata.go b/profiles/sampledata.go deleted file mode 100644 index c2050717..00000000 --- a/profiles/sampledata.go +++ /dev/null @@ -1,263 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package profiles - -import ( - "runtime" - - ds "github.com/ipfs/go-datastore" - - "github.com/Safing/safing-core/database" - "github.com/Safing/safing-core/log" -) - -func init() { - - // Data here is for demo purposes, Profiles will be served over network soon™. - - log.Tracef("profiles: loading sample profiles for %s", runtime.GOOS) - - switch runtime.GOOS { - case "linux": - - log.Trace("profiles: loading linux sample profiles") - - (&Profile{ - Name: "Chromium", - Description: "Browser by Google", - Path: "/usr/lib/chromium-browser/chromium-browser", - Flags: []int8{User, Internet, LocalNet, Browser}, - ConnectPorts: []uint16{80, 443}, - }).CreateInDist() - - (&Profile{ - Name: "Evolution", - Description: "PIM solution by GNOME", - Path: "/usr/bin/evolution", - Flags: []int8{User, Internet, Gateway}, - ConnectPorts: []uint16{25, 80, 143, 443, 465, 587, 993, 995}, - SecurityLevel: 2, - }).CreateInDist() - - (&Profile{ - Name: "Evolution Calendar", - Description: "PIM solution by GNOME - Calendar", - Path: "/usr/lib/evolution/evolution-calendar-factory-subprocess", - Flags: []int8{User, Internet, Gateway}, - ConnectPorts: []uint16{80, 443}, - SecurityLevel: 2, - }).CreateInDist() - - (&Profile{ - Name: "Spotify", - Description: "Music streaming", - Path: "/usr/share/spotify/spotify", - ConnectPorts: []uint16{80, 443, 4070}, - Flags: []int8{User, Internet, Strict}, - }).CreateInDist() - - (&Profile{ - // flatpak edition - Name: "Spotify", - Description: "Music streaming", - Path: "/newroot/app/extra/share/spotify/spotify", - ConnectPorts: []uint16{80, 443, 4070}, - Flags: []int8{User, Internet, Strict}, - }).CreateInDist() - - (&Profile{ - Name: "Evince", - Description: "PDF Document Reader", - Path: "/usr/bin/evince", - Flags: []int8{}, - SecurityLevel: 2, - }).CreateInDist() - - (&Profile{ - Name: "Ahavi", - Description: "mDNS service", - Path: "/usr/bin/avahi-daemon", - Flags: []int8{System, LocalNet, Service, Directconnect}, - }).CreateInDist() - - (&Profile{ - Name: "Python 2.7 Framework", - Description: "Correctly handle python scripts", - Path: "/usr/bin/python2.7", - Framework: &Framework{ - Find: "^[^ ]+ ([^ ]+)", - Build: "{1}|{CWD}/{1}", - }, - }).CreateInDist() - - (&Profile{ - Name: "Python 3.5 Framework", - Description: "Correctly handle python scripts", - Path: "/usr/bin/python3.5", - Framework: &Framework{ - Find: "^[^ ]+ ([^ ]+)", - Build: "{1}|{CWD}/{1}", - }, - }).CreateInDist() - - (&Profile{ - Name: "DHCP Client", - Description: "Client software for the DHCP protocol", - Path: "/sbin/dhclient", - Framework: &Framework{ - FindParent: 1, - MergeWithParent: true, - }, - }).CreateInDist() - - // Default Profiles - // Until Profiles are distributed over the network, default profiles are activated when the Default Profile for "/" is missing. - - if ok, err := database.Has(ds.NewKey("/Data/Profiles/Profile_d-2f")); !ok || err != nil { - - log.Trace("profiles: loading linux default sample profiles") - - (&Profile{ - Name: "Default Base", - Description: "Default Profile for /", - Path: "/", - Flags: []int8{Internet, LocalNet, Strict}, - Default: true, - }).Create() - - (&Profile{ - Name: "Installed Applications", - Description: "Default Profile for /usr/bin", - Path: "/usr/bin/", - Flags: []int8{Internet, LocalNet, Gateway}, - Default: true, - }).Create() - - (&Profile{ - Name: "System Binaries (/sbin)", - Description: "Default Profile for ~/Downloads", - Path: "/sbin/", - Flags: []int8{Internet, LocalNet, Directconnect, Service, System}, - Default: true, - }).Create() - - (&Profile{ - Name: "System Binaries (/usr/sbin)", - Description: "Default Profile for ~/Downloads", - Path: "/usr/sbin/", - Flags: []int8{Internet, LocalNet, Directconnect, Service, System}, - Default: true, - }).Create() - - (&Profile{ - Name: "System Tmp folder", - Description: "Default Profile for /tmp", - Path: "/tmp/", - Flags: []int8{}, // deny all - Default: true, - }).Create() - - (&Profile{ - Name: "User Home", - Description: "Default Profile for ~/", - Path: "~/", - Flags: []int8{Internet, LocalNet, Gateway}, - Default: true, - }).Create() - - (&Profile{ - Name: "User Downloads", - Description: "Default Profile for ~/Downloads", - Path: "~/Downloads/", - Flags: []int8{}, // deny all - Default: true, - }).Create() - - (&Profile{ - Name: "User Cache", - Description: "Default Profile for ~/.cache", - Path: "~/.cache/", - Flags: []int8{}, // deny all - Default: true, - }).Create() - - } - - case "windows": - - log.Trace("profiles: loading windows sample profiles") - - (&Profile{ - Name: "Firefox", - Description: "Firefox Browser by Mozilla", - Path: "C:\\Program Files\\Mozilla Firefox\\firefox.exe", - Flags: []int8{User, Internet, LocalNet, Browser}, - ConnectPorts: []uint16{80, 443}, - }).CreateInDist() - - // Default Profiles - // Until Profiles are distributed over the network, default profiles are activated when the Default Profile for "C" is missing. - - if ok, err := database.Has(ds.NewKey("/Data/Profiles/Profile:d-C")); !ok || err != nil { - - log.Trace("profiles: loading windows default sample profiles") - - (&Profile{ - Name: "Default Base", - Description: "Default Profile for C", - Path: "C", - Flags: []int8{Internet, LocalNet, Strict}, - Default: true, - }).Create() - - (&Profile{ - Name: "Installed Applications", - Description: "Default Profile for C:\\Program Files", - Path: "C:\\Program Files\\", - Flags: []int8{Internet, LocalNet, Gateway}, - Default: true, - }).Create() - - (&Profile{ - Name: "Installed Applications (x86)", - Description: "Default Profile for C:\\Program Files (x86)", - Path: "C:\\Program Files (x86)\\", - Flags: []int8{Internet, LocalNet, Gateway}, - Default: true, - }).Create() - - (&Profile{ - Name: "System Applications (C:\\Windows\\System32)", - Description: "Default Profile for C:\\Windows\\System32", - Path: "C:\\Windows\\System32\\", - Flags: []int8{Internet, LocalNet, Directconnect, Service, System}, - Default: true, - }).Create() - - (&Profile{ - Name: "User Home", - Description: "Default Profile for ~/", - Path: "~/", - Flags: []int8{Internet, LocalNet, Gateway}, - Default: true, - }).Create() - - (&Profile{ - Name: "User Downloads", - Description: "Default Profile for ~/Downloads", - Path: "~/Downloads/", - Flags: []int8{}, // deny all - Default: true, - }).Create() - - (&Profile{ - Name: "User Cache", - Description: "Default Profile for ~/.cache", - Path: "~/.cache/", - Flags: []int8{}, // deny all - Default: true, - }).Create() - } - } - -} diff --git a/status/const.go b/status/const.go new file mode 100644 index 00000000..db12ed68 --- /dev/null +++ b/status/const.go @@ -0,0 +1,20 @@ +package status + +// Definitions of Security and Status Levels +const ( + SecurityLevelOff uint8 = 0 + + SecurityLevelDynamic uint8 = 1 + SecurityLevelSecure uint8 = 2 + SecurityLevelFortress uint8 = 4 + + SecurityLevelsDynamicAndSecure uint8 = SecurityLevelDynamic | SecurityLevelSecure + SecurityLevelsDynamicAndFortress uint8 = SecurityLevelDynamic | SecurityLevelFortress + SecurityLevelsSecureAndFortress uint8 = SecurityLevelSecure | SecurityLevelFortress + SecurityLevelsAll uint8 = SecurityLevelDynamic | SecurityLevelSecure | SecurityLevelFortress + + StatusOff uint8 = 0 + StatusError uint8 = 1 + StatusWarning uint8 = 2 + StatusOk uint8 = 3 +) diff --git a/status/get-config.go b/status/get-config.go new file mode 100644 index 00000000..134a9cef --- /dev/null +++ b/status/get-config.go @@ -0,0 +1,33 @@ +package status + +import ( + "github.com/Safing/portbase/config" +) + +type ( + // SecurityLevelOption defines the returned function by ConfigIsActive. + SecurityLevelOption func(minSecurityLevel uint8) bool +) + +func max(a, b uint8) uint8 { + if a > b { + return a + } + return b +} + +// ConfigIsActive returns whether the given security level dependent config option is on or off. +func ConfigIsActive(name string) SecurityLevelOption { + activeAtLevel := config.GetAsInt(name, int64(SecurityLevelDynamic)) + return func(minSecurityLevel uint8) bool { + return uint8(activeAtLevel()) <= max(CurrentSecurityLevel(), minSecurityLevel) + } +} + +// ConfigIsActiveConcurrent returns whether the given security level dependent config option is on or off and is concurrency safe. +func ConfigIsActiveConcurrent(name string) SecurityLevelOption { + activeAtLevel := config.Concurrent.GetAsInt(name, int64(SecurityLevelDynamic)) + return func(minSecurityLevel uint8) bool { + return uint8(activeAtLevel()) <= max(CurrentSecurityLevel(), minSecurityLevel) + } +} diff --git a/status/get.go b/status/get.go new file mode 100644 index 00000000..158a36eb --- /dev/null +++ b/status/get.go @@ -0,0 +1,54 @@ +package status + +import ( + "sync/atomic" +) + +var ( + currentSecurityLevel *uint32 + selectedSecurityLevel *uint32 + threatLevel *uint32 + portmasterStatus *uint32 + gate17Status *uint32 +) + +func init() { + var ( + currentSecurityLevelValue uint32 + selectedSecurityLevelValue uint32 + threatLevelValue uint32 + portmasterStatusValue uint32 + gate17StatusValue uint32 + ) + + currentSecurityLevel = ¤tSecurityLevelValue + selectedSecurityLevel = &selectedSecurityLevelValue + threatLevel = &threatLevelValue + portmasterStatus = &portmasterStatusValue + gate17Status = &gate17StatusValue +} + +// CurrentSecurityLevel returns the current security level. +func CurrentSecurityLevel() uint8 { + return uint8(atomic.LoadUint32(currentSecurityLevel)) +} + +// SelectedSecurityLevel returns the selected security level. +func SelectedSecurityLevel() uint8 { + return uint8(atomic.LoadUint32(selectedSecurityLevel)) +} + +// ThreatLevel returns the current threat level. +func ThreatLevel() uint8 { + return uint8(atomic.LoadUint32(threatLevel)) +} + +// PortmasterStatus returns the current Portmaster status. +func PortmasterStatus() uint8 { + return uint8(atomic.LoadUint32(portmasterStatus)) +} + +// Gate17Status returns the current Gate17 status. +func Gate17Status() uint8 { + return uint8(atomic.LoadUint32(gate17Status)) +} diff --git a/status/get_test.go b/status/get_test.go new file mode 100644 index 00000000..2777157e --- /dev/null +++ b/status/get_test.go @@ -0,0 +1,18 @@ +package status + +import "testing" + +func TestGet(t *testing.T) { + + // only test for panics + CurrentSecurityLevel() + SelectedSecurityLevel() + ThreatLevel() + PortmasterStatus() + Gate17Status() + option := ConfigIsActive("invalid") + option(0) + option = ConfigIsActiveConcurrent("invalid") + option(0) + +} diff --git a/status/module.go b/status/module.go new file mode 100644 index 00000000..bd0c3a42 --- /dev/null +++ b/status/module.go @@ -0,0 +1,25 @@ +package status + +import ( + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/modules" +) + +func init() { + modules.Register("status", prep, nil, nil) +} + +func prep() error { + + if CurrentSecurityLevel() == SecurityLevelOff { + log.Infof("switching to default active security level: dynamic") + SetCurrentSecurityLevel(SecurityLevelDynamic) + } + + if SelectedSecurityLevel() == SecurityLevelOff { + log.Infof("switching to default selected security level: dynamic") + SetSelectedSecurityLevel(SecurityLevelDynamic) + } + + return nil +} diff --git a/status/set.go b/status/set.go new file mode 100644 index 00000000..f46b3ded --- /dev/null +++ b/status/set.go @@ -0,0 +1,65 @@ +package status + +import "sync/atomic" + +// SetCurrentSecurityLevel sets the current security level. +func SetCurrentSecurityLevel(level uint8) { + sysStatusLock.Lock() + defer sysStatusLock.Unlock() + sysStatus.CurrentSecurityLevel = level + atomicUpdateCurrentSecurityLevel(level) +} + +// SetSelectedSecurityLevel sets the selected security level. +func SetSelectedSecurityLevel(level uint8) { + sysStatusLock.Lock() + defer sysStatusLock.Unlock() + sysStatus.SelectedSecurityLevel = level + atomicUpdateSelectedSecurityLevel(level) +} + +// SetThreatLevel sets the current threat level. +func SetThreatLevel(level uint8) { + sysStatusLock.Lock() + defer sysStatusLock.Unlock() + sysStatus.ThreatLevel = level + atomicUpdateThreatLevel(level) +} + +// SetPortmasterStatus sets the current Portmaster status. +func SetPortmasterStatus(status uint8) { + sysStatusLock.Lock() + defer sysStatusLock.Unlock() + sysStatus.PortmasterStatus = status + atomicUpdatePortmasterStatus(status) +} + +// SetGate17Status sets the current Gate17 status. +func SetGate17Status(status uint8) { + sysStatusLock.Lock() + defer sysStatusLock.Unlock() + sysStatus.Gate17Status = status + atomicUpdateGate17Status(status) +} + +// update functions for atomic stuff + +func atomicUpdateCurrentSecurityLevel(level uint8) { + atomic.StoreUint32(currentSecurityLevel, uint32(level)) +} + +func atomicUpdateSelectedSecurityLevel(level uint8) { + atomic.StoreUint32(selectedSecurityLevel, uint32(level)) +} + +func atomicUpdateThreatLevel(level uint8) { + atomic.StoreUint32(threatLevel, uint32(level)) +} + +func atomicUpdatePortmasterStatus(status uint8) { + atomic.StoreUint32(portmasterStatus, uint32(status)) +} + +func atomicUpdateGate17Status(status uint8) { + atomic.StoreUint32(gate17Status, uint32(status)) +} diff --git a/status/set_test.go b/status/set_test.go new file mode 100644 index 00000000..626882ab --- /dev/null +++ b/status/set_test.go @@ -0,0 +1,14 @@ +package status + +import "testing" + +func TestSet(t *testing.T) { + + // only test for panics + SetCurrentSecurityLevel(0) + SetSelectedSecurityLevel(0) + SetThreatLevel(0) + SetPortmasterStatus(0) + SetGate17Status(0) + +} diff --git a/status/status.go b/status/status.go new file mode 100644 index 00000000..86a34f8c --- /dev/null +++ b/status/status.go @@ -0,0 +1,63 @@ +package status + +import "sync" + +var ( + sysStatus *SystemStatus + sysStatusLock sync.RWMutex +) + +func init() { + sysStatus = &SystemStatus{} +} + +// SystemStatus saves basic information about the current system status. +type SystemStatus struct { + // database.Base + CurrentSecurityLevel uint8 + SelectedSecurityLevel uint8 + + ThreatLevel uint8 `json:",omitempty" bson:",omitempty"` + ThreatReason string `json:",omitempty" bson:",omitempty"` + + PortmasterStatus uint8 `json:",omitempty" bson:",omitempty"` + PortmasterStatusMsg string `json:",omitempty" bson:",omitempty"` + + Gate17Status uint8 `json:",omitempty" bson:",omitempty"` + Gate17StatusMsg string `json:",omitempty" bson:",omitempty"` +} + +// FmtCurrentSecurityLevel returns the current security level as a string. +func FmtCurrentSecurityLevel() string { + current := CurrentSecurityLevel() + selected := SelectedSecurityLevel() + s := FmtSecurityLevel(current) + if current != selected { + s += "*" + } + return s +} + +// FmtSecurityLevel returns the given security level as a string. +func FmtSecurityLevel(level uint8) string { + switch level { + case SecurityLevelOff: + return "Off" + case SecurityLevelDynamic: + return "Dynamic" + case SecurityLevelSecure: + return "Secure" + case SecurityLevelFortress: + return "Fortress" + case SecurityLevelsDynamicAndSecure: + return "Dynamic and Secure" + case SecurityLevelsDynamicAndFortress: + return "Dynamic and Fortress" + case SecurityLevelsSecureAndFortress: + return "Secure and Fortress" + case SecurityLevelsAll: + return "Dynamic, Secure and Fortress" + default: + return "INVALID" + } +} diff --git a/status/status_test.go b/status/status_test.go new file mode 100644 index 00000000..d4c4d328 --- /dev/null +++ b/status/status_test.go @@ -0,0 +1,36 @@ +package status + +import "testing" + +func TestStatus(t *testing.T) { + + SetCurrentSecurityLevel(SecurityLevelOff) + SetSelectedSecurityLevel(SecurityLevelOff) + if FmtCurrentSecurityLevel() != "Off" { + t.Error("unexpected string representation") + } + + SetCurrentSecurityLevel(SecurityLevelDynamic) + SetSelectedSecurityLevel(SecurityLevelDynamic) + if FmtCurrentSecurityLevel() != "Dynamic" { + t.Error("unexpected string representation") + } + + SetCurrentSecurityLevel(SecurityLevelSecure) + SetSelectedSecurityLevel(SecurityLevelSecure) + if FmtCurrentSecurityLevel() != "Secure" { + t.Error("unexpected string representation") + } + + SetCurrentSecurityLevel(SecurityLevelFortress) + SetSelectedSecurityLevel(SecurityLevelFortress) + if FmtCurrentSecurityLevel() != "Fortress" { + t.Error("unexpected string representation") + } + + SetSelectedSecurityLevel(SecurityLevelDynamic) + if FmtCurrentSecurityLevel() != "Fortress*" { + t.Error("unexpected string representation") + } + +} diff --git a/ui/launch.go b/ui/launch.go new file mode 100644 index 00000000..e0f4086b --- /dev/null +++ b/ui/launch.go @@ -0,0 +1,66 @@ +package ui + +import ( + "errors" + "flag" + "fmt" + "os" + "os/exec" + "runtime" + + "github.com/Safing/portbase/modules" + "github.com/Safing/portmaster/updates" +) + +var ( + launchUI bool +) + +func init() { + flag.BoolVar(&launchUI, "ui", false, "launch user interface and exit") +} + +func launchUIByFlag() error { + if !launchUI { + return nil + } + + err := updates.ReloadLatest() + if err != nil { + return err + } + + osAndPlatform := fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH) + + switch osAndPlatform { + case "linux_amd64": + + file, err := updates.GetPlatformFile("app/portmaster-ui") + if err != nil { + return fmt.Errorf("ui currently not available: %s - you may need to first start portmaster and wait for it to fetch the update index", err) + } + + // check permission + info, err := os.Stat(file.Path()) + if info.Mode() != 0755 { + fmt.Printf("%v\n", info.Mode()) + err := os.Chmod(file.Path(), 0755) + if err != nil { + return fmt.Errorf("failed to set exec permissions on %s: %s", file.Path(), err) + } + } + + // exec + cmd := exec.Command(file.Path()) + err = cmd.Start() + if err != nil { + return fmt.Errorf("failed to start ui: %s", err) + } + + // gracefully exit portmaster + return modules.ErrCleanExit + + default: + return errors.New("this os/platform is no UI support yet") + } +} diff --git a/ui/module.go b/ui/module.go new file mode 100644 index 00000000..16acf9b1 --- /dev/null +++ b/ui/module.go @@ -0,0 +1,18 @@ +package ui + +import ( + "github.com/Safing/portbase/modules" +) + +func init() { + modules.Register("ui", prep, nil, nil, "updates", "api") +} + +func prep() error { + err := launchUIByFlag() + if err != nil { + return err + } + + return registerRoutes() +} diff --git a/ui/serve.go b/ui/serve.go new file mode 100644 index 00000000..d6ef8af6 --- /dev/null +++ b/ui/serve.go @@ -0,0 +1,139 @@ +package ui + +import ( + "fmt" + "io" + "mime" + "net/http" + "net/url" + "path/filepath" + "strings" + "sync" + + resources "github.com/cookieo9/resources-go" + "github.com/gorilla/mux" + + "github.com/Safing/portbase/api" + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/updates" +) + +var ( + apps = make(map[string]*resources.BundleSequence) + appsLock sync.RWMutex + assets *resources.BundleSequence + assetsLock sync.RWMutex +) + +func registerRoutes() error { + api.RegisterHandleFunc("/assets/{resPath:[a-zA-Z0-9/\\._-]+}", ServeBundle("assets")).Methods("GET", "HEAD") + api.RegisterHandleFunc("/ui/modules/{moduleName:[a-z]+}", redirAddSlash).Methods("GET", "HEAD") + api.RegisterHandleFunc("/ui/modules/{moduleName:[a-z]+}/", ServeBundle("")).Methods("GET", "HEAD") + api.RegisterHandleFunc("/ui/modules/{moduleName:[a-z]+}/{resPath:[a-zA-Z0-9/\\._-]+}", ServeBundle("")).Methods("GET", "HEAD") + api.RegisterHandleFunc("/", RedirectToBase) + + return nil +} + +// ServeBundle serves bundles. +func ServeBundle(defaultModuleName string) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + + // log.Tracef("ui: request for %s", r.RequestURI) + + vars := mux.Vars(r) + moduleName, ok := vars["moduleName"] + if !ok { + moduleName = defaultModuleName + if moduleName == "" { + http.Error(w, "missing module name", http.StatusBadRequest) + return + } + } + + resPath, ok := vars["resPath"] + if !ok || strings.HasSuffix(resPath, "/") { + resPath = "index.html" + } + + appsLock.RLock() + bundle, ok := apps[moduleName] + appsLock.RUnlock() + if ok { + ServeFileFromBundle(w, r, moduleName, bundle, resPath) + return + } + + // get file from update system + zipFile, err := updates.GetFile(fmt.Sprintf("ui/modules/%s.zip", moduleName)) + if err != nil { + if err == updates.ErrNotFound { + log.Tracef("ui: requested module %s does not exist", moduleName) + http.Error(w, err.Error(), http.StatusNotFound) + } else { + log.Tracef("ui: error loading module %s: %s", moduleName, err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + + // open bundle + newBundle, err := resources.OpenZip(zipFile.Path()) + if err != nil { + log.Tracef("ui: error prepping module %s: %s", moduleName, err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + bundle = &resources.BundleSequence{newBundle} + appsLock.Lock() + apps[moduleName] = bundle + appsLock.Unlock() + + ServeFileFromBundle(w, r, moduleName, bundle, resPath) + } +} + +// ServeFileFromBundle serves a file from the given bundle. +func ServeFileFromBundle(w http.ResponseWriter, r *http.Request, bundleName string, bundle *resources.BundleSequence, path string) { + readCloser, err := bundle.Open(path) + if err != nil { + log.Tracef("ui: error opening module %s: %s", bundleName, err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + _, ok := w.Header()["Content-Type"] + if !ok { + contentType := mime.TypeByExtension(filepath.Ext(path)) + if contentType != "" { + w.Header().Set("Content-Type", contentType) + } + } + + w.WriteHeader(http.StatusOK) + if r.Method != "HEAD" { + _, err = io.Copy(w, readCloser) + if err != nil { + log.Errorf("ui: failed to serve file: %s", err) + return + } + } + + readCloser.Close() + return +} + +// RedirectToBase redirects the requests to the control app +func RedirectToBase(w http.ResponseWriter, r *http.Request) { + u, err := url.Parse("/ui/modules/base/") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + http.Redirect(w, r, r.URL.ResolveReference(u).String(), http.StatusPermanentRedirect) +} + +func redirAddSlash(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, r.RequestURI+"/", http.StatusPermanentRedirect) +} diff --git a/updates/doc.go b/updates/doc.go new file mode 100644 index 00000000..4dd1c7c3 --- /dev/null +++ b/updates/doc.go @@ -0,0 +1,9 @@ +package updates + +// current paths: +// all/ui/assets.zip +// all/ui/modules/base.zip +// all/ui/modules/settings.zip +// all/ui/modules/profilemgr.zip +// all/ui/modules/monitor.zip +// linux_amd64/app/portmaster-ui diff --git a/updates/fetch.go b/updates/fetch.go new file mode 100644 index 00000000..9330d0c5 --- /dev/null +++ b/updates/fetch.go @@ -0,0 +1,122 @@ +package updates + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "time" + + "github.com/google/renameio" + + "github.com/Safing/portbase/log" +) + +var ( + updateURLs = []string{ + "https://updates.safing.io", + } +) + +func fetchFile(realFilepath, updateFilepath string, tries int) error { + // backoff when retrying + if tries > 0 { + time.Sleep(time.Duration(tries*tries) * time.Second) + } + + // create URL + downloadURL, err := joinURLandPath(updateURLs[tries%len(updateURLs)], updateFilepath) + if err != nil { + return fmt.Errorf("error build url (%s + %s): %s", updateURLs[tries%len(updateURLs)], updateFilepath, err) + } + + // create destination dir + dirPath := filepath.Dir(realFilepath) + err = os.MkdirAll(dirPath, 0755) + if err != nil { + return fmt.Errorf("updates: could not create updates folder: %s", dirPath) + } + + // open file for writing + atomicFile, err := renameio.TempFile(filepath.Join(updateStoragePath, "tmp"), realFilepath) + if err != nil { + return fmt.Errorf("updates: could not create temp file for download: %s", err) + } + defer atomicFile.Cleanup() + + // start file download + resp, err := http.Get(downloadURL) + if err != nil { + return fmt.Errorf("error fetching url (%s): %s", downloadURL, err) + } + defer resp.Body.Close() + + // download and write file + n, err := io.Copy(atomicFile, resp.Body) + if err != nil { + return fmt.Errorf("failed downloading %s: %s", downloadURL, err) + } + if resp.ContentLength != n { + return fmt.Errorf("download unfinished, written %d out of %d bytes.", n, resp.ContentLength) + } + + // finalize file + err = atomicFile.CloseAtomicallyReplace() + if err != nil { + return fmt.Errorf("updates: failed to finalize file %s: %s", realFilepath, err) + } + // set permissions + err = os.Chmod(realFilepath, 0644) + if err != nil { + log.Warningf("updates: failed to set permissions on downloaded file %s: %s", realFilepath, err) + } + + log.Infof("update: fetched %s (stored to %s)", downloadURL, realFilepath) + return nil +} + +func fetchData(downloadPath string, tries int) ([]byte, error) { + // backoff when retrying + if tries > 0 { + time.Sleep(time.Duration(tries*tries) * time.Second) + } + + // create URL + downloadURL, err := joinURLandPath(updateURLs[tries%len(updateURLs)], downloadPath) + if err != nil { + return nil, fmt.Errorf("error build url (%s + %s): %s", updateURLs[tries%len(updateURLs)], downloadPath, err) + } + + // start file download + resp, err := http.Get(downloadURL) + if err != nil { + return nil, fmt.Errorf("error fetching url (%s): %s", downloadURL, err) + } + defer resp.Body.Close() + + // download and write file + buf := bytes.NewBuffer(make([]byte, 0, resp.ContentLength)) + n, err := io.Copy(buf, resp.Body) + if err != nil { + return nil, fmt.Errorf("failed downloading %s: %s", downloadURL, err) + } + if resp.ContentLength != n { + return nil, fmt.Errorf("download unfinished, written %d out of %d bytes.", n, resp.ContentLength) + } + + return buf.Bytes(), nil +} + +func joinURLandPath(baseURL, urlPath string) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", err + } + + u.Path = path.Join(u.Path, urlPath) + return u.String(), nil +} diff --git a/updates/file.go b/updates/file.go new file mode 100644 index 00000000..3f0d133a --- /dev/null +++ b/updates/file.go @@ -0,0 +1,46 @@ +package updates + +// File represents a file from the update system. +type File struct { + filepath string + version string + stable bool +} + +func newFile(filepath string, version string, stable bool) *File { + return &File{ + filepath: filepath, + version: version, + stable: stable, + } +} + +// Path returns the filepath of the file. +func (f *File) Path() string { + return f.filepath +} + +// Version returns the version of the file. +func (f *File) Version() string { + return f.version +} + +// Stable returns whether the file is from a stable release. +func (f *File) Stable() bool { + return f.stable +} + +// Open opens the file and returns the +func (f *File) Open() { + +} + +// ReportError reports an error back to Safing. This will not automatically blacklist the file. +func (f *File) ReportError() { + +} + +// Blacklist notifies the update system that this file is somehow broken, and should be ignored from now on. +func (f *File) Blacklist() { + +} diff --git a/updates/filename.go b/updates/filename.go new file mode 100644 index 00000000..cca6cc4f --- /dev/null +++ b/updates/filename.go @@ -0,0 +1,41 @@ +package updates + +import ( + "fmt" + "regexp" + "strings" +) + +var versionRegex = regexp.MustCompile("_v[0-9]+-[0-9]+-[0-9]+b?") + +func getIdentifierAndVersion(versionedPath string) (identifier, version string, ok bool) { + // extract version + rawVersion := versionRegex.FindString(versionedPath) + if rawVersion == "" { + return "", "", false + } + + // replace - with . and trim _ + version = strings.Replace(strings.TrimLeft(rawVersion, "_v"), "-", ".", -1) + + // put together without version + i := strings.Index(versionedPath, rawVersion) + if i < 0 { + // extracted version not in string (impossible) + return "", "", false + } + return versionedPath[:i] + versionedPath[i+len(rawVersion):], version, true +} + +func getVersionedPath(identifier, version string) (versionedPath string) { + // split in half + splittedFilePath := strings.SplitN(identifier, ".", 2) + // replace . with - + transformedVersion := strings.Replace(version, ".", "-", -1) + + // put together + if len(splittedFilePath) == 1 { + return fmt.Sprintf("%s_v%s", splittedFilePath[0], transformedVersion) + } + return fmt.Sprintf("%s_v%s.%s", splittedFilePath[0], transformedVersion, splittedFilePath[1]) +} diff --git a/updates/get.go b/updates/get.go new file mode 100644 index 00000000..d72ac0a3 --- /dev/null +++ b/updates/get.go @@ -0,0 +1,77 @@ +package updates + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + + "github.com/Safing/portbase/log" +) + +var ( + ErrNotFound = errors.New("the requested file could not be found") +) + +// GetPlatformFile returns the latest platform specific file identified by the given identifier. +func GetPlatformFile(identifier string) (*File, error) { + identifier = filepath.Join(fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH), identifier) + // From https://golang.org/pkg/runtime/#GOARCH + // GOOS is the running program's operating system target: one of darwin, freebsd, linux, and so on. + // GOARCH is the running program's architecture target: one of 386, amd64, arm, s390x, and so on. + return loadOrFetchFile(identifier) +} + +// GetFile returns the latest generic file identified by the given identifier. +func GetFile(identifier string) (*File, error) { + identifier = filepath.Join("all", identifier) + return loadOrFetchFile(identifier) +} + +func getLatestFilePath(identifier string) (versionedFilePath, version string, stable bool, ok bool) { + updatesLock.RLock() + version, ok = stableUpdates[identifier] + if !ok { + version, ok = latestUpdates[identifier] + if !ok { + log.Tracef("updates: file %s does not exist", identifier) + return "", "", false, false + // TODO: if in development mode, reload latest index to check for newly sideloaded updates + // err := reloadLatest() + } + } + updatesLock.RUnlock() + + // TODO: Fix for stable release + return getVersionedPath(identifier, version), version, false, true +} + +func loadOrFetchFile(identifier string) (*File, error) { + versionedFilePath, version, stable, ok := getLatestFilePath(identifier) + if !ok { + // TODO: if in development mode, search updates dir for sideloaded apps + return nil, ErrNotFound + } + + // build final filepath + realFilePath := filepath.Join(updateStoragePath, versionedFilePath) + if _, err := os.Stat(realFilePath); err == nil { + // file exists + return newFile(realFilePath, version, stable), nil + } + + // download file + log.Tracef("updates: starting download of %s", versionedFilePath) + var err error + for tries := 0; tries < 5; tries++ { + err := fetchFile(realFilePath, versionedFilePath, tries) + if err != nil { + log.Tracef("updates: failed to download %s: %s, retrying (%d)", versionedFilePath, err, tries+1) + } else { + return newFile(realFilePath, version, stable), nil + } + } + log.Warningf("updates: failed to download %s: %s", versionedFilePath, err) + return nil, err +} diff --git a/updates/get_test.go b/updates/get_test.go new file mode 100644 index 00000000..92dbc324 --- /dev/null +++ b/updates/get_test.go @@ -0,0 +1,24 @@ +package updates + +import "testing" + +func testBuildVersionedFilePath(t *testing.T, identifier, version, expectedVersionedFilePath string) { + updatesLock.Lock() + stableUpdates[identifier] = version + // betaUpdates[identifier] = version + updatesLock.Unlock() + + versionedFilePath, _, _, ok := getLatestFilePath(identifier) + if !ok { + t.Errorf("identifier %s should exist", identifier) + } + if versionedFilePath != expectedVersionedFilePath { + t.Errorf("unexpected versionedFilePath: %s", versionedFilePath) + } +} + +func TestBuildVersionedFilePath(t *testing.T) { + testBuildVersionedFilePath(t, "path/to/asset.zip", "1.2.3", "path/to/asset_v1-2-3.zip") + testBuildVersionedFilePath(t, "path/to/asset.tar.gz", "1.2.3b", "path/to/asset_v1-2-3b.tar.gz") + testBuildVersionedFilePath(t, "path/to/asset", "1.2.3b", "path/to/asset_v1-2-3b") +} diff --git a/updates/latest.go b/updates/latest.go new file mode 100644 index 00000000..fc41b252 --- /dev/null +++ b/updates/latest.go @@ -0,0 +1,141 @@ +package updates + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + + "github.com/Safing/portbase/log" +) + +var ( + stableUpdates = make(map[string]string) + betaUpdates = make(map[string]string) + latestUpdates = make(map[string]string) + updatesLock sync.RWMutex +) + +// ReloadLatest reloads available updates from disk. +func ReloadLatest() error { + newLatestUpdates := make(map[string]string) + + // all + new, err1 := ScanForLatest(filepath.Join(updateStoragePath, "all"), false) + for key, val := range new { + newLatestUpdates[key] = val + } + + // os_platform + new, err2 := ScanForLatest(filepath.Join(updateStoragePath, fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH)), false) + for key, val := range new { + newLatestUpdates[key] = val + } + + if err1 != nil && err2 != nil { + return fmt.Errorf("could not load latest update versions: %s, %s", err1, err2) + } + + log.Tracef("updates: loading latest updates:") + + for key, val := range newLatestUpdates { + log.Tracef("updates: %s v%s", key, val) + } + + updatesLock.Lock() + latestUpdates = newLatestUpdates + updatesLock.Unlock() + + log.Tracef("updates: load complete") + + if len(stableUpdates) == 0 { + err := loadIndexesFromDisk() + if err != nil { + return err + } + } + + return nil +} + +func ScanForLatest(baseDir string, hardFail bool) (latest map[string]string, lastError error) { + var added int + latest = make(map[string]string) + + filepath.Walk(baseDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + lastError = err + if hardFail { + return err + } + log.Warningf("updates: could not read %s", path) + return nil + } + if !info.IsDir() { + added++ + } + + relativePath := strings.TrimLeft(strings.TrimPrefix(path, baseDir), "/") + identifierPath, version, ok := getIdentifierAndVersion(relativePath) + if !ok { + return nil + } + + // add/update index + storedVersion, ok := latest[identifierPath] + if ok { + // FIXME: this will fail on multi-digit version segments! + if version > storedVersion { + latest[identifierPath] = version + } + } else { + latest[identifierPath] = version + } + + return nil + }) + + if lastError != nil { + if hardFail { + return nil, lastError + } + if added == 0 { + return latest, lastError + } + } + return latest, nil +} + +func loadIndexesFromDisk() error { + data, err := ioutil.ReadFile(filepath.Join(updateStoragePath, "stable.json")) + if err != nil { + if os.IsNotExist(err) { + log.Infof("updates: stable.json does not yet exist, waiting for first update cycle") + return nil + } + return err + } + + newStableUpdates := make(map[string]string) + err = json.Unmarshal(data, &newStableUpdates) + if err != nil { + return err + } + + if len(newStableUpdates) == 0 { + return errors.New("stable.json is empty") + } + + log.Tracef("updates: loaded stable.json") + + updatesLock.Lock() + stableUpdates = newStableUpdates + updatesLock.Unlock() + + return nil +} diff --git a/updates/latest_test.go b/updates/latest_test.go new file mode 100644 index 00000000..542bdb45 --- /dev/null +++ b/updates/latest_test.go @@ -0,0 +1,73 @@ +package updates + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +func testLoadLatestScope(t *testing.T, basePath, filePath, expectedIdentifier, expectedVersion string) { + fullPath := filepath.Join(basePath, filePath) + + // create dir + dirPath := filepath.Dir(fullPath) + err := os.MkdirAll(dirPath, 0755) + if err != nil { + t.Fatalf("could not create test dir: %s\n", err) + return + } + + // touch file + err = ioutil.WriteFile(fullPath, []byte{}, 0644) + if err != nil { + t.Fatalf("could not create test file: %s\n", err) + return + } + + // run loadLatestScope + latest, err := ScanForLatest(basePath, true) + if err != nil { + t.Errorf("could not update latest: %s\n", err) + return + } + for key, val := range latest { + latestUpdates[key] = val + } + + // test result + version, ok := latestUpdates[expectedIdentifier] + if !ok { + t.Errorf("identifier %s not in map", expectedIdentifier) + t.Errorf("current map: %v", latestUpdates) + } + if version != expectedVersion { + t.Errorf("unexpected version for %s: %s", filePath, version) + } +} + +func TestLoadLatestScope(t *testing.T) { + + updatesLock.Lock() + defer updatesLock.Unlock() + + tmpDir, err := ioutil.TempDir("", "testing_") + if err != nil { + t.Fatalf("could not create test dir: %s\n", err) + return + } + defer os.RemoveAll(tmpDir) + + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-3.zip", "all/ui/assets.zip", "1.2.3") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-3b.zip", "all/ui/assets.zip", "1.2.3b") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-4.zip", "all/ui/assets.zip", "1.2.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-3-4.zip", "all/ui/assets.zip", "1.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v2-3-4.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-3.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-4.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-3-4.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v1-2-3", "os_platform/portmaster/portmaster", "1.2.3") + testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v2-1-1", "os_platform/portmaster/portmaster", "2.1.1") + testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v1-2-3", "os_platform/portmaster/portmaster", "2.1.1") + +} diff --git a/updates/main.go b/updates/main.go new file mode 100644 index 00000000..4f6695b7 --- /dev/null +++ b/updates/main.go @@ -0,0 +1,95 @@ +package updates + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/modules" +) + +var ( + updateStoragePath string +) + +func init() { + modules.Register("updates", prep, start, nil, "database") +} + +func prep() error { + updateStoragePath = filepath.Join(database.GetDatabaseRoot(), "updates") + + err := checkUpdateDirs() + if err != nil { + return err + } + + return nil +} + +func start() error { + err := ReloadLatest() + if err != nil { + return err + } + + go updater() + return nil +} + +func stop() error { + return os.RemoveAll(filepath.Join(updateStoragePath, "tmp")) +} + +func checkUpdateDirs() error { + // all + err := checkDir(filepath.Join(updateStoragePath, "all")) + if err != nil { + return err + } + + // os_platform + err = checkDir(filepath.Join(updateStoragePath, fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH))) + if err != nil { + return err + } + + // tmp + err = checkDir(filepath.Join(updateStoragePath, "tmp")) + if err != nil { + return err + } + + return nil +} + +func checkDir(dirPath string) error { + f, err := os.Stat(dirPath) + if err == nil { + // file exists + if f.IsDir() { + return nil + } + err = os.Remove(dirPath) + if err != nil { + return fmt.Errorf("could not remove file %s to place dir: %s", dirPath, err) + } + err = os.MkdirAll(dirPath, 0755) + if err != nil { + return fmt.Errorf("could not create dir %s: %s", dirPath, err) + } + return nil + } + // file does not exist + if os.IsNotExist(err) { + err = os.MkdirAll(dirPath, 0755) + if err != nil { + return fmt.Errorf("could not create dir %s: %s", dirPath, err) + } + return nil + } + // other error + return fmt.Errorf("failed to access %s: %s", dirPath, err) +} diff --git a/updates/updater.go b/updates/updater.go new file mode 100644 index 00000000..56a39738 --- /dev/null +++ b/updates/updater.go @@ -0,0 +1,88 @@ +package updates + +import ( + "encoding/json" + "errors" + "io/ioutil" + "path/filepath" + "time" + + "github.com/Safing/portbase/log" +) + +func updater() { + time.Sleep(10 * time.Second) + for { + err := checkForUpdates() + if err != nil { + log.Warningf("updates: failed to check for updates: %s", err) + } + time.Sleep(1 * time.Hour) + } +} + +func checkForUpdates() error { + + // download new index + var data []byte + var err error + for tries := 0; tries < 3; tries++ { + data, err = fetchData("stable.json", tries) + if err == nil { + break + } + } + if err != nil { + return err + } + + newStableUpdates := make(map[string]string) + err = json.Unmarshal(data, &newStableUpdates) + if err != nil { + return err + } + + if len(newStableUpdates) == 0 { + return errors.New("stable.json is empty") + } + + // FIXINSTABLE: correct log line + log.Infof("updates: downloaded new update index: stable.json (alpha until we actually reach stable)") + + // update existing files + log.Tracef("updates: updating existing files") + updatesLock.RLock() + for identifier, newVersion := range newStableUpdates { + oldVersion, ok := latestUpdates[identifier] + if ok && newVersion != oldVersion { + + filePath := getVersionedPath(identifier, newVersion) + realFilePath := filepath.Join(updateStoragePath, filePath) + for tries := 0; tries < 3; tries++ { + err := fetchFile(realFilePath, filePath, tries) + if err == nil { + break + } + } + if err != nil { + log.Warningf("failed to update %s to %s: %s", identifier, newVersion, err) + } + + } + } + updatesLock.RUnlock() + log.Tracef("updates: finished updating existing files") + + // update stable index + updatesLock.Lock() + stableUpdates = newStableUpdates + updatesLock.Unlock() + + // save stable index + err = ioutil.WriteFile(filepath.Join(updateStoragePath, "stable.json"), data, 0644) + if err != nil { + log.Warningf("updates: failed to save new version of stable.json: %s", err) + } + + return nil +} diff --git a/updates/uptool/.gitignore b/updates/uptool/.gitignore new file mode 100644 index 00000000..c5074cf6 --- /dev/null +++ b/updates/uptool/.gitignore @@ -0,0 +1 @@ +uptool diff --git a/updates/uptool/root.go b/updates/uptool/root.go new file mode 100644 index 00000000..8c75312c --- /dev/null +++ b/updates/uptool/root.go @@ -0,0 +1,23 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +var rootCmd = &cobra.Command{ + Use: "uptool", + Short: "helper tool for the update process", + Run: func(cmd *cobra.Command, args []string) { + cmd.Usage() + }, +} + +func main() { + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/updates/uptool/scan.go b/updates/uptool/scan.go new file mode 100644 index 00000000..ccb5185d --- /dev/null +++ b/updates/uptool/scan.go @@ -0,0 +1,35 @@ +package main + +import ( + "encoding/json" + "fmt" + + "github.com/Safing/portmaster/updates" + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(scanCmd) +} + +var scanCmd = &cobra.Command{ + Use: "scan", + Short: "Scan the current directory and print the result", + RunE: scan, +} + +func scan(cmd *cobra.Command, args []string) error { + + latest, err := updates.ScanForLatest(".", true) + if err != nil { + return err + } + + data, err := json.MarshalIndent(latest, "", " ") + if err != nil { + return err + } + + fmt.Println(string(data)) + return nil +}