diff --git a/service/updates/bundle.go b/service/updates/bundle.go index c2de5a92..e590446d 100644 --- a/service/updates/bundle.go +++ b/service/updates/bundle.go @@ -115,12 +115,12 @@ func (bundle Bundle) CopyMatchingFilesFromCurrent(current Bundle, currentDir, ne return nil } -func (bundle Bundle) DownloadAndVerify(dir string) { - client := http.Client{} +func (bundle Bundle) DownloadAndVerify(client *http.Client, dir string) { + // Make sure dir exists + _ = os.MkdirAll(dir, defaultDirMode) + for _, artifact := range bundle.Artifacts { filePath := filepath.Join(dir, artifact.Filename) - // TODO(vladimir): is this needed? - _ = os.MkdirAll(filepath.Dir(filePath), defaultDirMode) // Check file is already downloaded and valid. exists, _ := checkIfFileIsValid(filePath, artifact) @@ -130,7 +130,7 @@ func (bundle Bundle) DownloadAndVerify(dir string) { } // Download artifact - err := processArtifact(&client, artifact, filePath) + err := processArtifact(client, artifact, filePath) if err != nil { log.Errorf("updates: %s", err) } @@ -202,9 +202,7 @@ func processArtifact(client *http.Client, artifact Artifact, filePath string) er // Verify hash := sha256.Sum256(content) if !bytes.Equal(providedHash, hash[:]) { - // FIXME(vladimir): just for testing. Make it an error. - err = fmt.Errorf("failed to verify artifact: %s", artifact.Filename) - log.Debugf("updates: %s", err) + return fmt.Errorf("failed to verify artifact: %s", artifact.Filename) } // Save diff --git a/service/updates/index.go b/service/updates/index.go index 19713064..b753450c 100644 --- a/service/updates/index.go +++ b/service/updates/index.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "os" + "path/filepath" "github.com/safing/portmaster/base/log" ) @@ -19,10 +20,11 @@ type UpdateIndex struct { AutoApply bool } -func (ui *UpdateIndex) DownloadIndexFile() (err error) { +func (ui *UpdateIndex) DownloadIndexFile(client *http.Client) (err error) { + // Make sure dir exists _ = os.MkdirAll(ui.DownloadDirectory, defaultDirMode) for _, url := range ui.IndexURLs { - err = ui.downloadIndexFileFromURL(url) + err = ui.downloadIndexFileFromURL(client, url) if err != nil { log.Warningf("updates: failed while downloading index file %s", err) continue @@ -34,20 +36,27 @@ func (ui *UpdateIndex) DownloadIndexFile() (err error) { return } -func (ui *UpdateIndex) downloadIndexFileFromURL(url string) error { - client := http.Client{} +func (ui *UpdateIndex) downloadIndexFileFromURL(client *http.Client, url string) error { + // Request the index file resp, err := client.Get(url) if err != nil { return fmt.Errorf("failed GET request to %s: %w", url, err) } defer func() { _ = resp.Body.Close() }() - filePath := fmt.Sprintf("%s/%s", ui.DownloadDirectory, ui.IndexFile) + + // Check the status code + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("received error from the server status code: %s", resp.Status) + } + // Create file + filePath := filepath.Join(ui.DownloadDirectory, ui.IndexFile) file, err := os.Create(filePath) if err != nil { return err } defer func() { _ = file.Close() }() + // Write body of the response _, err = io.Copy(file, resp.Body) if err != nil { return err diff --git a/service/updates/module.go b/service/updates/module.go index ac15d1f8..99c110da 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -3,6 +3,7 @@ package updates import ( "errors" "fmt" + "net/http" "os" "time" @@ -131,7 +132,8 @@ func (reg *Updates) processBundle(bundle *Bundle) { } func (u *Updates) checkForUpdates(_ *mgr.WorkerCtx) error { - err := u.updateIndex.DownloadIndexFile() + httpClient := http.Client{} + err := u.updateIndex.DownloadIndexFile(&httpClient) if err != nil { return fmt.Errorf("failed to download index file: %s", err) } @@ -153,7 +155,7 @@ func (u *Updates) checkForUpdates(_ *mgr.WorkerCtx) error { } log.Infof("updates: check complete: downloading new version: %s %s", u.updateBundle.Name, u.updateBundle.Version) - err = u.downloadUpdates() + err = u.downloadUpdates(&httpClient) if err != nil { log.Errorf("updates: failed to download bundle: %s", err) } else { @@ -181,7 +183,7 @@ func (u *Updates) checkVersionIncrement() (bool, error) { return downloadVersion.GreaterThan(currentVersion), nil } -func (u *Updates) downloadUpdates() error { +func (u *Updates) downloadUpdates(client *http.Client) error { if u.updateBundle == nil { // checkForUpdates needs to be called before this. return fmt.Errorf("no valid update bundle found") @@ -191,7 +193,7 @@ func (u *Updates) downloadUpdates() error { if err != nil { log.Warningf("updates: error while coping file from current to update: %s", err) } - u.updateBundle.DownloadAndVerify(u.updateIndex.DownloadDirectory) + u.updateBundle.DownloadAndVerify(client, u.updateIndex.DownloadDirectory) return nil } diff --git a/service/updates/updater.go b/service/updates/updater.go index 3c6b2892..4a0358ec 100644 --- a/service/updates/updater.go +++ b/service/updates/updater.go @@ -59,20 +59,20 @@ func switchFolders(updateIndex UpdateIndex, newBundle Bundle) error { } func deleteUnfinishedDownloads(rootDir string) error { - return filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - + entries, err := os.ReadDir(rootDir) + if err != nil { + return err + } + for _, e := range entries { // Check if the current file has the download extension - if !info.IsDir() && strings.HasSuffix(info.Name(), ".download") { + if !e.IsDir() && strings.HasSuffix(e.Name(), ".download") { + path := filepath.Join(rootDir, e.Name()) log.Warningf("updates: deleting unfinished download file: %s\n", path) err := os.Remove(path) if err != nil { - log.Errorf("updates: failed to delete unfinished download file %s: %w", path, err) + log.Errorf("updates: failed to delete unfinished download file %s: %s", path, err) } } - - return nil - }) + } + return nil }