diff --git a/service/intel/filterlists/module.go b/service/intel/filterlists/module.go index ecded5dc..8f3a5a5e 100644 --- a/service/intel/filterlists/module.go +++ b/service/intel/filterlists/module.go @@ -62,7 +62,7 @@ func prep() error { if ignoreUpdateEvents.IsSet() { return false, nil } - log.Debugf("performing filter list upadte") + log.Debugf("performing filter list update") return false, tryListUpdate(wc.Ctx()) }) diff --git a/service/updates/bundle.go b/service/updates/bundle.go index 904ae8fd..02b6089c 100644 --- a/service/updates/bundle.go +++ b/service/updates/bundle.go @@ -1,23 +1,16 @@ package updates import ( - "archive/zip" "bytes" - "compress/gzip" - "context" "crypto/sha256" "encoding/hex" "encoding/json" - "errors" "fmt" "io" - "net/http" "os" "path/filepath" "runtime" "time" - - "github.com/safing/portmaster/base/log" ) const MaxUnpackSize = 1 << 30 // 2^30 == 1GB @@ -33,6 +26,19 @@ type Artifact struct { Version string `json:"Version,omitempty"` } +func (a *Artifact) GetFileMode() os.FileMode { + // Special case for portmaster ui. Should be able to be executed from the regular user + if a.Platform == currentPlatform && a.Filename == "portmaster" { + return executableUIFileMode + } + + if a.Platform == currentPlatform { + return executableFileMode + } + + return defaultFileMode +} + type Bundle struct { Name string `json:"Bundle"` Version string `json:"Version"` @@ -40,10 +46,9 @@ type Bundle struct { Artifacts []Artifact `json:"Artifacts"` } -func ParseBundle(dir string, indexFile string) (*Bundle, error) { - filepath := fmt.Sprintf("%s/%s", dir, indexFile) +func ParseBundle(indexFile string) (*Bundle, error) { // Check if the file exists. - file, err := os.Open(filepath) + file, err := os.Open(indexFile) if err != nil { return nil, fmt.Errorf("failed to open index file: %w", err) } @@ -59,7 +64,7 @@ func ParseBundle(dir string, indexFile string) (*Bundle, error) { var bundle Bundle err = json.Unmarshal(content, &bundle) if err != nil { - return nil, fmt.Errorf("%s %w", filepath, err) + return nil, fmt.Errorf("%s %w", indexFile, err) } // Filter artifacts @@ -74,70 +79,6 @@ func ParseBundle(dir string, indexFile string) (*Bundle, error) { return &bundle, nil } -// CopyMatchingFilesFromCurrent check if there the current bundle files has matching files with the new bundle and copies them if they match. -func (bundle Bundle) CopyMatchingFilesFromCurrent(current Bundle, currentDir, newDir string) error { - // Make sure new dir exists - _ = os.MkdirAll(newDir, defaultDirMode) - - for _, currentArtifact := range current.Artifacts { - new: - for _, newArtifact := range bundle.Artifacts { - if currentArtifact.Filename == newArtifact.Filename { - if currentArtifact.SHA256 == newArtifact.SHA256 { - // Read the content of the current file. - sourceFilePath := filepath.Join(currentDir, newArtifact.Filename) - content, err := os.ReadFile(sourceFilePath) - if err != nil { - return fmt.Errorf("failed to read file %s: %w", sourceFilePath, err) - } - - // Check if the content matches the artifact hash - expectedHash, err := hex.DecodeString(newArtifact.SHA256) - if err != nil || len(expectedHash) != sha256.Size { - return fmt.Errorf("invalid artifact hash %s: %w", newArtifact.SHA256, err) - } - hash := sha256.Sum256(content) - if !bytes.Equal(expectedHash, hash[:]) { - return fmt.Errorf("expected and file hash mismatch: %s", sourceFilePath) - } - - // Create new file - destFilePath := filepath.Join(newDir, newArtifact.Filename) - err = os.WriteFile(destFilePath, content, defaultFileMode) - if err != nil { - return fmt.Errorf("failed to write to file %s: %w", destFilePath, err) - } - log.Debugf("updates: file copied from current version: %s", newArtifact.Filename) - } - break new - } - } - } - return nil -} - -func (bundle Bundle) DownloadAndVerify(ctx context.Context, client *http.Client, dir string) { - // Make sure dir exists - _ = os.MkdirAll(dir, defaultDirMode) - - for _, artifact := range bundle.Artifacts { - filePath := filepath.Join(dir, artifact.Filename) - - // Check file is already downloaded and valid. - exists, _ := checkIfFileIsValid(filePath, artifact) - if exists { - log.Debugf("updates: file already downloaded: %s", filePath) - continue - } - - // Download artifact - err := processArtifact(ctx, client, artifact, filePath) - if err != nil { - log.Errorf("updates: %s", err) - } - } -} - // Verify checks if the files are present int the dataDir and have the correct hash. func (bundle Bundle) Verify(dir string) error { for _, artifact := range bundle.Artifacts { @@ -179,144 +120,3 @@ func checkIfFileIsValid(filename string, artifact Artifact) (bool, error) { } return true, nil } - -func processArtifact(ctx context.Context, client *http.Client, artifact Artifact, filePath string) error { - providedHash, err := hex.DecodeString(artifact.SHA256) - if err != nil || len(providedHash) != sha256.Size { - return fmt.Errorf("invalid provided hash %s: %w", artifact.SHA256, err) - } - - // Download - log.Debugf("updates: downloading file: %s", artifact.Filename) - content, err := downloadFile(ctx, client, artifact.URLs) - if err != nil { - return fmt.Errorf("failed to download artifact: %w", err) - } - - // Decompress - if artifact.Unpack != "" { - content, err = unpack(artifact.Unpack, content) - if err != nil { - return fmt.Errorf("failed to decompress artifact: %w", err) - } - } - - // Verify - hash := sha256.Sum256(content) - if !bytes.Equal(providedHash, hash[:]) { - return fmt.Errorf("failed to verify artifact: %s", artifact.Filename) - } - - // Save - tmpFilename := fmt.Sprintf("%s.download", filePath) - fileMode := defaultFileMode - if artifact.Platform == currentPlatform { - fileMode = executableFileMode - } - err = os.WriteFile(tmpFilename, content, fileMode) - if err != nil { - return fmt.Errorf("failed to write to file: %w", err) - } - - // Rename - err = os.Rename(tmpFilename, filePath) - if err != nil { - return fmt.Errorf("failed to rename file: %w", err) - } - - log.Infof("updates: file downloaded and verified: %s", artifact.Filename) - - return nil -} - -func downloadFile(ctx context.Context, client *http.Client, urls []string) ([]byte, error) { - for _, url := range urls { - // Try to make the request - req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) - if err != nil { - log.Warningf("failed to create GET request to %s: %s", url, err) - continue - } - if UserAgent != "" { - req.Header.Set("User-Agent", UserAgent) - } - resp, err := client.Do(req) - if err != nil { - log.Warningf("failed a get file request to: %s", err) - continue - } - defer func() { _ = resp.Body.Close() }() - - // Check if the server returned an error - if resp.StatusCode != http.StatusOK { - log.Warningf("server returned non-OK status: %d %s", resp.StatusCode, resp.Status) - continue - } - - content, err := io.ReadAll(resp.Body) - if err != nil { - log.Warningf("failed to read body of response: %s", err) - continue - } - return content, nil - } - - return nil, fmt.Errorf("failed to download file from the provided urls") -} - -func unpack(cType string, fileBytes []byte) ([]byte, error) { - switch cType { - case "zip": - return decompressZip(fileBytes) - case "gz": - return decompressGzip(fileBytes) - default: - return nil, fmt.Errorf("unsupported compression type") - } -} - -func decompressGzip(data []byte) ([]byte, error) { - // Create a gzip reader from the byte array - gzipReader, err := gzip.NewReader(bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("failed to create gzip reader: %w", err) - } - defer func() { _ = gzipReader.Close() }() - - var buf bytes.Buffer - _, err = io.CopyN(&buf, gzipReader, MaxUnpackSize) - if err != nil && !errors.Is(err, io.EOF) { - return nil, fmt.Errorf("failed to read gzip file: %w", err) - } - - return buf.Bytes(), nil -} - -func decompressZip(data []byte) ([]byte, error) { - // Create a zip reader from the byte array - zipReader, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) - if err != nil { - return nil, fmt.Errorf("failed to create zip reader: %w", err) - } - - // Ensure there is only one file in the zip - if len(zipReader.File) != 1 { - return nil, fmt.Errorf("zip file must contain exactly one file") - } - - // Read the single file in the zip - file := zipReader.File[0] - fileReader, err := file.Open() - if err != nil { - return nil, fmt.Errorf("failed to open file in zip: %w", err) - } - defer func() { _ = fileReader.Close() }() - - var buf bytes.Buffer - _, err = io.CopyN(&buf, fileReader, MaxUnpackSize) - if err != nil && !errors.Is(err, io.EOF) { - return nil, fmt.Errorf("failed to read file in zip: %w", err) - } - - return buf.Bytes(), nil -} diff --git a/service/updates/downloader.go b/service/updates/downloader.go new file mode 100644 index 00000000..3cc35624 --- /dev/null +++ b/service/updates/downloader.go @@ -0,0 +1,343 @@ +package updates + +import ( + "archive/zip" + "bytes" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + + semver "github.com/hashicorp/go-version" + "github.com/safing/portmaster/base/log" +) + +type Downloader struct { + dir string + indexFile string + indexURLs []string + bundle *Bundle + version *semver.Version + + httpClient http.Client +} + +func CreateDownloader(index UpdateIndex) Downloader { + return Downloader{ + dir: index.DownloadDirectory, + indexFile: index.IndexFile, + indexURLs: index.IndexURLs, + } +} + +func (d *Downloader) downloadIndexFile(ctx context.Context) (err error) { + // Make sure dir exists + _ = os.MkdirAll(d.dir, defaultDirMode) + for _, url := range d.indexURLs { + err = d.downloadIndexFileFromURL(ctx, url) + if err != nil { + log.Warningf("updates: failed while downloading index file %s", err) + continue + } + // Downloading was successful. + err = nil + break + } + + if err == nil { + err = d.parseBundle() + } + + return +} + +// Verify verifies if the downloaded files match the corresponding hash. +func (d *Downloader) Verify() error { + err := d.parseBundle() + if err != nil { + return err + } + + return d.bundle.Verify(d.dir) +} + +func (d *Downloader) parseBundle() error { + indexFilepath := filepath.Join(d.dir, d.indexFile) + var err error + d.bundle, err = ParseBundle(indexFilepath) + if err != nil { + return err + } + + d.version, err = semver.NewVersion(d.bundle.Version) + if err != nil { + return err + } + return nil +} + +func (d *Downloader) downloadIndexFileFromURL(ctx context.Context, url string) error { + // Request the index file + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return fmt.Errorf("failed to create GET request to %s: %w", url, err) + } + if UserAgent != "" { + req.Header.Set("User-Agent", UserAgent) + } + + // Perform request + resp, err := d.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed GET request to %s: %w", url, err) + } + defer func() { _ = resp.Body.Close() }() + + // Check the status code + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("received error from the server status code: %s", resp.Status) + } + // Create file + indexFilepath := filepath.Join(d.dir, d.indexFile) + file, err := os.Create(indexFilepath) + if err != nil { + return err + } + defer func() { _ = file.Close() }() + + // Write body of the response + _, err = io.Copy(file, resp.Body) + if err != nil { + return err + } + + return nil +} + +// CopyMatchingFilesFromCurrent check if there the current bundle files has matching files with the new bundle and copies them if they match. +func (d *Downloader) copyMatchingFilesFromCurrent(currentFiles map[string]File) error { + // Make sure new dir exists + _ = os.MkdirAll(d.dir, defaultDirMode) + + for _, a := range d.bundle.Artifacts { + currentFile, ok := currentFiles[a.Filename] + if ok && currentFile.Sha256() == a.SHA256 { + // Read the content of the current file. + content, err := os.ReadFile(currentFile.Path()) + if err != nil { + return fmt.Errorf("failed to read file %s: %w", currentFile.Path(), err) + } + + // Check if the content matches the artifact hash + expectedHash, err := hex.DecodeString(a.SHA256) + if err != nil || len(expectedHash) != sha256.Size { + return fmt.Errorf("invalid artifact hash %s: %w", a.SHA256, err) + } + hash := sha256.Sum256(content) + if !bytes.Equal(expectedHash, hash[:]) { + return fmt.Errorf("expected and file hash mismatch: %s", currentFile.Path()) + } + + // Create new file + destFilePath := filepath.Join(d.dir, a.Filename) + err = os.WriteFile(destFilePath, content, a.GetFileMode()) + if err != nil { + return fmt.Errorf("failed to write to file %s: %w", destFilePath, err) + } + log.Debugf("updates: file copied from current version: %s", a.Filename) + } + } + return nil +} + +func (d *Downloader) downloadAndVerify(ctx context.Context) error { + // Make sure we have the bundle file parsed. + err := d.parseBundle() + if err != nil { + log.Errorf("updates: invalid update bundle file: %s", err) + } + + // Make sure dir exists + _ = os.MkdirAll(d.dir, defaultDirMode) + + for _, artifact := range d.bundle.Artifacts { + filePath := filepath.Join(d.dir, artifact.Filename) + + // Check file is already downloaded and valid. + exists, _ := checkIfFileIsValid(filePath, artifact) + if exists { + log.Debugf("updates: file already downloaded: %s", filePath) + continue + } + + // Download artifact + err := d.processArtifact(ctx, artifact, filePath) + if err != nil { + return err + } + } + return nil +} + +func (d *Downloader) processArtifact(ctx context.Context, artifact Artifact, filePath string) error { + providedHash, err := hex.DecodeString(artifact.SHA256) + if err != nil || len(providedHash) != sha256.Size { + return fmt.Errorf("invalid provided hash %s: %w", artifact.SHA256, err) + } + + // Download + log.Debugf("updates: downloading file: %s", artifact.Filename) + content, err := d.downloadFile(ctx, artifact.URLs) + if err != nil { + return fmt.Errorf("failed to download artifact: %w", err) + } + + // Decompress + if artifact.Unpack != "" { + content, err = unpack(artifact.Unpack, content) + if err != nil { + return fmt.Errorf("failed to decompress artifact: %w", err) + } + } + + // Verify + hash := sha256.Sum256(content) + if !bytes.Equal(providedHash, hash[:]) { + return fmt.Errorf("failed to verify artifact: %s", artifact.Filename) + } + + // Save + tmpFilename := fmt.Sprintf("%s.download", filePath) + err = os.WriteFile(tmpFilename, content, artifact.GetFileMode()) + if err != nil { + return fmt.Errorf("failed to write to file: %w", err) + } + + // Rename + err = os.Rename(tmpFilename, filePath) + if err != nil { + return fmt.Errorf("failed to rename file: %w", err) + } + + log.Infof("updates: file downloaded and verified: %s", artifact.Filename) + + return nil +} + +func (d *Downloader) downloadFile(ctx context.Context, urls []string) ([]byte, error) { + for _, url := range urls { + // Try to make the request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + log.Warningf("failed to create GET request to %s: %s", url, err) + continue + } + if UserAgent != "" { + req.Header.Set("User-Agent", UserAgent) + } + resp, err := d.httpClient.Do(req) + if err != nil { + log.Warningf("failed a get file request to: %s", err) + continue + } + defer func() { _ = resp.Body.Close() }() + + // Check if the server returned an error + if resp.StatusCode != http.StatusOK { + log.Warningf("server returned non-OK status: %d %s", resp.StatusCode, resp.Status) + continue + } + + content, err := io.ReadAll(resp.Body) + if err != nil { + log.Warningf("failed to read body of response: %s", err) + continue + } + return content, nil + } + + return nil, fmt.Errorf("failed to download file from the provided urls") +} + +func (d *Downloader) deleteUnfinishedDownloads() error { + entries, err := os.ReadDir(d.dir) + if err != nil { + return err + } + for _, e := range entries { + // Check if the current file has the download extension + if !e.IsDir() && strings.HasSuffix(e.Name(), ".download") { + path := filepath.Join(d.dir, e.Name()) + log.Warningf("updates: deleting unfinished download file: %s\n", path) + err := os.Remove(path) + if err != nil { + log.Errorf("updates: failed to delete unfinished download file %s: %s", path, err) + } + } + } + return nil +} + +func unpack(cType string, fileBytes []byte) ([]byte, error) { + switch cType { + case "zip": + return decompressZip(fileBytes) + case "gz": + return decompressGzip(fileBytes) + default: + return nil, fmt.Errorf("unsupported compression type") + } +} + +func decompressGzip(data []byte) ([]byte, error) { + // Create a gzip reader from the byte array + gzipReader, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("failed to create gzip reader: %w", err) + } + defer func() { _ = gzipReader.Close() }() + + var buf bytes.Buffer + _, err = io.CopyN(&buf, gzipReader, MaxUnpackSize) + if err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("failed to read gzip file: %w", err) + } + + return buf.Bytes(), nil +} + +func decompressZip(data []byte) ([]byte, error) { + // Create a zip reader from the byte array + zipReader, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + return nil, fmt.Errorf("failed to create zip reader: %w", err) + } + + // Ensure there is only one file in the zip + if len(zipReader.File) != 1 { + return nil, fmt.Errorf("zip file must contain exactly one file") + } + + // Read the single file in the zip + file := zipReader.File[0] + fileReader, err := file.Open() + if err != nil { + return nil, fmt.Errorf("failed to open file in zip: %w", err) + } + defer func() { _ = fileReader.Close() }() + + var buf bytes.Buffer + _, err = io.CopyN(&buf, fileReader, MaxUnpackSize) + if err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("failed to read file in zip: %w", err) + } + + return buf.Bytes(), nil +} diff --git a/service/updates/index.go b/service/updates/index.go index c9a5f29d..de15a98b 100644 --- a/service/updates/index.go +++ b/service/updates/index.go @@ -1,77 +1 @@ package updates - -import ( - "context" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - - "github.com/safing/portmaster/base/log" -) - -type UpdateIndex struct { - Directory string - DownloadDirectory string - PurgeDirectory string - Ignore []string - IndexURLs []string - IndexFile string - AutoApply bool - NeedsRestart bool -} - -func (ui *UpdateIndex) DownloadIndexFile(ctx context.Context, client *http.Client) (err error) { - // Make sure dir exists - _ = os.MkdirAll(ui.DownloadDirectory, defaultDirMode) - for _, url := range ui.IndexURLs { - err = ui.downloadIndexFileFromURL(ctx, client, url) - if err != nil { - log.Warningf("updates: failed while downloading index file %s", err) - continue - } - // Downloading was successful. - err = nil - break - } - return -} - -func (ui *UpdateIndex) downloadIndexFileFromURL(ctx context.Context, client *http.Client, url string) error { - // Request the index file - req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) - if err != nil { - return fmt.Errorf("failed to create GET request to %s: %w", url, err) - } - if UserAgent != "" { - req.Header.Set("User-Agent", UserAgent) - } - - // Perform request - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed GET request to %s: %w", url, err) - } - defer func() { _ = resp.Body.Close() }() - - // Check the status code - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("received error from the server status code: %s", resp.Status) - } - // Create file - filePath := filepath.Join(ui.DownloadDirectory, ui.IndexFile) - file, err := os.Create(filePath) - if err != nil { - return err - } - defer func() { _ = file.Close() }() - - // Write body of the response - _, err = io.Copy(file, resp.Body) - if err != nil { - return err - } - - return nil -} diff --git a/service/updates/module.go b/service/updates/module.go index 5caf2ab4..813d5025 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -1,15 +1,10 @@ package updates import ( - "context" - "errors" "fmt" - "net/http" - "os" "runtime" "time" - semver "github.com/hashicorp/go-version" "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" @@ -21,12 +16,6 @@ const ( updateTaskRepeatDuration = 1 * time.Hour updateAvailableNotificationID = "updates:update-available" - // VersionUpdateEvent is emitted every time a new - // version of a monitored resource is selected. - // During module initialization VersionUpdateEvent - // is also emitted. - VersionUpdateEvent = "active version update" - // ResourceUpdateEvent is emitted every time the // updater successfully performed a resource update. // ResourceUpdateEvent is emitted even if no new @@ -34,6 +23,12 @@ const ( // to check if new versions of their resources are // available by checking File.UpgradeAvailable(). ResourceUpdateEvent = "resource update" + + // VersionUpdateEvent is emitted every time a new + // version of a monitored resource is selected. + // During module initialization VersionUpdateEvent + // is also emitted. + VersionUpdateEvent = "active version update" ) // UserAgent is an HTTP User-Agent that is used to add @@ -41,25 +36,18 @@ const ( // fetching resources from the update server. var UserAgent = fmt.Sprintf("Portmaster (%s %s)", runtime.GOOS, runtime.GOARCH) -type File struct { - id string - path string +// UpdateIndex holds the configuration for the updates module +type UpdateIndex struct { + Directory string + DownloadDirectory string + PurgeDirectory string + Ignore []string + IndexURLs []string + IndexFile string + AutoApply bool + NeedsRestart bool } -func (f *File) Identifier() string { - return f.id -} - -func (f *File) Path() string { - return f.path -} - -func (f *File) Version() string { - return "" -} - -var ErrNotFound error = errors.New("file not found") - // Updates provides access to released artifacts. type Updates struct { m *mgr.Manager @@ -71,12 +59,11 @@ type Updates struct { EventResourcesUpdated *mgr.EventMgr[struct{}] EventVersionsUpdated *mgr.EventMgr[struct{}] - updateIndex UpdateIndex + registry Registry + downloader Downloader - bundle *Bundle - updateBundle *Bundle - - files map[string]File + autoApply bool + needsRestart bool instance instance } @@ -91,8 +78,8 @@ func New(instance instance, name string, index UpdateIndex) (*Updates, error) { EventResourcesUpdated: mgr.NewEventMgr[struct{}](ResourceUpdateEvent, m), EventVersionsUpdated: mgr.NewEventMgr[struct{}](VersionUpdateEvent, m), - updateIndex: index, - files: make(map[string]File), + autoApply: index.AutoApply, + needsRestart: index.NeedsRestart, instance: instance, } @@ -103,53 +90,39 @@ func New(instance instance, name string, index UpdateIndex) (*Updates, error) { module.upgraderWorkerMgr = m.NewWorkerMgr("upgrader", module.applyUpdates, nil) var err error - module.bundle, err = ParseBundle(module.updateIndex.Directory, module.updateIndex.IndexFile) + module.registry, err = CreateRegistry(index) if err != nil { - return nil, fmt.Errorf("failed to parse binary bundle: %s", err) + return nil, fmt.Errorf("failed to create registry: %w", err) } - // Add bundle artifacts to registry. - module.processBundle(module.bundle) + module.downloader = CreateDownloader(index) return module, nil } -func (reg *Updates) processBundle(bundle *Bundle) { - for _, artifact := range bundle.Artifacts { - artifactPath := fmt.Sprintf("%s/%s", reg.updateIndex.Directory, artifact.Filename) - reg.files[artifact.Filename] = File{id: artifact.Filename, path: artifactPath} - } -} - func (u *Updates) checkForUpdates(wc *mgr.WorkerCtx) error { - httpClient := http.Client{} - err := u.updateIndex.DownloadIndexFile(wc.Ctx(), &httpClient) + err := u.downloader.downloadIndexFile(wc.Ctx()) if err != nil { - return fmt.Errorf("failed to download index file: %s", err) + return fmt.Errorf("failed to download index file: %w", err) } - u.updateBundle, err = ParseBundle(u.updateIndex.DownloadDirectory, u.updateIndex.IndexFile) - if err != nil { - return fmt.Errorf("failed parsing bundle: %s", err) - } defer u.EventResourcesUpdated.Submit(struct{}{}) - hasUpdate, err := u.checkVersionIncrement() - if err != nil { - return fmt.Errorf("failed to compare versions: %s", err) - } - - if !hasUpdate { + if u.downloader.version.LessThanOrEqual(u.registry.version) { log.Infof("updates: check compete: no new updates") return nil } - - log.Infof("updates: check complete: downloading new version: %s %s", u.updateBundle.Name, u.updateBundle.Version) - err = u.downloadUpdates(wc.Ctx(), &httpClient) + downloadBundle := u.downloader.bundle + log.Infof("updates: check complete: downloading new version: %s %s", downloadBundle.Name, downloadBundle.Version) + err = u.downloader.copyMatchingFilesFromCurrent(u.registry.files) if err != nil { - log.Errorf("updates: failed to download bundle: %s", err) + log.Warningf("updates: failed to copy files from current installation: %s", err) + } + err = u.downloader.downloadAndVerify(wc.Ctx()) + if err != nil { + log.Errorf("updates: failed to download update: %s", err) } else { - if u.updateIndex.AutoApply { + if u.autoApply { u.upgraderWorkerMgr.Go() } else { notifications.NotifyPrompt(updateAvailableNotificationID, "Update available", "Apply update and restart?", notifications.Action{ @@ -166,72 +139,28 @@ func (u *Updates) checkForUpdates(wc *mgr.WorkerCtx) error { return nil } -func (u *Updates) checkVersionIncrement() (bool, error) { - // Compare current and downloaded index version. - currentVersion, err := semver.NewVersion(u.bundle.Version) - if err != nil { - return false, err - } - downloadVersion, err := semver.NewVersion(u.updateBundle.Version) - if err != nil { - return false, err - } - log.Debugf("updates: checking version: curr: %s new: %s", currentVersion.String(), downloadVersion.String()) - return downloadVersion.GreaterThan(currentVersion), nil -} - -func (u *Updates) downloadUpdates(ctx context.Context, client *http.Client) error { - if u.updateBundle == nil { - // checkForUpdates needs to be called before this. - return fmt.Errorf("no valid update bundle found") - } - _ = deleteUnfinishedDownloads(u.updateIndex.DownloadDirectory) - err := u.updateBundle.CopyMatchingFilesFromCurrent(*u.bundle, u.updateIndex.Directory, u.updateIndex.DownloadDirectory) - if err != nil { - log.Warningf("updates: error while coping file from current to update: %s", err) - } - u.updateBundle.DownloadAndVerify(ctx, client, u.updateIndex.DownloadDirectory) - return nil -} - func (u *Updates) applyUpdates(_ *mgr.WorkerCtx) error { - // Check if we have new version - hasNewVersion, err := u.checkVersionIncrement() - if err != nil { - return fmt.Errorf("error while reading bundle version: %w", err) - } - - if !hasNewVersion { - return fmt.Errorf("there is no new version to apply") - } - - // Verify files of the downloaded files. - err = u.updateBundle.Verify(u.updateIndex.DownloadDirectory) - if err != nil { - return fmt.Errorf("failed to verify downloaded files: %s", err) - } - - // New version is downloaded and verified. Start the update process - log.Infof("update: starting update: %s %s -> %s", u.bundle.Name, u.bundle.Version, u.updateBundle.Version) - err = switchFolders(u.updateIndex, *u.updateBundle) + currentBundle := u.registry.bundle + downloadBundle := u.downloader.bundle + log.Infof("update: starting update: %s %s -> %s", currentBundle.Name, currentBundle.Version, downloadBundle.Version) + err := u.registry.performUpgrade(u.downloader.dir, u.downloader.indexFile) if err != nil { // TODO(vladimir): Send notification to UI log.Errorf("updates: failed to apply updates: %s", err) - } else { + } else if u.needsRestart { // TODO(vladimir): Prompt user to restart? - if u.updateIndex.NeedsRestart { - u.instance.Restart() - } + u.instance.Restart() } + u.EventResourcesUpdated.Submit(struct{}{}) return nil } -// TriggerUpdateCheck triggers an update check +// TriggerUpdateCheck triggers an update check. func (u *Updates) TriggerUpdateCheck() { u.updateCheckWorkerMgr.Go() } -// TriggerApplyUpdates triggers upgrade +// TriggerApplyUpdates triggers upgrade. func (u *Updates) TriggerApplyUpdates() { u.upgraderWorkerMgr.Go() } @@ -250,10 +179,8 @@ func (u *Updates) Manager() *mgr.Manager { func (u *Updates) Start() error { // Remove old files u.m.Go("old files cleaner", func(ctx *mgr.WorkerCtx) error { - err := os.RemoveAll(u.updateIndex.PurgeDirectory) - if err != nil { - return fmt.Errorf("failed to delete folder: %w", err) - } + _ = u.registry.CleanOldFiles() + _ = u.downloader.deleteUnfinishedDownloads() return nil }) u.updateCheckWorkerMgr.Go() @@ -261,15 +188,13 @@ func (u *Updates) Start() error { return nil } +// GetFile returns the path of a file given the name. func (u *Updates) GetFile(id string) (*File, error) { - file, ok := u.files[id] + file, ok := u.registry.files[id] if ok { return &file, nil } else { log.Errorf("updates: requested file id not found: %s", id) - for _, file := range u.files { - log.Debugf("File: %s", file) - } return nil, ErrNotFound } } diff --git a/service/updates/registry.go b/service/updates/registry.go new file mode 100644 index 00000000..72248508 --- /dev/null +++ b/service/updates/registry.go @@ -0,0 +1,177 @@ +package updates + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + + semver "github.com/hashicorp/go-version" + "github.com/safing/portmaster/base/log" +) + +const ( + defaultFileMode = os.FileMode(0o0644) + executableFileMode = os.FileMode(0o0744) + executableUIFileMode = os.FileMode(0o0755) + defaultDirMode = os.FileMode(0o0755) +) + +type Registry struct { + bundle *Bundle + dir string + purgeDir string + files map[string]File + + version *semver.Version +} + +func CreateRegistry(index UpdateIndex) (Registry, error) { + registry := Registry{ + dir: index.Directory, + purgeDir: index.PurgeDirectory, + files: make(map[string]File), + } + // Parse bundle + var err error + registry.bundle, err = ParseBundle(filepath.Join(index.Directory, index.IndexFile)) + if err != nil { + return Registry{}, err + } + + // Parse version + registry.version, err = semver.NewVersion(registry.bundle.Version) + if err != nil { + log.Errorf("updates: failed to parse current version: %s", err) + } + + // Process files + for _, artifact := range registry.bundle.Artifacts { + artifactPath := filepath.Join(registry.dir, artifact.Filename) + registry.files[artifact.Filename] = File{id: artifact.Filename, path: artifactPath, version: registry.bundle.Version, sha256: artifact.SHA256} + } + return registry, nil +} + +func (r *Registry) performUpgrade(downloadDir string, indexFile string) error { + // Make sure provided update is valid + indexFilepath := filepath.Join(downloadDir, indexFile) + bundle, err := ParseBundle(indexFilepath) + if err != nil { + return fmt.Errorf("invalid update: %w", err) + } + + err = bundle.Verify(downloadDir) + if err != nil { + return fmt.Errorf("invalid update: %w", err) + } + + // Create purge dir. + err = os.MkdirAll(r.purgeDir, defaultDirMode) + if err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + // Read all files in the current version folder. + files, err := os.ReadDir(r.dir) + if err != nil { + return err + } + + // Move current version files into purge folder. + log.Debugf("updates: removing the old version") + for _, file := range files { + currentFilepath := filepath.Join(r.dir, file.Name()) + purgePath := filepath.Join(r.purgeDir, file.Name()) + err := os.Rename(currentFilepath, purgePath) + if err != nil { + return fmt.Errorf("failed to move file %s: %w", currentFilepath, err) + } + } + + // Move the new index file + log.Debugf("updates: installing the new version") + newIndexFile := filepath.Join(r.dir, indexFile) + err = os.Rename(indexFilepath, newIndexFile) + if err != nil { + return fmt.Errorf("failed to move index file %s: %w", indexFile, err) + } + + // Move downloaded files to the current version folder. + for _, artifact := range bundle.Artifacts { + fromFilepath := filepath.Join(downloadDir, artifact.Filename) + toFilepath := filepath.Join(r.dir, artifact.Filename) + err = os.Rename(fromFilepath, toFilepath) + if err != nil { + log.Errorf("failed to move file %s: %s", fromFilepath, err) + } else { + log.Debugf("updates: %s moved", artifact.Filename) + } + + // Special case for linux. + // When installed the portmaster ui path is `/usr/bin/portmaster`. During update the ui will be placed in `/usr/lib/portmaster/portmaster` + // After an update the original binary should be deleted and replaced by symlink + // `/usr/bin/portmaster` -> `/usr/lib/portmaster/portmaster` + if runtime.GOOS == "linux" && artifact.Filename == "portmaster" && artifact.Platform == currentPlatform { + err = r.makeSymlinkForUI() + if err != nil { + log.Errorf("failed to create symlink for the ui: %s", err) + } else { + log.Infof("updates: ui symlink successfully created") + } + } + } + + log.Infof("updates: update complete") + + err = r.CleanOldFiles() + if err != nil { + log.Warningf("updates: error while cleaning old file: %s", err) + } + + return nil +} + +func (r *Registry) CleanOldFiles() error { + err := os.RemoveAll(r.purgeDir) + if err != nil { + return fmt.Errorf("failed to delete folder: %w", err) + } + return nil +} + +func (r *Registry) makeSymlinkForUI() error { + portmasterBinPath := "/usr/bin/portmaster" + _ = os.Remove(portmasterBinPath) + err := os.Symlink(filepath.Join(r.dir, "portmaster"), portmasterBinPath) + if err != nil { + return fmt.Errorf("failed to create symlink: %w", err) + } + return nil +} + +type File struct { + id string + path string + version string + sha256 string +} + +func (f *File) Identifier() string { + return f.id +} + +func (f *File) Path() string { + return f.path +} + +func (f *File) Version() string { + return f.version +} + +func (f *File) Sha256() string { + return f.sha256 +} + +var ErrNotFound error = errors.New("file not found") diff --git a/service/updates/updater.go b/service/updates/updater.go deleted file mode 100644 index f76ed2b9..00000000 --- a/service/updates/updater.go +++ /dev/null @@ -1,107 +0,0 @@ -package updates - -import ( - "fmt" - "os" - "path/filepath" - "runtime" - "strings" - - "github.com/safing/portmaster/base/log" -) - -const ( - defaultFileMode = os.FileMode(0o0644) - executableFileMode = os.FileMode(0o0744) - defaultDirMode = os.FileMode(0o0755) -) - -func switchFolders(updateIndex UpdateIndex, newBundle Bundle) error { - // Create purge dir. - err := os.MkdirAll(updateIndex.PurgeDirectory, defaultDirMode) - if err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - - // Read all files in the current version folder. - files, err := os.ReadDir(updateIndex.Directory) - if err != nil { - return err - } - - // Move current version files into purge folder. - log.Debugf("updates: removing the old version") - for _, file := range files { - currentFilepath := filepath.Join(updateIndex.Directory, file.Name()) - purgePath := filepath.Join(updateIndex.PurgeDirectory, file.Name()) - err := os.Rename(currentFilepath, purgePath) - if err != nil { - return fmt.Errorf("failed to move file %s: %w", currentFilepath, err) - } - } - - // Move the new index file - log.Debugf("updates: installing the new version") - indexFile := filepath.Join(updateIndex.DownloadDirectory, updateIndex.IndexFile) - newIndexFile := filepath.Join(updateIndex.Directory, updateIndex.IndexFile) - err = os.Rename(indexFile, newIndexFile) - if err != nil { - return fmt.Errorf("failed to move index file %s: %w", indexFile, err) - } - - // Move downloaded files to the current version folder. - for _, artifact := range newBundle.Artifacts { - fromFilepath := filepath.Join(updateIndex.DownloadDirectory, artifact.Filename) - toFilepath := filepath.Join(updateIndex.Directory, artifact.Filename) - err = os.Rename(fromFilepath, toFilepath) - if err != nil { - log.Errorf("failed to move file %s: %s", fromFilepath, err) - } else { - log.Debugf("updates: %s moved", artifact.Filename) - } - - // Special case for linux. - // When installed the portmaster ui path is `/usr/bin/portmaster`. During update the ui will be placed in `/usr/lib/portmaster/portmaster` - // After an update the original binary should be deleted and replaced by symlink - // `/usr/bin/portmaster` -> `/usr/lib/portmaster/portmaster` - if runtime.GOOS == "linux" && artifact.Filename == "portmaster" && artifact.Platform == currentPlatform { - err = makeSymlinkForUI(updateIndex.Directory) - if err != nil { - log.Errorf("failed to create symlink for the ui: %s", err) - } else { - log.Infof("ui symlink successfully created") - } - } - } - - log.Debugf("updates: update complete") - - return nil -} - -func deleteUnfinishedDownloads(rootDir string) error { - entries, err := os.ReadDir(rootDir) - if err != nil { - return err - } - for _, e := range entries { - // Check if the current file has the download extension - if !e.IsDir() && strings.HasSuffix(e.Name(), ".download") { - path := filepath.Join(rootDir, e.Name()) - log.Warningf("updates: deleting unfinished download file: %s\n", path) - err := os.Remove(path) - if err != nil { - log.Errorf("updates: failed to delete unfinished download file %s: %s", path, err) - } - } - } - return nil -} - -func makeSymlinkForUI(directory string) error { - err := os.Symlink(filepath.Join(directory, "portmaster"), "/usr/bin/portmaster") - if err != nil { - return fmt.Errorf("failed to create symlink: %w", err) - } - return nil -}