diff --git a/firewall/master.go b/firewall/master.go index bc96c6f0..b8dadb01 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -48,7 +48,7 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) connection.Deny("no profile set") return } - profileSet.Update(status.CurrentSecurityLevel()) + profileSet.Update(status.ActiveSecurityLevel()) // check for any network access if !profileSet.CheckFlag(profile.Internet) && !profileSet.CheckFlag(profile.LAN) { @@ -152,7 +152,7 @@ func DecideOnConnectionAfterIntel(connection *network.Connection, fqdn string, r connection.Deny("no profile") return rrCache } - profileSet.Update(status.CurrentSecurityLevel()) + profileSet.Update(status.ActiveSecurityLevel()) // TODO: Stamp integration @@ -188,7 +188,7 @@ func DecideOnConnection(connection *network.Connection, pkt packet.Packet) { connection.Deny("no profile") return } - profileSet.Update(status.CurrentSecurityLevel()) + profileSet.Update(status.ActiveSecurityLevel()) // check connection type switch connection.Domain { @@ -286,7 +286,7 @@ func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet link.Block("no profile") return } - profileSet.Update(status.CurrentSecurityLevel()) + profileSet.Update(status.ActiveSecurityLevel()) // get host var domainOrIP string diff --git a/intel/resolve.go b/intel/resolve.go index a389b4f0..87bf0249 100644 --- a/intel/resolve.go +++ b/intel/resolve.go @@ -251,13 +251,13 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel uint8) *RRCac func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype dns.Type, securityLevel uint8) (*RRCache, bool) { // skip if not allowed in current security level - if resolver.AllowedSecurityLevel < status.CurrentSecurityLevel() || resolver.AllowedSecurityLevel < securityLevel { - log.Tracef("intel: skipping resolver %s, because it isn't allowed to operate on the current security level: %d|%d", resolver, status.CurrentSecurityLevel(), securityLevel) + if resolver.AllowedSecurityLevel < status.ActiveSecurityLevel() || resolver.AllowedSecurityLevel < securityLevel { + log.Tracef("intel: skipping resolver %s, because it isn't allowed to operate on the current security level: %d|%d", resolver, status.ActiveSecurityLevel(), securityLevel) return nil, false } // skip if not security level denies assigned dns servers if doNotUseAssignedNameservers(securityLevel) && resolver.Source == "dhcp" { - log.Tracef("intel: skipping resolver %s, because assigned nameservers are not allowed on the current security level: %d|%d", resolver, status.CurrentSecurityLevel(), securityLevel) + log.Tracef("intel: skipping resolver %s, because assigned nameservers are not allowed on the current security level: %d|%d", resolver, status.ActiveSecurityLevel(), securityLevel) return nil, false } // check if failed recently diff --git a/status/database.go b/status/database.go index 935e4049..8aa34ef7 100644 --- a/status/database.go +++ b/status/database.go @@ -26,15 +26,18 @@ func (sh *statusHook) UsesPrePut() bool { // PrePut implements the Hook interface. func (sh *statusHook) PrePut(r record.Record) (record.Record, error) { + // record is already locked! + newStatus, err := EnsureSystemStatus(r) if err != nil { return nil, err } - newStatus.Lock() - defer newStatus.Unlock() // apply applicable settings - setSelectedSecurityLevel(newStatus.SelectedSecurityLevel) + if SelectedSecurityLevel() != newStatus.SelectedSecurityLevel { + go setSelectedSecurityLevel(newStatus.SelectedSecurityLevel) + } + // TODO: allow setting of Gate17 status (on/off) // return original status diff --git a/status/module.go b/status/module.go index 4aebbc0c..96e0b093 100644 --- a/status/module.go +++ b/status/module.go @@ -11,7 +11,7 @@ var ( ) func init() { - modules.Register("status", nil, start, stop) + modules.Register("status", nil, start, stop, "database") } func start() error { diff --git a/updates/get.go b/updates/get.go index d72ac0a3..7f3a2160 100644 --- a/updates/get.go +++ b/updates/get.go @@ -33,7 +33,7 @@ func getLatestFilePath(identifier string) (versionedFilePath, version string, st updatesLock.RLock() version, ok = stableUpdates[identifier] if !ok { - version, ok = latestUpdates[identifier] + version, ok = localUpdates[identifier] if !ok { log.Tracef("updates: file %s does not exist", identifier) return "", "", false, false @@ -58,6 +58,7 @@ func loadOrFetchFile(identifier string) (*File, error) { realFilePath := filepath.Join(updateStoragePath, versionedFilePath) if _, err := os.Stat(realFilePath); err == nil { // file exists + updateUsedStatus(identifier, version) return newFile(realFilePath, version, stable), nil } @@ -69,6 +70,7 @@ func loadOrFetchFile(identifier string) (*File, error) { if err != nil { log.Tracef("updates: failed to download %s: %s, retrying (%d)", versionedFilePath, err, tries+1) } else { + updateUsedStatus(identifier, version) return newFile(realFilePath, version, stable), nil } } diff --git a/updates/latest.go b/updates/latest.go index fc41b252..d77c529d 100644 --- a/updates/latest.go +++ b/updates/latest.go @@ -17,24 +17,26 @@ import ( var ( stableUpdates = make(map[string]string) betaUpdates = make(map[string]string) - latestUpdates = make(map[string]string) + localUpdates = make(map[string]string) updatesLock sync.RWMutex ) // ReloadLatest reloads available updates from disk. func ReloadLatest() error { - newLatestUpdates := make(map[string]string) + newLocalUpdates := make(map[string]string) // all - new, err1 := ScanForLatest(filepath.Join(updateStoragePath, "all"), false) + prefix := "all" + new, err1 := ScanForLatest(filepath.Join(updateStoragePath, prefix), false) for key, val := range new { - newLatestUpdates[key] = val + newLocalUpdates[filepath.Join(prefix, key)] = val } // os_platform - new, err2 := ScanForLatest(filepath.Join(updateStoragePath, fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH)), false) + prefix = fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH) + new, err2 := ScanForLatest(filepath.Join(updateStoragePath, prefix), false) for key, val := range new { - newLatestUpdates[key] = val + newLocalUpdates[filepath.Join(prefix, key)] = val } if err1 != nil && err2 != nil { @@ -43,12 +45,12 @@ func ReloadLatest() error { log.Tracef("updates: loading latest updates:") - for key, val := range newLatestUpdates { + for key, val := range newLocalUpdates { log.Tracef("updates: %s v%s", key, val) } updatesLock.Lock() - latestUpdates = newLatestUpdates + localUpdates = newLocalUpdates updatesLock.Unlock() log.Tracef("updates: load complete") @@ -60,6 +62,11 @@ func ReloadLatest() error { } } + // update version status + updatesLock.RLock() + defer updatesLock.RUnlock() + updateStatus(versionClassLocal, localUpdates) + return nil } @@ -137,5 +144,10 @@ func loadIndexesFromDisk() error { stableUpdates = newStableUpdates updatesLock.Unlock() + // update version status + updatesLock.RLock() + defer updatesLock.RUnlock() + updateStatus(versionClassStable, stableUpdates) + return nil } diff --git a/updates/latest_test.go b/updates/latest_test.go index 542bdb45..b3afc600 100644 --- a/updates/latest_test.go +++ b/updates/latest_test.go @@ -32,14 +32,14 @@ func testLoadLatestScope(t *testing.T, basePath, filePath, expectedIdentifier, e return } for key, val := range latest { - latestUpdates[key] = val + localUpdates[key] = val } // test result - version, ok := latestUpdates[expectedIdentifier] + version, ok := localUpdates[expectedIdentifier] if !ok { t.Errorf("identifier %s not in map", expectedIdentifier) - t.Errorf("current map: %v", latestUpdates) + t.Errorf("current map: %v", localUpdates) } if version != expectedVersion { t.Errorf("unexpected version for %s: %s", filePath, version) diff --git a/updates/main.go b/updates/main.go index 4f6695b7..98ae95e2 100644 --- a/updates/main.go +++ b/updates/main.go @@ -15,7 +15,7 @@ var ( ) func init() { - modules.Register("updates", prep, start, nil, "database") + modules.Register("updates", prep, start, nil, "global", "database") } func prep() error { @@ -30,7 +30,12 @@ func prep() error { } func start() error { - err := ReloadLatest() + err := initUpdateStatusHook() + if err != nil { + return err + } + + err = ReloadLatest() if err != nil { return err } diff --git a/updates/status.go b/updates/status.go new file mode 100644 index 00000000..2333c66c --- /dev/null +++ b/updates/status.go @@ -0,0 +1,129 @@ +package updates + +import ( + "errors" + "sync" + + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/log" + "github.com/tevino/abool" +) + +// database key for update information +const ( + statusDBKey = "core:status/updates" +) + +// version type +type versionClass int + +const ( + versionClassLocal versionClass = iota + versionClassStable + versionClassBeta +) + +// working vars +var ( + status *versionStatus + + statusDB = database.NewInterface(nil) + statusHook *database.RegisteredHook + enableStatusSave = abool.NewBool(false) +) + +func init() { + status = &versionStatus{ + Versions: make(map[string]*versionStatusEntry), + } + status.SetKey(statusDBKey) +} + +// versionStatus holds update version status information. +type versionStatus struct { + record.Base + sync.Mutex + Versions map[string]*versionStatusEntry +} + +func (vs *versionStatus) save() { + enableStatusSave.SetTo(true) + err := statusDB.Put(vs) + if err != nil { + log.Warningf("could not save updates version status: %s", err) + } +} + +// versionStatusEntry holds information about the update status of a module. +type versionStatusEntry struct { + LastVersionUsed string + LocalVersion string + StableVersion string + BetaVersion string +} + +func updateUsedStatus(identifier string, version string) { + status.Lock() + defer status.Unlock() + + entry, ok := status.Versions[identifier] + if !ok { + entry = &versionStatusEntry{} + status.Versions[identifier] = entry + } + + entry.LastVersionUsed = version + + log.Tracef("updates: updated last used version of %s: %s", identifier, version) + + go status.save() +} + +func updateStatus(vClass versionClass, state map[string]string) { + status.Lock() + defer status.Unlock() + + for identifier, version := range state { + + entry, ok := status.Versions[identifier] + if !ok { + entry = &versionStatusEntry{} + status.Versions[identifier] = entry + } + + switch vClass { + case versionClassLocal: + entry.LocalVersion = version + case versionClassStable: + entry.StableVersion = version + case versionClassBeta: + entry.BetaVersion = version + } + } + + go status.save() +} + +type updateStatusHook struct { + database.HookBase +} + +// UsesPrePut implements the Hook interface. +func (sh *updateStatusHook) UsesPrePut() bool { + return true +} + +// PrePut implements the Hook interface. +func (sh *updateStatusHook) PrePut(r record.Record) (record.Record, error) { + if enableStatusSave.SetToIf(true, false) { + return r, nil + } + return nil, errors.New("may only be changed by updates module") +} + +func initUpdateStatusHook() (err error) { + statusHook, err = database.RegisterHook(query.New(statusDBKey), &updateStatusHook{}) + return err +} diff --git a/updates/updater.go b/updates/updater.go index 56a39738..91769679 100644 --- a/updates/updater.go +++ b/updates/updater.go @@ -53,7 +53,7 @@ func checkForUpdates() error { log.Tracef("updates: updating existing files") updatesLock.RLock() for identifier, newVersion := range newStableUpdates { - oldVersion, ok := latestUpdates[identifier] + oldVersion, ok := localUpdates[identifier] if ok && newVersion != oldVersion { filePath := getVersionedPath(identifier, newVersion) @@ -84,5 +84,10 @@ func checkForUpdates() error { log.Warningf("updates: failed to save new version of stable.json: %s", err) } + // update version status + updatesLock.RLock() + defer updatesLock.RUnlock() + updateStatus(versionClassStable, stableUpdates) + return nil }