diff --git a/Gopkg.lock b/Gopkg.lock index 68febc82..2c6a4c07 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -84,12 +84,12 @@ version = "v0.1.0" [[projects]] - digest = "1:cbec35fe4d5a4fba369a656a8cd65e244ea2c743007d8f6c1ccb132acf9d1296" - name = "github.com/gorilla/mux" + digest = "1:88e0b0baeb9072f0a4afbcf12dda615fc8be001d1802357538591155998da21b" + name = "github.com/hashicorp/go-version" packages = ["."] pruneopts = "UT" - revision = "00bdffe0f3c77e27d2cf6f5c70232a2d3e4d9c15" - version = "v1.7.3" + revision = "ac23dc3fea5d1a983c43f6a0f6e2c13f0195d8bd" + version = "v1.2.0" [[projects]] digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be" @@ -219,12 +219,17 @@ [[projects]] branch = "master" - digest = "1:7e3e63385ebe2dd5210be1534c516d8ae33c7ffac74126e5243d43ba7222e0d4" + digest = "1:84945c0665ea5fc3ccbd067c35890a7d28e369131ac411b8a820b40115245c19" name = "golang.org/x/sys" packages = [ "cpu", "unix", "windows", + "windows/registry", + "windows/svc", + "windows/svc/debug", + "windows/svc/eventlog", + "windows/svc/mgr", ] pruneopts = "UT" revision = "04f50cda93cbb67f2afa353c52f342100e80e625" @@ -242,7 +247,7 @@ "github.com/google/gopacket/layers", "github.com/google/gopacket/tcpassembly", "github.com/google/renameio", - "github.com/gorilla/mux", + "github.com/hashicorp/go-version", "github.com/miekg/dns", "github.com/oschwald/maxminddb-golang", "github.com/satori/go.uuid", @@ -256,6 +261,10 @@ "golang.org/x/net/icmp", "golang.org/x/net/ipv4", "golang.org/x/sys/windows", + "golang.org/x/sys/windows/svc", + "golang.org/x/sys/windows/svc/debug", + "golang.org/x/sys/windows/svc/eventlog", + "golang.org/x/sys/windows/svc/mgr", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index d398ae51..0fe7f61a 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -54,10 +54,6 @@ ignored = ["github.com/safing/portbase/*"] name = "github.com/google/renameio" version = "0.1.0" -[[constraint]] - name = "github.com/gorilla/mux" - version = "1.7.3" - [[constraint]] name = "github.com/miekg/dns" version = "1.1.15" @@ -101,3 +97,7 @@ ignored = ["github.com/safing/portbase/*"] [prune] go-tests = true unused-packages = true + +[[constraint]] + name = "github.com/hashicorp/go-version" + version = "1.2.0" diff --git a/analytics/algs/lms.go b/analytics/algs/lms.go index 6f4ceaba..cc7b227d 100644 --- a/analytics/algs/lms.go +++ b/analytics/algs/lms.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package algs import ( diff --git a/analytics/algs/lms_test.go b/analytics/algs/lms_test.go index a73db9d9..c465f654 100644 --- a/analytics/algs/lms_test.go +++ b/analytics/algs/lms_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package algs import "testing" diff --git a/core/base.go b/core/base.go new file mode 100644 index 00000000..a3e823ba --- /dev/null +++ b/core/base.go @@ -0,0 +1,60 @@ +package core + +import ( + "errors" + "flag" + + "github.com/safing/portbase/api" + "github.com/safing/portbase/database/dbmodule" + "github.com/safing/portbase/modules" + "github.com/safing/portbase/notifications" + + "github.com/safing/portmaster/core/structure" +) + +var ( + dataDir string + databaseDir string + + baseModule = modules.Register("base", prepBase, nil, nil) +) + +func init() { + flag.StringVar(&dataDir, "data", "", "set data directory") + flag.StringVar(&databaseDir, "db", "", "alias to --data (deprecated)") + + notifications.SetPersistenceBasePath("core:notifications") +} + +func prepBase() error { + // backwards compatibility + if dataDir == "" { + dataDir = databaseDir + } + + // check data dir + if dataDir == "" { + return errors.New("please set the data directory using --data=/path/to/data/dir") + } + + // initialize structure + err := structure.Initialize(dataDir, 0755) + if err != nil { + return err + } + + // set database location + dbmodule.SetDatabaseLocation("", structure.Root()) + + // init config + logFlagOverrides() + err = registerConfig() + if err != nil { + return err + } + + // set api listen address + api.SetDefaultAPIListenAddress("127.0.0.1:817") + + return nil +} diff --git a/core/config.go b/core/config.go new file mode 100644 index 00000000..f4e1a7ae --- /dev/null +++ b/core/config.go @@ -0,0 +1,39 @@ +package core + +import ( + "flag" + + "github.com/safing/portbase/config" + "github.com/safing/portbase/log" +) + +var ( + devMode config.BoolOption + defaultDevMode bool +) + +func init() { + flag.BoolVar(&defaultDevMode, "devmode", false, "force development mode") +} + +func logFlagOverrides() { + if defaultDevMode { + log.Warning("core: core/devMode default config is being forced by -devmode flag") + } +} + +func registerConfig() error { + err := config.Register(&config.Option{ + Name: "Development Mode", + Key: "core/devMode", + Description: "In Development Mode security restrictions are lifted/softened to enable easier access to Portmaster for debugging and testing purposes.", + ExpertiseLevel: config.ExpertiseLevelDeveloper, + OptType: config.OptTypeBool, + DefaultValue: defaultDevMode, + }) + if err != nil { + return err + } + + return nil +} diff --git a/core/core.go b/core/core.go new file mode 100644 index 00000000..9b34f92f --- /dev/null +++ b/core/core.go @@ -0,0 +1,19 @@ +package core + +import ( + "fmt" + + "github.com/safing/portbase/modules" +) + +var ( + coreModule = modules.Register("core", nil, startCore, nil, "base", "database", "config", "api", "random") +) + +func startCore() error { + if err := startPlatformSpecific(); err != nil { + return fmt.Errorf("failed to start plattform-specific components: %s", err) + } + + return registerDatabases() +} diff --git a/core/databases.go b/core/databases.go index 3373ea6e..ddb9cf9f 100644 --- a/core/databases.go +++ b/core/databases.go @@ -2,21 +2,12 @@ package core import ( "github.com/safing/portbase/database" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/notifications" // module dependencies - _ "github.com/safing/portbase/database/dbmodule" _ "github.com/safing/portbase/database/storage/bbolt" ) -func init() { - modules.Register("core", nil, start, nil, "database") - - notifications.SetPersistenceBasePath("core:notifications") -} - -func start() error { +func registerDatabases() error { _, err := database.Register(&database.Database{ Name: "core", Description: "Holds core data, such as settings and profiles", diff --git a/core/os_default.go b/core/os_default.go new file mode 100644 index 00000000..459e41a2 --- /dev/null +++ b/core/os_default.go @@ -0,0 +1,8 @@ +// +build !windows + +package core + +// only return on Fatal error! +func startPlatformSpecific() error { + return nil +} diff --git a/core/os_windows.go b/core/os_windows.go new file mode 100644 index 00000000..41f1f594 --- /dev/null +++ b/core/os_windows.go @@ -0,0 +1,16 @@ +package core + +import ( + "github.com/safing/portbase/log" + "github.com/safing/portbase/utils/osdetail" +) + +// only return on Fatal error! +func startPlatformSpecific() error { + // We can't catch errors when calling WindowsNTVersion() in logging, so we call the function here, just to catch possible errors + if _, err := osdetail.WindowsNTVersion(); err != nil { + log.Errorf("failed to obtain WindowsNTVersion: %s", err) + } + + return nil +} diff --git a/core/structure/dirs.go b/core/structure/dirs.go new file mode 100644 index 00000000..c2df06f9 --- /dev/null +++ b/core/structure/dirs.go @@ -0,0 +1,27 @@ +package structure + +import ( + "os" + + "github.com/safing/portbase/utils" +) + +var ( + root *utils.DirStructure +) + +// Initialize initializes the data root directory +func Initialize(rootDir string, perm os.FileMode) error { + root = utils.NewDirStructure(rootDir, perm) + return root.Ensure() +} + +// Root returns the data root directory. +func Root() *utils.DirStructure { + return root +} + +// NewRootDir calls ChildDir() on the data root directory. +func NewRootDir(dirName string, perm os.FileMode) (childDir *utils.DirStructure) { + return root.ChildDir(dirName, perm) +} diff --git a/firewall/api.go b/firewall/api.go new file mode 100644 index 00000000..3803126c --- /dev/null +++ b/firewall/api.go @@ -0,0 +1,110 @@ +package firewall + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "strconv" + "strings" + + "github.com/safing/portbase/utils" + "github.com/safing/portmaster/core/structure" + + "github.com/safing/portbase/log" + + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/process" + + "github.com/safing/portbase/api" +) + +var ( + dataRoot *utils.DirStructure + + apiPortSet bool + apiPort uint16 +) + +func prepAPIAuth() error { + dataRoot = structure.Root() + return api.SetAuthenticator(apiAuthenticator) +} + +func startAPIAuth() { + var err error + _, apiPort, err = parseHostPort(apiListenAddress()) + if err != nil { + log.Warningf("firewall: failed to parse API address for improved api auth mechanism: %s", err) + return + } + apiPortSet = true + log.Tracef("firewall: api port set to %d", apiPort) +} + +func apiAuthenticator(s *http.Server, r *http.Request) (grantAccess bool, err error) { + if devMode() { + return true, nil + } + + // get local IP/Port + localIP, localPort, err := parseHostPort(s.Addr) + if err != nil { + return false, fmt.Errorf("failed to get local IP/Port: %s", err) + } + + // get remote IP/Port + remoteIP, remotePort, err := parseHostPort(r.RemoteAddr) + if err != nil { + return false, fmt.Errorf("failed to get remote IP/Port: %s", err) + } + + var procsChecked []string + + // get process + proc, err := process.GetProcessByEndpoints(r.Context(), remoteIP, remotePort, localIP, localPort, packet.TCP) // switch reverse/local to get remote process + if err != nil { + return false, fmt.Errorf("failed to get process: %s", err) + } + + // go up up to two levels, if we don't match + for i := 0; i < 3; i++ { + // check if the requesting process is in database root / updates dir + if strings.HasPrefix(proc.Path, dataRoot.Path) { + return true, nil + } + // add checked process to list + procsChecked = append(procsChecked, proc.Path) + + if i < 2 { + // get parent process + proc, err = process.GetOrFindProcess(context.Background(), proc.ParentPid) + if err != nil { + return false, fmt.Errorf("failed to get process: %s", err) + } + } + } + + log.Debugf("firewall: denying api access to %s - also checked %s (trusted root is %s)", procsChecked[0], strings.Join(procsChecked[1:], " "), dataRoot.Path) + return false, nil +} + +func parseHostPort(address string) (net.IP, uint16, error) { + ipString, portString, err := net.SplitHostPort(address) + if err != nil { + return nil, 0, err + } + + ip := net.ParseIP(ipString) + if ip == nil { + return nil, 0, errors.New("invalid IP address") + } + + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, 0, err + } + + return ip, uint16(port), nil +} diff --git a/firewall/config.go b/firewall/config.go index 8e79e686..985138c8 100644 --- a/firewall/config.go +++ b/firewall/config.go @@ -9,6 +9,10 @@ var ( permanentVerdicts config.BoolOption filterDNSByScope status.SecurityLevelOption filterDNSByProfile status.SecurityLevelOption + promptTimeout config.IntOption + + devMode config.BoolOption + apiListenAddress config.StringOption ) func registerConfig() error { @@ -55,5 +59,21 @@ func registerConfig() error { } filterDNSByProfile = status.ConfigIsActiveConcurrent("firewall/filterDNSByProfile") + err = config.Register(&config.Option{ + Name: "Timeout for prompt notifications", + Key: "firewall/promptTimeout", + Description: "Amount of time how long Portmaster will wait for a response when prompting about a connection via a notification. In seconds.", + ExpertiseLevel: config.ExpertiseLevelUser, + OptType: config.OptTypeInt, + DefaultValue: 60, + }) + if err != nil { + return err + } + promptTimeout = config.Concurrent.GetAsInt("firewall/promptTimeout", 30) + + devMode = config.Concurrent.GetAsBool("firewall/permanentVerdicts", false) + apiListenAddress = config.GetAsString("api/listenAddress", "") + return nil } diff --git a/firewall/firewall.go b/firewall/firewall.go index 74fb95bb..e8675455 100644 --- a/firewall/firewall.go +++ b/firewall/firewall.go @@ -50,6 +50,11 @@ func prep() (err error) { return err } + err = prepAPIAuth() + if err != nil { + return err + } + _, localNet4, err = net.ParseCIDR("127.0.0.0/24") // Yes, this would normally be 127.0.0.0/8 // TODO: figure out any side effects @@ -77,12 +82,9 @@ func prep() (err error) { } func start() error { + startAPIAuth() go statLogger() go run() - // go run() - // go run() - // go run() - go portsInUseCleaner() return interception.Start() @@ -108,6 +110,15 @@ func handlePacket(pkt packet.Packet) { return } + // allow api access, if address was parsed successfully + if apiPortSet { + if (pkt.Info().DstPort == apiPort || pkt.Info().SrcPort == apiPort) && pkt.Info().Src.Equal(pkt.Info().Dst) { + log.Debugf("accepting api connection: %s", pkt) + pkt.PermanentAccept() + return + } + } + // // redirect dns (if we know that it's not our own request) // if pkt.IsOutbound() && intel.RemoteIsActiveNameserver(pkt) { // log.Debugf("redirecting dns: %s", pkt) diff --git a/firewall/inspection/inspection.go b/firewall/inspection/inspection.go index a3a536fa..3ecd23cc 100644 --- a/firewall/inspection/inspection.go +++ b/firewall/inspection/inspection.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package inspection import ( diff --git a/firewall/inspection/tls/tls_test.go b/firewall/inspection/tls/tls_test.go index e000f026..79ec4be4 100644 --- a/firewall/inspection/tls/tls_test.go +++ b/firewall/inspection/tls/tls_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package tls import ( diff --git a/firewall/interception/interception_linux.go b/firewall/interception/interception_linux.go index bf453449..0021d9e3 100644 --- a/firewall/interception/interception_linux.go +++ b/firewall/interception/interception_linux.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package interception import "github.com/safing/portmaster/network/packet" diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index 57db89f3..3844022e 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -81,7 +81,7 @@ func notifyDisableDNSCache() { ID: "windows-disable-dns-cache", Message: "The Portmaster needs the Windows Service \"DNS Client\" (dnscache) to be disabled for best effectiveness.", Type: notifications.Warning, - }).Init().Save() + }).Save() } func notifyRebootRequired() { @@ -89,5 +89,5 @@ func notifyRebootRequired() { ID: "windows-dnscache-reboot-required", Message: "Please restart your system to complete Portmaster integration.", Type: notifications.Warning, - }).Init().Save() + }).Save() } diff --git a/firewall/interception/nfqueue/multiqueue.go b/firewall/interception/nfqueue/multiqueue.go index 2c5aaa41..884947bd 100644 --- a/firewall/interception/nfqueue/multiqueue.go +++ b/firewall/interception/nfqueue/multiqueue.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package nfqueue // suspended for now diff --git a/firewall/interception/nfqueue/nfqueue.go b/firewall/interception/nfqueue/nfqueue.go index 1c2d5d08..27665240 100644 --- a/firewall/interception/nfqueue/nfqueue.go +++ b/firewall/interception/nfqueue/nfqueue.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package nfqueue /* diff --git a/firewall/interception/nfqueue/packet.go b/firewall/interception/nfqueue/packet.go index cbbfc27a..ea873ddc 100644 --- a/firewall/interception/nfqueue/packet.go +++ b/firewall/interception/nfqueue/packet.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package nfqueue import ( diff --git a/firewall/master.go b/firewall/master.go index 9a752cae..8795c975 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -5,10 +5,9 @@ import ( "net" "os" "strings" - "time" + "github.com/miekg/dns" "github.com/safing/portbase/log" - "github.com/safing/portbase/notifications" "github.com/safing/portmaster/intel" "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/netutils" @@ -16,7 +15,6 @@ import ( "github.com/safing/portmaster/process" "github.com/safing/portmaster/profile" "github.com/safing/portmaster/status" - "github.com/miekg/dns" "github.com/agext/levenshtein" ) @@ -137,93 +135,7 @@ func DecideOnCommunicationAfterIntel(comm *network.Communication, fqdn string, r } // prompt - - // first check if there is an existing notification for this. - nID := fmt.Sprintf("firewall-prompt-%d-%s", comm.Process().Pid, comm.Domain) - nTTL := 15 * time.Second - n := notifications.Get(nID) - if n != nil { - // we were not here first, only get verdict, do not make changes - select { - case promptResponse := <-n.Response(): - switch promptResponse { - case "permit-all", "permit-distinct": - comm.Accept("permitted by user") - default: - comm.Deny("denied by user") - } - case <-time.After(nTTL): - comm.SetReason("user did not respond to prompt") - } - return - } - - // create new notification - n = (¬ifications.Notification{ - ID: nID, - Message: fmt.Sprintf("Application %s wants to connect to %s", comm.Process(), comm.Domain), - Type: notifications.Prompt, - AvailableActions: []*notifications.Action{ - ¬ifications.Action{ - ID: "permit-all", - Text: fmt.Sprintf("Permit all %s", comm.Domain), - }, - ¬ifications.Action{ - ID: "permit-distinct", - Text: fmt.Sprintf("Permit %s", comm.Domain), - }, - ¬ifications.Action{ - ID: "deny", - Text: "Deny", - }, - }, - Expires: time.Now().Add(nTTL).Unix(), - }).Init().Save() - - // react - select { - case promptResponse := <-n.Response(): - n.Cancel() - - new := &profile.EndpointPermission{ - Type: profile.EptDomain, - Value: comm.Domain, - Permit: true, - Created: time.Now().Unix(), - } - - switch promptResponse { - case "permit-all": - new.Value = "." + new.Value - case "permit-distinct": - // everything already set - default: - // deny - new.Permit = false - } - - if new.Permit { - log.Infof("firewall: user permitted communication %s -> %s", comm.Process(), new.Value) - comm.Accept("permitted by user") - } else { - log.Infof("firewall: user denied communication %s -> %s", comm.Process(), new.Value) - comm.Deny("denied by user") - } - - profileSet.Lock() - defer profileSet.Unlock() - userProfile := profileSet.UserProfile() - userProfile.Lock() - defer userProfile.Unlock() - - userProfile.Endpoints = append(userProfile.Endpoints, new) - go userProfile.Save("") - - case <-time.After(nTTL): - n.Cancel() - comm.SetReason("user did not respond to prompt") - - } + prompt(comm, nil, nil, fqdn) } // FilterDNSResponse filters a dns response according to the application profile and settings. @@ -573,134 +485,8 @@ func DecideOnLink(comm *network.Communication, link *network.Link, pkt packet.Pa } } - // first check if there is an existing notification for this. - var nID string - switch { - case comm.Direction: - nID = fmt.Sprintf("firewall-prompt-%d-%s-%s-%d-%d", comm.Process().Pid, comm.Domain, remoteIP, protocol, dstPort) - case fqdn == "": - nID = fmt.Sprintf("firewall-prompt-%d-%s-%s-%d-%d", comm.Process().Pid, comm.Domain, remoteIP, protocol, dstPort) - default: - nID = fmt.Sprintf("firewall-prompt-%d-%s-%s-%d-%d", comm.Process().Pid, comm.Domain, remoteIP, protocol, dstPort) - } - nTTL := 15 * time.Second - n := notifications.Get(nID) - - if n != nil { - // we were not here first, only get verdict, do not make changes - select { - case promptResponse := <-n.Response(): - switch promptResponse { - case "permit-domain-all", "permit-domain-distinct", "permit-ip", "permit-ip-incoming": - link.Accept("permitted by user") - default: - link.Deny("denied by user") - } - case <-time.After(nTTL): - link.Deny("user did not respond to prompt") - } - return - } - - // create new notification - n = (¬ifications.Notification{ - ID: nID, - Type: notifications.Prompt, - Expires: time.Now().Add(nTTL).Unix(), - }) - - switch { - case comm.Direction: - n.Message = fmt.Sprintf("Application %s wants to accept connections from %s (%d/%d)", comm.Process(), remoteIP, protocol, dstPort) - n.AvailableActions = []*notifications.Action{ - ¬ifications.Action{ - ID: "permit-ip-incoming", - Text: fmt.Sprintf("Permit serving to %s", remoteIP), - }, - } - case fqdn == "": - n.Message = fmt.Sprintf("Application %s wants to connect to %s (%d/%d)", comm.Process(), remoteIP, protocol, dstPort) - n.AvailableActions = []*notifications.Action{ - ¬ifications.Action{ - ID: "permit-ip", - Text: fmt.Sprintf("Permit %s", remoteIP), - }, - } - default: - n.Message = fmt.Sprintf("Application %s wants to connect to %s (%s %d/%d)", comm.Process(), comm.Domain, remoteIP, protocol, dstPort) - n.AvailableActions = []*notifications.Action{ - ¬ifications.Action{ - ID: "permit-domain-all", - Text: fmt.Sprintf("Permit all %s", comm.Domain), - }, - ¬ifications.Action{ - ID: "permit-domain-distinct", - Text: fmt.Sprintf("Permit %s", comm.Domain), - }, - } - } - - n.AvailableActions = append(n.AvailableActions, ¬ifications.Action{ - ID: "deny", - Text: "deny", - }) - n.Init().Save() - - // react - select { - case promptResponse := <-n.Response(): - n.Cancel() - - new := &profile.EndpointPermission{ - Type: profile.EptDomain, - Value: comm.Domain, - Permit: true, - Created: time.Now().Unix(), - } - - switch promptResponse { - case "permit-domain-all": - new.Value = "." + new.Value - case "permit-domain-distinct": - // everything already set - case "permit-ip", "permit-ip-incoming": - if pkt.Info().Version == packet.IPv4 { - new.Type = profile.EptIPv4 - } else { - new.Type = profile.EptIPv6 - } - new.Value = remoteIP.String() - default: - // deny - new.Permit = false - } - - if new.Permit { - log.Infof("firewall: user permitted link %s -> %s", comm.Process(), new.Value) - link.Accept("permitted by user") - } else { - log.Infof("firewall: user denied link %s -> %s", comm.Process(), new.Value) - link.Deny("denied by user") - } - - profileSet.Lock() - defer profileSet.Unlock() - userProfile := profileSet.UserProfile() - userProfile.Lock() - defer userProfile.Unlock() - - if promptResponse == "permit-ip-incoming" { - userProfile.ServiceEndpoints = append(userProfile.ServiceEndpoints, new) - } else { - userProfile.Endpoints = append(userProfile.Endpoints, new) - } - go userProfile.Save("") - - case <-time.After(nTTL): - n.Cancel() - link.Deny("user did not respond to prompt") - - } + // prompt + prompt(comm, link, pkt, fqdn) } func checkRelation(comm *network.Communication, fqdn string) (related bool) { diff --git a/firewall/prompt.go b/firewall/prompt.go new file mode 100644 index 00000000..274927c7 --- /dev/null +++ b/firewall/prompt.go @@ -0,0 +1,194 @@ +package firewall + +import ( + "fmt" + "time" + + "github.com/safing/portbase/log" + "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/network" + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/profile" +) + +const ( + // notification action IDs + permitDomainAll = "permit-domain-all" + permitDomainDistinct = "permit-domain-distinct" + denyDomainAll = "deny-domain-all" + denyDomainDistinct = "deny-domain-distinct" + + permitIP = "permit-ip" + denyIP = "deny-ip" + permitServingIP = "permit-serving-ip" + denyServingIP = "deny-serving-ip" +) + +func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet, fqdn string) { + nTTL := time.Duration(promptTimeout()) * time.Second + + // first check if there is an existing notification for this. + // build notification ID + var nID string + switch { + case comm.Direction, fqdn == "": // connection to/from IP + if pkt == nil { + log.Error("firewall: could not prompt for incoming/direct connection: missing pkt") + if link != nil { + link.Deny("internal error") + } else { + comm.Deny("internal error") + } + return + } + nID = fmt.Sprintf("firewall-prompt-%d-%s-%s", comm.Process().Pid, comm.Domain, pkt.Info().RemoteIP) + default: // connection to domain + nID = fmt.Sprintf("firewall-prompt-%d-%s", comm.Process().Pid, comm.Domain) + } + n := notifications.Get(nID) + saveResponse := true + + if n != nil { + // update with new expiry + n.Update(time.Now().Add(nTTL).Unix()) + // do not save response to profile + saveResponse = false + } else { + // create new notification + n = (¬ifications.Notification{ + ID: nID, + Type: notifications.Prompt, + Expires: time.Now().Add(nTTL).Unix(), + }) + + // add message and actions + switch { + case comm.Direction: // incoming + n.Message = fmt.Sprintf("Application %s wants to accept connections from %s (on %d/%d)", comm.Process(), pkt.Info().RemoteIP(), pkt.Info().Protocol, pkt.Info().LocalPort()) + n.AvailableActions = []*notifications.Action{ + ¬ifications.Action{ + ID: permitServingIP, + Text: "Permit", + }, + ¬ifications.Action{ + ID: denyServingIP, + Text: "Deny", + }, + } + case fqdn == "": // direct connection + n.Message = fmt.Sprintf("Application %s wants to connect to %s (on %d/%d)", comm.Process(), pkt.Info().RemoteIP(), pkt.Info().Protocol, pkt.Info().RemotePort()) + n.AvailableActions = []*notifications.Action{ + ¬ifications.Action{ + ID: permitIP, + Text: "Permit", + }, + ¬ifications.Action{ + ID: denyIP, + Text: "Deny", + }, + } + default: // connection to domain + if pkt != nil { + n.Message = fmt.Sprintf("Application %s wants to connect to %s (%s %d/%d)", comm.Process(), comm.Domain, pkt.Info().RemoteIP(), pkt.Info().Protocol, pkt.Info().RemotePort()) + } else { + n.Message = fmt.Sprintf("Application %s wants to connect to %s", comm.Process(), comm.Domain) + } + n.AvailableActions = []*notifications.Action{ + ¬ifications.Action{ + ID: permitDomainAll, + Text: "Permit all", + }, + ¬ifications.Action{ + ID: permitDomainDistinct, + Text: "Permit", + }, + ¬ifications.Action{ + ID: denyDomainDistinct, + Text: "Deny", + }, + } + } + // save new notification + n.Save() + } + + // wait for response/timeout + select { + case promptResponse := <-n.Response(): + switch promptResponse { + case permitDomainAll, permitDomainDistinct, permitIP, permitServingIP: + if link != nil { + link.Accept("permitted by user") + } else { + comm.Accept("permitted by user") + } + default: // deny + if link != nil { + link.Accept("denied by user") + } else { + comm.Accept("denied by user") + } + } + + // end here if we won't save the response to the profile + if !saveResponse { + return + } + + new := &profile.EndpointPermission{ + Type: profile.EptDomain, + Value: comm.Domain, + Permit: false, + Created: time.Now().Unix(), + } + + // permission type + switch promptResponse { + case permitDomainAll, denyDomainAll: + new.Value = "." + new.Value + case permitIP, permitServingIP, denyIP, denyServingIP: + if pkt == nil { + log.Warningf("firewall: received invalid prompt response: %s for %s", promptResponse, comm.Domain) + return + } + if pkt.Info().Version == packet.IPv4 { + new.Type = profile.EptIPv4 + } else { + new.Type = profile.EptIPv6 + } + new.Value = pkt.Info().RemoteIP().String() + } + + // permission verdict + switch promptResponse { + case permitDomainAll, permitDomainDistinct, permitIP, permitServingIP: + new.Permit = false + } + + // get user profile + profileSet := comm.Process().ProfileSet() + profileSet.Lock() + defer profileSet.Unlock() + userProfile := profileSet.UserProfile() + userProfile.Lock() + defer userProfile.Unlock() + + // add to correct list + switch promptResponse { + case permitServingIP, denyServingIP: + userProfile.ServiceEndpoints = append(userProfile.ServiceEndpoints, new) + default: + userProfile.Endpoints = append(userProfile.Endpoints, new) + } + + // save! + go userProfile.Save("") + + case <-n.Expired(): + if link != nil { + link.Accept("no response to prompt") + } else { + comm.Accept("no response to prompt") + } + } +} diff --git a/intel/doc.go b/intel/doc.go index 70edfcf7..ffa1f767 100644 --- a/intel/doc.go +++ b/intel/doc.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - /* Package intel is responsible for fetching intelligence data, including DNS, on remote entities. diff --git a/intel/intel.go b/intel/intel.go index ce7822ca..66125213 100644 --- a/intel/intel.go +++ b/intel/intel.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package intel import ( diff --git a/intel/mdns.go b/intel/mdns.go index a03436a5..c9aa675c 100644 --- a/intel/mdns.go +++ b/intel/mdns.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package intel import ( diff --git a/intel/resolve.go b/intel/resolve.go index 2ccc8e5e..fa0ec343 100644 --- a/intel/resolve.go +++ b/intel/resolve.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package intel import ( diff --git a/intel/resolve_test.go b/intel/resolve_test.go index b206bdcf..4a57d114 100644 --- a/intel/resolve_test.go +++ b/intel/resolve_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package intel // DISABLE TESTING FOR NOW: find a way to have tests with the module system diff --git a/intel/rrcache.go b/intel/rrcache.go index f33abbf0..c99f8b4b 100644 --- a/intel/rrcache.go +++ b/intel/rrcache.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package intel import ( diff --git a/intel/special.go b/intel/special.go index c979599b..e2301df3 100644 --- a/intel/special.go +++ b/intel/special.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package intel import "strings" diff --git a/main.go b/main.go index 1600e363..18dc4377 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "flag" "fmt" "os" @@ -21,17 +22,18 @@ import ( ) var ( - printStackOnExit bool + printStackOnExit bool + enableInputSignals bool ) func init() { flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down") + flag.BoolVar(&enableInputSignals, "input-signals", false, "emulate signals using stdin") } func main() { - // Set Info - info.Set("Portmaster", "0.3.1", "AGPLv3", true) + info.Set("Portmaster", "0.3.8", "AGPLv3", true) // Start err := modules.Start() @@ -47,6 +49,9 @@ func main() { // Shutdown // catch interrupt for clean shutdown signalCh := make(chan os.Signal) + if enableInputSignals { + go inputSignals(signalCh) + } signal.Notify( signalCh, os.Interrupt, @@ -82,9 +87,9 @@ func main() { } go func() { - time.Sleep(5 * time.Second) - fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====") - pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + time.Sleep(10 * time.Second) + fmt.Fprintln(os.Stderr, "===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====") + pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) os.Exit(1) }() @@ -99,3 +104,19 @@ func main() { } } + +func inputSignals(signalCh chan os.Signal) { + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + switch scanner.Text() { + case "SIGHUP": + signalCh <- syscall.SIGHUP + case "SIGINT": + signalCh <- syscall.SIGINT + case "SIGQUIT": + signalCh <- syscall.SIGQUIT + case "SIGTERM": + signalCh <- syscall.SIGTERM + } + } +} diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index ed0b9831..d2e1aa5a 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package nameserver import ( @@ -29,7 +27,7 @@ var ( ) func init() { - modules.Register("nameserver", prep, start, nil, "intel") + modules.Register("nameserver", prep, start, nil, "core", "intel") if runtime.GOOS == "windows" { listenAddress = "0.0.0.0:53" diff --git a/nameserver/only/nameserver.go b/nameserver/only/nameserver.go index 07cc5ae8..8b6c2ba1 100644 --- a/nameserver/only/nameserver.go +++ b/nameserver/only/nameserver.go @@ -21,7 +21,7 @@ var ( ) func init() { - modules.Register("nameserver", prep, start, nil, "intel") + modules.Register("nameserver", prep, start, nil, "core", "intel") } func prep() error { diff --git a/nameserver/takeover.go b/nameserver/takeover.go index 82d4e42c..ab245b05 100644 --- a/nameserver/takeover.go +++ b/nameserver/takeover.go @@ -26,7 +26,7 @@ func checkForConflictingService(err error) { (¬ifications.Notification{ ID: "nameserver-stopped-conflicting-service", Message: fmt.Sprintf("Portmaster stopped a conflicting name service (pid %d) to gain required system integration.", pid), - }).Init().Save() + }).Save() // wait for a short duration for the other service to shut down time.Sleep(100 * time.Millisecond) diff --git a/network/clean.go b/network/clean.go index 92313576..b4e5560c 100644 --- a/network/clean.go +++ b/network/clean.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package network import ( diff --git a/network/communication.go b/network/communication.go index 369904fa..cedd26f6 100644 --- a/network/communication.go +++ b/network/communication.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package network import ( diff --git a/network/environment/dbus_linux.go b/network/environment/dbus_linux.go index 8e6751c3..8260d343 100644 --- a/network/environment/dbus_linux.go +++ b/network/environment/dbus_linux.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - // +build !server package environment diff --git a/network/environment/dbus_linux_mock.go b/network/environment/dbus_linux_mock.go index 1e4e6e01..d03f8248 100644 --- a/network/environment/dbus_linux_mock.go +++ b/network/environment/dbus_linux_mock.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - // +build !linux package environment diff --git a/network/environment/dbus_linux_test.go b/network/environment/dbus_linux_test.go index 2c0fd729..a274297c 100644 --- a/network/environment/dbus_linux_test.go +++ b/network/environment/dbus_linux_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package environment import "testing" diff --git a/network/environment/environment.go b/network/environment/environment.go index 03d92bf9..701ee8c0 100644 --- a/network/environment/environment.go +++ b/network/environment/environment.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package environment import ( diff --git a/network/environment/environment_linux.go b/network/environment/environment_linux.go index 49d48314..3c2d3d14 100644 --- a/network/environment/environment_linux.go +++ b/network/environment/environment_linux.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package environment import ( diff --git a/network/environment/interface.go b/network/environment/interface.go index 8ae6390f..338eee87 100644 --- a/network/environment/interface.go +++ b/network/environment/interface.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package environment import ( diff --git a/network/environment/location.go b/network/environment/location.go index db765021..806009f3 100644 --- a/network/environment/location.go +++ b/network/environment/location.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package environment import ( diff --git a/network/environment/location_test.go b/network/environment/location_test.go index 9004ed4f..44f9f51f 100644 --- a/network/environment/location_test.go +++ b/network/environment/location_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - // +build root package environment diff --git a/network/link.go b/network/link.go index 4dbd07fa..a1cbd04f 100644 --- a/network/link.go +++ b/network/link.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package network import ( diff --git a/network/module.go b/network/module.go index 85b3d170..1bd8ba9f 100644 --- a/network/module.go +++ b/network/module.go @@ -5,7 +5,7 @@ import ( ) func init() { - modules.Register("network", nil, start, nil, "database") + modules.Register("network", nil, start, nil, "core") } func start() error { diff --git a/network/netutils/cleandns.go b/network/netutils/cleandns.go index 66bcb24a..33cec6cc 100644 --- a/network/netutils/cleandns.go +++ b/network/netutils/cleandns.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package netutils import ( diff --git a/network/netutils/ip.go b/network/netutils/ip.go index 780c6901..651a5d95 100644 --- a/network/netutils/ip.go +++ b/network/netutils/ip.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package netutils import "net" diff --git a/network/netutils/tcpassembly.go b/network/netutils/tcpassembly.go index 51f0c1bc..0d5d8762 100644 --- a/network/netutils/tcpassembly.go +++ b/network/netutils/tcpassembly.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package netutils import ( diff --git a/network/packet/packet.go b/network/packet/packet.go index f150476e..69bf590c 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package packet import ( diff --git a/network/packet/packetinfo.go b/network/packet/packetinfo.go index d1dc1ce0..74698156 100644 --- a/network/packet/packetinfo.go +++ b/network/packet/packetinfo.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package packet import ( diff --git a/network/status.go b/network/status.go index 6e2ab7b8..16534025 100644 --- a/network/status.go +++ b/network/status.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package network // Verdict describes the decision made about a connection or link. diff --git a/pmctl/build b/pmctl/build index a311b43b..26fb3525 100755 --- a/pmctl/build +++ b/pmctl/build @@ -43,10 +43,32 @@ if [[ "$BUILD_SOURCE" == "" ]]; then exit 1 fi +# build tools +EXTRA_LD_FLAGS="" +if [[ $GOOS == "windows" ]]; then + # checks + if [[ $CC_FOR_windows_amd64 == "" ]]; then + echo "ENV variable CC_FOR_windows_amd64 (c compiler) is not set. Please set it to the cross compiler you want to use for compiling for windows_amd64" + exit 1 + fi + if [[ $CXX_FOR_windows_amd64 == "" ]]; then + echo "ENV variable CXX_FOR_windows_amd64 (c++ compiler) is not set. Please set it to the cross compiler you want to use for compiling for windows_amd64" + exit 1 + fi + # compilers + export CC=$CC_FOR_windows_amd64 + export CXX=$CXX_FOR_windows_amd64 + # custom + export CGO_ENABLED=1 + EXTRA_LD_FLAGS='-H windowsgui' # Hide console window by default (but we attach to parent console if available) + # generate resource.syso for windows metadata / icon + go generate +fi + echo "Please notice, that this build script includes metadata into the build." echo "This information is useful for debugging and license compliance." echo "Run the compiled binary with the -version flag to see the information included." # build BUILD_PATH="github.com/safing/portbase/info" -go build -ldflags "-X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" $* +go build -ldflags "$EXTRA_LD_FLAGS -X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" $* diff --git a/pmctl/console_default.go b/pmctl/console_default.go new file mode 100644 index 00000000..7f2fc446 --- /dev/null +++ b/pmctl/console_default.go @@ -0,0 +1,12 @@ +// +build !windows + +package main + +import "os/exec" + +func attachToParentConsole() (attached bool, err error) { + return true, nil +} + +func hideWindow(cmd *exec.Cmd) { +} diff --git a/pmctl/console_windows.go b/pmctl/console_windows.go new file mode 100644 index 00000000..5a1c09b3 --- /dev/null +++ b/pmctl/console_windows.go @@ -0,0 +1,150 @@ +package main + +// Parts of this file are FORKED +// from https://github.com/apenwarr/fixconsole/blob/35b2e7d921eb80a71a5f04f166ff0a1405bddf79/fixconsole_windows.go +// on 16.07.2019 +// with Apache-2.0 license +// authored by https://github.com/apenwarr + +// docs/sources: +// Stackoverflow Question: https://stackoverflow.com/questions/23743217/printing-output-to-a-command-window-when-golang-application-is-compiled-with-ld +// MS AttachConsole: https://docs.microsoft.com/en-us/windows/console/attachconsole + +import ( + "log" + "os" + "os/exec" + "syscall" + + "golang.org/x/sys/windows" +) + +const ( + windowsAttachParentProcess = ^uintptr(0) // (DWORD)-1 +) + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + procAttachConsole = kernel32.NewProc("AttachConsole") +) + +// Windows console output is a mess. +// +// If you compile as "-H windows", then if you launch your program without +// a console, Windows forcibly creates one to use as your stdin/stdout, which +// is silly for a GUI app, so we can't do that. +// +// If you compile as "-H windowsgui", then it doesn't create a console for +// your app... but also doesn't provide a working stdin/stdout/stderr even if +// you *did* launch from the console. However, you can use AttachConsole() +// to get a handle to your parent process's console, if any, and then +// os.NewFile() to turn that handle into a fd usable as stdout/stderr. +// +// However, then you have the problem that if you redirect stdout or stderr +// from the shell, you end up ignoring the redirection by forcing it to the +// console. +// +// To fix *that*, we have to detect whether there was a pre-existing stdout +// or not. We can check GetStdHandle(), which returns 0 for "should be +// console" and nonzero for "already pointing at a file." +// +// Be careful though! As soon as you run AttachConsole(), it resets *all* +// the GetStdHandle() handles to point them at the console instead, thus +// throwing away the original file redirects. So we have to GetStdHandle() +// *before* AttachConsole(). +// +// For some reason, powershell redirections provide a valid file handle, but +// writing to that handle doesn't write to the file. I haven't found a way +// to work around that. (Windows 10.0.17763.379) +// +// Net result is as follows. +// Before: +// SHELL NON-REDIRECTED REDIRECTED +// explorer.exe no console n/a +// cmd.exe broken works +// powershell broken broken +// WSL bash broken works +// After +// SHELL NON-REDIRECTED REDIRECTED +// explorer.exe no console n/a +// cmd.exe works works +// powershell works broken +// WSL bash works works +// +// We don't seem to make anything worse, at least. +func attachToParentConsole() (attached bool, err error) { + // get std handles before we attempt to attach to parent console + stdin, _ := syscall.GetStdHandle(syscall.STD_INPUT_HANDLE) + stdout, _ := syscall.GetStdHandle(syscall.STD_OUTPUT_HANDLE) + stderr, _ := syscall.GetStdHandle(syscall.STD_ERROR_HANDLE) + + // attempt to attach to parent console + err = procAttachConsole.Find() + if err != nil { + return false, err + } + r1, _, err := procAttachConsole.Call(windowsAttachParentProcess) + if r1 == 0 { + // possible errors: + // ERROR_ACCESS_DENIED: already attached to console + // ERROR_INVALID_HANDLE: process does not have console + // ERROR_INVALID_PARAMETER: process does not exist + return false, nil + } + + // get std handles after we attached to console + var invalid syscall.Handle + con := invalid + + if stdin == invalid { + stdin, _ = syscall.GetStdHandle(syscall.STD_INPUT_HANDLE) + } + if stdout == invalid { + stdout, _ = syscall.GetStdHandle(syscall.STD_OUTPUT_HANDLE) + con = stdout + } + if stderr == invalid { + stderr, _ = syscall.GetStdHandle(syscall.STD_ERROR_HANDLE) + con = stderr + } + + // correct output mode + if con != invalid { + // Make sure the console is configured to convert + // \n to \r\n, like Go programs expect. + h := windows.Handle(con) + var st uint32 + err := windows.GetConsoleMode(h, &st) + if err != nil { + log.Printf("failed to get console mode: %s\n", err) + } else { + err = windows.SetConsoleMode(h, st&^windows.DISABLE_NEWLINE_AUTO_RETURN) + if err != nil { + log.Printf("failed to set console mode: %s\n", err) + } + } + } + + // fix std handles to correct values (ie. redirects) + if stdin != invalid { + os.Stdin = os.NewFile(uintptr(stdin), "stdin") + log.Printf("fixed os.Stdin after attaching to parent console\n") + } + if stdout != invalid { + os.Stdout = os.NewFile(uintptr(stdout), "stdout") + log.Printf("fixed os.Stdout after attaching to parent console\n") + } + if stderr != invalid { + os.Stderr = os.NewFile(uintptr(stderr), "stderr") + log.Printf("fixed os.Stderr after attaching to parent console\n") + } + + log.Println("attached to parent console") + return true, nil +} + +func hideWindow(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + HideWindow: true, + } +} diff --git a/pmctl/get.go b/pmctl/get.go index 248e8f4f..81cc3842 100644 --- a/pmctl/get.go +++ b/pmctl/get.go @@ -2,7 +2,7 @@ package main import ( "errors" - "fmt" + "log" "time" "github.com/safing/portmaster/updates" @@ -22,7 +22,7 @@ func getFile(opts *Options) (*updates.File, error) { // download if opts.AllowDownload { - fmt.Printf("%s downloading %s...\n", logPrefix, opts.Identifier) + log.Printf("downloading %s...\n", opts.Identifier) // download indexes err = updates.UpdateIndexes() @@ -39,7 +39,7 @@ func getFile(opts *Options) (*updates.File, error) { } // wait for 30 seconds - fmt.Printf("%s waiting for download of %s (by Portmaster Core) to complete...\n", logPrefix, opts.Identifier) + log.Printf("waiting for download of %s (by Portmaster Core) to complete...\n", opts.Identifier) // try every 0.5 secs for tries := 0; tries < 60; tries++ { diff --git a/pmctl/install_windows.go b/pmctl/install_windows.go new file mode 100644 index 00000000..63cfa5e6 --- /dev/null +++ b/pmctl/install_windows.go @@ -0,0 +1,209 @@ +package main + +// Based on the offical Go examples from +// https://github.com/golang/sys/blob/master/windows/svc/example +// by The Go Authors. +// Original LICENSE (sha256sum: 2d36597f7117c38b006835ae7f537487207d8ec407aa9d9980794b2030cbc067) can be found in vendor/pkg cache directory. + +import ( + "fmt" + "log" + "os" + "path/filepath" + "strings" + "time" + + "github.com/spf13/cobra" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" +) + +func init() { + rootCmd.AddCommand(installCmd) + installCmd.AddCommand(installService) + + rootCmd.AddCommand(uninstallCmd) + uninstallCmd.AddCommand(uninstallService) +} + +var installCmd = &cobra.Command{ + Use: "install", + Short: "Install system integrations", +} + +var uninstallCmd = &cobra.Command{ + Use: "uninstall", + Short: "Uninstall system integrations", +} + +var installService = &cobra.Command{ + Use: "core-service", + Short: "Install Portmaster Core Windows Service", + RunE: installWindowsService, +} + +var uninstallService = &cobra.Command{ + Use: "core-service", + Short: "Uninstall Portmaster Core Windows Service", + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + // non-nil dummy to override db flag requirement + return nil + }, + RunE: uninstallWindowsService, +} + +func getExePath() (string, error) { + // get own filepath + prog := os.Args[0] + p, err := filepath.Abs(prog) + if err != nil { + return "", err + } + // check if the path is valid + fi, err := os.Stat(p) + if err == nil { + if !fi.Mode().IsDir() { + return p, nil + } + err = fmt.Errorf("%s is directory", p) + } + // check if we have a .exe extension, add and check if not + if filepath.Ext(p) == "" { + p += ".exe" + fi, err := os.Stat(p) + if err == nil { + if !fi.Mode().IsDir() { + return p, nil + } + err = fmt.Errorf("%s is directory", p) + } + } + return "", err +} + +func getServiceExecCommand(exePath string, escape bool) []string { + return []string{ + maybeEscape(exePath, escape), + "run", + "core-service", + "--db", + maybeEscape(dataRoot.Path, escape), + "--input-signals", + } +} + +func maybeEscape(s string, escape bool) string { + if escape { + return windows.EscapeArg(s) + } + return s +} + +func getServiceConfig(exePath string) mgr.Config { + return mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + BinaryPathName: strings.Join(getServiceExecCommand(exePath, true), " "), + DisplayName: "Portmaster Core", + Description: "Portmaster Application Firewall - Core Service", + } +} + +func getRecoveryActions() (recoveryActions []mgr.RecoveryAction, resetPeriod uint32) { + return []mgr.RecoveryAction{ + //mgr.RecoveryAction{ + // Type: mgr.ServiceRestart, // one of NoAction, ComputerReboot, ServiceRestart or RunCommand + // Delay: 1 * time.Minute, // the time to wait before performing the specified action + //}, + // mgr.RecoveryAction{ + // Type: mgr.ServiceRestart, // one of NoAction, ComputerReboot, ServiceRestart or RunCommand + // Delay: 1 * time.Minute, // the time to wait before performing the specified action + // }, + mgr.RecoveryAction{ + Type: mgr.ServiceRestart, // one of NoAction, ComputerReboot, ServiceRestart or RunCommand + Delay: 1 * time.Minute, // the time to wait before performing the specified action + }, + }, 86400 +} + +func installWindowsService(cmd *cobra.Command, args []string) error { + // get exe path + exePath, err := getExePath() + if err != nil { + return fmt.Errorf("failed to get exe path: %s", err) + } + + // connect to Windows service manager + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to service manager: %s", err) + } + defer m.Disconnect() + + // open service + created := false + s, err := m.OpenService(serviceName) + if err != nil { + // create service + cmd := getServiceExecCommand(exePath, false) + s, err = m.CreateService(serviceName, cmd[0], getServiceConfig(exePath), cmd[1:]...) + if err != nil { + return fmt.Errorf("failed to create service: %s", err) + } + defer s.Close() + created = true + } else { + // update service + s.UpdateConfig(getServiceConfig(exePath)) + if err != nil { + return fmt.Errorf("failed to update service: %s", err) + } + defer s.Close() + } + + // update recovery actions + err = s.SetRecoveryActions(getRecoveryActions()) + if err != nil { + return fmt.Errorf("failed to update recovery actions: %s", err) + } + + if created { + log.Printf("created service %s\n", serviceName) + } else { + log.Printf("updated service %s\n", serviceName) + } + + return nil +} + +func uninstallWindowsService(cmd *cobra.Command, args []string) error { + // connect to Windows service manager + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + + // open service + s, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("service %s is not installed", serviceName) + } + defer s.Close() + + _, err = s.Control(svc.Stop) + if err != nil { + log.Printf("failed to stop service: %s\n", err) + } + + // delete service + err = s.Delete() + if err != nil { + return fmt.Errorf("failed to delete service: %s", err) + } + + log.Printf("uninstalled service %s\n", serviceName) + return nil +} diff --git a/pmctl/lock.go b/pmctl/lock.go new file mode 100644 index 00000000..8117109c --- /dev/null +++ b/pmctl/lock.go @@ -0,0 +1,70 @@ +package main + +import ( + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + "strconv" + "strings" + + processInfo "github.com/shirou/gopsutil/process" +) + +func checkAndCreateInstanceLock(name string) (pid int32, err error) { + lockFilePath := filepath.Join(dataRoot.Path, fmt.Sprintf("%s-lock.pid", name)) + + // read current pid file + data, err := ioutil.ReadFile(lockFilePath) + if err != nil { + if os.IsNotExist(err) { + // create new lock + return 0, createInstanceLock(lockFilePath) + } + return 0, err + } + + // file exists! + parsedPid, err := strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64) + if err != nil { + log.Printf("failed to parse existing lock pid file (ignoring): %s\n", err) + return 0, createInstanceLock(lockFilePath) + } + + // check if process exists + p, err := processInfo.NewProcess(int32(parsedPid)) + if err == nil { + // TODO: remove this workaround as soon as NewProcess really returns an error on windows when the process does not exist + // Issue: https://github.com/shirou/gopsutil/issues/729 + _, err = p.Name() + if err == nil { + // process exists + return p.Pid, nil + } + } + + // else create new lock + return 0, createInstanceLock(lockFilePath) +} + +func createInstanceLock(lockFilePath string) error { + // check data root dir + err := dataRoot.Ensure() + if err != nil { + log.Printf("failed to check data root dir: %s\n", err) + } + + // create lock file + err = ioutil.WriteFile(lockFilePath, []byte(fmt.Sprintf("%d", os.Getpid())), 0666) + if err != nil { + return err + } + + return nil +} + +func deleteInstanceLock(name string) error { + lockFilePath := filepath.Join(dataRoot.Path, fmt.Sprintf("%s-lock.pid", name)) + return os.Remove(lockFilePath) +} diff --git a/pmctl/logs.go b/pmctl/logs.go new file mode 100644 index 00000000..d2ea37f7 --- /dev/null +++ b/pmctl/logs.go @@ -0,0 +1,142 @@ +package main + +import ( + "fmt" + "log" + "os" + "path/filepath" + "runtime" + "runtime/pprof" + "time" + + "github.com/safing/portbase/container" + "github.com/safing/portbase/database/record" + "github.com/safing/portbase/formats/dsd" + "github.com/safing/portbase/info" + "github.com/safing/portmaster/updates" + "github.com/spf13/cobra" +) + +func initializeLogFile(logFilePath string, identifier string, updateFile *updates.File) *os.File { + logFile, err := os.OpenFile(logFilePath, os.O_RDWR|os.O_CREATE, 0444) + if err != nil { + log.Printf("failed to create log file %s: %s\n", logFilePath, err) + return nil + } + + // create header, so that the portmaster can view log files as a database + meta := record.Meta{} + meta.Update() + meta.SetAbsoluteExpiry(time.Now().Add(720 * time.Hour).Unix()) // one month + + // manually marshal + // version + c := container.New([]byte{1}) + // meta + metaSection, err := dsd.Dump(meta, dsd.JSON) + if err != nil { + log.Printf("failed to serialize header for log file %s: %s\n", logFilePath, err) + finalizeLogFile(logFile, logFilePath) + return nil + } + c.AppendAsBlock(metaSection) + // log file data type (string) and newline for better manual viewing + c.Append([]byte("S\n")) + c.Append([]byte(fmt.Sprintf("executing %s version %s on %s %s\n", identifier, updateFile.Version(), runtime.GOOS, runtime.GOARCH))) + + _, err = logFile.Write(c.CompileData()) + if err != nil { + log.Printf("failed to write header for log file %s: %s\n", logFilePath, err) + finalizeLogFile(logFile, logFilePath) + return nil + } + + return logFile +} + +func finalizeLogFile(logFile *os.File, logFilePath string) { + err := logFile.Close() + if err != nil { + log.Printf("failed to close log file %s: %s\n", logFilePath, err) + } + + // check file size + stat, err := os.Stat(logFilePath) + if err == nil { + // delete if file is smaller than + if stat.Size() < 200 { // header + info is about 150 bytes + err := os.Remove(logFilePath) + if err != nil { + log.Printf("failed to delete empty log file %s: %s\n", logFilePath, err) + } + } + } +} + +func initControlLogFile() *os.File { + // check logging dir + logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") + err := logsRoot.EnsureAbsPath(logFileBasePath) + if err != nil { + log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err) + } + + // open log file + logFilePath := filepath.Join(logFileBasePath, fmt.Sprintf("%s.log", time.Now().UTC().Format("2006-02-01-15-04-05"))) + return initializeLogFile(logFilePath, "control/portmaster-control", updates.NewFile("", info.Version(), false)) +} + +func logControlError(cErr error) { + // check if error present + if cErr == nil { + return + } + + // check logging dir + logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") + err := logsRoot.EnsureAbsPath(logFileBasePath) + if err != nil { + log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err) + } + + // open log file + logFilePath := filepath.Join(logFileBasePath, fmt.Sprintf("%s.error.log", time.Now().UTC().Format("2006-02-01-15-04-05"))) + errorFile := initializeLogFile(logFilePath, "control/portmaster-control", updates.NewFile("", info.Version(), false)) + if errorFile == nil { + return + } + + // write error and close + fmt.Fprintln(errorFile, cErr.Error()) + errorFile.Close() +} + +func logControlStack() { + // check logging dir + logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") + err := logsRoot.EnsureAbsPath(logFileBasePath) + if err != nil { + log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err) + } + + // open log file + logFilePath := filepath.Join(logFileBasePath, fmt.Sprintf("%s.stack.log", time.Now().UTC().Format("2006-02-01-15-04-05"))) + errorFile := initializeLogFile(logFilePath, "control/portmaster-control", updates.NewFile("", info.Version(), false)) + if errorFile == nil { + return + } + + // write error and close + pprof.Lookup("goroutine").WriteTo(errorFile, 1) + errorFile.Close() +} + +func runAndLogControlError(wrappedFunc func(cmd *cobra.Command, args []string) error) func(cmd *cobra.Command, args []string) error { + return func(cmd *cobra.Command, args []string) error { + err := wrappedFunc(cmd, args) + if err != nil { + logControlError(err) + } + return err + } +} diff --git a/pmctl/main.go b/pmctl/main.go index c6f66109..16934329 100644 --- a/pmctl/main.go +++ b/pmctl/main.go @@ -2,51 +2,66 @@ package main import ( "errors" - "flag" "fmt" + "log" "os" - "os/user" - "path/filepath" - "runtime" + "os/signal" "strings" + "syscall" + + "github.com/safing/portmaster/core/structure" + "github.com/safing/portmaster/updates" + + "github.com/safing/portbase/utils" "github.com/safing/portbase/info" - "github.com/safing/portbase/log" - "github.com/safing/portmaster/updates" + portlog "github.com/safing/portbase/log" "github.com/spf13/cobra" ) -const ( - logPrefix = "[control]" -) - var ( - updateStoragePath string - databaseRootDir *string + dataDir string + databaseDir string + dataRoot *utils.DirStructure + logsRoot *utils.DirStructure + + showShortVersion bool + showFullVersion bool rootCmd = &cobra.Command{ Use: "portmaster-control", - Short: "contoller for all portmaster components", - PersistentPreRunE: initPmCtl, + Short: "Controller for all portmaster components", + PersistentPreRunE: cmdSetup, RunE: func(cmd *cobra.Command, args []string) error { + if showShortVersion { + fmt.Println(info.Version()) + return nil + } + if showFullVersion { + fmt.Println(info.FullVersion()) + return nil + } return cmd.Help() }, + SilenceUsage: true, } ) func init() { - databaseRootDir = rootCmd.PersistentFlags().String("db", "", "set database directory") - err := rootCmd.MarkPersistentFlagRequired("db") - if err != nil { - panic(err) - } + // Let cobra ignore if we are running as "GUI" or not + cobra.MousetrapHelpText = "" + + rootCmd.PersistentFlags().StringVar(&dataDir, "data", "", "set data directory") + rootCmd.PersistentFlags().StringVar(&databaseDir, "db", "", "alias to --data (deprecated)") + rootCmd.MarkPersistentFlagDirname("data") + rootCmd.MarkPersistentFlagDirname("db") + rootCmd.Flags().BoolVar(&showFullVersion, "version", false, "print version") + rootCmd.Flags().BoolVar(&showShortVersion, "ver", false, "print version number only") } func main() { - flag.Parse() - - // not using portbase logger - log.SetLogLevel(log.CriticalLevel) + // set meta info + info.Set("Portmaster Control", "0.2.11", "AGPLv3", true) // for debugging // log.Start() @@ -57,66 +72,105 @@ func main() { // os.Exit(1) // }() - // set meta info - info.Set("Portmaster Control", "0.2.1", "AGPLv3", true) - - // check if meta info is ok - err := info.CheckVersion() - if err != nil { - fmt.Printf("%s compile error: please compile using the provided build script\n", logPrefix) - os.Exit(1) - } - - // react to version flag - if info.PrintVersion() { - os.Exit(0) - } + // catch interrupt for clean shutdown + signalCh := make(chan os.Signal) + signal.Notify( + signalCh, + os.Interrupt, + os.Kill, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT, + ) // start root command - if err := rootCmd.Execute(); err != nil { - os.Exit(1) + go func() { + if err := rootCmd.Execute(); err != nil { + os.Exit(1) + } + os.Exit(0) + }() + + // for debugging windows service (no stdout/err) + // go func() { + // time.Sleep(10 * time.Second) + // // initiateShutdown(nil) + // // logControlStack() + // }() + + // wait for signals + for sig := range signalCh { + if childIsRunning.IsSet() { + log.Printf("got %s signal (ignoring), waiting for child to exit...\n", sig) + } else { + log.Printf("got %s signal, exiting... (not executing anything)\n", sig) + os.Exit(0) + } } - os.Exit(0) } -func initPmCtl(cmd *cobra.Command, args []string) error { - - // transform from db base path to updates path - if *databaseRootDir != "" { - updates.SetDatabaseRoot(*databaseRootDir) - updateStoragePath = filepath.Join(*databaseRootDir, "updates") - } else { - return errors.New("please supply the database directory using the --db flag") - } - - // check if we are root/admin for self upgrade - userInfo, err := user.Current() +func cmdSetup(cmd *cobra.Command, args []string) (err error) { + // check if we are running in a console (try to attach to parent console if available) + runningInConsole, err = attachToParentConsole() if err != nil { - return nil - } - switch runtime.GOOS { - case "linux": - if userInfo.Username != "root" { - return nil - } - case "windows": - if !strings.HasSuffix(userInfo.Username, "SYSTEM") { // is this correct? - return nil - } + log.Printf("failed to attach to parent console: %s\n", err) + os.Exit(1) } - err = removeOldBin() + // check if meta info is ok + err = info.CheckVersion() if err != nil { - fmt.Printf("%s warning: failed to remove old upgrade: %s\n", logPrefix, err) + fmt.Println("compile error: please compile using the provided build script") + os.Exit(1) } - update := checkForUpgrade() - if update != nil { - err = doSelfUpgrade(update) + // set up logging + log.SetFlags(log.Ldate | log.Ltime | log.LUTC) + log.SetPrefix("[control] ") + log.SetOutput(os.Stdout) + + // not using portbase logger + portlog.SetLogLevel(portlog.CriticalLevel) + + // data directory + if !showShortVersion && !showFullVersion { + // set data root + // backwards compatibility + if dataDir == "" { + dataDir = databaseDir + } + // check data dir + if dataDir == "" { + return errors.New("please set the data directory using --data=/path/to/data/dir") + } + + // remove redundant escape characters and quotes + dataDir = strings.Trim(dataDir, `\"`) + // initialize structure + err = structure.Initialize(dataDir, 0755) if err != nil { - return fmt.Errorf("%s failed to upgrade self: %s", logPrefix, err) + return fmt.Errorf("failed to initialize data root: %s", err) + } + dataRoot = structure.Root() + // manually set updates root (no modules) + updates.SetDataRoot(structure.Root()) + } + + // logs and warning + if !showShortVersion && !showFullVersion && !strings.Contains(cmd.CommandPath(), " show ") { + // set up logs root + logsRoot = structure.NewRootDir("logs", 0777) + err = logsRoot.Ensure() + if err != nil { + return fmt.Errorf("failed to initialize logs root: %s", err) + } + + // warn about CTRL-C on windows + if runningInConsole && onWindows { + log.Println("WARNING: portmaster-control is marked as a GUI application in order to get rid of the console window.") + log.Println("WARNING: CTRL-C will immediately kill without clean shutdown.") } - fmt.Println("upgraded portmaster-control") } return nil diff --git a/pmctl/run.go b/pmctl/run.go index 054d13ef..f2d625ed 100644 --- a/pmctl/run.go +++ b/pmctl/run.go @@ -3,20 +3,35 @@ package main import ( "fmt" "io" + "log" "os" "os/exec" - "os/signal" + "path" + "path/filepath" "runtime" "strings" - "syscall" + "sync" + "time" "github.com/spf13/cobra" + "github.com/tevino/abool" +) + +var ( + runningInConsole bool + onWindows = runtime.GOOS == "windows" + + childIsRunning = abool.NewBool(false) ) // Options for starting component type Options struct { - Identifier string - AllowDownload bool + Identifier string // component identifier + ShortIdentifier string // populated automatically + SuppressArgs bool // do not use any args + AllowDownload bool // allow download of component if it is not yet available + AllowHidingWindow bool // allow hiding the window of the subprocess + NoOutput bool // do not use stdout/err if logging to file is available (did not fail to open log file) } func init() { @@ -35,9 +50,10 @@ var runCore = &cobra.Command{ Use: "core", Short: "Run the Portmaster Core", RunE: func(cmd *cobra.Command, args []string) error { - return run(cmd, &Options{ - Identifier: "core/portmaster-core", - AllowDownload: true, + return handleRun(cmd, &Options{ + Identifier: "core/portmaster-core", + AllowDownload: true, + AllowHidingWindow: true, }) }, FParseErrWhitelist: cobra.FParseErrWhitelist{ @@ -50,9 +66,10 @@ var runApp = &cobra.Command{ Use: "app", Short: "Run the Portmaster App", RunE: func(cmd *cobra.Command, args []string) error { - return run(cmd, &Options{ - Identifier: "app/portmaster-app", - AllowDownload: false, + return handleRun(cmd, &Options{ + Identifier: "app/portmaster-app", + AllowDownload: false, + AllowHidingWindow: false, }) }, FParseErrWhitelist: cobra.FParseErrWhitelist{ @@ -65,9 +82,10 @@ var runNotifier = &cobra.Command{ Use: "notifier", Short: "Run the Portmaster Notifier", RunE: func(cmd *cobra.Command, args []string) error { - return run(cmd, &Options{ - Identifier: "notifier/portmaster-notifier", - AllowDownload: false, + return handleRun(cmd, &Options{ + Identifier: "notifier/portmaster-notifier", + AllowDownload: false, + AllowHidingWindow: true, }) }, FParseErrWhitelist: cobra.FParseErrWhitelist{ @@ -76,7 +94,40 @@ var runNotifier = &cobra.Command{ }, } -func run(cmd *cobra.Command, opts *Options) error { +func handleRun(cmd *cobra.Command, opts *Options) (err error) { + err = run(cmd, opts) + initiateShutdown(err) + return +} + +func run(cmd *cobra.Command, opts *Options) (err error) { + + // parse identifier + opts.ShortIdentifier = path.Dir(opts.Identifier) + + // check for concurrent error (eg. service) + shutdownLock.Lock() + alreadyDead := shutdownInitiated + shutdownLock.Unlock() + if alreadyDead { + return + } + + // check for duplicate instances + if opts.ShortIdentifier == "core" { + pid, _ := checkAndCreateInstanceLock(opts.ShortIdentifier) + if pid != 0 { + return fmt.Errorf("another instance of Portmaster Core is already running: PID %d", pid) + } + defer deleteInstanceLock(opts.ShortIdentifier) + } + + // notify service after some time + go func() { + // assume that after 5 seconds service has finished starting + time.Sleep(3 * time.Second) + startupComplete <- struct{}{} + }() // get original arguments var args []string @@ -84,115 +135,247 @@ func run(cmd *cobra.Command, opts *Options) error { return cmd.Help() } args = os.Args[3:] + if opts.SuppressArgs { + args = nil + } // adapt identifier - if windows() { + if onWindows { opts.Identifier += ".exe" } - // run - for { - file, err := getFile(opts) - if err != nil { - return fmt.Errorf("could not get component: %s", err) + // setup logging + // init log file + logFile := initControlLogFile() + if logFile != nil { + // don't close logFile, will be closed by system + if opts.NoOutput { + log.Println("disabling log output to stdout... bye!") + log.SetOutput(logFile) + } else { + log.SetOutput(io.MultiWriter(os.Stdout, logFile)) } - - // check permission - if !windows() { - info, err := os.Stat(file.Path()) - if err != nil { - return fmt.Errorf("failed to get file info on %s: %s", file.Path(), err) - } - if info.Mode() != 0755 { - err := os.Chmod(file.Path(), 0755) - if err != nil { - return fmt.Errorf("failed to set exec permissions on %s: %s", file.Path(), err) - } - } - } - - fmt.Printf("%s starting %s %s\n", logPrefix, file.Path(), strings.Join(args, " ")) - - // create command - exc := exec.Command(file.Path(), args...) - - // consume stdout/stderr - stdout, err := exc.StdoutPipe() - if err != nil { - return fmt.Errorf("failed to connect stdout: %s", err) - } - stderr, err := exc.StderrPipe() - if err != nil { - return fmt.Errorf("failed to connect stderr: %s", err) - } - - // start - err = exc.Start() - if err != nil { - return fmt.Errorf("failed to start %s: %s", opts.Identifier, err) - } - - // start output writers - go func() { - io.Copy(os.Stdout, stdout) - }() - go func() { - io.Copy(os.Stderr, stderr) - }() - - // catch interrupt for clean shutdown - signalCh := make(chan os.Signal) - signal.Notify( - signalCh, - os.Interrupt, - os.Kill, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - syscall.SIGQUIT, - ) - go func() { - for { - sig := <-signalCh - fmt.Printf("%s got %s signal (ignoring), waiting for %s to exit...\n", logPrefix, sig, opts.Identifier) - } - }() - - // wait for completion - err = exc.Wait() - if err != nil { - exErr, ok := err.(*exec.ExitError) - if ok { - switch exErr.ProcessState.ExitCode() { - case 0: - // clean exit - fmt.Printf("%s clean exit of %s, but with error: %s\n", logPrefix, opts.Identifier, err) - os.Exit(1) - case 1: - // error exit - fmt.Printf("%s error during execution of %s: %s\n", logPrefix, opts.Identifier, err) - os.Exit(1) - case 2357427: // Leet Speak for "restart" - // restart request - fmt.Printf("%s restarting %s\n", logPrefix, opts.Identifier) - continue - default: - fmt.Printf("%s unexpected error during execution of %s: %s\n", logPrefix, opts.Identifier, err) - os.Exit(exErr.ProcessState.ExitCode()) - } - } else { - fmt.Printf("%s unexpected error type during execution of %s: %s\n", logPrefix, opts.Identifier, err) - os.Exit(1) - } - } - // clean exit - break } - fmt.Printf("%s %s completed successfully\n", logPrefix, opts.Identifier) - return nil + // run + tries := 0 + for { + // normal execution + tryAgain := false + tryAgain, err = execute(opts, args) + switch { + case tryAgain && err != nil: + // temporary? execution error + log.Printf("execution of %s failed: %s\n", opts.Identifier, err) + tries++ + if tries >= 5 { + log.Println("error seems to be permanent, giving up...") + return err + } + log.Println("trying again...") + case tryAgain && err == nil: + // upgrade + log.Println("restarting by request...") + case !tryAgain && err != nil: + // fatal error + return err + case !tryAgain && err == nil: + // clean exit + log.Printf("%s completed successfully\n", opts.Identifier) + return nil + } + } } -func windows() bool { - return runtime.GOOS == "windows" +func execute(opts *Options, args []string) (cont bool, err error) { + file, err := getFile(opts) + if err != nil { + return true, fmt.Errorf("could not get component: %s", err) + } + + // check permission + if !onWindows { + info, err := os.Stat(file.Path()) + if err != nil { + return true, fmt.Errorf("failed to get file info on %s: %s", file.Path(), err) + } + if info.Mode() != 0755 { + err := os.Chmod(file.Path(), 0755) + if err != nil { + return true, fmt.Errorf("failed to set exec permissions on %s: %s", file.Path(), err) + } + } + } + + log.Printf("starting %s %s\n", file.Path(), strings.Join(args, " ")) + + // log files + var logFile, errorFile *os.File + logFileBasePath := filepath.Join(logsRoot.Path, "fstree", opts.ShortIdentifier) + err = logsRoot.EnsureAbsPath(logFileBasePath) + if err != nil { + log.Printf("failed to check/create log file dir %s: %s\n", logFileBasePath, err) + } else { + // open log file + logFilePath := filepath.Join(logFileBasePath, fmt.Sprintf("%s.log", time.Now().UTC().Format("2006-02-01-15-04-05"))) + logFile = initializeLogFile(logFilePath, opts.Identifier, file) + if logFile != nil { + defer finalizeLogFile(logFile, logFilePath) + } + // open error log file + errorFilePath := filepath.Join(logFileBasePath, fmt.Sprintf("%s.error.log", time.Now().UTC().Format("2006-02-01-15-04-05"))) + errorFile = initializeLogFile(errorFilePath, opts.Identifier, file) + if errorFile != nil { + defer finalizeLogFile(errorFile, errorFilePath) + } + } + + // create command + exc := exec.Command(file.Path(), args...) + + if !runningInConsole && opts.AllowHidingWindow { + // Windows only: + // only hide (all) windows of program if we are not running in console and windows may be hidden + hideWindow(exc) + } + + // check if input signals are enabled + inputSignalsEnabled := false + for _, arg := range args { + if strings.HasSuffix(arg, "-input-signals") { + inputSignalsEnabled = true + break + } + } + + // consume stdout/stderr + stdout, err := exc.StdoutPipe() + if err != nil { + return true, fmt.Errorf("failed to connect stdout: %s", err) + } + stderr, err := exc.StderrPipe() + if err != nil { + return true, fmt.Errorf("failed to connect stderr: %s", err) + } + var stdin io.WriteCloser + if inputSignalsEnabled { + stdin, err = exc.StdinPipe() + if err != nil { + return true, fmt.Errorf("failed to connect stdin: %s", err) + } + } + + // start + err = exc.Start() + if err != nil { + return true, fmt.Errorf("failed to start %s: %s", opts.Identifier, err) + } + childIsRunning.Set() + + // start output writers + var wg sync.WaitGroup + wg.Add(2) + go func() { + var logFileError error + if logFile == nil { + _, logFileError = io.Copy(os.Stdout, stdout) + } else { + if opts.NoOutput { + _, logFileError = io.Copy(logFile, stdout) + } else { + _, logFileError = io.Copy(io.MultiWriter(os.Stdout, logFile), stdout) + } + } + if logFileError != nil { + log.Printf("failed write logs: %s\n", logFileError) + } + wg.Done() + }() + go func() { + var errorFileError error + if logFile == nil { + _, errorFileError = io.Copy(os.Stderr, stderr) + } else { + if opts.NoOutput { + _, errorFileError = io.Copy(errorFile, stderr) + } else { + _, errorFileError = io.Copy(io.MultiWriter(os.Stderr, errorFile), stderr) + } + } + if errorFileError != nil { + log.Printf("failed write error logs: %s\n", errorFileError) + } + wg.Done() + }() + + // wait for completion + finished := make(chan error) + go func() { + // wait for output writers to complete + wg.Wait() + // wait for process to return + finished <- exc.Wait() + // update status + childIsRunning.UnSet() + // notify manager + close(finished) + }() + + // state change listeners + for { + select { + case <-shuttingDown: + // signal process shutdown + if inputSignalsEnabled { + // for windows + _, err = stdin.Write([]byte("SIGINT\n")) + } else { + err = exc.Process.Signal(os.Interrupt) + } + if err != nil { + log.Printf("failed to signal %s to shutdown: %s\n", opts.Identifier, err) + err = exc.Process.Kill() + if err != nil { + return false, fmt.Errorf("failed to kill %s: %s", opts.Identifier, err) + } + return false, fmt.Errorf("killed %s", opts.Identifier) + } + // wait until shut down + select { + case <-finished: + case <-time.After(11 * time.Second): // portmaster core prints stack if not able to shutdown in 10 seconds + // kill + err = exc.Process.Kill() + if err != nil { + return false, fmt.Errorf("failed to kill %s: %s", opts.Identifier, err) + } + return false, fmt.Errorf("killed %s", opts.Identifier) + } + return false, nil + case err := <-finished: + if err != nil { + exErr, ok := err.(*exec.ExitError) + if ok { + switch exErr.ProcessState.ExitCode() { + case 0: + // clean exit + return false, fmt.Errorf("clean exit, but with error: %s", err) + case 1: + // error exit + return true, fmt.Errorf("error during execution: %s", err) + case 2357427: // Leet Speak for "restart" + // restart request + log.Printf("restarting %s\n", opts.Identifier) + return true, nil + default: + return true, fmt.Errorf("unexpected error during execution: %s", err) + } + } else { + return true, fmt.Errorf("unexpected error type during execution: %s", err) + } + } + // clean exit + return false, nil + } + } } diff --git a/pmctl/service.go b/pmctl/service.go new file mode 100644 index 00000000..4a9130ca --- /dev/null +++ b/pmctl/service.go @@ -0,0 +1,31 @@ +package main + +import ( + "sync" +) + +var ( + startupComplete = make(chan struct{}) // signal that the start procedure completed (is never closed, just signaled once) + shuttingDown = make(chan struct{}) // signal that we are shutting down (will be closed, may not be closed directly, use initiateShutdown) + shutdownInitiated = false // not to be used directly + shutdownError error // may not be read or written to directly + shutdownLock sync.Mutex +) + +func initiateShutdown(err error) { + shutdownLock.Lock() + defer shutdownLock.Unlock() + + if !shutdownInitiated { + shutdownInitiated = true + shutdownError = err + close(shuttingDown) + } +} + +func getShutdownError() error { + shutdownLock.Lock() + defer shutdownLock.Unlock() + + return shutdownError +} diff --git a/pmctl/service_windows.go b/pmctl/service_windows.go new file mode 100644 index 00000000..4fb68bc6 --- /dev/null +++ b/pmctl/service_windows.go @@ -0,0 +1,138 @@ +package main + +// Based on the offical Go examples from +// https://github.com/golang/sys/blob/master/windows/svc/example +// by The Go Authors. +// Original LICENSE (sha256sum: 2d36597f7117c38b006835ae7f537487207d8ec407aa9d9980794b2030cbc067) can be found in vendor/pkg cache directory. + +import ( + "fmt" + "log" + "sync" + "time" + + "github.com/spf13/cobra" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/debug" + "golang.org/x/sys/windows/svc/eventlog" +) + +var ( + runCoreService = &cobra.Command{ + Use: "core-service", + Short: "Run the Portmaster Core as a Windows Service", + RunE: runAndLogControlError(func(cmd *cobra.Command, args []string) error { + return runService(cmd, &Options{ + Identifier: "core/portmaster-core", + AllowDownload: true, + AllowHidingWindow: false, + NoOutput: true, + }) + }), + FParseErrWhitelist: cobra.FParseErrWhitelist{ + // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags + UnknownFlags: true, + }, + } + + // wait groups + runWg sync.WaitGroup + finishWg sync.WaitGroup +) + +func init() { + runCmd.AddCommand(runCoreService) +} + +const serviceName = "PortmasterCore" + +type windowsService struct{} + +func (ws *windowsService) Execute(args []string, changeRequests <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) { + const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown + changes <- svc.Status{State: svc.StartPending} + +service: + for { + select { + case <-startupComplete: + changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} + case <-shuttingDown: + changes <- svc.Status{State: svc.StopPending} + break service + case c := <-changeRequests: + switch c.Cmd { + case svc.Interrogate: + changes <- c.CurrentStatus + case svc.Stop, svc.Shutdown: + initiateShutdown(nil) + default: + log.Printf("unexpected control request: #%d\n", c) + } + } + } + + // define return values + if getShutdownError() != nil { + ssec = true // this error is specific to this service (ie. custom) + errno = 1 // generic error, check logs / windows events + } + + // wait until everything else is finished + finishWg.Wait() + // send stopped status + changes <- svc.Status{State: svc.Stopped} + // wait a little for the status to reach Windows + time.Sleep(100 * time.Millisecond) + return +} + +func runService(cmd *cobra.Command, opts *Options) error { + // check if we are running interactively + isDebug, err := svc.IsAnInteractiveSession() + if err != nil { + return fmt.Errorf("could not determine if running interactively: %s", err) + } + // select service run type + svcRun := svc.Run + if isDebug { + log.Printf("WARNING: running interactively, switching to debug execution (no real service).\n") + svcRun = debug.Run + } + + // open eventlog + elog, err := eventlog.Open(serviceName) + if err != nil { + return fmt.Errorf("failed to open eventlog: %s", err) + } + defer elog.Close() + + runWg.Add(2) + finishWg.Add(1) + + // run service client + go func() { + sErr := svcRun(serviceName, &windowsService{}) + initiateShutdown(sErr) + runWg.Done() + }() + + // run service + go func() { + // run slightly delayed + time.Sleep(250 * time.Millisecond) + handleRun(cmd, opts) + finishWg.Done() + runWg.Done() + }() + + runWg.Wait() + + err = getShutdownError() + if err != nil { + log.Printf("%s service experienced an error: %s\n", serviceName, err) + elog.Error(1, fmt.Sprintf("%s experienced an error: %s", serviceName, err)) + } + + return err +} diff --git a/pmctl/show.go b/pmctl/show.go new file mode 100644 index 00000000..9ea0934f --- /dev/null +++ b/pmctl/show.go @@ -0,0 +1,90 @@ +package main + +import ( + "fmt" + "os" + "strings" + + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(showCmd) + showCmd.AddCommand(showCore) + showCmd.AddCommand(showApp) + showCmd.AddCommand(showNotifier) +} + +var showCmd = &cobra.Command{ + Use: "show", + Short: "Show the command to run a Portmaster component yourself", +} + +var showCore = &cobra.Command{ + Use: "core", + Short: "Show command to run the Portmaster Core", + RunE: func(cmd *cobra.Command, args []string) error { + return show(cmd, &Options{ + Identifier: "core/portmaster-core", + }) + }, + FParseErrWhitelist: cobra.FParseErrWhitelist{ + // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags + UnknownFlags: true, + }, +} + +var showApp = &cobra.Command{ + Use: "app", + Short: "Show command to run the Portmaster App", + RunE: func(cmd *cobra.Command, args []string) error { + return show(cmd, &Options{ + Identifier: "app/portmaster-app", + }) + }, + FParseErrWhitelist: cobra.FParseErrWhitelist{ + // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags + UnknownFlags: true, + }, +} + +var showNotifier = &cobra.Command{ + Use: "notifier", + Short: "Show command to run the Portmaster Notifier", + RunE: func(cmd *cobra.Command, args []string) error { + return show(cmd, &Options{ + Identifier: "notifier/portmaster-notifier", + SuppressArgs: true, + }) + }, + FParseErrWhitelist: cobra.FParseErrWhitelist{ + // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags + UnknownFlags: true, + }, +} + +func show(cmd *cobra.Command, opts *Options) error { + // get original arguments + var args []string + if len(os.Args) < 4 { + return cmd.Help() + } + args = os.Args[3:] + if opts.SuppressArgs { + args = nil + } + + // adapt identifier + if onWindows { + opts.Identifier += ".exe" + } + + file, err := getFile(opts) + if err != nil { + return fmt.Errorf("could not get component: %s", err) + } + + fmt.Printf("%s %s\n", file.Path(), strings.Join(args, " ")) + + return nil +} diff --git a/pmctl/snoretoast_windows.go b/pmctl/snoretoast_windows.go new file mode 100644 index 00000000..75421253 --- /dev/null +++ b/pmctl/snoretoast_windows.go @@ -0,0 +1,40 @@ +package main + +import "github.com/spf13/cobra" + +func init() { + showCmd.AddCommand(showSnoreToast) + runCmd.AddCommand(runSnoreToast) +} + +var showSnoreToast = &cobra.Command{ + Use: "notifier-snoretoast", + Short: "Show command to run the Notifier component SnoreToast", + RunE: func(cmd *cobra.Command, args []string) error { + return show(cmd, &Options{ + Identifier: "notifier/portmaster-snoretoast", + SuppressArgs: true, + }) + }, + FParseErrWhitelist: cobra.FParseErrWhitelist{ + // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags + UnknownFlags: true, + }, +} + +var runSnoreToast = &cobra.Command{ + Use: "notifier-snoretoast", + Short: "Run the Notifier component SnoreToast", + RunE: func(cmd *cobra.Command, args []string) error { + return handleRun(cmd, &Options{ + Identifier: "notifier/portmaster-snoretoast", + AllowDownload: false, + AllowHidingWindow: true, + SuppressArgs: true, + }) + }, + FParseErrWhitelist: cobra.FParseErrWhitelist{ + // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags + UnknownFlags: true, + }, +} diff --git a/pmctl/upgrade.go b/pmctl/upgrade.go deleted file mode 100644 index 142840a8..00000000 --- a/pmctl/upgrade.go +++ /dev/null @@ -1,124 +0,0 @@ -package main - -import ( - "fmt" - "io" - "os" - "path/filepath" - "runtime" - - "github.com/safing/portbase/info" - "github.com/safing/portmaster/updates" -) - -var ( - oldBinSuffix = "-old" -) - -func checkForUpgrade() (update *updates.File) { - info := info.GetInfo() - file, err := updates.GetLocalPlatformFile("control/portmaster-control") - if err != nil { - return nil - } - if info.Version != file.Version() { - return file - } - return nil -} - -func doSelfUpgrade(file *updates.File) error { - - // FIXME: fix permissions if needed - - // get destination - dst, err := os.Executable() - if err != nil { - return err - } - dst, err = filepath.EvalSymlinks(dst) - if err != nil { - return err - } - - // mv destination - err = os.Rename(dst, dst+oldBinSuffix) - if err != nil { - return err - } - - // hard link - err = os.Link(file.Path(), dst) - if err != nil { - fmt.Printf("%s failed to hardlink self upgrade: %s, will copy...\n", logPrefix, err) - err = copyFile(file.Path(), dst) - if err != nil { - return err - } - } - - // check permission - if runtime.GOOS != "windows" { - info, err := os.Stat(dst) - if err != nil { - return fmt.Errorf("failed to get file info on %s: %s", dst, err) - } - if info.Mode() != 0755 { - err := os.Chmod(dst, 0755) - if err != nil { - return fmt.Errorf("failed to set permissions on %s: %s", dst, err) - } - } - } - return nil -} - -func copyFile(srcPath, dstPath string) (err error) { - srcFile, err := os.Open(srcPath) - if err != nil { - return - } - defer srcFile.Close() - - dstFile, err := os.Create(dstPath) - if err != nil { - return - } - defer func() { - closeErr := dstFile.Close() - if err == nil { - err = closeErr - } - }() - - _, err = io.Copy(dstFile, srcFile) - if err != nil { - return - } - err = dstFile.Sync() - return -} - -func removeOldBin() error { - // get location - dst, err := os.Executable() - if err != nil { - return err - } - dst, err = filepath.EvalSymlinks(dst) - if err != nil { - return err - } - - // delete old - err = os.Remove(dst + oldBinSuffix) - if err != nil { - if !os.IsNotExist(err) { - return err - } - return nil - } - - fmt.Println("removed previous portmaster-control") - return nil -} diff --git a/process/executable.go b/process/executable.go index 3bb6c00d..a6f4fdce 100644 --- a/process/executable.go +++ b/process/executable.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package process import ( diff --git a/process/proc/gather.go b/process/proc/gather.go index 98a83be4..30f3c053 100644 --- a/process/proc/gather.go +++ b/process/proc/gather.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package proc import ( diff --git a/process/proc/processfinder.go b/process/proc/processfinder.go index a6391c2c..4a3647ee 100644 --- a/process/proc/processfinder.go +++ b/process/proc/processfinder.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package proc import ( diff --git a/process/proc/processfinder_test.go b/process/proc/processfinder_test.go index ada7d02b..30df4483 100644 --- a/process/proc/processfinder_test.go +++ b/process/proc/processfinder_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package proc import ( diff --git a/process/proc/sockets.go b/process/proc/sockets.go index b9c773ea..036f5c4d 100644 --- a/process/proc/sockets.go +++ b/process/proc/sockets.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package proc import ( diff --git a/process/proc/sockets_test.go b/process/proc/sockets_test.go index 6bc77bfc..71574ffe 100644 --- a/process/proc/sockets_test.go +++ b/process/proc/sockets_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package proc import ( diff --git a/process/process.go b/process/process.go index cdadc389..6d18474f 100644 --- a/process/process.go +++ b/process/process.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package process import ( @@ -264,7 +262,13 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { pInfo, err := processInfo.NewProcess(int32(pid)) if err != nil { - return nil, err + // TODO: remove this workaround as soon as NewProcess really returns an error on windows when the process does not exist + // Issue: https://github.com/shirou/gopsutil/issues/729 + _, err = pInfo.Name() + if err != nil { + // process does not exists + return nil, err + } } // UID diff --git a/process/xdgicons/humanInfo.go b/process/xdgicons/humanInfo.go index bbab8f75..f02c271e 100644 --- a/process/xdgicons/humanInfo.go +++ b/process/xdgicons/humanInfo.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package process import ( diff --git a/process/xdgicons/humanInfo_test.go b/process/xdgicons/humanInfo_test.go index ed7501e8..4af29226 100644 --- a/process/xdgicons/humanInfo_test.go +++ b/process/xdgicons/humanInfo_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package process import ( diff --git a/process/xdgicons/locations.go b/process/xdgicons/locations.go index bd258827..d3956199 100644 --- a/process/xdgicons/locations.go +++ b/process/xdgicons/locations.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package process // spec: https://specifications.freedesktop.org/icon-theme-spec/icon-theme-spec-latest.html diff --git a/profile/index/indexer.go b/profile/index/indexer.go index 471c66b9..e76beb7a 100644 --- a/profile/index/indexer.go +++ b/profile/index/indexer.go @@ -23,7 +23,7 @@ var ( ) func init() { - modules.Register("profile:index", nil, start, stop, "profile", "database") + modules.Register("profile:index", nil, start, stop, "core", "profile") } func start() (err error) { diff --git a/ui/module.go b/ui/module.go index 3927c4ed..3971ae43 100644 --- a/ui/module.go +++ b/ui/module.go @@ -1,13 +1,11 @@ package ui import ( - "github.com/safing/portbase/api" "github.com/safing/portbase/modules" ) func init() { - modules.Register("ui", prep, nil, nil, "updates", "api") - api.SetDefaultAPIListenAddress("127.0.0.1:817") + modules.Register("ui", prep, nil, nil, "core", "updates") } func prep() error { diff --git a/ui/serve.go b/ui/serve.go index 659d4d1c..4632c2db 100644 --- a/ui/serve.go +++ b/ui/serve.go @@ -140,7 +140,7 @@ func RedirectToBase(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - http.Redirect(w, r, r.URL.ResolveReference(u).String(), http.StatusPermanentRedirect) + http.Redirect(w, r, r.URL.ResolveReference(u).String(), http.StatusTemporaryRedirect) } func redirAddSlash(w http.ResponseWriter, r *http.Request) { diff --git a/updates/fetch.go b/updates/fetch.go index 4d3d3c02..fdbf86d7 100644 --- a/updates/fetch.go +++ b/updates/fetch.go @@ -15,7 +15,6 @@ import ( "github.com/google/renameio" "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" ) var ( @@ -38,13 +37,13 @@ func fetchFile(realFilepath, updateFilepath string, tries int) error { // check destination dir dirPath := filepath.Dir(realFilepath) - err = utils.EnsureDirectory(dirPath, 0755) + err = updateStorage.EnsureAbsPath(dirPath) if err != nil { return fmt.Errorf("could not create updates folder: %s", dirPath) } // open file for writing - atomicFile, err := renameio.TempFile(downloadTmpPath, realFilepath) + atomicFile, err := renameio.TempFile(tmpStorage.Path, realFilepath) if err != nil { return fmt.Errorf("could not create temp file for download: %s", err) } diff --git a/updates/filename.go b/updates/filename.go index 1768255d..a4c52a4c 100644 --- a/updates/filename.go +++ b/updates/filename.go @@ -6,12 +6,15 @@ import ( "strings" ) -var versionRegex = regexp.MustCompile("_v[0-9]+-[0-9]+-[0-9]+b?") +var ( + fileVersionRegex = regexp.MustCompile(`_v[0-9]+-[0-9]+-[0-9]+b?`) + rawVersionRegex = regexp.MustCompile(`^[0-9]+\.[0-9]+\.[0-9]+b?\*?$`) +) // GetIdentifierAndVersion splits the given file path into its identifier and version. func GetIdentifierAndVersion(versionedPath string) (identifier, version string, ok bool) { // extract version - rawVersion := versionRegex.FindString(versionedPath) + rawVersion := fileVersionRegex.FindString(versionedPath) if rawVersion == "" { return "", "", false } diff --git a/updates/filename_test.go b/updates/filename_test.go new file mode 100644 index 00000000..274e8e6d --- /dev/null +++ b/updates/filename_test.go @@ -0,0 +1,51 @@ +package updates + +import ( + "regexp" + "testing" +) + +func testRegexMatch(t *testing.T, testRegex *regexp.Regexp, testString string, shouldMatch bool) { + if testRegex.MatchString(testString) != shouldMatch { + if shouldMatch { + t.Errorf("regex %s should match %s", testRegex, testString) + } else { + t.Errorf("regex %s should not match %s", testRegex, testString) + } + } +} + +func testRegexFind(t *testing.T, testRegex *regexp.Regexp, testString string, shouldMatch bool) { + if (testRegex.FindString(testString) != "") != shouldMatch { + if shouldMatch { + t.Errorf("regex %s should find %s", testRegex, testString) + } else { + t.Errorf("regex %s should not find %s", testRegex, testString) + } + } +} + +func TestRegexes(t *testing.T) { + testRegexMatch(t, rawVersionRegex, "0.1.2", true) + testRegexMatch(t, rawVersionRegex, "0.1.2*", true) + testRegexMatch(t, rawVersionRegex, "0.1.2b", true) + testRegexMatch(t, rawVersionRegex, "0.1.2b*", true) + testRegexMatch(t, rawVersionRegex, "12.13.14", true) + + testRegexMatch(t, rawVersionRegex, "v0.1.2", false) + testRegexMatch(t, rawVersionRegex, "0.", false) + testRegexMatch(t, rawVersionRegex, "0.1", false) + testRegexMatch(t, rawVersionRegex, "0.1.", false) + testRegexMatch(t, rawVersionRegex, ".1.2", false) + testRegexMatch(t, rawVersionRegex, ".1.", false) + testRegexMatch(t, rawVersionRegex, "012345", false) + + testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2-3", true) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2-3.exe", true) + + testRegexFind(t, fileVersionRegex, "/path/to/file-v1-2-3", false) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1.2.3", false) + testRegexFind(t, fileVersionRegex, "/path/to/file_1-2-3", false) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2", false) + testRegexFind(t, fileVersionRegex, "/path/to/file-v1-2-3", false) +} diff --git a/updates/get.go b/updates/get.go index 9d4b7ac8..dfd5ba50 100644 --- a/updates/get.go +++ b/updates/get.go @@ -9,7 +9,6 @@ import ( "runtime" "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" ) // Errors @@ -75,7 +74,7 @@ func loadOrFetchFile(identifier string, fetch bool) (*File, error) { } // build final filepath - realFilePath := filepath.Join(updateStoragePath, filepath.FromSlash(versionedFilePath)) + realFilePath := filepath.Join(updateStorage.Path, filepath.FromSlash(versionedFilePath)) if _, err := os.Stat(realFilePath); err == nil { // file exists updateUsedStatus(identifier, version) @@ -83,7 +82,7 @@ func loadOrFetchFile(identifier string, fetch bool) (*File, error) { } // check download dir - err := utils.EnsureDirectory(downloadTmpPath, 0755) + err := tmpStorage.Ensure() if err != nil { return nil, fmt.Errorf("could not prepare tmp directory for download: %s", err) } diff --git a/updates/latest.go b/updates/latest.go index 57109a7f..0e96af8c 100644 --- a/updates/latest.go +++ b/updates/latest.go @@ -11,6 +11,9 @@ import ( "sync" "github.com/safing/portbase/log" + "github.com/safing/portbase/utils" + + semver "github.com/hashicorp/go-version" ) var ( @@ -26,14 +29,14 @@ func LoadLatest() error { // all prefix := "all" - new, err1 := ScanForLatest(filepath.Join(updateStoragePath, prefix), false) + new, err1 := ScanForLatest(filepath.Join(updateStorage.Path, prefix), false) for key, val := range new { newLocalUpdates[filepath.ToSlash(filepath.Join(prefix, key))] = val } // os_platform prefix = fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH) - new, err2 := ScanForLatest(filepath.Join(updateStoragePath, prefix), false) + new, err2 := ScanForLatest(filepath.Join(updateStorage.Path, prefix), false) for key, val := range new { newLocalUpdates[filepath.ToSlash(filepath.Join(prefix, key))] = val } @@ -70,11 +73,11 @@ func ScanForLatest(baseDir string, hardFail bool) (latest map[string]string, las filepath.Walk(baseDir, func(path string, info os.FileInfo, err error) error { if err != nil { if !os.IsNotExist(err) { - lastError = err + lastError = fmt.Errorf("updates: could not read %s: %s", path, err) if hardFail { - return err + return lastError } - log.Warningf("updates: could not read %s", path) + log.Warning(lastError.Error()) } return nil } @@ -95,9 +98,24 @@ func ScanForLatest(baseDir string, hardFail bool) (latest map[string]string, las // add/update index storedVersion, ok := latest[identifierPath] if ok { - // FIXME: this will fail on multi-digit version segments! - // FIXME: use https://github.com/hashicorp/go-version - if version > storedVersion { + parsedVersion, err := semver.NewVersion(version) + if err != nil { + lastError = fmt.Errorf("updates: could not parse version of %s: %s", path, err) + if hardFail { + return lastError + } + log.Warning(lastError.Error()) + } + parsedStoredVersion, err := semver.NewVersion(storedVersion) + if err != nil { + lastError = fmt.Errorf("updates: could not parse version of %s: %s", path, err) + if hardFail { + return lastError + } + log.Warning(lastError.Error()) + } + // compare + if parsedVersion.GreaterThan(parsedStoredVersion) { latest[identifierPath] = version } } else { @@ -120,7 +138,7 @@ func ScanForLatest(baseDir string, hardFail bool) (latest map[string]string, las // LoadIndexes loads the current update indexes from disk. func LoadIndexes() error { - data, err := ioutil.ReadFile(filepath.Join(updateStoragePath, "stable.json")) + data, err := ioutil.ReadFile(filepath.Join(updateStorage.Path, "stable.json")) if err != nil { return err } @@ -148,3 +166,39 @@ func LoadIndexes() error { return nil } + +// CreateSymlinks creates a directory structure with unversions symlinks to the given updates list. +func CreateSymlinks(symlinkRoot, updateStorage *utils.DirStructure, updatesList map[string]string) error { + err := os.RemoveAll(symlinkRoot.Path) + if err != nil { + return fmt.Errorf("failed to wipe symlink root: %s", err) + } + + err = symlinkRoot.Ensure() + if err != nil { + return fmt.Errorf("failed to create symlink root: %s", err) + } + + for identifier, version := range updatesList { + targetPath := filepath.Join(updateStorage.Path, filepath.FromSlash(GetVersionedPath(identifier, version))) + linkPath := filepath.Join(symlinkRoot.Path, filepath.FromSlash(identifier)) + linkPathDir := filepath.Dir(linkPath) + + err = symlinkRoot.EnsureAbsPath(linkPathDir) + if err != nil { + return fmt.Errorf("failed to create dir for link: %s", err) + } + + relativeTargetPath, err := filepath.Rel(linkPathDir, targetPath) + if err != nil { + return fmt.Errorf("failed to get relative target path: %s", err) + } + + err = os.Symlink(relativeTargetPath, linkPath) + if err != nil { + return fmt.Errorf("failed to link %s: %s", identifier, err) + } + } + + return nil +} diff --git a/updates/main.go b/updates/main.go index 395e2f92..11c2c087 100644 --- a/updates/main.go +++ b/updates/main.go @@ -3,25 +3,30 @@ package updates import ( "errors" "os" - "path/filepath" + "runtime" + + "github.com/safing/portmaster/core/structure" - "github.com/safing/portbase/database" "github.com/safing/portbase/info" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" "github.com/safing/portbase/utils" ) -var ( - updateStoragePath string - downloadTmpPath string +const ( + isWindows = runtime.GOOS == "windows" ) -// SetDatabaseRoot tells the updates module where the database is - and where to put its stuff. -func SetDatabaseRoot(path string) { - if updateStoragePath == "" { - updateStoragePath = filepath.Join(path, "updates") - downloadTmpPath = filepath.Join(updateStoragePath, "tmp") +var ( + updateStorage *utils.DirStructure + tmpStorage *utils.DirStructure +) + +// SetDataRoot sets the data root from which the updates module derives its paths. +func SetDataRoot(root *utils.DirStructure) { + if root != nil && updateStorage == nil { + updateStorage = root.ChildDir("updates", 0755) + tmpStorage = updateStorage.ChildDir("tmp", 0777) } } @@ -30,19 +35,12 @@ func init() { } func prep() error { - dbRoot := database.GetDatabaseRoot() - if dbRoot == "" { - return errors.New("database root is not set") - } - updateStoragePath = filepath.Join(dbRoot, "updates") - downloadTmpPath = filepath.Join(updateStoragePath, "tmp") - - err := utils.EnsureDirectory(updateStoragePath, 0755) - if err != nil { - return err + SetDataRoot(structure.Root()) + if updateStorage == nil { + return errors.New("update storage path is not set") } - err = utils.EnsureDirectory(downloadTmpPath, 0700) + err := updateStorage.Ensure() if err != nil { return err } @@ -61,7 +59,13 @@ func start() error { err = LoadIndexes() if err != nil { if os.IsNotExist(err) { - log.Infof("updates: stable.json does not yet exist, waiting for first update cycle") + // download indexes + log.Infof("updates: downloading update index...") + + err = UpdateIndexes() + if err != nil { + log.Errorf("updates: failed to download update index: %s", err) + } } else { return err } @@ -79,5 +83,5 @@ func start() error { func stop() error { // delete download tmp dir - return os.RemoveAll(downloadTmpPath) + return os.RemoveAll(tmpStorage.Path) } diff --git a/updates/notify.go b/updates/notify.go index cb785952..216b4a00 100644 --- a/updates/notify.go +++ b/updates/notify.go @@ -7,29 +7,33 @@ import ( "github.com/safing/portbase/notifications" ) -const coreIdentifier = "core/portmaster" +const coreIdentifier = "core/portmaster-core" var lastNotified time.Time func updateNotifier() { - time.Sleep(30 * time.Second) + time.Sleep(5 * time.Minute) for { + ident := coreIdentifier + if isWindows { + ident += ".exe" + } - _, version, _, ok := getLatestFilePath(coreIdentifier) - if ok { + file, err := GetLocalPlatformFile(ident) + if err == nil { status.Lock() liveVersion := status.Core.Version status.Unlock() - if version != liveVersion { + if file.Version() != liveVersion { // create notification (¬ifications.Notification{ ID: "updates-core-update-available", - Message: fmt.Sprintf("There is an update available for the Portmaster core (v%s), please restart the Portmaster to apply the update.", version), + Message: fmt.Sprintf("There is an update available for the Portmaster core (v%s), please restart the Portmaster to apply the update.", file.Version()), Type: notifications.Info, Expires: time.Now().Add(1 * time.Minute).Unix(), - }).Init().Save() + }).Save() } } diff --git a/updates/updater.go b/updates/updater.go index 2a01babe..4183b952 100644 --- a/updates/updater.go +++ b/updates/updater.go @@ -5,13 +5,13 @@ import ( "errors" "fmt" "io/ioutil" + "os" "path" "path/filepath" "runtime" "time" "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" ) func updater() { @@ -25,6 +25,14 @@ func updater() { if err != nil { log.Warningf("updates: downloading updates failed: %s", err) } + err = runFileUpgrades() + if err != nil { + log.Warningf("updates: failed to upgrade portmaster-control: %s", err) + } + err = cleanOldUpgradedFiles() + if err != nil { + log.Warningf("updates: failed to clean old upgraded files: %s", err) + } time.Sleep(1 * time.Hour) } } @@ -75,13 +83,13 @@ func UpdateIndexes() (err error) { updatesLock.Unlock() // check dir - err = utils.EnsureDirectory(updateStoragePath, 0755) + err = updateStorage.Ensure() if err != nil { return err } // save stable index - err = ioutil.WriteFile(filepath.Join(updateStoragePath, "stable.json"), data, 0644) + err = ioutil.WriteFile(filepath.Join(updateStorage.Path, "stable.json"), data, 0644) if err != nil { log.Warningf("updates: failed to save new version of stable.json: %s", err) } @@ -107,6 +115,7 @@ func DownloadUpdates() (err error) { markPlatformFileForDownload("control/portmaster-control.exe") markPlatformFileForDownload("app/portmaster-app.exe") markPlatformFileForDownload("notifier/portmaster-notifier.exe") + markPlatformFileForDownload("notifier/portmaster-snoretoast.exe") } else { markPlatformFileForDownload("core/portmaster-core") markPlatformFileForDownload("control/portmaster-control") @@ -115,6 +124,12 @@ func DownloadUpdates() (err error) { } updatesLock.Unlock() + // check download dir + err = tmpStorage.Ensure() + if err != nil { + return fmt.Errorf("could not prepare tmp directory for download: %s", err) + } + // RLock for the remaining function updatesLock.RLock() defer updatesLock.RUnlock() @@ -127,7 +142,7 @@ func DownloadUpdates() (err error) { log.Tracef("updates: updating %s to %s", identifier, newVersion) filePath := GetVersionedPath(identifier, newVersion) - realFilePath := filepath.Join(updateStoragePath, filePath) + realFilePath := filepath.Join(updateStorage.Path, filePath) for tries := 0; tries < 3; tries++ { err = fetchFile(realFilePath, filePath, tries) if err == nil { @@ -142,5 +157,11 @@ func DownloadUpdates() (err error) { } log.Tracef("updates: finished updating existing files") + // remove tmp folder after we are finished + err = os.RemoveAll(tmpStorage.Path) + if err != nil { + log.Tracef("updates: failed to remove tmp dir %s after downloading updates: %s", updateStorage.Path, err) + } + return nil } diff --git a/updates/upgrader.go b/updates/upgrader.go new file mode 100644 index 00000000..d41bca51 --- /dev/null +++ b/updates/upgrader.go @@ -0,0 +1,186 @@ +package updates + +import ( + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/google/renameio" + + "github.com/safing/portbase/log" + + processInfo "github.com/shirou/gopsutil/process" +) + +const ( + upgradedSuffix = "-upgraded" +) + +func runFileUpgrades() error { + filename := "portmaster-control" + if runtime.GOOS == "windows" { + filename += ".exe" + } + + // get newest portmaster-control + newFile, err := GetPlatformFile("control/" + filename) // identifier, use forward slash! + if err != nil { + return err + } + + // update portmaster-control in data root + rootControlPath := filepath.Join(filepath.Dir(updateStorage.Path), filename) + err = upgradeFile(rootControlPath, newFile) + if err != nil { + return err + } + log.Infof("updates: upgraded %s", rootControlPath) + + // upgrade parent process, if it's portmaster-control + parent, err := processInfo.NewProcess(int32(os.Getppid())) + if err != nil { + return fmt.Errorf("could not get parent process for upgrade checks: %s", err) + } + parentName, err := parent.Name() + if err != nil { + return fmt.Errorf("could not get parent process name for upgrade checks: %s", err) + } + if !strings.HasPrefix(parentName, filename) { + log.Tracef("updates: parent process does not seem to be portmaster-control, name is %s", parentName) + return nil + } + parentPath, err := parent.Exe() + if err != nil { + return fmt.Errorf("could not get parent process path for upgrade: %s", err) + } + err = upgradeFile(parentPath, newFile) + if err != nil { + return err + } + log.Infof("updates: upgraded %s", parentPath) + + return nil +} + +func upgradeFile(fileToUpgrade string, file *File) error { + fileExists := false + _, err := os.Stat(fileToUpgrade) + if err == nil { + // file exists and is accessible + fileExists = true + } + + // ensure that the tmp dir exists + err = tmpStorage.Ensure() + if err != nil { + return fmt.Errorf("unable to create directory for upgrade process: %s", err) + } + + if fileExists { + // get current version + var currentVersion string + cmd := exec.Command(fileToUpgrade, "--ver") + out, err := cmd.Output() + if err == nil { + // abort if version matches + currentVersion = strings.Trim(strings.TrimSpace(string(out)), "*") + if currentVersion == file.Version() { + // already up to date! + return nil + } + } else { + log.Warningf("updates: failed to run %s to get version for upgrade check: %s", fileToUpgrade, err) + currentVersion = "0.0.0" + } + + // test currentVersion for sanity + if !rawVersionRegex.MatchString(currentVersion) { + log.Tracef("updates: version string returned by %s is invalid: %s", fileToUpgrade, currentVersion) + currentVersion = "0.0.0" + } + + // try removing old version + err = os.Remove(fileToUpgrade) + if err != nil { + // maybe we're on windows and it's in use, try moving + err = os.Rename(fileToUpgrade, filepath.Join( + tmpStorage.Path, + fmt.Sprintf( + "%s-%d%s", + GetVersionedPath(filepath.Base(fileToUpgrade), currentVersion), + time.Now().UTC().Unix(), + upgradedSuffix, + ), + )) + if err != nil { + return fmt.Errorf("unable to move file that needs upgrade: %s", err) + } + } + } + + // copy upgrade + // TODO: handle copy failure + err = copyFile(file.Path(), fileToUpgrade) + if err != nil { + time.Sleep(1 * time.Second) + // try again + err = copyFile(file.Path(), fileToUpgrade) + if err != nil { + return err + } + } + + // check permissions + if runtime.GOOS != "windows" { + info, err := os.Stat(fileToUpgrade) + if err != nil { + return fmt.Errorf("failed to get file info on %s: %s", fileToUpgrade, err) + } + if info.Mode() != 0755 { + err := os.Chmod(fileToUpgrade, 0755) + if err != nil { + return fmt.Errorf("failed to set permissions on %s: %s", fileToUpgrade, err) + } + } + } + return nil +} + +func copyFile(srcPath, dstPath string) (err error) { + // open file for writing + atomicDstFile, err := renameio.TempFile(tmpStorage.Path, dstPath) + if err != nil { + return fmt.Errorf("could not create temp file for atomic copy: %s", err) + } + defer atomicDstFile.Cleanup() + + // open source + srcFile, err := os.Open(srcPath) + if err != nil { + return + } + defer srcFile.Close() + + // copy data + _, err = io.Copy(atomicDstFile, srcFile) + if err != nil { + return + } + + // finalize file + err = atomicDstFile.CloseAtomicallyReplace() + if err != nil { + return fmt.Errorf("updates: failed to finalize copy to file %s: %s", dstPath, err) + } + + return nil +} + +func cleanOldUpgradedFiles() error { + return os.RemoveAll(tmpStorage.Path) +} diff --git a/updates/uptool/root.go b/updates/uptool/root.go index 8c75312c..f05c3a08 100644 --- a/updates/uptool/root.go +++ b/updates/uptool/root.go @@ -1,23 +1,38 @@ package main import ( - "fmt" "os" + "path/filepath" + + "github.com/safing/portbase/utils" "github.com/spf13/cobra" ) +var ( + updatesStorage *utils.DirStructure +) + var rootCmd = &cobra.Command{ Use: "uptool", Short: "helper tool for the update process", Run: func(cmd *cobra.Command, args []string) { cmd.Usage() }, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + absPath, err := filepath.Abs(".") + if err != nil { + return err + } + + updatesStorage = utils.NewDirStructure(absPath, 0755) + return nil + }, + SilenceUsage: true, } func main() { if err := rootCmd.Execute(); err != nil { - fmt.Println(err) os.Exit(1) } } diff --git a/updates/uptool/update.go b/updates/uptool/update.go new file mode 100644 index 00000000..6e63fd77 --- /dev/null +++ b/updates/uptool/update.go @@ -0,0 +1,44 @@ +package main + +import ( + "encoding/json" + "io/ioutil" + + "github.com/safing/portmaster/updates" + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(updateCmd) +} + +var updateCmd = &cobra.Command{ + Use: "update", + Short: "Update scans the current directory and updates the index and symlink structure", + RunE: update, +} + +func update(cmd *cobra.Command, args []string) error { + + latest, err := updates.ScanForLatest(".", true) + if err != nil { + return err + } + + data, err := json.MarshalIndent(latest, "", " ") + if err != nil { + return err + } + + err = ioutil.WriteFile("stable.json", data, 0755) + if err != nil { + return err + } + + err = updates.CreateSymlinks(updatesStorage.ChildDir("latest", 0755), updatesStorage, latest) + if err != nil { + return err + } + + return nil +}