From c22814e6e10838dc1bef569154954fb6eb02ea71 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 7 Nov 2024 16:20:58 +0100 Subject: [PATCH] Improve start and shutdown controls and flow --- cmds/portmaster-core/recover_linux.go | 17 +- cmds/portmaster-core/run.go | 99 ++++++---- cmds/portmaster-core/run_linux.go | 143 +++++++------- cmds/portmaster-core/run_windows.go | 262 ++++++++++++++++---------- cmds/portmaster-core/update.go | 5 +- service/instance.go | 59 ++++-- spn/access/module.go | 4 +- 7 files changed, 355 insertions(+), 234 deletions(-) diff --git a/cmds/portmaster-core/recover_linux.go b/cmds/portmaster-core/recover_linux.go index 8ce989ba..6f5532c2 100644 --- a/cmds/portmaster-core/recover_linux.go +++ b/cmds/portmaster-core/recover_linux.go @@ -8,17 +8,28 @@ import ( "strings" "github.com/hashicorp/go-multierror" + "github.com/spf13/cobra" "github.com/safing/portmaster/service/firewall/interception" ) -var recoverIPTables bool +var ( + recoverCmd = &cobra.Command{ + Use: "recover-iptables", + Short: "Force an update of all components.", + RunE: update, + } + + recoverIPTables bool +) func init() { - flag.BoolVar(&recoverIPTables, "recover-iptables", false, "recovers ip table rules") + rootCmd.AddCommand(recoverCmd) + + flag.BoolVar(&recoverIPTables, "recover-iptables", false, "recovers ip table rules (backward compatibility; use command instead)") } -func recoverIPTablesCmd() error { +func recover(cmd *cobra.Command, args []string) error { // interception.DeactiveNfqueueFirewall uses coreos/go-iptables // which shells out to the /sbin/iptables binary. As a result, // we don't get the errno of the actual error and need to parse the diff --git a/cmds/portmaster-core/run.go b/cmds/portmaster-core/run.go index 0cbe4812..4c11a1cd 100644 --- a/cmds/portmaster-core/run.go +++ b/cmds/portmaster-core/run.go @@ -1,7 +1,6 @@ package main import ( - "bufio" "errors" "flag" "fmt" @@ -9,44 +8,43 @@ import ( "log/slog" "os" "runtime/pprof" - "syscall" + "time" "github.com/spf13/cobra" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) -var ( - printStackOnExit bool - enableInputSignals bool - - sigUSR1 = syscall.Signal(0xa) // dummy for windows -) +var printStackOnExit bool func init() { flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down") - flag.BoolVar(&enableInputSignals, "input-signals", false, "emulate signals using stdin") +} + +type SystemService interface { + Run() + IsService() bool + RestartService() error } func cmdRun(cmd *cobra.Command, args []string) { - // Call platform specific checks, that will execute commands like "recover-iptables" - platformSpecificChecks() + // Run platform specific setup or switches. + runPlatformSpecifics(cmd, args) + + // SETUP svcCfg.VerifyBinaryUpdates = nil // FIXME svcCfg.VerifyIntelUpdates = nil // FIXME - instance := createInstance() - run(instance) -} - -func createInstance() *service.Instance { // enable SPN client mode conf.EnableClient(true) conf.EnableIntegration(true) // Create instance. + // Instance modules might request a cmdline execution of a function. var execCmdLine bool instance, err := service.New(svcCfg) switch { @@ -76,7 +74,56 @@ func createInstance() *service.Instance { } os.Exit(0) } - return instance + + // START + + // Set default log level and start logging. + log.SetLogLevel(log.WarningLevel) + _ = log.Start() + + // Create system service. + service := NewSystemService(instance) + + // Start instance via system service manager. + go func() { + service.Run() + }() + + // SHUTDOWN + + // Wait for shutdown to be started. + <-instance.ShuttingDown() + + // Wait for shutdown to be finished. + select { + case <-instance.ShutdownComplete(): + case <-time.After(3 * time.Minute): + printStackTo(os.Stderr, "PRINTING STACK - TAKING TOO LONG FOR SHUTDOWN") + } + + // Stop logging. + log.Shutdown() + + // Print stack on shutdown, if enabled. + if printStackOnExit { + printStackTo(os.Stdout, "PRINTING STACK ON EXIT") + } + + // Check if restart was triggered and send start service command if true. + if instance.ShouldRestart && service.IsService() { + if err := service.RestartService(); err != nil { + slog.Error("failed to restart service", "err", err) + } + } + + // Give a small amount of time for everything to settle: + // - All logs written. + // - Restart command started, if needed. + // - Windows service manager notified. + time.Sleep(100 * time.Millisecond) + + // Exit + os.Exit(instance.ExitCode()) } func printStackTo(writer io.Writer, msg string) { @@ -88,21 +135,3 @@ func printStackTo(writer io.Writer, msg string) { slog.Error("failed to write stack trace", "err", err) } } - -func inputSignals(signalCh chan os.Signal) { - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - switch scanner.Text() { - case "SIGHUP": - signalCh <- syscall.SIGHUP - case "SIGINT": - signalCh <- syscall.SIGINT - case "SIGQUIT": - signalCh <- syscall.SIGQUIT - case "SIGTERM": - signalCh <- syscall.SIGTERM - case "SIGUSR1": - signalCh <- sigUSR1 - } - } -} diff --git a/cmds/portmaster-core/run_linux.go b/cmds/portmaster-core/run_linux.go index 25c6da0b..7859ed26 100644 --- a/cmds/portmaster-core/run_linux.go +++ b/cmds/portmaster-core/run_linux.go @@ -7,39 +7,37 @@ import ( "os/exec" "os/signal" "syscall" - "time" processInfo "github.com/shirou/gopsutil/process" + "github.com/spf13/cobra" - "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service" ) -func run(instance *service.Instance) { - // Set default log level. - log.SetLogLevel(log.WarningLevel) - _ = log.Start() +type LinuxSystemService struct { + instance *service.Instance +} - // Start - go func() { - err := instance.Start() - if err != nil { - fmt.Printf("instance start failed: %s\n", err) +func NewSystemService(instance *service.Instance) *LinuxSystemService { + return &LinuxSystemService{instance: instance} +} - // Print stack on start failure, if enabled. - if printStackOnExit { - printStackTo(os.Stdout, "PRINTING STACK ON START FAILURE") - } +func (s *LinuxSystemService) Run() { + // Start instance. + err := s.instance.Start() + if err != nil { + slog.Error("failed to start", "err", err) - os.Exit(1) + // Print stack on start failure, if enabled. + if printStackOnExit { + printStackTo(os.Stderr, "PRINTING STACK ON START FAILURE") } - }() - // Wait for signal. - signalCh := make(chan os.Signal, 1) - if enableInputSignals { - go inputSignals(signalCh) + os.Exit(1) } + + // Subscribe to signals. + signalCh := make(chan os.Signal, 1) signal.Notify( signalCh, os.Interrupt, @@ -47,74 +45,61 @@ func run(instance *service.Instance) { syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, - sigUSR1, + syscall.SIGUSR1, ) - select { - case sig := <-signalCh: - // Only print and continue to wait if SIGUSR1 - if sig == sigUSR1 { - printStackTo(os.Stderr, "PRINTING STACK ON REQUEST") - } else { - fmt.Println(" ") // CLI output. - slog.Warn("program was interrupted, stopping") + // Wait for shutdown signal. +wait: + for { + select { + case sig := <-signalCh: + // Only print and continue to wait if SIGUSR1 + if sig == syscall.SIGUSR1 { + printStackTo(os.Stdout, "PRINTING STACK ON REQUEST") + continue wait + } else { + // Trigger shutdown. + fmt.Printf(" ", sig) // CLI output. + slog.Warn("received stop signal", "signal", sig) + s.instance.Shutdown() + break wait + } + case <-s.instance.ShuttingDown(): + break wait } - - case <-instance.Stopped(): - log.Shutdown() - os.Exit(instance.ExitCode()) } + // Wait for shutdown to finish. + // Catch signals during shutdown. - // Rapid unplanned disassembly after 5 interrupts. - go func() { - forceCnt := 5 - for { - <-signalCh - forceCnt-- - if forceCnt > 0 { - fmt.Printf(" again, but already shutting down - %d more to force\n", forceCnt) - } else { - printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT") - os.Exit(1) + // Force exit after 5 interrupts. + forceCnt := 5 + for { + select { + case <-s.instance.ShutdownComplete(): + return + case sig := <-signalCh: + if sig != syscall.SIGUSR1 { + forceCnt-- + if forceCnt > 0 { + fmt.Printf(" again, but already shutting down - %d more to force\n", sig, forceCnt) + } else { + printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT") + os.Exit(1) + } } } - }() - - // Rapid unplanned disassembly after 3 minutes. - go func() { - time.Sleep(3 * time.Minute) - printStackTo(os.Stderr, "PRINTING STACK - TAKING TOO LONG FOR SHUTDOWN") - os.Exit(1) - }() - - // Stop instance. - if err := instance.Stop(); err != nil { - slog.Error("failed to stop", "err", err) } - log.Shutdown() - - // Print stack on shutdown, if enabled. - if printStackOnExit { - printStackTo(os.Stdout, "PRINTING STACK ON EXIT") - } - - // Check if restart was trigger and send start service command if true. - if isRunningAsService() && instance.ShouldRestart { - _ = runServiceRestart() - } - - os.Exit(instance.ExitCode()) } -func runServiceRestart() error { +func (s *LinuxSystemService) RestartService() error { // Check if user defined custom command for restarting the service. restartCommand, exists := os.LookupEnv("PORTMASTER_RESTART_COMMAND") // Run the service restart var cmd *exec.Cmd if exists && restartCommand != "" { - log.Debugf(`instance: running custom restart command: "%s"`, restartCommand) + slog.Debug("running custom restart command", "command", restartCommand) cmd = exec.Command("sh", "-c", restartCommand) } else { cmd = exec.Command("systemctl", "restart", "portmaster") @@ -125,28 +110,30 @@ func runServiceRestart() error { return nil } -func isRunningAsService() bool { - // Get the current process ID +func (s *LinuxSystemService) IsService() bool { + // Get own process ID pid := os.Getpid() + // Get parent process ID. currentProcess, err := processInfo.NewProcess(int32(pid)) //nolint:gosec if err != nil { return false } - ppid, err := currentProcess.Ppid() if err != nil { return false } + // Check if the parent process ID is 1 == init system return ppid == 1 } -func platformSpecificChecks() { - // If flag is set. Run recover IP tables and exit. (Can be true only on linux) +func runPlatformSpecifics(cmd *cobra.Command, args []string) { + // If recover-iptables flag is set, run the recover-iptables command. + // This is for backwards compatibility if recoverIPTables { exitCode := 0 - err := recoverIPTablesCmd() + err := recover(cmd, args) if err != nil { fmt.Printf("failed: %s", err) exitCode = 1 diff --git a/cmds/portmaster-core/run_windows.go b/cmds/portmaster-core/run_windows.go index 8e0d9a6d..3976745f 100644 --- a/cmds/portmaster-core/run_windows.go +++ b/cmds/portmaster-core/run_windows.go @@ -12,138 +12,165 @@ import ( "os/exec" "os/signal" "syscall" - "time" - "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service" + "github.com/spf13/cobra" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/debug" ) const serviceName = "PortmasterCore" -type windowsService struct { +type WindowsSystemService struct { instance *service.Instance } -func (ws *windowsService) Execute(args []string, changeRequests <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) { - const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown - changes <- svc.Status{State: svc.StartPending} - ws.instance.Start() - changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} +func NewSystemService(instance *service.Instance) *WindowsSystemService { + return &WindowsSystemService{instance: instance} +} -service: +func (s *WindowsSystemService) Run() { + svcRun := svc.Run + + // Check if we are running interactively. + isService, err := svc.IsWindowsService() + switch { + case err != nil: + slog.Warn("failed to determine if running interactively", "err", err) + slog.Warn("continuing without service integration (no real service)") + svcRun = debug.Run + + case !isService: + slog.Warn("running interactively, switching to debug execution (no real service)") + svcRun = debug.Run + } + + // Run service client. + err = svcRun(serviceName, s) + if err != nil { + slog.Error("service execution failed", "err", err) + os.Exit(1) + } + + // Execution continues in s.Execute(). +} + +func (s *WindowsSystemService) Execute(args []string, changeRequests <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { + // Tell service manager we are starting. + changes <- svc.Status{State: svc.StartPending} + + // Start instance. + err := s.instance.Start() + if err != nil { + fmt.Printf("failed to start: %s\n", err) + + // Print stack on start failure, if enabled. + if printStackOnExit { + printStackTo(os.Stderr, "PRINTING STACK ON START FAILURE") + } + + // Notify service manager we stopped again. + changes <- svc.Status{State: svc.Stopped} + + // Relay exit code to service manager. + return false, 1 + } + + // Tell service manager we are up and running! + changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} + + // Subscribe to signals. + // Docs: https://pkg.go.dev/os/signal?GOOS=windows + signalCh := make(chan os.Signal, 4) + signal.Notify( + signalCh, + + // Windows ^C (Control-C) or ^BREAK (Control-Break). + // Completely prevents kill. + os.Interrupt, + + // Windows CTRL_CLOSE_EVENT, CTRL_LOGOFF_EVENT or CTRL_SHUTDOWN_EVENT. + // Does not prevent kill, but gives a little time to stop service. + syscall.SIGTERM, + ) + + // Wait for shutdown signal. +waitSignal: for { select { - case <-ws.instance.Stopped(): - log.Infof("instance stopped") - break service + case sig := <-signalCh: + // Trigger shutdown. + fmt.Printf(" ", sig) // CLI output. + slog.Warn("received stop signal", "signal", sig) + break waitSignal + case c := <-changeRequests: switch c.Cmd { case svc.Interrogate: changes <- c.CurrentStatus + case svc.Stop, svc.Shutdown: - log.Debugf("received shutdown command") - changes <- svc.Status{State: svc.StopPending} - ws.instance.Shutdown() + fmt.Printf(" ", serviceCmdName(c.Cmd)) // CLI output. + slog.Warn("received service shutdown command", "cmd", c.Cmd) + break waitSignal + default: - log.Errorf("unexpected control request: %+v", c) + slog.Error("unexpected service control request", "cmd", serviceCmdName(c.Cmd)) } + + case <-s.instance.ShuttingDown(): + break waitSignal } } - log.Shutdown() - - // send stopped status - changes <- svc.Status{State: svc.Stopped} - // wait a little for the status to reach Windows - time.Sleep(100 * time.Millisecond) - - return ssec, errno -} - -func run(instance *service.Instance) error { - log.SetLogLevel(log.WarningLevel) - _ = log.Start() - - // check if we are running interactively - isService, err := svc.IsWindowsService() - if err != nil { - return fmt.Errorf("could not determine if running interactively: %s", err) - } - - // select service run type - svcRun := svc.Run - if !isService { - log.Warningf("running interactively, switching to debug execution (no real service).") - svcRun = debug.Run - go registerSignalHandler(instance) - } - - // run service client - sErr := svcRun(serviceName, &windowsService{ - instance: instance, - }) - if sErr != nil { - fmt.Printf("shuting down service with error: %s", sErr) - } else { - fmt.Printf("shuting down service") - } - - // Check if restart was trigger and send start service command if true. - if isRunningAsService() && instance.ShouldRestart { - _ = runServiceRestart() - } - - return err -} - -func registerSignalHandler(instance *service.Instance) { - // Wait for signal. - signalCh := make(chan os.Signal, 1) - if enableInputSignals { - go inputSignals(signalCh) - } - signal.Notify( - signalCh, - os.Interrupt, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - syscall.SIGQUIT, - sigUSR1, - ) - - select { - case sig := <-signalCh: - // Only print and continue to wait if SIGUSR1 - if sig == sigUSR1 { - printStackTo(os.Stderr, "PRINTING STACK ON REQUEST") - } else { - fmt.Println(" ") // CLI output. - slog.Warn("program was interrupted, stopping") - instance.Shutdown() - } - } + // Wait for shutdown to finish. + changes <- svc.Status{State: svc.StopPending} // Catch signals during shutdown. - // Rapid unplanned disassembly after 5 interrupts. - go func() { - forceCnt := 5 - for { - <-signalCh + // Force exit after 5 interrupts. + forceCnt := 5 +waitShutdown: + for { + select { + case <-s.instance.ShutdownComplete(): + break waitShutdown + + case sig := <-signalCh: forceCnt-- if forceCnt > 0 { - fmt.Printf(" again, but already shutting down - %d more to force\n", forceCnt) + fmt.Printf(" but already shutting down - %d more to force\n", sig, forceCnt) } else { printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT") os.Exit(1) } + + case c := <-changeRequests: + switch c.Cmd { + case svc.Interrogate: + changes <- c.CurrentStatus + + case svc.Stop, svc.Shutdown: + forceCnt-- + if forceCnt > 0 { + fmt.Printf(" but already shutting down - %d more to force\n", serviceCmdName(c.Cmd), forceCnt) + } else { + printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT") + os.Exit(1) + } + + default: + slog.Error("unexpected service control request", "cmd", serviceCmdName(c.Cmd)) + } } - }() + } + + // Notify service manager. + changes <- svc.Status{State: svc.Stopped} + + return false, 0 } -func isRunningAsService() bool { +func (s *WindowsSystemService) IsService() bool { isService, err := svc.IsWindowsService() if err != nil { return false @@ -151,7 +178,7 @@ func isRunningAsService() bool { return isService } -func runServiceRestart() error { +func (s *WindowsSystemService) RestartService() error { // Script that wait for portmaster service status to change to stop // and then sends a start command for the same service. command := ` @@ -172,4 +199,41 @@ sc.exe start $serviceName` return nil } -func platformSpecificChecks() {} +func runPlatformSpecifics(cmd *cobra.Command, args []string) + +func serviceCmdName(cmd svc.Cmd) string { + switch cmd { + case svc.Stop: + return "Stop" + case svc.Pause: + return "Pause" + case svc.Continue: + return "Continue" + case svc.Interrogate: + return "Interrogate" + case svc.Shutdown: + return "Shutdown" + case svc.ParamChange: + return "ParamChange" + case svc.NetBindAdd: + return "NetBindAdd" + case svc.NetBindRemove: + return "NetBindRemove" + case svc.NetBindEnable: + return "NetBindEnable" + case svc.NetBindDisable: + return "NetBindDisable" + case svc.DeviceEvent: + return "DeviceEvent" + case svc.HardwareProfileChange: + return "HardwareProfileChange" + case svc.PowerEvent: + return "PowerEvent" + case svc.SessionChange: + return "SessionChange" + case svc.PreShutdown: + return "PreShutdown" + default: + return "Unknown Command" + } +} diff --git a/cmds/portmaster-core/update.go b/cmds/portmaster-core/update.go index ceb0ff6c..a3158dea 100644 --- a/cmds/portmaster-core/update.go +++ b/cmds/portmaster-core/update.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "log/slog" "github.com/spf13/cobra" @@ -53,11 +54,11 @@ func update(cmd *cobra.Command, args []string) error { // Force update all. binErr := binaryUpdates.ForceUpdate() if binErr != nil { - log.Errorf("binary update failed: %s", binErr) + slog.Error("binary update failed", "err", binErr) } intelErr := intelUpdates.ForceUpdate() if intelErr != nil { - log.Errorf("intel update failed: %s", intelErr) + slog.Error("intel update failed", "err", intelErr) } // Return error. diff --git a/service/instance.go b/service/instance.go index 5bfeb7a1..8fc9a17c 100644 --- a/service/instance.go +++ b/service/instance.go @@ -49,8 +49,12 @@ import ( // Instance is an instance of a Portmaster service. type Instance struct { - ctx context.Context - cancelCtx context.CancelFunc + ctx context.Context + cancelCtx context.CancelFunc + + shutdownCtx context.Context + cancelShutdownCtx context.CancelFunc + serviceGroup *mgr.Group binDir string @@ -607,20 +611,13 @@ func (i *Instance) Ready() bool { return i.serviceGroup.Ready() } -// Ctx returns the instance context. -// It is only canceled on shutdown. -func (i *Instance) Ctx() context.Context { - return i.ctx -} - -// Start starts the instance. +// Start starts the instance modules. func (i *Instance) Start() error { return i.serviceGroup.Start() } -// Stop stops the instance and cancels the instance context when done. +// Stop stops the instance modules. func (i *Instance) Stop() error { - defer i.cancelCtx() return i.serviceGroup.Stop() } @@ -649,9 +646,19 @@ func (i *Instance) Shutdown() { } func (i *Instance) shutdown(exitCode int) { + // Only shutdown once. + if i.IsShuttingDown() { + return + } + // Set given exit code. i.exitCode.Store(int32(exitCode)) + // Cancel contexts. + i.cancelCtx() + defer i.cancelShutdownCtx() + + // Start shutdown asynchronously in a separate manager. m := mgr.New("instance") m.Go("shutdown", func(w *mgr.WorkerCtx) error { for { @@ -665,16 +672,38 @@ func (i *Instance) shutdown(exitCode int) { }) } -// Stopping returns whether the instance is shutting down. -func (i *Instance) Stopping() bool { +// Ctx returns the instance context. +// It is canceled when shutdown is started. +func (i *Instance) Ctx() context.Context { + return i.ctx +} + +// IsShuttingDown returns whether the instance is shutting down. +func (i *Instance) IsShuttingDown() bool { return i.ctx.Err() != nil } -// Stopped returns a channel that is triggered when the instance has shut down. -func (i *Instance) Stopped() <-chan struct{} { +// ShuttingDown returns a channel that is triggered when the instance starts shutting down. +func (i *Instance) ShuttingDown() <-chan struct{} { return i.ctx.Done() } +// ShutdownCtx returns the instance shutdown context. +// It is canceled when shutdown is complete. +func (i *Instance) ShutdownCtx() context.Context { + return i.shutdownCtx +} + +// IsShutDown returns whether the instance has stopped. +func (i *Instance) IsShutDown() bool { + return i.shutdownCtx.Err() != nil +} + +// ShutDownComplete returns a channel that is triggered when the instance has shut down. +func (i *Instance) ShutdownComplete() <-chan struct{} { + return i.shutdownCtx.Done() +} + // ExitCode returns the set exit code of the instance. func (i *Instance) ExitCode() int { return int(i.exitCode.Load()) diff --git a/spn/access/module.go b/spn/access/module.go index d49524f9..a805240a 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -80,7 +80,7 @@ func start() error { // Add config listener to enable/disable SPN. module.instance.Config().EventConfigChange.AddCallback("spn enable check", func(wc *mgr.WorkerCtx, s struct{}) (bool, error) { // Do not do anything when we are shutting down. - if module.instance.Stopping() { + if module.instance.IsShuttingDown() { return true, nil } @@ -255,5 +255,5 @@ func New(instance instance) (*Access, error) { type instance interface { Config() *config.Config SPNGroup() *mgr.ExtendedGroup - Stopping() bool + IsShuttingDown() bool }