Update netquery to support history module

This commit is contained in:
Patrick Pacher
2023-07-19 11:03:11 +02:00
committed by Daniel
parent cf2b8f26b9
commit dbffa8827b
13 changed files with 391 additions and 56 deletions

2
.gitignore vendored
View File

@@ -53,3 +53,5 @@ _testmain.go
# Custom dev scripts
win_dev_*
go.work
go.work.sum

View File

@@ -2,18 +2,23 @@ package netquery
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"path"
"sort"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/jackc/puddle/v2"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netquery/orm"
"github.com/safing/portmaster/network"
@@ -46,6 +51,7 @@ type (
Schema *orm.TableSchema
readConnPool *puddle.Pool[*sqlite.Conn]
historyPath string
l sync.Mutex
writeConn *sqlite.Conn
@@ -82,7 +88,9 @@ type (
Latitude float64 `sqlite:"latitude"`
Longitude float64 `sqlite:"longitude"`
Scope netutils.IPScope `sqlite:"scope"`
Verdict network.Verdict `sqlite:"verdict"`
WorstVerdict network.Verdict `sqlite:"worst_verdict"`
ActiveVerdict network.Verdict `sqlite:"verdict"`
FirewallVerdict network.Verdict `sqlite:"firewall_verdict"`
Started time.Time `sqlite:"started,text,time"`
Ended *time.Time `sqlite:"ended,text,time"`
Tunneled bool `sqlite:"tunneled"`
@@ -93,6 +101,8 @@ type (
Allowed *bool `sqlite:"allowed"`
ProfileRevision int `sqlite:"profile_revision"`
ExitNode *string `sqlite:"exit_node"`
BWIncoming uint64 `sqlite:"bw_incoming,default=0"`
BWOutgoing uint64 `sqlite:"bw_outgoing,default=0"`
// TODO(ppacher): support "NOT" in search query to get rid of the following helper fields
Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"
@@ -108,21 +118,27 @@ type (
// (see Execute). To perform database writes use either Save() or ExecuteWrite().
// Note that write connections are serialized by the Database object before being
// handed over to SQLite.
func New(path string) (*Database, error) {
func New(dbPath string) (*Database, error) {
historyParentDir := dataroot.Root().ChildDir("databases", 0o700)
if err := historyParentDir.Ensure(); err != nil {
return nil, fmt.Errorf("failed to ensure database directory exists: %w", err)
}
historyPath := "file://" + path.Join(historyParentDir.Path, "history.db")
constructor := func(ctx context.Context) (*sqlite.Conn, error) {
c, err := sqlite.OpenConn(
path,
dbPath,
sqlite.OpenReadOnly,
sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit.
sqlite.OpenSharedCache,
//sqlite.OpenMemory,
sqlite.OpenURI,
)
if err != nil {
return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", path, err)
return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", dbPath, err)
}
if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE 'file:///tmp/history.db?mode=ro' AS history", nil); err != nil {
if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE '"+historyPath+"?mode=ro' AS history", nil); err != nil {
return nil, fmt.Errorf("failed to attach history database: %w", err)
}
@@ -150,23 +166,23 @@ func New(path string) (*Database, error) {
}
writeConn, err := sqlite.OpenConn(
path,
dbPath,
sqlite.OpenCreate,
sqlite.OpenReadWrite,
sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit.
sqlite.OpenWAL,
sqlite.OpenSharedCache,
//sqlite.OpenMemory,
sqlite.OpenURI,
)
if err != nil {
return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err)
return nil, fmt.Errorf("failed to open sqlite at %s: %w", dbPath, err)
}
return &Database{
readConnPool: pool,
Schema: schema,
writeConn: writeConn,
historyPath: historyPath,
}, nil
}
@@ -197,9 +213,7 @@ func (db *Database) ApplyMigrations() error {
db.l.Lock()
defer db.l.Unlock()
// Attach the history database
log.Errorf("attaching database history")
if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE 'file:///tmp/history.db?mode=rwc' AS 'history';", nil); err != nil {
if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE '"+db.historyPath+"?mode=rwc' AS 'history';", nil); err != nil {
return fmt.Errorf("failed to attach history database: %w", err)
}
@@ -207,7 +221,7 @@ func (db *Database) ApplyMigrations() error {
for _, dbName := range dbNames {
// get the create-table SQL statement from the inferred schema
sql := db.Schema.CreateStatement(dbName, true)
log.Errorf("creating table schema for database %q", dbName)
log.Debugf("creating table schema for database %q", dbName)
// execute the SQL
if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil {
@@ -285,7 +299,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) {
return result[0].Count, nil
}
// Cleanup removes all connections that have ended before threshold.
// Cleanup removes all connections that have ended before threshold from the live database.
//
// NOTE(ppacher): there is no easy way to get the number of removed
// rows other than counting them in a first step. Though, that's
@@ -323,6 +337,18 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro
return result[0].Count, nil
}
// RemoveAllHistoryData deletes every connection row stored in the
// history database.
func (db *Database) RemoveAllHistoryData(ctx context.Context) error {
	return db.ExecuteWrite(ctx, fmt.Sprintf("DELETE FROM %s.connections", HistoryDatabase))
}
// RemoveHistoryForProfile deletes all connections in the history
// database that belong to the profile identified by profileID.
func (db *Database) RemoveHistoryForProfile(ctx context.Context, profileID string) error {
	stmt := fmt.Sprintf("DELETE FROM %s.connections WHERE profile = :profile", HistoryDatabase)
	args := map[string]any{
		":profile": profileID,
	}

	return db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(args))
}
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database
// as JSON to w.
// Any error aborts dumping rows and is returned.
@@ -350,13 +376,74 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:un
return enc.Encode(conns)
}
// MarkAllHistoryConnectionsEnded marks all connections in the history
// database that are still flagged as active as ended now.
func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
	stmt := fmt.Sprintf("UPDATE %s.connections SET active = FALSE, ended = :ended WHERE active = TRUE", HistoryDatabase)
	args := orm.WithNamedArgs(map[string]any{
		":ended": time.Now().Format(orm.SqliteTimeFormat),
	})

	return db.ExecuteWrite(ctx, stmt, args)
}
// UpdateBandwidth persists bandwidth counters for a single connection in the
// live database and, if enableHistory is true, in the history database as
// well. incoming and outgoing may each be nil to indicate "no update" for
// that direction.
//
// The target row ID is derived from connID and processKey the same way
// genConnID derives it when the connection is first saved, so the update
// addresses the matching row.
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, incoming *uint64, outgoing *uint64) error {
	// Nothing to update: bail out early instead of generating a malformed
	// "UPDATE ... SET  WHERE ..." statement with an empty column list.
	if incoming == nil && outgoing == nil {
		return nil
	}

	// The row ID is the hex-encoded SHA256 over "<connID>-<processKey>".
	hash := sha256.Sum256([]byte(connID + "-" + processKey))
	dbConnID := hex.EncodeToString(hash[:])

	params := map[string]any{
		":id": dbConnID,
	}

	parts := []string{}
	if incoming != nil {
		parts = append(parts, "bw_incoming = :bw_incoming")
		params[":bw_incoming"] = *incoming
	}
	if outgoing != nil {
		parts = append(parts, "bw_outgoing = :bw_outgoing")
		params[":bw_outgoing"] = *outgoing
	}

	updateSet := strings.Join(parts, ", ")

	updateStmts := []string{
		fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, LiveDatabase, updateSet),
	}
	if enableHistory {
		updateStmts = append(updateStmts,
			fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, HistoryDatabase, updateSet),
		)
	}

	// Collect errors from both statements so a failure on the live database
	// does not prevent the history update (and vice versa).
	merr := new(multierror.Error)
	for _, stmt := range updateStmts {
		if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
			merr.Errors = append(merr.Errors, err)
		}
	}

	return merr.ErrorOrNil()
}
// Save inserts the connection conn into the SQLite database. If conn
// already exists the table row is updated instead.
//
// Save uses the database write connection instead of relying on the
// connection pool.
func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) error {
connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig)
// convert the connection to a param map where each key is already translated
// to the sql column name. We also skip bw_incoming and bw_outgoing since those
// will be updated independently from the connection object.
connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig, []string{
"bw_incoming",
"bw_outgoing",
})
if err != nil {
return fmt.Errorf("failed to encode connection for SQL: %w", err)
}
@@ -387,10 +474,10 @@ func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) err
// TODO(ppacher): make sure this one can be cached to speed up inserting
// and save some CPU cycles for the user
dbNames := []string{"main"}
dbNames := []DatabaseName{LiveDatabase}
if enableHistory {
dbNames = append(dbNames, "history")
dbNames = append(dbNames, HistoryDatabase)
}
for _, dbName := range dbNames {

View File

@@ -13,6 +13,8 @@ import (
"github.com/safing/portbase/log"
"github.com/safing/portbase/runtime"
"github.com/safing/portmaster/network"
"github.com/safing/spn/access"
"github.com/safing/spn/access/account"
)
type (
@@ -26,6 +28,21 @@ type (
// The ID of Conn is unique and can be trusted to never collide with other
// connections of the same device.
Save(context.Context, Conn, bool) error
// MarkAllHistoryConnectionsEnded marks all active connections in the history
// database as ended NOW.
MarkAllHistoryConnectionsEnded(context.Context) error
// RemoveHistoryForProfile removes all connections from the history database
// for a given profile ID (source/id)
RemoveHistoryForProfile(context.Context, string) error
// RemoveAllHistoryData removes all connections from the history database.
RemoveAllHistoryData(context.Context) error
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database.
UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, incoming *uint64, outgoing *uint64) error
}
// Manager handles new and updated network.Connections feeds and persists them
@@ -100,7 +117,20 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
log.Tracef("netquery: updating connection %s", conn.ID)
if err := mng.store.Save(ctx, *model, conn.Process().Profile().HistoryEnabled()); err != nil {
// check if we should persist the connection in the history database.
// Also make sure the current SPN User/subscription allows use of the history.
historyEnabled := conn.Process().Profile().HistoryEnabled()
if historyEnabled {
user, err := access.GetUser()
if err != nil {
// there was an error so disable history
historyEnabled = false
} else if !user.MayUse(account.FeatureHistory) {
historyEnabled = false
}
}
if err := mng.store.Save(ctx, *model, historyEnabled); err != nil {
log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err)
continue
@@ -158,7 +188,9 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
IPProtocol: conn.IPProtocol,
LocalIP: conn.LocalIP.String(),
LocalPort: conn.LocalPort,
Verdict: conn.Verdict.Firewall, // TODO: Expose both Worst and Firewall verdicts.
FirewallVerdict: conn.Verdict.Firewall,
ActiveVerdict: conn.Verdict.Active,
WorstVerdict: conn.Verdict.Worst,
Started: time.Unix(conn.Started, 0),
Tunneled: conn.Tunneled,
Encrypted: conn.Encrypted,
@@ -250,7 +282,7 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
}
func genConnID(conn *network.Connection) string {
data := conn.ID + "-" + time.Unix(conn.Started, 0).String()
data := conn.ID + "-" + conn.Process().GetID()
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}

View File

@@ -2,15 +2,19 @@ package netquery
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/hashicorp/go-multierror"
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portbase/runtime"
"github.com/safing/portmaster/network"
)
@@ -35,6 +39,15 @@ func init() {
"network",
"database",
)
subsystems.Register(
"history",
"Network History",
"Keep Network History Data",
m.Module,
"config:history/",
nil,
)
}
func (m *module) prepare() error {
@@ -92,6 +105,58 @@ func (m *module) prepare() error {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
if err := api.RegisterEndpoint(api.Endpoint{
Path: "netquery/history/clear",
MimeType: "application/json",
Read: api.PermitUser,
Write: api.PermitUser,
BelongsTo: m.Module,
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
var body struct {
ProfileIDs []string `json:"profileIDs"`
}
defer r.Body.Close()
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(&body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(body.ProfileIDs) == 0 {
if err := m.mng.store.RemoveAllHistoryData(r.Context()); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} else {
merr := new(multierror.Error)
for _, profileID := range body.ProfileIDs {
if err := m.mng.store.RemoveHistoryForProfile(r.Context(), profileID); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("failed to clear history for %q: %w", profileID, err))
} else {
log.Infof("netquery: successfully cleared history for %s", profileID)
}
}
if err := merr.ErrorOrNil(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
w.WriteHeader(http.StatusNoContent)
},
Name: "Remove connections from profile history",
Description: "Remove all connections from the history database for one or more profiles",
}); err != nil {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
return nil
}
@@ -163,5 +228,16 @@ func (m *module) start() error {
}
// stop marks all still-active history connections as ended during module
// shutdown.
func (m *module) stop() error {
	// We don't use m.Module.Ctx here because it is already cancelled when stop
	// is called. Just give the cleanup 1 minute to happen and abort otherwise.
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	if err := m.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil {
		// Handle the error by just logging it. There's not much we can do here
		// and returning an error to the module system doesn't help much as well...
		// NOTE: %w is only meaningful in fmt.Errorf; use %s for logging.
		log.Errorf("failed to mark connections in history database as ended: %s", err)
	}

	return nil
}

View File

@@ -6,6 +6,7 @@ import (
"reflect"
"time"
"golang.org/x/exp/slices"
"zombiezen.com/go/sqlite"
)
@@ -22,7 +23,7 @@ type (
// ToParamMap returns a map that contains the sqlite compatible value of each struct field of
// r using the sqlite column name as a map key. It either uses the name of the
// exported struct field or the value of the "sqlite" tag.
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) {
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) {
// make sure we work on a struct type
val := reflect.Indirect(reflect.ValueOf(r))
if val.Kind() != reflect.Struct {
@@ -45,6 +46,10 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode
return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
}
if slices.Contains(skipFields, colDef.Name) {
continue
}
x, found, err := runEncodeHooks(
colDef,
fieldType.Type,

View File

@@ -119,7 +119,7 @@ func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
for idx := range cases { //nolint:paralleltest
c := cases[idx]
t.Run(c.Desc, func(t *testing.T) {
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig)
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig, nil)
assert.NoError(t, err)
assert.Equal(t, c.Expected, res)
})

View File

@@ -143,7 +143,23 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q
currentField := reflect.New(valElemType)
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
return err
resultDump := make(map[string]any)
for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ {
name := stmt.ColumnName(colIdx)
switch stmt.ColumnType(colIdx) {
case sqlite.TypeText:
resultDump[name] = stmt.ColumnText(colIdx)
case sqlite.TypeFloat:
resultDump[name] = stmt.ColumnFloat(colIdx)
case sqlite.TypeInteger:
resultDump[name] = stmt.ColumnInt(colIdx)
case sqlite.TypeNull:
resultDump[name] = "<null>"
}
}
return fmt.Errorf("%w: %+v", err, resultDump)
}
sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField))

View File

@@ -7,6 +7,7 @@ import (
"strconv"
"strings"
"github.com/safing/portbase/log"
"zombiezen.com/go/sqlite"
)
@@ -25,6 +26,7 @@ var (
TagTypePrefixVarchar = "varchar"
TagTypeBlob = "blob"
TagTypeFloat = "float"
TagTypePrefixDefault = "default="
)
var sqlTypeMap = map[sqlite.ColumnType]string{
@@ -52,6 +54,7 @@ type (
AutoIncrement bool
UnixNano bool
IsTime bool
Default any
}
)
@@ -105,6 +108,21 @@ func (def ColumnDef) AsSQL() string {
if def.AutoIncrement {
sql += " AUTOINCREMENT"
}
if def.Default != nil {
sql += " DEFAULT "
switch def.Type {
case sqlite.TypeFloat:
sql += strconv.FormatFloat(def.Default.(float64), 'b', 0, 64)
case sqlite.TypeInteger:
sql += strconv.FormatInt(def.Default.(int64), 10)
case sqlite.TypeText:
sql += fmt.Sprintf("%q", def.Default.(string))
default:
log.Errorf("unsupported default value: %q %q", def.Type, def.Default)
sql = strings.TrimSuffix(sql, " DEFAULT ")
}
sql += " "
}
if !def.Nullable {
sql += " NOT NULL"
}
@@ -160,7 +178,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
kind := normalizeKind(ft.Kind())
switch kind { //nolint:exhaustive
case reflect.Int:
case reflect.Int, reflect.Uint:
def.Type = sqlite.TypeInteger
case reflect.Float64:
@@ -237,6 +255,30 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
def.Length = int(length)
}
if strings.HasPrefix(k, TagTypePrefixDefault) {
defaultValue := strings.TrimPrefix(k, TagTypePrefixDefault)
switch def.Type {
case sqlite.TypeFloat:
fv, err := strconv.ParseFloat(defaultValue, 64)
if err != nil {
return fmt.Errorf("failed to parse default value as float %q: %w", defaultValue, err)
}
def.Default = fv
case sqlite.TypeInteger:
fv, err := strconv.ParseInt(defaultValue, 10, 0)
if err != nil {
return fmt.Errorf("failed to parse default value as int %q: %w", defaultValue, err)
}
def.Default = fv
case sqlite.TypeText:
def.Default = defaultValue
case sqlite.TypeBlob:
return fmt.Errorf("default values for TypeBlob not yet supported")
default:
return fmt.Errorf("failed to apply default value for unknown sqlite column type %s", def.Type)
}
}
}
}
}

View File

@@ -22,14 +22,14 @@ func TestSchemaBuilder(t *testing.T) {
Int *int `sqlite:",not-null"`
Float interface{} `sqlite:",float,nullable"`
}{},
`CREATE TABLE Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
`CREATE TABLE main.Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
},
{
"Varchar",
struct {
S string `sqlite:",varchar(10)"`
}{},
`CREATE TABLE Varchar ( S VARCHAR(10) NOT NULL );`,
`CREATE TABLE main.Varchar ( S VARCHAR(10) NOT NULL );`,
},
}
@@ -38,6 +38,6 @@ func TestSchemaBuilder(t *testing.T) {
res, err := GenerateTableSchema(c.Name, c.Model)
assert.NoError(t, err)
assert.Equal(t, c.ExpectedSQL, res.CreateStatement(false))
assert.Equal(t, c.ExpectedSQL, res.CreateStatement("main", false))
}
}

View File

@@ -14,6 +14,13 @@ import (
"github.com/safing/portmaster/netquery/orm"
)
type DatabaseName string
const (
LiveDatabase = DatabaseName("main")
HistoryDatabase = DatabaseName("history")
)
// Collection of Query and Matcher types.
// NOTE: whenever adding support for new operators make sure
// to update UnmarshalJSON as well.
@@ -48,11 +55,19 @@ type (
Distinct bool `json:"distinct"`
}
Min struct {
Condition *Query `json:"condition,omitempty"`
Field string `json:"field"`
As string `json:"as"`
Distinct bool `json:"distinct"`
}
Select struct {
Field string `json:"field"`
Count *Count `json:"$count,omitempty"`
Sum *Sum `json:"$sum,omitempty"`
Distinct *string `json:"$distinct"`
Min *Min `json:"$min,omitempty"`
Distinct *string `json:"$distinct,omitempty"`
}
Selects []Select
@@ -68,6 +83,9 @@ type (
OrderBy OrderBys `json:"orderBy"`
GroupBy []string `json:"groupBy"`
TextSearch *TextSearch `json:"textSearch"`
// A list of databases to query. If left empty,
// both, the LiveDatabase and the HistoryDatabase are queried
Databases []DatabaseName `json:"databases"`
Pagination
@@ -457,6 +475,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
Field string `json:"field"`
Count *Count `json:"$count"`
Sum *Sum `json:"$sum"`
Min *Min `json:"$min"`
Distinct *string `json:"$distinct"`
}
@@ -468,12 +487,23 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
sel.Field = res.Field
sel.Distinct = res.Distinct
sel.Sum = res.Sum
sel.Min = res.Min
if sel.Count != nil && sel.Count.As != "" {
if !charOnlyRegexp.MatchString(sel.Count.As) {
return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
}
}
if sel.Sum != nil && sel.Sum.As != "" {
if !charOnlyRegexp.MatchString(sel.Sum.As) {
return fmt.Errorf("invalid characters in $sum.as, value must match [a-zA-Z]+")
}
}
if sel.Min != nil && sel.Min.As != "" {
if !charOnlyRegexp.MatchString(sel.Min.As) {
return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+")
}
}
return nil
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netquery/orm"
"golang.org/x/exp/slices"
)
var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
@@ -152,13 +153,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
return "", nil, fmt.Errorf("generating where clause: %w", err)
}
if req.paramMap == nil {
req.paramMap = make(map[string]interface{})
}
for key, val := range paramMap {
req.paramMap[key] = val
}
req.mergeParams(paramMap)
if req.TextSearch != nil {
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
@@ -173,9 +168,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
whereClause += textClause
for key, val := range textParams {
req.paramMap[key] = val
}
req.mergeParams(textParams)
}
}
@@ -190,15 +183,21 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
}
selectClause := req.generateSelectClause()
inMem := `SELECT *, 'live' as _source FROM main.connections `
inHistory := `SELECT *, 'history' as _source FROM history.connections `
if whereClause != "" {
inMem += " WHERE " + whereClause
inHistory += " WHERE " + whereClause
whereClause = "WHERE " + whereClause
}
source := inMem + " UNION " + inHistory
if len(req.Databases) == 0 {
req.Databases = []DatabaseName{LiveDatabase, HistoryDatabase}
}
sources := make([]string, len(req.Databases))
for idx, db := range req.Databases {
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
}
source := strings.Join(sources, " UNION ")
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
@@ -210,6 +209,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
for idx, s := range req.Select {
var field string
switch {
case s.Count != nil:
field = s.Count.Field
@@ -218,6 +218,12 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
case s.Sum != nil:
// field is not used in case of $sum
field = "*"
case s.Min != nil:
if s.Min.Field != "" {
field = s.Min.Field
} else {
field = "*"
}
default:
field = s.Field
}
@@ -258,13 +264,40 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
return fmt.Errorf("in $sum: %w", err)
}
req.paramMap = params
req.mergeParams(params)
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
)
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
case s.Min != nil:
if s.Min.As == "" {
return fmt.Errorf("missing 'as' for $min")
}
var (
clause string
params map[string]any
)
if s.Min.Field != "" {
clause = field
} else {
var err error
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
if err != nil {
return fmt.Errorf("in $min: %w", err)
}
}
req.mergeParams(params)
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
)
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
case s.Distinct != nil:
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
req.whitelistedFields = append(req.whitelistedFields, colName)
@@ -277,6 +310,16 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
return nil
}
// mergeParams copies every entry of params into the request's parameter
// map, lazily allocating the map on first use.
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
	if req.paramMap == nil {
		req.paramMap = make(map[string]any, len(params))
	}

	for k, v := range params {
		req.paramMap[k] = v
	}
}
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
if len(req.GroupBy) == 0 {
return "", nil
@@ -339,16 +382,12 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
return colDef.Name, nil
}
for _, selected := range req.whitelistedFields {
if field == selected {
return field, nil
}
if slices.Contains(req.whitelistedFields, field) {
return field, nil
}
for _, selected := range req.selectedFields {
if field == selected {
return field, nil
}
if slices.Contains(req.selectedFields, field) {
return field, nil
}
return "", fmt.Errorf("column name %q not allowed", field)

View File

@@ -313,6 +313,10 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (*
return process, nil
}
// GetID returns the unique process key of p (built by getProcessKey as
// "<pid>-<createdTime>").
func (p *Process) GetID() string {
	return p.processKey
}
// Builds a unique identifier for a processes.
func getProcessKey(pid int32, createdTime int64) string {
return fmt.Sprintf("%d-%d", pid, createdTime)

View File

@@ -6,6 +6,7 @@ import (
"github.com/safing/portbase/config"
"github.com/safing/portmaster/profile/endpoints"
"github.com/safing/portmaster/status"
"github.com/safing/spn/access/account"
"github.com/safing/spn/navigator"
)
@@ -105,7 +106,7 @@ var (
// Setting "Permanent Verdicts" at order 96.
CfgOptionEnableHistoryKey = "filter/enableHistory"
CfgOptionEnableHistoryKey = "history/enabled"
cfgOptionEnableHistory config.BoolOption
cfgOptionEnableHistoryOrder = 66
@@ -249,12 +250,13 @@ func registerConfiguration() error { //nolint:maintidx
Key: CfgOptionEnableHistoryKey,
Description: "Whether or not to save connections to the history database",
OptType: config.OptTypeBool,
ReleaseLevel: config.ReleaseLevelExperimental,
ReleaseLevel: config.ReleaseLevelStable,
ExpertiseLevel: config.ExpertiseLevelExpert,
DefaultValue: false,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder,
config.CategoryAnnotation: "Advanced",
config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder,
config.CategoryAnnotation: "History",
config.SettingRequiresFeaturePlan: account.FeatureHistory,
},
})
if err != nil {