From ddf7ba170ea0d297d0b89e25ff21ccd932b4ff4d Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 8 Nov 2024 14:39:58 +0100 Subject: [PATCH] Fix tests and issues --- base/config/init_test.go | 14 ++-- cmds/updatemgr/scan.go | 6 +- service/instance.go | 22 +++--- service/updates/downloader.go | 15 ++-- service/updates/index.go | 38 +++++----- service/updates/index_scan.go | 103 +++---------------------- service/updates/module.go | 98 +++++++++++++++++------- service/updates/updates_test.go | 128 ++++++++++++++++++++++++-------- service/updates/upgrade.go | 43 ++++++----- spn/instance.go | 89 ++++++++++++++++------ 10 files changed, 325 insertions(+), 231 deletions(-) diff --git a/base/config/init_test.go b/base/config/init_test.go index e3390379..d5d90aea 100644 --- a/base/config/init_test.go +++ b/base/config/init_test.go @@ -29,21 +29,23 @@ func newTestInstance(testName string) (*testInstance, error) { }, nil } -func TestConfigPersistence(t *testing.T) { - t.Parallel() - +func TestMain(m *testing.M) { instance, err := newTestInstance("test-config") if err != nil { - t.Fatalf("failed to create test instance: %s", err) + panic(fmt.Errorf("failed to create test instance: %w", err)) } defer func() { _ = os.RemoveAll(instance.DataDir()) }() module, err = New(instance) if err != nil { - t.Fatalf("failed to initialize module: %s", err) + panic(fmt.Errorf("failed to initialize module: %w", err)) } - err = SaveConfig() + m.Run() +} + +func TestConfigPersistence(t *testing.T) { //nolint:paralleltest + err := SaveConfig() if err != nil { t.Fatal(err) } diff --git a/cmds/updatemgr/scan.go b/cmds/updatemgr/scan.go index db60058f..9ef29f15 100644 --- a/cmds/updatemgr/scan.go +++ b/cmds/updatemgr/scan.go @@ -70,16 +70,16 @@ func init() { } func scan(cmd *cobra.Command, args []string) error { - bundle, err := updates.GenerateIndexFromDir(scanDir, scanConfig) + index, err := updates.GenerateIndexFromDir(scanDir, scanConfig) if err != nil { return err } - bundleStr, err := json.MarshalIndent(&bundle, "", " ") + indexJson, err := json.MarshalIndent(&index, "", " ") if err != nil { return fmt.Errorf("marshal index: %w", err) } - fmt.Printf("%s", bundleStr) + fmt.Printf("%s", indexJson) return nil } diff --git a/service/instance.go b/service/instance.go index 8fc9a17c..ddb23671 100644 --- a/service/instance.go +++ b/service/instance.go @@ -130,6 +130,7 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { //nolint:maintidx dataDir: svcCfg.DataDir, } instance.ctx, instance.cancelCtx = context.WithCancel(context.Background()) + instance.shutdownCtx, instance.cancelShutdownCtx = context.WithCancel(context.Background()) // Base modules instance.base, err = base.New(instance) @@ -651,24 +652,23 @@ func (i *Instance) shutdown(exitCode int) { return } + // Cancel main context. + i.cancelCtx() + // Set given exit code. i.exitCode.Store(int32(exitCode)) - // Cancel contexts. - i.cancelCtx() - defer i.cancelShutdownCtx() - // Start shutdown asynchronously in a separate manager. m := mgr.New("instance") m.Go("shutdown", func(w *mgr.WorkerCtx) error { - for { - if err := i.Stop(); err != nil { - w.Error("failed to shutdown", "err", err, "retry", "1s") - time.Sleep(1 * time.Second) - } else { - return nil - } + // Stop all modules. + if err := i.Stop(); err != nil { + w.Error("failed to shutdown", "err", err) } + + // Cancel shutdown process context. + i.cancelShutdownCtx() + return nil }) } diff --git a/service/updates/downloader.go b/service/updates/downloader.go index 8d0d8b24..30c32261 100644 --- a/service/updates/downloader.go +++ b/service/updates/downloader.go @@ -56,7 +56,7 @@ func (d *Downloader) updateIndex(ctx context.Context) error { } log.Warningf("updates/%s: failed to update index from %q: %s", d.u.cfg.Name, url, err) - err = fmt.Errorf("update index file from %q: %s", url, err) + err = fmt.Errorf("update index file from %q: %w", url, err) } if err != nil { return fmt.Errorf("all index URLs failed, last error: %w", err) @@ -65,7 +65,7 @@ func (d *Downloader) updateIndex(ctx context.Context) error { // Write the index into a file. indexFilepath := filepath.Join(d.u.cfg.DownloadDirectory, d.u.cfg.IndexFile) - err = os.WriteFile(indexFilepath, []byte(indexData), defaultFileMode) + err = os.WriteFile(indexFilepath, indexData, defaultFileMode) if err != nil { return fmt.Errorf("write index file: %w", err) } @@ -111,7 +111,7 @@ func (d *Downloader) gatherExistingFiles(dir string) error { // Read full file. fileData, err := os.ReadFile(fullpath) if err != nil { - log.Debugf("updates/%s: failed to read file %q while searching for existing files: %w", d.u.cfg.Name, fullpath, err) + log.Debugf("updates/%s: failed to read file %q while searching for existing files: %s", d.u.cfg.Name, fullpath, err) return fmt.Errorf("failed to read file %s: %w", fullpath, err) } @@ -150,7 +150,12 @@ artifacts: if err == nil { continue artifacts } - log.Debugf("updates/%s: failed to copy existing file %s: %w", d.u.cfg.Name, artifact.Filename, err) + log.Debugf("updates/%s: failed to copy existing file %s: %s", d.u.cfg.Name, artifact.Filename, err) + } + + // Check if the artifact has download URLs. + if len(artifact.URLs) == 0 { + return fmt.Errorf("artifact %s is missing download URLs", artifact.Filename) } // Try to download the artifact from one of the URLs. @@ -163,7 +168,7 @@ artifacts: // Valid artifact found! break artifactURLs } - err = fmt.Errorf("update index file from %q: %s", url, err) + err = fmt.Errorf("update index file from %q: %w", url, err) } if err != nil { return fmt.Errorf("all artifact URLs for %s failed, last error: %w", artifact.Filename, err) diff --git a/service/updates/index.go b/service/updates/index.go index f627d0d3..8a30f643 100644 --- a/service/updates/index.go +++ b/service/updates/index.go @@ -15,6 +15,7 @@ import ( "time" semver "github.com/hashicorp/go-version" + "github.com/safing/jess" "github.com/safing/jess/filesig" ) @@ -26,7 +27,7 @@ const currentPlatform = runtime.GOOS + "_" + runtime.GOARCH var zeroVersion = semver.Must(semver.NewVersion("0.0.0")) -// Artifacts represents a single file with metadata. +// Artifact represents a single file with metadata. type Artifact struct { Filename string `json:"Filename"` SHA256 string `json:"SHA256"` @@ -85,7 +86,7 @@ func (a *Artifact) IsNewerThan(b *Artifact) (newer, ok bool) { } func (a *Artifact) export(dir string, indexVersion *semver.Version) *Artifact { - copy := &Artifact{ + copied := &Artifact{ Filename: a.Filename, SHA256: a.SHA256, URLs: a.URLs, @@ -98,20 +99,20 @@ func (a *Artifact) export(dir string, indexVersion *semver.Version) *Artifact { // Make sure we have a version number. switch { - case copy.versionNum != nil: + case copied.versionNum != nil: // Version already parsed. - case copy.Version != "": + case copied.Version != "": // Need to parse version. - v, err := semver.NewVersion(copy.Version) + v, err := semver.NewVersion(copied.Version) if err == nil { - copy.versionNum = v + copied.versionNum = v } default: // No version defined, inherit index version. - copy.versionNum = indexVersion + copied.versionNum = indexVersion } - return copy + return copied } // Index represents a collection of artifacts with metadata. @@ -146,8 +147,8 @@ func ParseIndex(jsonContent []byte, trustStore jess.TrustStore) (*Index, error) } // Parse json. - var index Index - err := json.Unmarshal([]byte(jsonContent), &index) + index := &Index{} + err := json.Unmarshal(jsonContent, index) if err != nil { return nil, fmt.Errorf("parse index: %w", err) } @@ -158,7 +159,7 @@ func ParseIndex(jsonContent []byte, trustStore jess.TrustStore) (*Index, error) return nil, err } - return &index, nil + return index, nil } func (index *Index) init() error { @@ -219,7 +220,7 @@ func (index *Index) ShouldUpgradeTo(newIndex *Index) error { return fmt.Errorf("current index cannot do upgrades: %w", err) } if err := newIndex.CanDoUpgrades(); err != nil { - return fmt.Errorf("new index cannot do upgrade: %w") + return fmt.Errorf("new index cannot do upgrade: %w", err) } switch { @@ -229,13 +230,14 @@ func (index *Index) ShouldUpgradeTo(newIndex *Index) error { return nil case index.Name != newIndex.Name: - return errors.New("index names do not match") - - case index.versionNum.GreaterThan(newIndex.versionNum): - return errors.New("current index has newer version") + return errors.New("new index name does not match") case index.Published.After(newIndex.Published): - return errors.New("current index was published later") + return errors.New("new index is older (time)") + + case index.versionNum.Segments()[0] > newIndex.versionNum.Segments()[0]: + // Downgrades are allowed, if they are not breaking changes. + return errors.New("new index is a breaking change downgrade") case index.Published.Equal(newIndex.Published): // "Do nothing". @@ -252,7 +254,7 @@ func (index *Index) VerifyArtifacts(dir string) error { for _, artifact := range index.Artifacts { err := checkSHA256SumFile(filepath.Join(dir, artifact.Filename), artifact.SHA256) if err != nil { - return fmt.Errorf("verify %s: %s", artifact.Filename, err) + return fmt.Errorf("verify %s: %w", artifact.Filename, err) } } diff --git a/service/updates/index_scan.go b/service/updates/index_scan.go index aa61417e..6e6a6af2 100644 --- a/service/updates/index_scan.go +++ b/service/updates/index_scan.go @@ -3,7 +3,6 @@ package updates import ( "crypto/sha256" "encoding/hex" - "encoding/json" "errors" "fmt" "io/fs" @@ -95,7 +94,7 @@ settings: } // GenerateIndexFromDir generates a index from a given folder. -func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error) { +func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error) { //nolint:maintidx artifacts := make(map[string]Artifact) // Initialize. @@ -107,6 +106,13 @@ func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error) if err != nil { return nil, fmt.Errorf("invalid index dir: %w", err) } + var indexVersion *semver.Version + if cfg.Version != "" { + indexVersion, err = semver.NewVersion(cfg.Version) + if err != nil { + return nil, fmt.Errorf("invalid index version: %w", err) + } + } err = filepath.WalkDir(sourceDir, func(fullpath string, d fs.DirEntry, err error) error { // Fail on access error. @@ -227,9 +233,10 @@ func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error) // Create base index. index := &Index{ - Name: cfg.Name, - Version: cfg.Version, - Published: time.Now(), + Name: cfg.Name, + Version: cfg.Version, + Published: time.Now(), + versionNum: indexVersion, } if index.Version == "" && cfg.PrimaryArtifact != "" { pv, ok := artifacts[cfg.PrimaryArtifact] @@ -286,45 +293,6 @@ func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error) return index, nil } -func selectLatestArtifacts(artifacts []Artifact) ([]Artifact, error) { - artifactsMap := make(map[string]Artifact) - - for _, a := range artifacts { - // Make the key platform specific since there can be same filename for multiple platforms. - key := a.Filename + a.Platform - aMap, ok := artifactsMap[key] - if !ok { - artifactsMap[key] = a - continue - } - - if aMap.Version == "" || a.Version == "" { - return nil, fmt.Errorf("invalid mix version and non versioned files for: %s", a.Filename) - } - - mapVersion, err := semver.NewVersion(aMap.Version) - if err != nil { - return nil, fmt.Errorf("invalid version for artifact: %s", aMap.Filename) - } - - artifactVersion, err := semver.NewVersion(a.Version) - if err != nil { - return nil, fmt.Errorf("invalid version for artifact: %s", a.Filename) - } - - if mapVersion.LessThan(artifactVersion) { - artifactsMap[key] = a - } - } - - artifactsFiltered := make([]Artifact, 0, len(artifactsMap)) - for _, a := range artifactsMap { - artifactsFiltered = append(artifactsFiltered, a) - } - - return artifactsFiltered, nil -} - func getSHA256(path string, unpackType string) (string, error) { content, err := os.ReadFile(path) if err != nil { @@ -372,50 +340,3 @@ func getIdentifierAndVersion(versionedPath string) (identifier, version string, // `dirPath + filename` is guaranteed by path.Split() return dirPath + filename, version, true } - -// GenerateMockFolder generates mock index folder for testing. -func GenerateMockFolder(dir, name, version string) error { // FIXME: move this to test? - // Make sure dir exists - _ = os.MkdirAll(dir, defaultDirMode) - - // Create empty files - file, err := os.Create(filepath.Join(dir, "portmaster")) - if err != nil { - return err - } - _ = file.Close() - file, err = os.Create(filepath.Join(dir, "portmaster-core")) - if err != nil { - return err - } - _ = file.Close() - file, err = os.Create(filepath.Join(dir, "portmaster.zip")) - if err != nil { - return err - } - _ = file.Close() - file, err = os.Create(filepath.Join(dir, "assets.zip")) - if err != nil { - return err - } - _ = file.Close() - - index, err := GenerateIndexFromDir(dir, IndexScanConfig{ - Name: name, - Version: version, - }) - if err != nil { - return err - } - - indexJson, err := json.MarshalIndent(index, "", " ") - if err != nil { - fmt.Fprintf(os.Stderr, "failed to marshal index: %s\n", err) - } - - err = os.WriteFile(filepath.Join(dir, "index.json"), indexJson, defaultFileMode) - if err != nil { - return err - } - return nil -} diff --git a/service/updates/module.go b/service/updates/module.go index 8b04d3db..4ae6022c 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -10,11 +10,12 @@ import ( "sync" "time" + "github.com/tevino/abool" + "github.com/safing/jess" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/mgr" - "github.com/tevino/abool" ) const ( @@ -169,7 +170,7 @@ func New(instance instance, name string, cfg Config) (*Updater, error) { // Fall back to scanning the directory. if !errors.Is(err, os.ErrNotExist) { - log.Errorf("updates/%s: invalid index file, falling back to dir scan: %w", cfg.Name, err) + log.Errorf("updates/%s: invalid index file, falling back to dir scan: %s", cfg.Name, err) } index, err = GenerateIndexFromDir(cfg.Directory, IndexScanConfig{Version: "0.0.0"}) if err == nil && index.init() == nil { @@ -181,13 +182,12 @@ func New(instance instance, name string, cfg Config) (*Updater, error) { return module, nil } -func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreVersion, forceApply bool) (err error) { +func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreVersion, forceApply bool) (err error) { //nolint:maintidx // Make sure only one update process is running. if !u.isUpdateRunning.SetToIf(false, true) { return fmt.Errorf("an updater task is already running, please try again later") } defer u.isUpdateRunning.UnSet() - // FIXME: Switch to mutex? // Create a new downloader. downloader := NewDownloader(u, indexURLs) @@ -201,7 +201,7 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV } } else { // Otherwise, load index from download dir. - downloader.index, err = LoadIndex(filepath.Join(u.cfg.Directory, u.cfg.IndexFile), u.cfg.Verify) + downloader.index, err = LoadIndex(filepath.Join(u.cfg.DownloadDirectory, u.cfg.IndexFile), u.cfg.Verify) if err != nil { return fmt.Errorf("load previously downloaded index file: %w", err) } @@ -215,23 +215,42 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV u.indexLock.Unlock() // Check with local pointer to index. if err := index.ShouldUpgradeTo(downloader.index); err != nil { - log.Infof("updates/%s: no new or eligible update: %s", u.cfg.Name, err) - if u.cfg.Notify && u.instance.Notifications() != nil { - u.instance.Notifications().Notify(¬ifications.Notification{ - EventID: noNewUpdateNotificationID, - Type: notifications.Info, - Title: "Portmaster Is Up-To-Date", - Message: "Portmaster v" + index.Version + " is the newest version.", - Expires: time.Now().Add(1 * time.Minute).Unix(), - AvailableActions: []*notifications.Action{ - { - ID: "ack", - Text: "OK", + if errors.Is(err, ErrSameIndex) { + log.Infof("updates/%s: no new update", u.cfg.Name) + if u.cfg.Notify && u.instance.Notifications() != nil { + u.instance.Notifications().Notify(¬ifications.Notification{ + EventID: noNewUpdateNotificationID, + Type: notifications.Info, + Title: "Portmaster Is Up-To-Date", + Message: "Portmaster v" + index.Version + " is the newest version.", + Expires: time.Now().Add(1 * time.Minute).Unix(), + AvailableActions: []*notifications.Action{ + { + ID: "ack", + Text: "OK", + }, }, - }, - }) + }) + } + } else { + log.Warningf("updates/%s: cannot update: %s", u.cfg.Name, err) + if u.cfg.Notify && u.instance.Notifications() != nil { + u.instance.Notifications().Notify(¬ifications.Notification{ + EventID: noNewUpdateNotificationID, + Type: notifications.Info, + Title: "Portmaster Is Up-To-Date*", + Message: "While Portmaster v" + index.Version + " is the newest version, there is an internal issue with checking for updates: " + err.Error(), + Expires: time.Now().Add(1 * time.Minute).Unix(), + AvailableActions: []*notifications.Action{ + { + ID: "ack", + Text: "OK", + }, + }, + }) + } } - return ErrNoUpdateAvailable + return fmt.Errorf("%w: %w", ErrNoUpdateAvailable, err) } } @@ -320,7 +339,10 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV // Install is complete! // Clean up and notify modules of changed files. - u.cleanupAfterUpgrade() + err = u.cleanupAfterUpgrade() + if err != nil { + log.Debugf("updates/%s: failed to clean up after upgrade: %s", u.cfg.Name, err) + } u.EventResourcesUpdated.Submit(struct{}{}) // If no restart is needed, we are done. @@ -363,7 +385,7 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV Type: notifications.ActionTypeWebhook, Payload: notifications.ActionTypeWebhookPayload{ Method: "POST", - URL: "updates/apply", // FIXME + URL: "core/restart", }, }, ) @@ -376,15 +398,35 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV } func (u *Updater) updateCheckWorker(w *mgr.WorkerCtx) error { - _ = u.updateAndUpgrade(w, u.cfg.IndexURLs, false, false) - // FIXME: Handle errors. - return nil + err := u.updateAndUpgrade(w, u.cfg.IndexURLs, false, false) + switch { + case err == nil: + return nil // Success! + case errors.Is(err, ErrSameIndex): + return nil // Nothing to do. + case errors.Is(err, ErrNoUpdateAvailable): + return nil // Already logged. + case errors.Is(err, ErrActionRequired) && !u.cfg.Notify: + return fmt.Errorf("user action required, but notifying user is disabled: %w", err) + default: + return fmt.Errorf("udpating failed: %w", err) + } } func (u *Updater) upgradeWorker(w *mgr.WorkerCtx) error { - _ = u.updateAndUpgrade(w, u.cfg.IndexURLs, false, true) - // FIXME: Handle errors. - return nil + err := u.updateAndUpgrade(w, u.cfg.IndexURLs, false, true) + switch { + case err == nil: + return nil // Success! + case errors.Is(err, ErrSameIndex): + return nil // Nothing to do. + case errors.Is(err, ErrNoUpdateAvailable): + return nil // Already logged. + case errors.Is(err, ErrActionRequired) && !u.cfg.Notify: + return fmt.Errorf("user action required, but notifying user is disabled: %w", err) + default: + return fmt.Errorf("udpating failed: %w", err) + } } // ForceUpdate executes a forced update and upgrade directly and synchronously diff --git a/service/updates/updates_test.go b/service/updates/updates_test.go index 7ca86445..d79706c7 100644 --- a/service/updates/updates_test.go +++ b/service/updates/updates_test.go @@ -1,12 +1,15 @@ package updates import ( + "encoding/json" "fmt" "os" "path/filepath" "testing" + "time" "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/service/mgr" ) type testInstance struct{} @@ -24,66 +27,129 @@ func (i *testInstance) Ready() bool { func (i *testInstance) SetCmdLineOperation(f func() error) {} -func TestPreformUpdate(t *testing.T) { +func TestPerformUpdate(t *testing.T) { t.Parallel() // Initialize mock instance stub := &testInstance{} // Make tmp dirs - installedDir, err := os.MkdirTemp("", "updates_current") + installedDir, err := os.MkdirTemp("", "updates_current_") if err != nil { - panic(err) + t.Fatal(err) } defer func() { _ = os.RemoveAll(installedDir) }() - updateDir, err := os.MkdirTemp("", "updates_new") + updateDir, err := os.MkdirTemp("", "updates_new_") if err != nil { - panic(err) + t.Fatal(err) } defer func() { _ = os.RemoveAll(updateDir) }() - purgeDir, err := os.MkdirTemp("", "updates_purge") + purgeDir, err := os.MkdirTemp("", "updates_purge_") if err != nil { - panic(err) + t.Fatal(err) } defer func() { _ = os.RemoveAll(purgeDir) }() // Generate mock files - if err := GenerateMockFolder(installedDir, "Test", "1.0.0"); err != nil { - panic(err) + now := time.Now() + if err := GenerateMockFolder(installedDir, "Test", "1.0.0", now); err != nil { + t.Fatal(err) } - if err := GenerateMockFolder(updateDir, "Test", "1.0.1"); err != nil { - panic(err) + if err := GenerateMockFolder(updateDir, "Test", "1.0.1", now.Add(1*time.Minute)); err != nil { + t.Fatal(err) } - // Create updater - updates, err := New(stub, "Test", Config{ + // Create updater (loads index). + updater, err := New(stub, "Test", Config{ + Name: "Test", Directory: installedDir, DownloadDirectory: updateDir, PurgeDirectory: purgeDir, IndexFile: "index.json", - AutoApply: false, - NeedsRestart: false, + AutoDownload: true, + AutoApply: true, }) if err != nil { - panic(err) - } - // Read and parse the index file - if err := updates.downloader.Verify(); err != nil { - panic(err) + t.Fatal(err) } + // Try to apply the updates - err = updates.applyUpdates(nil) - if err != nil { - panic(err) - } + m := mgr.New("updates test") + _ = m.Do("test update and upgrade", func(w *mgr.WorkerCtx) error { + if err := updater.updateAndUpgrade(w, nil, false, false); err != nil { + if data, err := os.ReadFile(filepath.Join(installedDir, "index.json")); err == nil { + fmt.Println(string(data)) + fmt.Println(updater.index.Version) + fmt.Println(updater.index.versionNum) + } + if data, err := os.ReadFile(filepath.Join(updateDir, "index.json")); err == nil { + fmt.Println(string(data)) + idx, err := ParseIndex(data, nil) + if err == nil { + fmt.Println(idx.Version) + fmt.Println(idx.versionNum) + } + } - // CHeck if the current version is now the new. - bundle, err := LoadBundle(filepath.Join(installedDir, "index.json")) - if err != nil { - panic(err) - } + t.Fatal(err) + } + return nil + }) - if bundle.Version != "1.0.1" { - panic(fmt.Errorf("expected version 1.0.1 found %s", bundle.Version)) + // Check if the current version is now the new. + newIndex, err := LoadIndex(filepath.Join(installedDir, "index.json"), nil) + if err != nil { + t.Fatal(err) + } + if newIndex.Version != "1.0.1" { + t.Fatalf("expected version 1.0.1 found %s", newIndex.Version) } } + +// GenerateMockFolder generates mock index folder for testing. +func GenerateMockFolder(dir, name, version string, published time.Time) error { + // Make sure dir exists + _ = os.MkdirAll(dir, defaultDirMode) + + // Create empty files + file, err := os.Create(filepath.Join(dir, "portmaster")) + if err != nil { + return err + } + _ = file.Close() + file, err = os.Create(filepath.Join(dir, "portmaster-core")) + if err != nil { + return err + } + _ = file.Close() + file, err = os.Create(filepath.Join(dir, "portmaster.zip")) + if err != nil { + return err + } + _ = file.Close() + file, err = os.Create(filepath.Join(dir, "assets.zip")) + if err != nil { + return err + } + _ = file.Close() + + index, err := GenerateIndexFromDir(dir, IndexScanConfig{ + Name: name, + Version: version, + }) + if err != nil { + return err + } + index.Published = published + + indexJSON, err := json.MarshalIndent(index, "", " ") + if err != nil { + fmt.Fprintf(os.Stderr, "failed to marshal index: %s\n", err) + } + + err = os.WriteFile(filepath.Join(dir, "index.json"), indexJSON, defaultFileMode) + if err != nil { + return err + } + return nil +} diff --git a/service/updates/upgrade.go b/service/updates/upgrade.go index ddb86bfa..593c3dd5 100644 --- a/service/updates/upgrade.go +++ b/service/updates/upgrade.go @@ -1,6 +1,7 @@ package updates import ( + "errors" "fmt" "io/fs" "os" @@ -31,7 +32,7 @@ func (u *Updater) upgrade(downloader *Downloader, ignoreVersion bool) error { } // Execute the upgrade. - upgradeError := u.upgradeMoveFiles(downloader, ignoreVersion) + upgradeError := u.upgradeMoveFiles(downloader) if upgradeError == nil { return nil } @@ -43,10 +44,10 @@ func (u *Updater) upgrade(downloader *Downloader, ignoreVersion bool) error { } // Recovery failed too. - return fmt.Errorf("upgrade (including recovery) failed: %s", u.cfg.Name, upgradeError) + return fmt.Errorf("upgrade (including recovery) failed: %w", upgradeError) } -func (u *Updater) upgradeMoveFiles(downloader *Downloader, ignoreVersion bool) error { +func (u *Updater) upgradeMoveFiles(downloader *Downloader) error { // Important: // We assume that the downloader has done its job and all artifacts are verified. // Files will just be moved here. @@ -65,20 +66,28 @@ func (u *Updater) upgradeMoveFiles(downloader *Downloader, ignoreVersion bool) e } files, err := os.ReadDir(u.cfg.Directory) if err != nil { - return fmt.Errorf("read current directory: %w", err) - } - for _, file := range files { - // Check if file is ignored. - if slices.Contains(u.cfg.Ignore, file.Name()) { - continue + if !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("read current directory: %w", err) } - - // Otherwise, move file to purge dir. - src := filepath.Join(u.cfg.Directory, file.Name()) - dst := filepath.Join(u.cfg.PurgeDirectory, file.Name()) - err := u.moveFile(src, dst, "", file.Type().Perm()) + err = os.MkdirAll(u.cfg.Directory, defaultDirMode) if err != nil { - return fmt.Errorf("failed to move current file %s to purge dir: %w", file.Name(), err) + return fmt.Errorf("create current directory: %w", err) + } + } else { + // Move files. + for _, file := range files { + // Check if file is ignored. + if slices.Contains(u.cfg.Ignore, file.Name()) { + continue + } + + // Otherwise, move file to purge dir. + src := filepath.Join(u.cfg.Directory, file.Name()) + dst := filepath.Join(u.cfg.PurgeDirectory, file.Name()) + err := u.moveFile(src, dst, "", file.Type().Perm()) + if err != nil { + return fmt.Errorf("failed to move current file %s to purge dir: %w", file.Name(), err) + } } } @@ -118,7 +127,7 @@ func (u *Updater) moveFile(currentPath, newPath string, sha256sum string, fileMo // Moving was successful, return. return nil } - log.Tracef("updates/%s: failed to move to %q, falling back to copy+delete: %w", u.cfg.Name, newPath, err) + log.Tracef("updates/%s: failed to move to %q, falling back to copy+delete: %s", u.cfg.Name, newPath, err) // Copy and check the checksum while we are at it. err = copyAndCheckSHA256Sum(currentPath, newPath, sha256sum, fileMode) @@ -144,7 +153,7 @@ func (u *Updater) recoverFromFailedUpgrade() error { err := u.moveFile(purgedFile, activeFile, "", file.Type().Perm()) if err != nil { // Only warn and continue to recover as many files as possible. - log.Warningf("updates/%s: failed to roll back file %s: %w", u.cfg.Name, file.Name(), err) + log.Warningf("updates/%s: failed to roll back file %s: %s", u.cfg.Name, file.Name(), err) } } diff --git a/spn/instance.go b/spn/instance.go index 5c9b3957..ce555d15 100644 --- a/spn/instance.go +++ b/spn/instance.go @@ -34,10 +34,17 @@ import ( // Instance is an instance of a Portmaster service. type Instance struct { - ctx context.Context - cancelCtx context.CancelFunc + ctx context.Context + cancelCtx context.CancelFunc + + shutdownCtx context.Context + cancelShutdownCtx context.CancelFunc + serviceGroup *mgr.Group + binDir string + dataDir string + exitCode atomic.Int32 base *base.Base @@ -67,6 +74,7 @@ type Instance struct { terminal *terminal.TerminalModule CommandLineOperation func() error + ShouldRestart bool } // New returns a new Portmaster service instance. @@ -74,6 +82,7 @@ func New() (*Instance, error) { // Create instance to pass it to modules. instance := &Instance{} instance.ctx, instance.cancelCtx = context.WithCancel(context.Background()) + instance.shutdownCtx, instance.cancelShutdownCtx = context.WithCancel(context.Background()) binaryUpdateIndex := updates.Config{ // FIXME: fill @@ -234,6 +243,18 @@ func (i *Instance) SetSleep(enabled bool) { } } +// BinDir returns the directory for binaries. +// This directory may be read-only. +func (i *Instance) BinDir() string { + return i.binDir +} + +// DataDir returns the directory for variable data. +// This directory is expected to be read/writeable. +func (i *Instance) DataDir() string { + return i.dataDir +} + // Database returns the database module. func (i *Instance) Database() *dbmodule.DBModule { return i.database @@ -379,12 +400,6 @@ func (i *Instance) Ready() bool { return i.serviceGroup.Ready() } -// Ctx returns the instance context. -// It is only canceled on shutdown. -func (i *Instance) Ctx() context.Context { - return i.ctx -} - // Start starts the instance. func (i *Instance) Start() error { return i.serviceGroup.Start() @@ -392,7 +407,6 @@ func (i *Instance) Start() error { // Stop stops the instance and cancels the instance context when done. func (i *Instance) Stop() error { - defer i.cancelCtx() return i.serviceGroup.Stop() } @@ -406,6 +420,8 @@ func (i *Instance) Restart() { i.core.EventRestart.Submit(struct{}{}) time.Sleep(10 * time.Millisecond) + // Set the restart flag and shutdown. + i.ShouldRestart = true i.shutdown(RestartExitCode) } @@ -419,32 +435,63 @@ func (i *Instance) Shutdown() { } func (i *Instance) shutdown(exitCode int) { + // Only shutdown once. + if i.IsShuttingDown() { + return + } + + // Cancel main context. + i.cancelCtx() + // Set given exit code. i.exitCode.Store(int32(exitCode)) + // Start shutdown asynchronously in a separate manager. m := mgr.New("instance") m.Go("shutdown", func(w *mgr.WorkerCtx) error { - for { - if err := i.Stop(); err != nil { - w.Error("failed to shutdown", "err", err, "retry", "1s") - time.Sleep(1 * time.Second) - } else { - return nil - } + // Stop all modules. + if err := i.Stop(); err != nil { + w.Error("failed to shutdown", "err", err) } + + // Cancel shutdown process context. + i.cancelShutdownCtx() + return nil }) } -// Stopping returns whether the instance is shutting down. -func (i *Instance) Stopping() bool { - return i.ctx.Err() == nil +// Ctx returns the instance context. +// It is canceled when shutdown is started. +func (i *Instance) Ctx() context.Context { + return i.ctx } -// Stopped returns a channel that is triggered when the instance has shut down. -func (i *Instance) Stopped() <-chan struct{} { +// IsShuttingDown returns whether the instance is shutting down. +func (i *Instance) IsShuttingDown() bool { + return i.ctx.Err() != nil +} + +// ShuttingDown returns a channel that is triggered when the instance starts shutting down. +func (i *Instance) ShuttingDown() <-chan struct{} { return i.ctx.Done() } +// ShutdownCtx returns the instance shutdown context. +// It is canceled when shutdown is complete. +func (i *Instance) ShutdownCtx() context.Context { + return i.shutdownCtx +} + +// IsShutDown returns whether the instance has stopped. +func (i *Instance) IsShutDown() bool { + return i.shutdownCtx.Err() != nil +} + +// ShutDownComplete returns a channel that is triggered when the instance has shut down. +func (i *Instance) ShutdownComplete() <-chan struct{} { + return i.shutdownCtx.Done() +} + // ExitCode returns the set exit code of the instance. func (i *Instance) ExitCode() int { return int(i.exitCode.Load())