diff --git a/broadcasts/api.go b/broadcasts/api.go new file mode 100644 index 00000000..f855ddfc --- /dev/null +++ b/broadcasts/api.go @@ -0,0 +1,116 @@ +package broadcasts + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/safing/portbase/api" + "github.com/safing/portbase/database" + "github.com/safing/portbase/database/accessor" +) + +func registerAPIEndpoints() error { + if err := api.RegisterEndpoint(api.Endpoint{ + Path: `broadcasts/matching-data`, + Read: api.PermitAdmin, + BelongsTo: module, + StructFunc: handleMatchingData, + Name: "Get Broadcast Notifications Matching Data", + Description: "Returns the data used by the broadcast notifications to match the instance.", + }); err != nil { + return err + } + + if err := api.RegisterEndpoint(api.Endpoint{ + Path: `broadcasts/reset-state`, + Write: api.PermitAdmin, + WriteMethod: http.MethodPost, + BelongsTo: module, + ActionFunc: handleResetState, + Name: "Resets the Broadcast Notification States", + Description: "Delete the cache of Broadcast Notifications, making them appear again.", + }); err != nil { + return err + } + + if err := api.RegisterEndpoint(api.Endpoint{ + Path: `broadcasts/simulate`, + Write: api.PermitAdmin, + WriteMethod: http.MethodPost, + BelongsTo: module, + ActionFunc: handleSimulate, + Name: "Simulate Broadcast Notifications", + Description: "Test broadcast notifications by sending a valid source file in the body.", + Parameters: []api.Parameter{ + { + Method: http.MethodPost, + Field: "state", + Value: "true", + Description: "Check against state when deciding to display a broadcast notification. Acknowledgements are always saved.", + }, + }, + }); err != nil { + return err + } + + return nil +} + +func handleMatchingData(ar *api.Request) (i interface{}, err error) { + return collectData(), nil +} + +func handleResetState(ar *api.Request) (msg string, err error) { + err = db.Delete(broadcastStatesDBKey) + if err != nil { + return "", err + } + return "Reset complete.", nil +} + +func handleSimulate(ar *api.Request) (msg string, err error) { + // Parse broadcast notification data. + broadcasts, err := parseBroadcastSource(ar.InputData) + if err != nil { + return "", fmt.Errorf("failed to parse broadcast notifications update: %w", err) + } + + // Get and marshal matching data. + matchingData := collectData() + matchingJSON, err := json.Marshal(matchingData) + if err != nil { + return "", fmt.Errorf("failed to marshal broadcast notifications matching data: %w", err) + } + matchingDataAccessor := accessor.NewJSONBytesAccessor(&matchingJSON) + + var bss *BroadcastStates + if ar.URL.Query().Get("state") == "true" { + // Get broadcast notification states. + bss, err = getBroadcastStates() + if err != nil { + if !errors.Is(err, database.ErrNotFound) { + return "", fmt.Errorf("failed to get broadcast notifications states: %w", err) + } + bss = newBroadcastStates() + } + } + + // Go through all broadcast nofications and check if they match. + var results []string + for _, bn := range broadcasts.Notifications { + err := handleBroadcast(bn, matchingDataAccessor, bss) + switch { + case err == nil: + results = append(results, fmt.Sprintf("%30s: displayed", bn.id)) + case errors.Is(err, ErrSkip): + results = append(results, fmt.Sprintf("%30s: %s", bn.id, err)) + default: + results = append(results, fmt.Sprintf("FAILED %23s: %s", bn.id, err)) + } + } + + return strings.Join(results, "\n"), nil +} diff --git a/broadcasts/data.go b/broadcasts/data.go new file mode 100644 index 00000000..73da1bb9 --- /dev/null +++ b/broadcasts/data.go @@ -0,0 +1,102 @@ +package broadcasts + +import ( + "time" + + "github.com/safing/portbase/config" + "github.com/safing/portmaster/intel/geoip" + "github.com/safing/portmaster/netenv" + "github.com/safing/portmaster/updates" + "github.com/safing/spn/access" + "github.com/safing/spn/captain" +) + +var portmasterStarted = time.Now() + +func collectData() interface{} { + data := make(map[string]interface{}) + + // Get data about versions. + versions := updates.GetSimpleVersions() + data["Updates"] = versions + data["Version"] = versions.Build.Version + numericVersion, err := MakeNumericVersion(versions.Build.Version) + if err != nil { + data["NumericVersion"] = &DataError{ + Error: err, + } + } else { + data["NumericVersion"] = numericVersion + } + + // Get data about install. + installInfo, err := GetInstallInfo() + if err != nil { + data["Install"] = &DataError{ + Error: err, + } + } else { + data["Install"] = installInfo + } + + // Get global configuration. + data["Config"] = config.GetActiveConfigValues() + + // Get data about device location. + locs, ok := netenv.GetInternetLocation() + if ok && locs.Best().LocationOrNil() != nil { + loc := locs.Best() + data["Location"] = &Location{ + Country: loc.Location.Country.ISOCode, + Coordinates: loc.Location.Coordinates, + ASN: loc.Location.AutonomousSystemNumber, + ASOrg: loc.Location.AutonomousSystemOrganization, + Source: loc.Source, + SourceAccuracy: loc.SourceAccuracy, + } + } + + // Get data about SPN status. + data["SPN"] = captain.GetSPNStatus() + + // Get data about account. + userRecord, err := access.GetUser() + if err != nil { + data["Account"] = &DataError{ + Error: err, + } + } else { + data["Account"] = &Account{ + UserRecord: userRecord, + UpToDate: userRecord.Meta().Modified > time.Now().Add(-7*24*time.Hour).Unix(), + MayUseUSP: userRecord.MayUseSPN(), + } + } + + // Time running. + data["UptimeHours"] = int(time.Since(portmasterStarted).Hours()) + + return data +} + +// Location holds location matching data. +type Location struct { + Country string + Coordinates geoip.Coordinates + ASN uint + ASOrg string + Source netenv.DeviceLocationSource + SourceAccuracy int +} + +// Account holds SPN account matching data. +type Account struct { + *access.UserRecord + UpToDate bool + MayUseUSP bool +} + +// DataError represents an error getting some matching data. +type DataError struct { + Error error +} diff --git a/broadcasts/install_info.go b/broadcasts/install_info.go new file mode 100644 index 00000000..2f667a17 --- /dev/null +++ b/broadcasts/install_info.go @@ -0,0 +1,175 @@ +package broadcasts + +import ( + "errors" + "fmt" + "strconv" + "sync" + "time" + + semver "github.com/hashicorp/go-version" + + "github.com/safing/portbase/database" + "github.com/safing/portbase/database/query" + "github.com/safing/portbase/database/record" + "github.com/safing/portbase/info" + "github.com/safing/portbase/log" +) + +const installInfoDBKey = "core:status/install-info" + +// InstallInfo holds generic info about the install. +type InstallInfo struct { + record.Base + sync.Mutex + + Version string + NumericVersion int64 + + Time time.Time + NumericDate int64 + DaysSinceInstall int64 + UnixTimestamp int64 +} + +// GetInstallInfo returns the install info from the database. +func GetInstallInfo() (*InstallInfo, error) { + r, err := db.Get(installInfoDBKey) + if err != nil { + return nil, err + } + + // Unwrap. + if r.IsWrapped() { + // Only allocate a new struct, if we need it. + newRecord := &InstallInfo{} + err = record.Unwrap(r, newRecord) + if err != nil { + return nil, err + } + return newRecord, nil + } + + // or adjust type + newRecord, ok := r.(*InstallInfo) + if !ok { + return nil, fmt.Errorf("record not of type *InstallInfo, but %T", r) + } + return newRecord, nil +} + +func ensureInstallInfo() { + // Get current install info from database. + installInfo, err := GetInstallInfo() + if err != nil { + installInfo = &InstallInfo{} + if !errors.Is(err, database.ErrNotFound) { + log.Warningf("updates: failed to load install info: %s", err) + } + } + + // Fill in missing data and save. + installInfo.checkAll() + if err := installInfo.save(); err != nil { + log.Warningf("updates: failed to save install info: %s", err) + } +} + +func (ii *InstallInfo) save() error { + if !ii.KeyIsSet() { + ii.SetKey(installInfoDBKey) + } + return db.Put(ii) +} + +func (ii *InstallInfo) checkAll() { + ii.checkVersion() + ii.checkInstallDate() +} + +func (ii *InstallInfo) checkVersion() { + // Check if everything is present. + if ii.Version != "" && ii.NumericVersion > 0 { + return + } + + // Update version information. + versionInfo := info.GetInfo() + ii.Version = versionInfo.Version + + // Update numeric version. + if versionInfo.Version != "" { + numericVersion, err := MakeNumericVersion(versionInfo.Version) + if err != nil { + log.Warningf("updates: failed to make numeric version: %s", err) + } else { + ii.NumericVersion = numericVersion + } + } +} + +// MakeNumericVersion makes a numeric version with the first three version +// segment always using three digits. +func MakeNumericVersion(version string) (numericVersion int64, err error) { + // Parse version string. + ver, err := semver.NewVersion(version) + if err != nil { + return 0, fmt.Errorf("failed to parse core version: %w", err) + } + + // Transform version for numeric representation. + segments := ver.Segments() + for i := 0; i < 3 && i < len(segments); i++ { + segmentNumber := int64(segments[i]) + if segmentNumber > 999 { + segmentNumber = 999 + } + switch i { + case 0: + numericVersion += segmentNumber * 1000000 + case 1: + numericVersion += segmentNumber * 1000 + case 2: + numericVersion += segmentNumber + } + } + + return numericVersion, nil +} + +func (ii *InstallInfo) checkInstallDate() { + // Check if everything is present. + if ii.UnixTimestamp > 0 && + ii.NumericDate > 0 && + ii.DaysSinceInstall > 0 && + !ii.Time.IsZero() { + return + } + + // Find oldest created database entry and use it as install time. + oldest := time.Now().Unix() + it, err := db.Query(query.New("core")) + if err != nil { + log.Warningf("updates: failed to create iterator for searching DB for install time: %s", err) + return + } + defer it.Cancel() + for r := range it.Next { + if oldest > r.Meta().Created { + oldest = r.Meta().Created + } + } + + // Set data. + ii.UnixTimestamp = oldest + ii.Time = time.Unix(oldest, 0) + ii.DaysSinceInstall = int64(time.Since(ii.Time).Hours()) / 24 + + // Transform date for numeric representation. + numericDate, err := strconv.ParseInt(ii.Time.Format("20060102"), 10, 64) + if err != nil { + log.Warningf("updates: failed to make numeric date from %s: %s", ii.Time, err) + } else { + ii.NumericDate = numericDate + } +} diff --git a/broadcasts/module.go b/broadcasts/module.go new file mode 100644 index 00000000..360bc912 --- /dev/null +++ b/broadcasts/module.go @@ -0,0 +1,46 @@ +package broadcasts + +import ( + "sync" + "time" + + "github.com/safing/portbase/database" + "github.com/safing/portbase/modules" +) + +var ( + module *modules.Module + + db = database.NewInterface(&database.Options{ + Local: true, + Internal: true, + }) + + startOnce sync.Once +) + +func init() { + module = modules.Register("broadcasts", prep, start, nil, "updates", "netenv", "notifications") +} + +func prep() error { + // Register API endpoints. + if err := registerAPIEndpoints(); err != nil { + return err + } + + return nil +} + +func start() error { + // Ensure the install info is up to date. + ensureInstallInfo() + + // Start broadcast notifier task. + startOnce.Do(func() { + module.NewTask("broadcast notifier", broadcastNotify). + Repeat(10 * time.Minute).Queue() + }) + + return nil +} diff --git a/broadcasts/notify.go b/broadcasts/notify.go new file mode 100644 index 00000000..c7b7d661 --- /dev/null +++ b/broadcasts/notify.go @@ -0,0 +1,285 @@ +package broadcasts + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "strings" + "sync" + "time" + + "github.com/ghodss/yaml" + + "github.com/safing/portbase/database" + "github.com/safing/portbase/database/accessor" + "github.com/safing/portbase/database/query" + "github.com/safing/portbase/log" + "github.com/safing/portbase/modules" + "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/updates" +) + +const ( + broadcastsResourcePath = "intel/portmaster/notifications.yaml" + + broadcastNotificationIDPrefix = "broadcasts:" + + minRepeatDuration = 1 * time.Hour +) + +// Errors. +var ( + ErrSkip = errors.New("broadcast skipped") + ErrSkipDoesNotMatch = fmt.Errorf("%w: does not match", ErrSkip) + ErrSkipAlreadyActive = fmt.Errorf("%w: already active", ErrSkip) + ErrSkipAlreadyShown = fmt.Errorf("%w: already shown", ErrSkip) + ErrSkipRemovedByMismatch = fmt.Errorf("%w: removed due to mismatch", ErrSkip) + ErrSkipRemovedBySource = fmt.Errorf("%w: removed by source", ErrSkip) +) + +// BroadcastNotifications holds the data structure of the broadcast +// notifications update file. +type BroadcastNotifications struct { + Notifications map[string]*BroadcastNotification +} + +// BroadcastNotification is a single broadcast notification. +type BroadcastNotification struct { + *notifications.Notification + id string + + // Match holds a query string that needs to match the local matching data in + // order for the broadcast to be displayed. + Match string + matchingQuery *query.Query + // AttachToModule signifies if the broadcast notification should be attached to the module. + AttachToModule bool + // Remove signifies that the broadcast should be canceled and its state removed. + Remove bool + // Permanent signifies that the broadcast cannot be acknowledge by the user + // and remains in the UI indefinitely. + Permanent bool + // Repeat specifies a duration after which the broadcast should be shown again. + Repeat string + repeatDuration time.Duration +} + +func broadcastNotify(ctx context.Context, t *modules.Task) error { + // Get broadcast notifications file, load it from disk and parse it. + broadcastsResource, err := updates.GetFile(broadcastsResourcePath) + if err != nil { + return fmt.Errorf("failed to get broadcast notifications update: %w", err) + } + broadcastsData, err := ioutil.ReadFile(broadcastsResource.Path()) + if err != nil { + return fmt.Errorf("failed to load broadcast notifications update: %w", err) + } + broadcasts, err := parseBroadcastSource(broadcastsData) + if err != nil { + return fmt.Errorf("failed to parse broadcast notifications update: %w", err) + } + + // Get and marshal matching data. + matchingData := collectData() + matchingJSON, err := json.Marshal(matchingData) + if err != nil { + return fmt.Errorf("failed to marshal broadcast notifications matching data: %w", err) + } + matchingDataAccessor := accessor.NewJSONBytesAccessor(&matchingJSON) + + // Get broadcast notification states. + bss, err := getBroadcastStates() + if err != nil { + if !errors.Is(err, database.ErrNotFound) { + return fmt.Errorf("failed to get broadcast notifications states: %w", err) + } + bss = newBroadcastStates() + } + + // Go through all broadcast nofications and check if they match. + for _, bn := range broadcasts.Notifications { + err := handleBroadcast(bn, matchingDataAccessor, bss) + switch { + case err == nil: + log.Infof("broadcasts: displaying broadcast %s", bn.id) + case errors.Is(err, ErrSkip): + log.Tracef("broadcasts: skipped displaying broadcast %s: %s", bn.id, err) + default: + log.Warningf("broadcasts: failed to handle broadcast %s: %s", bn.id, err) + } + } + + return nil +} + +func parseBroadcastSource(yamlData []byte) (*BroadcastNotifications, error) { + // Parse data. + broadcasts := &BroadcastNotifications{} + err := yaml.Unmarshal(yamlData, broadcasts) + if err != nil { + return nil, err + } + + // Add IDs to struct for easier handling. + for id, bn := range broadcasts.Notifications { + bn.id = id + + // Parse matching query. + if bn.Match != "" { + q, err := query.ParseQuery("query / where " + bn.Match) + if err != nil { + return nil, fmt.Errorf("failed to parse query of broadcast notification %s: %w", bn.id, err) + } + bn.matchingQuery = q + } + + // Parse the repeat duration. + if bn.Repeat != "" { + duration, err := time.ParseDuration(bn.Repeat) + if err != nil { + return nil, fmt.Errorf("failed to parse repeat duration of broadcast notification %s: %w", bn.id, err) + } + bn.repeatDuration = duration + // Raise duration to minimum. + if bn.repeatDuration < minRepeatDuration { + bn.repeatDuration = minRepeatDuration + } + } + } + + return broadcasts, nil +} + +func handleBroadcast(bn *BroadcastNotification, matchingDataAccessor accessor.Accessor, bss *BroadcastStates) error { + // Check if broadcast was already shown. + if bss != nil { + state, ok := bss.States[bn.id] + switch { + case !ok || state.Read.IsZero(): + // Was never shown, continue. + case bn.repeatDuration == 0 && !state.Read.IsZero(): + // Was already shown and is not repeated, skip. + return ErrSkipAlreadyShown + case bn.repeatDuration > 0 && time.Now().Add(-bn.repeatDuration).After(state.Read): + // Was already shown, but should be repeated now, continue. + } + } + + // Check if broadcast should be removed. + if bn.Remove { + removeBroadcast(bn, bss) + return ErrSkipRemovedBySource + } + + // Skip if broadcast does not match. + if bn.matchingQuery != nil && !bn.matchingQuery.MatchesAccessor(matchingDataAccessor) { + removed := removeBroadcast(bn, bss) + if removed { + return ErrSkipRemovedByMismatch + } + return ErrSkipDoesNotMatch + } + + // Check if there is already an active notification for this. + eventID := broadcastNotificationIDPrefix + bn.id + n := notifications.Get(eventID) + if n != nil { + // Already active! + return ErrSkipAlreadyActive + } + + // Prepare notification for displaying. + n = bn.Notification + n.EventID = eventID + n.GUID = "" + n.State = "" + n.SelectedActionID = "" + + // It is okay to edit the notification, as they are loaded from the file every time. + // Add dismiss button if the notification is not permanent. + if !bn.Permanent { + n.AvailableActions = append(n.AvailableActions, ¬ifications.Action{ + ID: "ack", + Text: "Got it!", + }) + } + n.SetActionFunction(markBroadcastAsRead) + + // Display notification. + n.Save() + + // Attach to module to raise more awareness. + if bn.AttachToModule { + n.AttachToModule(module) + } + + return nil +} + +func removeBroadcast(bn *BroadcastNotification, bss *BroadcastStates) (removed bool) { + // Remove any active notification. + n := notifications.Get(broadcastNotificationIDPrefix + bn.id) + if n != nil { + removed = true + n.Delete() + } + + // Remove any state. + if bss != nil { + delete(bss.States, bn.id) + } + + return +} + +var savingBroadcastStateLock sync.Mutex + +func markBroadcastAsRead(ctx context.Context, n *notifications.Notification) error { + // Lock persisting broadcast state. + savingBroadcastStateLock.Lock() + defer savingBroadcastStateLock.Unlock() + + // Get notification data. + var broadcastID, actionID string + func() { + n.Lock() + defer n.Unlock() + broadcastID = strings.TrimPrefix(n.EventID, broadcastNotificationIDPrefix) + actionID = n.SelectedActionID + }() + + // Check response. + switch actionID { + case "ack": + case "": + return fmt.Errorf("no action ID for %s", broadcastID) + default: + return fmt.Errorf("unexpected action ID for %s: %s", broadcastID, actionID) + } + + // Get broadcast notification states. + bss, err := getBroadcastStates() + if err != nil { + if !errors.Is(err, database.ErrNotFound) { + return fmt.Errorf("failed to get broadcast notifications states: %w", err) + } + bss = newBroadcastStates() + } + + // Get state for this notification. + bs, ok := bss.States[broadcastID] + if !ok { + bs = &BroadcastState{} + bss.States[broadcastID] = bs + } + + // Delete to allow for timely repeats. + n.Delete() + + // Mark as read and save to DB. + log.Infof("broadcasts: user acknowledged broadcast %s", broadcastID) + bs.Read = time.Now() + return bss.save() +} diff --git a/broadcasts/state.go b/broadcasts/state.go new file mode 100644 index 00000000..afe8994c --- /dev/null +++ b/broadcasts/state.go @@ -0,0 +1,64 @@ +package broadcasts + +import ( + "fmt" + "sync" + "time" + + "github.com/safing/portbase/database/record" +) + +const broadcastStatesDBKey = "core:broadcasts/state" + +// BroadcastStates holds states for broadcast notifications. +type BroadcastStates struct { + record.Base + sync.Mutex + + States map[string]*BroadcastState +} + +// BroadcastState holds state for a single broadcast notifications. +type BroadcastState struct { + Read time.Time +} + +func (bss *BroadcastStates) save() error { + return db.Put(bss) +} + +// getbroadcastStates returns the broadcast states from the database. +func getBroadcastStates() (*BroadcastStates, error) { + r, err := db.Get(broadcastStatesDBKey) + if err != nil { + return nil, err + } + + // Unwrap. + if r.IsWrapped() { + // Only allocate a new struct, if we need it. + newRecord := &BroadcastStates{} + err = record.Unwrap(r, newRecord) + if err != nil { + return nil, err + } + return newRecord, nil + } + + // or adjust type + newRecord, ok := r.(*BroadcastStates) + if !ok { + return nil, fmt.Errorf("record not of type *BroadcastStates, but %T", r) + } + return newRecord, nil +} + +// newBroadcastStates returns a new BroadcastStates. +func newBroadcastStates() *BroadcastStates { + bss := &BroadcastStates{ + States: make(map[string]*BroadcastState), + } + bss.SetKey(broadcastStatesDBKey) + + return bss +} diff --git a/core/core.go b/core/core.go index 80bd6343..8b70cba6 100644 --- a/core/core.go +++ b/core/core.go @@ -7,6 +7,7 @@ import ( "github.com/safing/portbase/modules" "github.com/safing/portbase/modules/subsystems" + _ "github.com/safing/portmaster/broadcasts" _ "github.com/safing/portmaster/netenv" _ "github.com/safing/portmaster/status" _ "github.com/safing/portmaster/ui" @@ -25,7 +26,7 @@ var ( ) func init() { - module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "interception", "compat") + module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "interception", "compat", "broadcasts") subsystems.Register( "core", "Core",