diff --git a/service/core/api.go b/service/core/api.go index bb63464b..ea4f18d1 100644 --- a/service/core/api.go +++ b/service/core/api.go @@ -133,6 +133,22 @@ func registerAPIEndpoints() error { return err } + if err := api.RegisterEndpoint(api.Endpoint{ + Path: "updates/from-url", + WriteMethod: "POST", + Write: api.PermitAnyone, + ActionFunc: func(ar *api.Request) (string, error) { + err := module.instance.BinaryUpdates().UpdateFromURL(string(ar.InputData)) + if err != nil { + return err.Error(), err + } + return "upgrade triggered", nil + }, + Name: "Replace current version from the version supplied in the URL", + }); err != nil { + return err + } + return nil } diff --git a/service/updates/downloader.go b/service/updates/downloader.go index 4087964b..84259901 100644 --- a/service/updates/downloader.go +++ b/service/updates/downloader.go @@ -48,7 +48,7 @@ func (d *Downloader) downloadIndexFile(ctx context.Context) error { for _, url := range d.indexURLs { content, err = d.downloadIndexFileFromURL(ctx, url) if err != nil { - log.Warningf("updates: failed while downloading index file %s", err) + log.Warningf("updates: failed while downloading index file: %s", err) continue } // Downloading was successful. @@ -60,7 +60,7 @@ func (d *Downloader) downloadIndexFile(ctx context.Context) error { } // Parsing was successful var version *semver.Version - version, err = semver.NewVersion(d.bundle.Version) + version, err = semver.NewVersion(bundle.Version) if err != nil { log.Warningf("updates: failed to parse bundle version: %s", err) continue @@ -116,7 +116,7 @@ func (d *Downloader) downloadIndexFileFromURL(ctx context.Context, url string) ( // Request the index file req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { - return "", fmt.Errorf("failed to create GET request to %s: %w", url, err) + return "", fmt.Errorf("failed to create GET request to: %w", err) } if UserAgent != "" { req.Header.Set("User-Agent", UserAgent) @@ -281,14 +281,14 @@ func (d *Downloader) downloadFile(ctx context.Context, url string) ([]byte, erro // Try to make the request req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { - return nil, fmt.Errorf("failed to create GET request to %s: %s", url, err) + return nil, fmt.Errorf("failed to create GET request to %s: %w", url, err) } if UserAgent != "" { req.Header.Set("User-Agent", UserAgent) } resp, err := d.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed a get file request to: %s", err) + return nil, fmt.Errorf("failed a get file request to: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -299,7 +299,7 @@ func (d *Downloader) downloadFile(ctx context.Context, url string) ([]byte, erro content, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read body of response: %s", err) + return nil, fmt.Errorf("failed to read body of response: %w", err) } return content, nil } diff --git a/service/updates/module.go b/service/updates/module.go index 44b1146f..2e583a84 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -9,6 +9,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/mgr" + "github.com/tevino/abool" ) const ( @@ -48,7 +49,7 @@ type Updates struct { states *mgr.StateMgr updateCheckWorkerMgr *mgr.WorkerMgr - upgraderWorkerMgr *mgr.WorkerMgr + upgradeWorkerMgr *mgr.WorkerMgr EventResourcesUpdated *mgr.EventMgr[struct{}] @@ -58,6 +59,8 @@ type Updates struct { autoApply bool needsRestart bool + isUpdateRunning *abool.AtomicBool + instance instance } @@ -70,15 +73,25 @@ func New(instance instance, name string, index UpdateIndex) (*Updates, error) { EventResourcesUpdated: mgr.NewEventMgr[struct{}](ResourceUpdateEvent, m), - autoApply: index.AutoApply, - needsRestart: index.NeedsRestart, + autoApply: index.AutoApply, + needsRestart: index.NeedsRestart, + isUpdateRunning: abool.NewBool(false), instance: instance, } // Workers module.updateCheckWorkerMgr = m.NewWorkerMgr("update checker", module.checkForUpdates, nil).Repeat(updateTaskRepeatDuration) - module.upgraderWorkerMgr = m.NewWorkerMgr("upgrader", module.applyUpdates, nil) + module.upgradeWorkerMgr = m.NewWorkerMgr("upgrader", func(w *mgr.WorkerCtx) error { + if !module.isUpdateRunning.SetToIf(false, true) { + return fmt.Errorf("unable to apply updates, concurrent updater task is running") + } + // Make sure to unset it + defer module.isUpdateRunning.UnSet() + + module.applyUpdates(module.downloader, false) + return nil + }, nil) var err error module.registry, err = CreateRegistry(index) @@ -92,11 +105,17 @@ func New(instance instance, name string, index UpdateIndex) (*Updates, error) { } func (u *Updates) checkForUpdates(wc *mgr.WorkerCtx) error { + if !u.isUpdateRunning.SetToIf(false, true) { + return fmt.Errorf("unable to check for updates, concurrent updater task is running") + } + // Make sure to unset it on return. + defer u.isUpdateRunning.UnSet() // Download the index file. err := u.downloader.downloadIndexFile(wc.Ctx()) if err != nil { return fmt.Errorf("failed to download index file: %w", err) } + // Check if there is a new version. if u.downloader.version.LessThanOrEqual(u.registry.version) { log.Infof("updates: check compete: no new updates") @@ -115,8 +134,8 @@ func (u *Updates) checkForUpdates(wc *mgr.WorkerCtx) error { log.Errorf("updates: failed to download update: %s", err) } else { if u.autoApply { - // Trigger upgrade. - u.upgraderWorkerMgr.Go() + // Apply updates. + u.applyUpdates(u.downloader, false) } else { // Notify the user with option to trigger upgrade. notifications.NotifyPrompt(updateAvailableNotificationID, "New update is available.", fmt.Sprintf("%s %s", downloadBundle.Name, downloadBundle.Version), notifications.Action{ @@ -133,16 +152,57 @@ func (u *Updates) checkForUpdates(wc *mgr.WorkerCtx) error { return nil } -func (u *Updates) applyUpdates(_ *mgr.WorkerCtx) error { - currentBundle := u.registry.bundle - downloadBundle := u.downloader.bundle - if u.downloader.version.LessThanOrEqual(u.registry.version) { - // No new version, silently return. +// UpdateFromURL installs an update from the provided url. +func (u *Updates) UpdateFromURL(url string) error { + if !u.isUpdateRunning.SetToIf(false, true) { + return fmt.Errorf("unable to upgrade from url, concurrent updater task is running") + } + + u.m.Go("custom-url-downloader", func(w *mgr.WorkerCtx) error { + // Make sure to unset it on return. + defer u.isUpdateRunning.UnSet() + + // Initialize parameters + index := UpdateIndex{ + DownloadDirectory: u.downloader.dir, + IndexURLs: []string{url}, + IndexFile: u.downloader.indexFile, + } + + // Initialize with proper values and download the index file. + downloader := CreateDownloader(index) + err := downloader.downloadIndexFile(w.Ctx()) + if err != nil { + return err + } + + // Start downloading the artifacts + err = downloader.downloadAndVerify(w.Ctx()) + if err != nil { + return err + } + + // Artifacts are downloaded, perform the update. + u.applyUpdates(downloader, true) + return nil + }) + return nil +} + +func (u *Updates) applyUpdates(downloader Downloader, force bool) error { + currentBundle := u.registry.bundle + downloadBundle := downloader.bundle + + if !force { + if u.downloader.version.LessThanOrEqual(u.registry.version) { + // No new version, silently return. + return nil + } } log.Infof("update: starting update: %s %s -> %s", currentBundle.Name, currentBundle.Version, downloadBundle.Version) - err := u.registry.performRecoverableUpgrade(u.downloader.dir, u.downloader.indexFile) + err := u.registry.performRecoverableUpgrade(downloader.dir, downloader.indexFile) if err != nil { // Notify the user that update failed. notifications.NotifyPrompt(updateFailedNotificationID, "Failed to apply update.", err.Error()) @@ -166,7 +226,7 @@ func (u *Updates) TriggerUpdateCheck() { // TriggerApplyUpdates triggers upgrade. func (u *Updates) TriggerApplyUpdates() { - u.upgraderWorkerMgr.Go() + u.upgradeWorkerMgr.Go() } // States returns the state manager.