diff --git a/core/api.go b/core/api.go index 399903c9..559c7914 100644 --- a/core/api.go +++ b/core/api.go @@ -1,9 +1,7 @@ package core import ( - "context" "net/http" - "time" "github.com/safing/portbase/api" "github.com/safing/portbase/log" @@ -13,16 +11,6 @@ import ( "github.com/safing/portmaster/updates" ) -const ( - eventShutdown = "shutdown" - eventRestart = "restart" -) - -func registerEvents() { - module.RegisterEvent(eventShutdown, true) - module.RegisterEvent(eventRestart, true) -} - func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: "core/shutdown", @@ -70,16 +58,8 @@ func registerAPIEndpoints() error { func shutdown(_ *api.Request) (msg string, err error) { log.Warning("core: user requested shutdown via action") - module.StartWorker("shutdown", func(context.Context) error { - // Notify everyone of the shutdown. - module.TriggerEvent(eventShutdown, nil) - // Wait a bit for the event to propagate. - time.Sleep(1 * time.Second) - - // Do not run in worker, as this would block itself here. - go modules.Shutdown() //nolint:errcheck - return nil - }) + // Do not run in worker, as this would block itself here. + go modules.Shutdown() //nolint:errcheck return "shutdown initiated", nil } @@ -88,15 +68,10 @@ func shutdown(_ *api.Request) (msg string, err error) { func restart(_ *api.Request) (msg string, err error) { log.Info("core: user requested restart via action") - module.StartWorker("restart", func(context.Context) error { - // Notify everyone of the shutdown. - module.TriggerEvent(eventRestart, nil) - // Wait a bit for the event to propagate. - time.Sleep(1 * time.Second) - - updates.RestartNow() - return nil - }) + // Trigger restart event instead of shutdown event. + restarting.Set() + // Let the updates module handle restarting. + updates.RestartNow() return "restart initiated", nil } diff --git a/core/core.go b/core/core.go index 9acd4dad..1e08e0f5 100644 --- a/core/core.go +++ b/core/core.go @@ -2,9 +2,13 @@ package core import ( "fmt" + "time" + + "github.com/safing/portbase/config" "github.com/safing/portbase/modules" "github.com/safing/portbase/modules/subsystems" + "github.com/tevino/abool" // module dependencies _ "github.com/safing/portmaster/netenv" @@ -13,8 +17,16 @@ import ( _ "github.com/safing/portmaster/updates" ) +const ( + eventShutdown = "shutdown" + eventRestart = "restart" +) + var ( module *modules.Module + + restarting = abool.New() + devMode = config.Concurrent.GetAsBool(config.CfgDevModeKey, false) ) func init() { @@ -27,6 +39,8 @@ func init() { "config:core/", nil, ) + + modules.SetGlobalShutdownFn(shutdownHook) } func prep() error { @@ -54,3 +68,25 @@ func start() error { return nil } + +func registerEvents() { + module.RegisterEvent(eventShutdown, true) + module.RegisterEvent(eventRestart, true) +} + +func shutdownHook() { + // Don't trigger event in Dev Mode. + if devMode() { + return + } + + // Notify everyone of the restart/shutdown. + if restarting.IsNotSet() { + module.TriggerEvent(eventShutdown, nil) + } else { + module.TriggerEvent(eventRestart, nil) + } + + // Wait a bit for the event to propagate. + time.Sleep(1 * time.Second) +}