diff --git a/.gitignore b/.gitignore index 7332997a..b3c96ff6 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,5 @@ _testmain.go # Custom dev scripts win_dev_* +go.work +go.work.sum diff --git a/netquery/database.go b/netquery/database.go index df038103..397a0030 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -2,18 +2,23 @@ package netquery import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" + "path" "sort" "strings" "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/jackc/puddle/v2" "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" + "github.com/safing/portbase/dataroot" "github.com/safing/portbase/log" "github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/network" @@ -46,6 +51,7 @@ type ( Schema *orm.TableSchema readConnPool *puddle.Pool[*sqlite.Conn] + historyPath string l sync.Mutex writeConn *sqlite.Conn @@ -82,7 +88,9 @@ type ( Latitude float64 `sqlite:"latitude"` Longitude float64 `sqlite:"longitude"` Scope netutils.IPScope `sqlite:"scope"` - Verdict network.Verdict `sqlite:"verdict"` + WorstVerdict network.Verdict `sqlite:"worst_verdict"` + ActiveVerdict network.Verdict `sqlite:"verdict"` + FirewallVerdict network.Verdict `sqlite:"firewall_verdict"` Started time.Time `sqlite:"started,text,time"` Ended *time.Time `sqlite:"ended,text,time"` Tunneled bool `sqlite:"tunneled"` @@ -93,6 +101,8 @@ type ( Allowed *bool `sqlite:"allowed"` ProfileRevision int `sqlite:"profile_revision"` ExitNode *string `sqlite:"exit_node"` + BWIncoming uint64 `sqlite:"bw_incoming,default=0"` + BWOutgoing uint64 `sqlite:"bw_outgoing,default=0"` // TODO(ppacher): support "NOT" in search query to get rid of the following helper fields Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL" @@ -108,21 +118,27 @@ type ( // (see Execute). To perform database writes use either Save() or ExecuteWrite(). // Note that write connections are serialized by the Database object before being // handed over to SQLite. -func New(path string) (*Database, error) { +func New(dbPath string) (*Database, error) { + historyParentDir := dataroot.Root().ChildDir("databases", 0o700) + if err := historyParentDir.Ensure(); err != nil { + return nil, fmt.Errorf("failed to ensure database directory exists: %w", err) + } + + historyPath := "file://" + path.Join(historyParentDir.Path, "history.db") + constructor := func(ctx context.Context) (*sqlite.Conn, error) { c, err := sqlite.OpenConn( - path, + dbPath, sqlite.OpenReadOnly, - sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit. sqlite.OpenSharedCache, //sqlite.OpenMemory, sqlite.OpenURI, ) if err != nil { - return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", path, err) + return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", dbPath, err) } - if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE 'file:///tmp/history.db?mode=ro' AS history", nil); err != nil { + if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE '"+historyPath+"?mode=ro' AS history", nil); err != nil { return nil, fmt.Errorf("failed to attach history database: %w", err) } @@ -150,23 +166,23 @@ func New(path string) (*Database, error) { } writeConn, err := sqlite.OpenConn( - path, + dbPath, sqlite.OpenCreate, sqlite.OpenReadWrite, - sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit. sqlite.OpenWAL, sqlite.OpenSharedCache, //sqlite.OpenMemory, sqlite.OpenURI, ) if err != nil { - return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err) + return nil, fmt.Errorf("failed to open sqlite at %s: %w", dbPath, err) } return &Database{ readConnPool: pool, Schema: schema, writeConn: writeConn, + historyPath: historyPath, }, nil } @@ -197,9 +213,7 @@ func (db *Database) ApplyMigrations() error { db.l.Lock() defer db.l.Unlock() - // Attach the history database - log.Errorf("attaching database history") - if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE 'file:///tmp/history.db?mode=rwc' AS 'history';", nil); err != nil { + if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE '"+db.historyPath+"?mode=rwc' AS 'history';", nil); err != nil { return fmt.Errorf("failed to attach history database: %w", err) } @@ -207,7 +221,7 @@ func (db *Database) ApplyMigrations() error { for _, dbName := range dbNames { // get the create-table SQL statement from the inferred schema sql := db.Schema.CreateStatement(dbName, true) - log.Errorf("creating table schema for database %q", dbName) + log.Debugf("creating table schema for database %q", dbName) // execute the SQL if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil { @@ -285,7 +299,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) { return result[0].Count, nil } -// Cleanup removes all connections that have ended before threshold. +// Cleanup removes all connections that have ended before threshold from the live database. // // NOTE(ppacher): there is no easy way to get the number of removed // rows other than counting them in a first step. Though, that's @@ -323,6 +337,18 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro return result[0].Count, nil } +func (db *Database) RemoveAllHistoryData(ctx context.Context) error { + query := fmt.Sprintf("DELETE FROM %s.connections", HistoryDatabase) + return db.ExecuteWrite(ctx, query) +} + +func (db *Database) RemoveHistoryForProfile(ctx context.Context, profileID string) error { + query := fmt.Sprintf("DELETE FROM %s.connections WHERE profile = :profile", HistoryDatabase) + return db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{ + ":profile": profileID, + })) +} + // dumpTo is a simple helper method that dumps all rows stored in the SQLite database // as JSON to w. // Any error aborts dumping rows and is returned. @@ -350,13 +376,74 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:un return enc.Encode(conns) } +// MarkAllHistoryConnectionsEnded marks all connections in the history database as ended. +func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error { + query := fmt.Sprintf("UPDATE %s.connections SET active = FALSE, ended = :ended WHERE active = TRUE", HistoryDatabase) + + if err := db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{ + ":ended": time.Now().Format(orm.SqliteTimeFormat), + })); err != nil { + return err + } + + return nil +} + +func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, incoming *uint64, outgoing *uint64) error { + data := connID + "-" + processKey + hash := sha256.Sum256([]byte(data)) + dbConnId := hex.EncodeToString(hash[:]) + + params := map[string]any{ + ":id": dbConnId, + } + + parts := []string{} + if incoming != nil { + parts = append(parts, "bw_incoming = :bw_incoming") + params[":bw_incoming"] = *incoming + } + + if outgoing != nil { + parts = append(parts, "bw_outgoing = :bw_outgoing") + params[":bw_outgoing"] = *outgoing + } + + updateSet := strings.Join(parts, ", ") + + updateStmts := []string{ + fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, LiveDatabase, updateSet), + } + + if enableHistory { + updateStmts = append(updateStmts, + fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, HistoryDatabase, updateSet), + ) + } + + merr := new(multierror.Error) + for _, stmt := range updateStmts { + if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil { + merr.Errors = append(merr.Errors, err) + } + } + + return merr.ErrorOrNil() +} + // Save inserts the connection conn into the SQLite database. If conn // already exists the table row is updated instead. // // Save uses the database write connection instead of relying on the // connection pool. func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) error { - connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig) + // convert the connection to a param map where each key is already translated + // to the sql column name. We also skip bw_incoming and bw_outgoing since those + // will be updated independenly from the connection object. + connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig, []string{ + "bw_incoming", + "bw_outgoing", + }) if err != nil { return fmt.Errorf("failed to encode connection for SQL: %w", err) } @@ -387,10 +474,10 @@ func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) err // TODO(ppacher): make sure this one can be cached to speed up inserting // and save some CPU cycles for the user - dbNames := []string{"main"} + dbNames := []DatabaseName{LiveDatabase} if enableHistory { - dbNames = append(dbNames, "history") + dbNames = append(dbNames, HistoryDatabase) } for _, dbName := range dbNames { diff --git a/netquery/manager.go b/netquery/manager.go index bcd60618..531063b8 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -13,6 +13,8 @@ import ( "github.com/safing/portbase/log" "github.com/safing/portbase/runtime" "github.com/safing/portmaster/network" + "github.com/safing/spn/access" + "github.com/safing/spn/access/account" ) type ( @@ -26,6 +28,21 @@ type ( // The ID of Conn is unique and can be trusted to never collide with other // connections of the save device. Save(context.Context, Conn, bool) error + + // MarkAllHistoryConnectionsEnded marks all active connections in the history + // database as ended NOW. + MarkAllHistoryConnectionsEnded(context.Context) error + + // RemoveHistoryForProfile removes all connections from the history database + // for a given profile ID (source/id) + RemoveHistoryForProfile(context.Context, string) error + + // RemoveAllHistoryData removes all connections from the history database. + RemoveAllHistoryData(context.Context) error + + // UpdateBandwidth updates bandwith data for the connection and optionally also writes + // the bandwidth data to the history database. + UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, incoming *uint64, outgoing *uint64) error } // Manager handles new and updated network.Connections feeds and persists them @@ -100,7 +117,20 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect log.Tracef("netquery: updating connection %s", conn.ID) - if err := mng.store.Save(ctx, *model, conn.Process().Profile().HistoryEnabled()); err != nil { + // check if we should persist the connection in the history database. + // Also make sure the current SPN User/subscription allows use of the history. + historyEnabled := conn.Process().Profile().HistoryEnabled() + if historyEnabled { + user, err := access.GetUser() + if err != nil { + // there was an error so disable history + historyEnabled = false + } else if !user.MayUse(account.FeatureHistory) { + historyEnabled = false + } + } + + if err := mng.store.Save(ctx, *model, historyEnabled); err != nil { log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err) continue @@ -158,7 +188,9 @@ func convertConnection(conn *network.Connection) (*Conn, error) { IPProtocol: conn.IPProtocol, LocalIP: conn.LocalIP.String(), LocalPort: conn.LocalPort, - Verdict: conn.Verdict.Firewall, // TODO: Expose both Worst and Firewall verdicts. + FirewallVerdict: conn.Verdict.Firewall, + ActiveVerdict: conn.Verdict.Active, + WorstVerdict: conn.Verdict.Worst, Started: time.Unix(conn.Started, 0), Tunneled: conn.Tunneled, Encrypted: conn.Encrypted, @@ -250,7 +282,7 @@ func convertConnection(conn *network.Connection) (*Conn, error) { } func genConnID(conn *network.Connection) string { - data := conn.ID + "-" + time.Unix(conn.Started, 0).String() + data := conn.ID + "-" + conn.Process().GetID() hash := sha256.Sum256([]byte(data)) return hex.EncodeToString(hash[:]) } diff --git a/netquery/module_api.go b/netquery/module_api.go index 4cb02462..3746fdf4 100644 --- a/netquery/module_api.go +++ b/netquery/module_api.go @@ -2,15 +2,19 @@ package netquery import ( "context" + "encoding/json" "fmt" + "net/http" "time" + "github.com/hashicorp/go-multierror" "github.com/safing/portbase/api" "github.com/safing/portbase/config" "github.com/safing/portbase/database" "github.com/safing/portbase/database/query" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" + "github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/runtime" "github.com/safing/portmaster/network" ) @@ -35,6 +39,15 @@ func init() { "network", "database", ) + + subsystems.Register( + "history", + "Network History", + "Keep Network History Data", + m.Module, + "config:history/", + nil, + ) } func (m *module) prepare() error { @@ -92,6 +105,58 @@ func (m *module) prepare() error { return fmt.Errorf("failed to register API endpoint: %w", err) } + if err := api.RegisterEndpoint(api.Endpoint{ + Path: "netquery/history/clear", + MimeType: "application/json", + Read: api.PermitUser, + Write: api.PermitUser, + BelongsTo: m.Module, + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + var body struct { + ProfileIDs []string `json:"profileIDs"` + } + + defer r.Body.Close() + + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + + if err := dec.Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if len(body.ProfileIDs) == 0 { + if err := m.mng.store.RemoveAllHistoryData(r.Context()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + + return + } + } else { + merr := new(multierror.Error) + for _, profileID := range body.ProfileIDs { + if err := m.mng.store.RemoveHistoryForProfile(r.Context(), profileID); err != nil { + merr.Errors = append(merr.Errors, fmt.Errorf("failed to clear history for %q: %w", profileID, err)) + } else { + log.Infof("netquery: successfully cleared history for %s", profileID) + } + } + + if err := merr.ErrorOrNil(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + + return + } + } + + w.WriteHeader(http.StatusNoContent) + }, + Name: "Remove connections from profile history", + Description: "Remove all connections from the history database for one or more profiles", + }); err != nil { + return fmt.Errorf("failed to register API endpoint: %w", err) + } + return nil } @@ -163,5 +228,16 @@ func (m *module) start() error { } func (m *module) stop() error { + // we don't use m.Module.Ctx here because it is already cancelled when stop is called. + // just give the clean up 1 minute to happen and abort otherwise. + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + if err := m.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil { + // handle the error by just logging it. There's not much we can do here + // and returning an error to the module system doesn't help much as well... + log.Errorf("failed to mark connections in history database as eded: %w", err) + } + return nil } diff --git a/netquery/orm/encoder.go b/netquery/orm/encoder.go index 7961f088..ef86b842 100644 --- a/netquery/orm/encoder.go +++ b/netquery/orm/encoder.go @@ -6,6 +6,7 @@ import ( "reflect" "time" + "golang.org/x/exp/slices" "zombiezen.com/go/sqlite" ) @@ -22,7 +23,7 @@ type ( // ToParamMap returns a map that contains the sqlite compatible value of each struct field of // r using the sqlite column name as a map key. It either uses the name of the // exported struct field or the value of the "sqlite" tag. -func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) { +func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) { // make sure we work on a struct type val := reflect.Indirect(reflect.ValueOf(r)) if val.Kind() != reflect.Struct { @@ -45,6 +46,10 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err) } + if slices.Contains(skipFields, colDef.Name) { + continue + } + x, found, err := runEncodeHooks( colDef, fieldType.Type, diff --git a/netquery/orm/encoder_test.go b/netquery/orm/encoder_test.go index e5142962..d0d3c039 100644 --- a/netquery/orm/encoder_test.go +++ b/netquery/orm/encoder_test.go @@ -119,7 +119,7 @@ func TestEncodeAsMap(t *testing.T) { //nolint:tparallel for idx := range cases { //nolint:paralleltest c := cases[idx] t.Run(c.Desc, func(t *testing.T) { - res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig) + res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig, nil) assert.NoError(t, err) assert.Equal(t, c.Expected, res) }) diff --git a/netquery/orm/query_runner.go b/netquery/orm/query_runner.go index 55bafe30..f59cca79 100644 --- a/netquery/orm/query_runner.go +++ b/netquery/orm/query_runner.go @@ -143,7 +143,23 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q currentField := reflect.New(valElemType) if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil { - return err + resultDump := make(map[string]any) + + for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ { + name := stmt.ColumnName(colIdx) + + switch stmt.ColumnType(colIdx) { + case sqlite.TypeText: + resultDump[name] = stmt.ColumnText(colIdx) + case sqlite.TypeFloat: + resultDump[name] = stmt.ColumnFloat(colIdx) + case sqlite.TypeInteger: + resultDump[name] = stmt.ColumnInt(colIdx) + case sqlite.TypeNull: + resultDump[name] = "" + } + } + return fmt.Errorf("%w: %+v", err, resultDump) } sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField)) diff --git a/netquery/orm/schema_builder.go b/netquery/orm/schema_builder.go index e9e46874..080c5003 100644 --- a/netquery/orm/schema_builder.go +++ b/netquery/orm/schema_builder.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" + "github.com/safing/portbase/log" "zombiezen.com/go/sqlite" ) @@ -25,6 +26,7 @@ var ( TagTypePrefixVarchar = "varchar" TagTypeBlob = "blob" TagTypeFloat = "float" + TagTypePrefixDefault = "default=" ) var sqlTypeMap = map[sqlite.ColumnType]string{ @@ -52,6 +54,7 @@ type ( AutoIncrement bool UnixNano bool IsTime bool + Default any } ) @@ -105,6 +108,21 @@ func (def ColumnDef) AsSQL() string { if def.AutoIncrement { sql += " AUTOINCREMENT" } + if def.Default != nil { + sql += " DEFAULT " + switch def.Type { + case sqlite.TypeFloat: + sql += strconv.FormatFloat(def.Default.(float64), 'b', 0, 64) + case sqlite.TypeInteger: + sql += strconv.FormatInt(def.Default.(int64), 10) + case sqlite.TypeText: + sql += fmt.Sprintf("%q", def.Default.(string)) + default: + log.Errorf("unsupported default value: %q %q", def.Type, def.Default) + sql = strings.TrimSuffix(sql, " DEFAULT ") + } + sql += " " + } if !def.Nullable { sql += " NOT NULL" } @@ -160,7 +178,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) { kind := normalizeKind(ft.Kind()) switch kind { //nolint:exhaustive - case reflect.Int: + case reflect.Int, reflect.Uint: def.Type = sqlite.TypeInteger case reflect.Float64: @@ -237,6 +255,30 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error { def.Length = int(length) } + if strings.HasPrefix(k, TagTypePrefixDefault) { + defaultValue := strings.TrimPrefix(k, TagTypePrefixDefault) + switch def.Type { + case sqlite.TypeFloat: + fv, err := strconv.ParseFloat(defaultValue, 64) + if err != nil { + return fmt.Errorf("failed to parse default value as float %q: %w", defaultValue, err) + } + def.Default = fv + case sqlite.TypeInteger: + fv, err := strconv.ParseInt(defaultValue, 10, 0) + if err != nil { + return fmt.Errorf("failed to parse default value as int %q: %w", defaultValue, err) + } + def.Default = fv + case sqlite.TypeText: + def.Default = defaultValue + case sqlite.TypeBlob: + return fmt.Errorf("default values for TypeBlob not yet supported") + default: + return fmt.Errorf("failed to apply default value for unknown sqlite column type %s", def.Type) + } + } + } } } diff --git a/netquery/orm/schema_builder_test.go b/netquery/orm/schema_builder_test.go index 734da981..fdd43ec7 100644 --- a/netquery/orm/schema_builder_test.go +++ b/netquery/orm/schema_builder_test.go @@ -22,14 +22,14 @@ func TestSchemaBuilder(t *testing.T) { Int *int `sqlite:",not-null"` Float interface{} `sqlite:",float,nullable"` }{}, - `CREATE TABLE Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`, + `CREATE TABLE main.Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`, }, { "Varchar", struct { S string `sqlite:",varchar(10)"` }{}, - `CREATE TABLE Varchar ( S VARCHAR(10) NOT NULL );`, + `CREATE TABLE main.Varchar ( S VARCHAR(10) NOT NULL );`, }, } @@ -38,6 +38,6 @@ func TestSchemaBuilder(t *testing.T) { res, err := GenerateTableSchema(c.Name, c.Model) assert.NoError(t, err) - assert.Equal(t, c.ExpectedSQL, res.CreateStatement(false)) + assert.Equal(t, c.ExpectedSQL, res.CreateStatement("main", false)) } } diff --git a/netquery/query.go b/netquery/query.go index 83dbc217..264f0bd7 100644 --- a/netquery/query.go +++ b/netquery/query.go @@ -14,6 +14,13 @@ import ( "github.com/safing/portmaster/netquery/orm" ) +type DatabaseName string + +const ( + LiveDatabase = DatabaseName("main") + HistoryDatabase = DatabaseName("history") +) + // Collection of Query and Matcher types. // NOTE: whenever adding support for new operators make sure // to update UnmarshalJSON as well. @@ -48,11 +55,19 @@ type ( Distinct bool `json:"distinct"` } + Min struct { + Condition *Query `json:"condition,omitempty"` + Field string `json:"field"` + As string `json:"as"` + Distinct bool `json:"distinct"` + } + Select struct { Field string `json:"field"` Count *Count `json:"$count,omitempty"` Sum *Sum `json:"$sum,omitempty"` - Distinct *string `json:"$distinct"` + Min *Min `json:"$min,omitempty"` + Distinct *string `json:"$distinct,omitempty"` } Selects []Select @@ -68,6 +83,9 @@ type ( OrderBy OrderBys `json:"orderBy"` GroupBy []string `json:"groupBy"` TextSearch *TextSearch `json:"textSearch"` + // A list of databases to query. If left empty, + // both, the LiveDatabase and the HistoryDatabase are queried + Databases []DatabaseName `json:"databases"` Pagination @@ -457,6 +475,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { Field string `json:"field"` Count *Count `json:"$count"` Sum *Sum `json:"$sum"` + Min *Min `json:"$min"` Distinct *string `json:"$distinct"` } @@ -468,12 +487,23 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { sel.Field = res.Field sel.Distinct = res.Distinct sel.Sum = res.Sum + sel.Min = res.Min if sel.Count != nil && sel.Count.As != "" { if !charOnlyRegexp.MatchString(sel.Count.As) { return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+") } } + if sel.Sum != nil && sel.Sum.As != "" { + if !charOnlyRegexp.MatchString(sel.Sum.As) { + return fmt.Errorf("invalid characters in $sum.as, value must match [a-zA-Z]+") + } + } + if sel.Min != nil && sel.Min.As != "" { + if !charOnlyRegexp.MatchString(sel.Min.As) { + return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+") + } + } return nil } diff --git a/netquery/query_handler.go b/netquery/query_handler.go index 141f7dd5..3c6bb453 100644 --- a/netquery/query_handler.go +++ b/netquery/query_handler.go @@ -14,6 +14,7 @@ import ( "github.com/safing/portbase/log" "github.com/safing/portmaster/netquery/orm" + "golang.org/x/exp/slices" ) var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+") @@ -152,13 +153,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab return "", nil, fmt.Errorf("generating where clause: %w", err) } - if req.paramMap == nil { - req.paramMap = make(map[string]interface{}) - } - - for key, val := range paramMap { - req.paramMap[key] = val - } + req.mergeParams(paramMap) if req.TextSearch != nil { textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig) @@ -173,9 +168,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab whereClause += textClause - for key, val := range textParams { - req.paramMap[key] = val - } + req.mergeParams(textParams) } } @@ -190,15 +183,21 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab } selectClause := req.generateSelectClause() - inMem := `SELECT *, 'live' as _source FROM main.connections ` - inHistory := `SELECT *, 'history' as _source FROM history.connections ` if whereClause != "" { - inMem += " WHERE " + whereClause - inHistory += " WHERE " + whereClause + whereClause = "WHERE " + whereClause } - source := inMem + " UNION " + inHistory + if len(req.Databases) == 0 { + req.Databases = []DatabaseName{LiveDatabase, HistoryDatabase} + } + + sources := make([]string, len(req.Databases)) + for idx, db := range req.Databases { + sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause) + } + + source := strings.Join(sources, " UNION ") query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) ` @@ -210,6 +209,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error { for idx, s := range req.Select { var field string + switch { case s.Count != nil: field = s.Count.Field @@ -218,6 +218,12 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem case s.Sum != nil: // field is not used in case of $sum field = "*" + case s.Min != nil: + if s.Min.Field != "" { + field = s.Min.Field + } else { + field = "*" + } default: field = s.Field } @@ -258,13 +264,40 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem return fmt.Errorf("in $sum: %w", err) } - req.paramMap = params + req.mergeParams(params) req.selectedFields = append( req.selectedFields, fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As), ) req.whitelistedFields = append(req.whitelistedFields, s.Sum.As) + case s.Min != nil: + if s.Min.As == "" { + return fmt.Errorf("missing 'as' for $min") + } + + var ( + clause string + params map[string]any + ) + + if s.Min.Field != "" { + clause = field + } else { + var err error + clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig) + if err != nil { + return fmt.Errorf("in $min: %w", err) + } + } + + req.mergeParams(params) + req.selectedFields = append( + req.selectedFields, + fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As), + ) + req.whitelistedFields = append(req.whitelistedFields, s.Min.As) + case s.Distinct != nil: req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName)) req.whitelistedFields = append(req.whitelistedFields, colName) @@ -277,6 +310,16 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem return nil } +func (req *QueryRequestPayload) mergeParams(params map[string]any) { + if req.paramMap == nil { + req.paramMap = make(map[string]any) + } + + for key, value := range params { + req.paramMap[key] = value + } +} + func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) { if len(req.GroupBy) == 0 { return "", nil @@ -339,16 +382,12 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel return colDef.Name, nil } - for _, selected := range req.whitelistedFields { - if field == selected { - return field, nil - } + if slices.Contains(req.whitelistedFields, field) { + return field, nil } - for _, selected := range req.selectedFields { - if field == selected { - return field, nil - } + if slices.Contains(req.selectedFields, field) { + return field, nil } return "", fmt.Errorf("column name %q not allowed", field) diff --git a/process/process.go b/process/process.go index 3f2779f9..99c281cc 100644 --- a/process/process.go +++ b/process/process.go @@ -313,6 +313,10 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (* return process, nil } +func (p *Process) GetID() string { + return p.processKey +} + // Builds a unique identifier for a processes. func getProcessKey(pid int32, createdTime int64) string { return fmt.Sprintf("%d-%d", pid, createdTime) diff --git a/profile/config.go b/profile/config.go index 3c87b009..51fd731b 100644 --- a/profile/config.go +++ b/profile/config.go @@ -6,6 +6,7 @@ import ( "github.com/safing/portbase/config" "github.com/safing/portmaster/profile/endpoints" "github.com/safing/portmaster/status" + "github.com/safing/spn/access/account" "github.com/safing/spn/navigator" ) @@ -105,7 +106,7 @@ var ( // Setting "Permanent Verdicts" at order 96. - CfgOptionEnableHistoryKey = "filter/enableHistory" + CfgOptionEnableHistoryKey = "history/enabled" cfgOptionEnableHistory config.BoolOption cfgOptionEnableHistoryOrder = 66 @@ -249,12 +250,13 @@ func registerConfiguration() error { //nolint:maintidx Key: CfgOptionEnableHistoryKey, Description: "Whether or not to save connections to the history database", OptType: config.OptTypeBool, - ReleaseLevel: config.ReleaseLevelExperimental, + ReleaseLevel: config.ReleaseLevelStable, ExpertiseLevel: config.ExpertiseLevelExpert, DefaultValue: false, Annotations: config.Annotations{ - config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder, - config.CategoryAnnotation: "Advanced", + config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder, + config.CategoryAnnotation: "History", + config.SettingRequiresFeaturePlan: account.FeatureHistory, }, }) if err != nil {