Update netquery to support history module

This commit is contained in:
Patrick Pacher
2023-07-19 11:03:11 +02:00
committed by Daniel
parent cf2b8f26b9
commit dbffa8827b
13 changed files with 391 additions and 56 deletions

2
.gitignore vendored
View File

@@ -53,3 +53,5 @@ _testmain.go
# Custom dev scripts # Custom dev scripts
win_dev_* win_dev_*
go.work
go.work.sum

View File

@@ -2,18 +2,23 @@ package netquery
import ( import (
"context" "context"
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"path"
"sort" "sort"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
"github.com/jackc/puddle/v2" "github.com/jackc/puddle/v2"
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex" "zombiezen.com/go/sqlite/sqlitex"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/netquery/orm"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
@@ -46,6 +51,7 @@ type (
Schema *orm.TableSchema Schema *orm.TableSchema
readConnPool *puddle.Pool[*sqlite.Conn] readConnPool *puddle.Pool[*sqlite.Conn]
historyPath string
l sync.Mutex l sync.Mutex
writeConn *sqlite.Conn writeConn *sqlite.Conn
@@ -82,7 +88,9 @@ type (
Latitude float64 `sqlite:"latitude"` Latitude float64 `sqlite:"latitude"`
Longitude float64 `sqlite:"longitude"` Longitude float64 `sqlite:"longitude"`
Scope netutils.IPScope `sqlite:"scope"` Scope netutils.IPScope `sqlite:"scope"`
Verdict network.Verdict `sqlite:"verdict"` WorstVerdict network.Verdict `sqlite:"worst_verdict"`
ActiveVerdict network.Verdict `sqlite:"verdict"`
FirewallVerdict network.Verdict `sqlite:"firewall_verdict"`
Started time.Time `sqlite:"started,text,time"` Started time.Time `sqlite:"started,text,time"`
Ended *time.Time `sqlite:"ended,text,time"` Ended *time.Time `sqlite:"ended,text,time"`
Tunneled bool `sqlite:"tunneled"` Tunneled bool `sqlite:"tunneled"`
@@ -93,6 +101,8 @@ type (
Allowed *bool `sqlite:"allowed"` Allowed *bool `sqlite:"allowed"`
ProfileRevision int `sqlite:"profile_revision"` ProfileRevision int `sqlite:"profile_revision"`
ExitNode *string `sqlite:"exit_node"` ExitNode *string `sqlite:"exit_node"`
BWIncoming uint64 `sqlite:"bw_incoming,default=0"`
BWOutgoing uint64 `sqlite:"bw_outgoing,default=0"`
// TODO(ppacher): support "NOT" in search query to get rid of the following helper fields // TODO(ppacher): support "NOT" in search query to get rid of the following helper fields
Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL" Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"
@@ -108,21 +118,27 @@ type (
// (see Execute). To perform database writes use either Save() or ExecuteWrite(). // (see Execute). To perform database writes use either Save() or ExecuteWrite().
// Note that write connections are serialized by the Database object before being // Note that write connections are serialized by the Database object before being
// handed over to SQLite. // handed over to SQLite.
func New(path string) (*Database, error) { func New(dbPath string) (*Database, error) {
historyParentDir := dataroot.Root().ChildDir("databases", 0o700)
if err := historyParentDir.Ensure(); err != nil {
return nil, fmt.Errorf("failed to ensure database directory exists: %w", err)
}
historyPath := "file://" + path.Join(historyParentDir.Path, "history.db")
constructor := func(ctx context.Context) (*sqlite.Conn, error) { constructor := func(ctx context.Context) (*sqlite.Conn, error) {
c, err := sqlite.OpenConn( c, err := sqlite.OpenConn(
path, dbPath,
sqlite.OpenReadOnly, sqlite.OpenReadOnly,
sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit.
sqlite.OpenSharedCache, sqlite.OpenSharedCache,
//sqlite.OpenMemory, //sqlite.OpenMemory,
sqlite.OpenURI, sqlite.OpenURI,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", path, err) return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", dbPath, err)
} }
if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE 'file:///tmp/history.db?mode=ro' AS history", nil); err != nil { if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE '"+historyPath+"?mode=ro' AS history", nil); err != nil {
return nil, fmt.Errorf("failed to attach history database: %w", err) return nil, fmt.Errorf("failed to attach history database: %w", err)
} }
@@ -150,23 +166,23 @@ func New(path string) (*Database, error) {
} }
writeConn, err := sqlite.OpenConn( writeConn, err := sqlite.OpenConn(
path, dbPath,
sqlite.OpenCreate, sqlite.OpenCreate,
sqlite.OpenReadWrite, sqlite.OpenReadWrite,
sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit.
sqlite.OpenWAL, sqlite.OpenWAL,
sqlite.OpenSharedCache, sqlite.OpenSharedCache,
//sqlite.OpenMemory, //sqlite.OpenMemory,
sqlite.OpenURI, sqlite.OpenURI,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err) return nil, fmt.Errorf("failed to open sqlite at %s: %w", dbPath, err)
} }
return &Database{ return &Database{
readConnPool: pool, readConnPool: pool,
Schema: schema, Schema: schema,
writeConn: writeConn, writeConn: writeConn,
historyPath: historyPath,
}, nil }, nil
} }
@@ -197,9 +213,7 @@ func (db *Database) ApplyMigrations() error {
db.l.Lock() db.l.Lock()
defer db.l.Unlock() defer db.l.Unlock()
// Attach the history database if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE '"+db.historyPath+"?mode=rwc' AS 'history';", nil); err != nil {
log.Errorf("attaching database history")
if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE 'file:///tmp/history.db?mode=rwc' AS 'history';", nil); err != nil {
return fmt.Errorf("failed to attach history database: %w", err) return fmt.Errorf("failed to attach history database: %w", err)
} }
@@ -207,7 +221,7 @@ func (db *Database) ApplyMigrations() error {
for _, dbName := range dbNames { for _, dbName := range dbNames {
// get the create-table SQL statement from the inferred schema // get the create-table SQL statement from the inferred schema
sql := db.Schema.CreateStatement(dbName, true) sql := db.Schema.CreateStatement(dbName, true)
log.Errorf("creating table schema for database %q", dbName) log.Debugf("creating table schema for database %q", dbName)
// execute the SQL // execute the SQL
if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil { if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil {
@@ -285,7 +299,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) {
return result[0].Count, nil return result[0].Count, nil
} }
// Cleanup removes all connections that have ended before threshold. // Cleanup removes all connections that have ended before threshold from the live database.
// //
// NOTE(ppacher): there is no easy way to get the number of removed // NOTE(ppacher): there is no easy way to get the number of removed
// rows other than counting them in a first step. Though, that's // rows other than counting them in a first step. Though, that's
@@ -323,6 +337,18 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro
return result[0].Count, nil return result[0].Count, nil
} }
// RemoveAllHistoryData deletes every connection row stored in the
// history database. The live database is left untouched.
func (db *Database) RemoveAllHistoryData(ctx context.Context) error {
	stmt := "DELETE FROM " + string(HistoryDatabase) + ".connections"

	return db.ExecuteWrite(ctx, stmt)
}
// RemoveHistoryForProfile deletes all connections from the history
// database that belong to the profile identified by profileID.
func (db *Database) RemoveHistoryForProfile(ctx context.Context, profileID string) error {
	stmt := "DELETE FROM " + string(HistoryDatabase) + ".connections WHERE profile = :profile"
	args := orm.WithNamedArgs(map[string]any{
		":profile": profileID,
	})

	return db.ExecuteWrite(ctx, stmt, args)
}
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database // dumpTo is a simple helper method that dumps all rows stored in the SQLite database
// as JSON to w. // as JSON to w.
// Any error aborts dumping rows and is returned. // Any error aborts dumping rows and is returned.
@@ -350,13 +376,74 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:un
return enc.Encode(conns) return enc.Encode(conns)
} }
// MarkAllHistoryConnectionsEnded marks all connections in the history
// database as ended, setting their "ended" timestamp to now. Connections
// that are already marked as inactive are not modified.
func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
	stmt := fmt.Sprintf("UPDATE %s.connections SET active = FALSE, ended = :ended WHERE active = TRUE", HistoryDatabase)
	args := orm.WithNamedArgs(map[string]any{
		":ended": time.Now().Format(orm.SqliteTimeFormat),
	})

	return db.ExecuteWrite(ctx, stmt, args)
}
// UpdateBandwidth updates the bandwidth counters of a single connection.
// The database row ID is derived as the hex-encoded SHA256 of
// "<connID>-<processKey>" — presumably this must match the ID produced by
// genConnID when the connection was saved; confirm when changing either.
//
// Only counters for which a non-nil pointer is passed are written. If both
// incoming and outgoing are nil the call is a no-op. When enableHistory is
// true the history database is updated as well as the live one; errors from
// the individual UPDATE statements are collected and returned together.
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, incoming *uint64, outgoing *uint64) error {
	data := connID + "-" + processKey
	hash := sha256.Sum256([]byte(data))
	dbConnID := hex.EncodeToString(hash[:])

	params := map[string]any{
		":id": dbConnID,
	}

	parts := make([]string, 0, 2)
	if incoming != nil {
		parts = append(parts, "bw_incoming = :bw_incoming")
		params[":bw_incoming"] = *incoming
	}

	if outgoing != nil {
		parts = append(parts, "bw_outgoing = :bw_outgoing")
		params[":bw_outgoing"] = *outgoing
	}

	// Nothing to update. Bail out early; otherwise the generated statement
	// would be "UPDATE ... SET  WHERE id = :id", which is invalid SQL.
	if len(parts) == 0 {
		return nil
	}

	updateSet := strings.Join(parts, ", ")

	updateStmts := []string{
		fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, LiveDatabase, updateSet),
	}

	if enableHistory {
		updateStmts = append(updateStmts,
			fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, HistoryDatabase, updateSet),
		)
	}

	// Run all statements and collect the errors instead of aborting on the
	// first failure so the live and history databases stay as in-sync as possible.
	merr := new(multierror.Error)
	for _, stmt := range updateStmts {
		if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
			merr.Errors = append(merr.Errors, err)
		}
	}

	return merr.ErrorOrNil()
}
// Save inserts the connection conn into the SQLite database. If conn // Save inserts the connection conn into the SQLite database. If conn
// already exists the table row is updated instead. // already exists the table row is updated instead.
// //
// Save uses the database write connection instead of relying on the // Save uses the database write connection instead of relying on the
// connection pool. // connection pool.
func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) error { func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) error {
connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig) // convert the connection to a param map where each key is already translated
// to the sql column name. We also skip bw_incoming and bw_outgoing since those
// will be updated independently from the connection object.
connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig, []string{
"bw_incoming",
"bw_outgoing",
})
if err != nil { if err != nil {
return fmt.Errorf("failed to encode connection for SQL: %w", err) return fmt.Errorf("failed to encode connection for SQL: %w", err)
} }
@@ -387,10 +474,10 @@ func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) err
// TODO(ppacher): make sure this one can be cached to speed up inserting // TODO(ppacher): make sure this one can be cached to speed up inserting
// and save some CPU cycles for the user // and save some CPU cycles for the user
dbNames := []string{"main"} dbNames := []DatabaseName{LiveDatabase}
if enableHistory { if enableHistory {
dbNames = append(dbNames, "history") dbNames = append(dbNames, HistoryDatabase)
} }
for _, dbName := range dbNames { for _, dbName := range dbNames {

View File

@@ -13,6 +13,8 @@ import (
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/runtime" "github.com/safing/portbase/runtime"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
"github.com/safing/spn/access"
"github.com/safing/spn/access/account"
) )
type ( type (
@@ -26,6 +28,21 @@ type (
// The ID of Conn is unique and can be trusted to never collide with other // The ID of Conn is unique and can be trusted to never collide with other
// connections of the save device. // connections of the save device.
Save(context.Context, Conn, bool) error Save(context.Context, Conn, bool) error
// MarkAllHistoryConnectionsEnded marks all active connections in the history
// database as ended NOW.
MarkAllHistoryConnectionsEnded(context.Context) error
// RemoveHistoryForProfile removes all connections from the history database
// for a given profile ID (source/id)
RemoveHistoryForProfile(context.Context, string) error
// RemoveAllHistoryData removes all connections from the history database.
RemoveAllHistoryData(context.Context) error
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database.
UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, incoming *uint64, outgoing *uint64) error
} }
// Manager handles new and updated network.Connections feeds and persists them // Manager handles new and updated network.Connections feeds and persists them
@@ -100,7 +117,20 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
log.Tracef("netquery: updating connection %s", conn.ID) log.Tracef("netquery: updating connection %s", conn.ID)
if err := mng.store.Save(ctx, *model, conn.Process().Profile().HistoryEnabled()); err != nil { // check if we should persist the connection in the history database.
// Also make sure the current SPN User/subscription allows use of the history.
historyEnabled := conn.Process().Profile().HistoryEnabled()
if historyEnabled {
user, err := access.GetUser()
if err != nil {
// there was an error so disable history
historyEnabled = false
} else if !user.MayUse(account.FeatureHistory) {
historyEnabled = false
}
}
if err := mng.store.Save(ctx, *model, historyEnabled); err != nil {
log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err) log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err)
continue continue
@@ -158,7 +188,9 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
IPProtocol: conn.IPProtocol, IPProtocol: conn.IPProtocol,
LocalIP: conn.LocalIP.String(), LocalIP: conn.LocalIP.String(),
LocalPort: conn.LocalPort, LocalPort: conn.LocalPort,
Verdict: conn.Verdict.Firewall, // TODO: Expose both Worst and Firewall verdicts. FirewallVerdict: conn.Verdict.Firewall,
ActiveVerdict: conn.Verdict.Active,
WorstVerdict: conn.Verdict.Worst,
Started: time.Unix(conn.Started, 0), Started: time.Unix(conn.Started, 0),
Tunneled: conn.Tunneled, Tunneled: conn.Tunneled,
Encrypted: conn.Encrypted, Encrypted: conn.Encrypted,
@@ -250,7 +282,7 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
} }
func genConnID(conn *network.Connection) string { func genConnID(conn *network.Connection) string {
data := conn.ID + "-" + time.Unix(conn.Started, 0).String() data := conn.ID + "-" + conn.Process().GetID()
hash := sha256.Sum256([]byte(data)) hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:]) return hex.EncodeToString(hash[:])
} }

View File

@@ -2,15 +2,19 @@ package netquery
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http"
"time" "time"
"github.com/hashicorp/go-multierror"
"github.com/safing/portbase/api" "github.com/safing/portbase/api"
"github.com/safing/portbase/config" "github.com/safing/portbase/config"
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
"github.com/safing/portbase/database/query" "github.com/safing/portbase/database/query"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portbase/runtime" "github.com/safing/portbase/runtime"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
) )
@@ -35,6 +39,15 @@ func init() {
"network", "network",
"database", "database",
) )
subsystems.Register(
"history",
"Network History",
"Keep Network History Data",
m.Module,
"config:history/",
nil,
)
} }
func (m *module) prepare() error { func (m *module) prepare() error {
@@ -92,6 +105,58 @@ func (m *module) prepare() error {
return fmt.Errorf("failed to register API endpoint: %w", err) return fmt.Errorf("failed to register API endpoint: %w", err)
} }
if err := api.RegisterEndpoint(api.Endpoint{
Path: "netquery/history/clear",
MimeType: "application/json",
Read: api.PermitUser,
Write: api.PermitUser,
BelongsTo: m.Module,
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
var body struct {
ProfileIDs []string `json:"profileIDs"`
}
defer r.Body.Close()
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(&body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(body.ProfileIDs) == 0 {
if err := m.mng.store.RemoveAllHistoryData(r.Context()); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} else {
merr := new(multierror.Error)
for _, profileID := range body.ProfileIDs {
if err := m.mng.store.RemoveHistoryForProfile(r.Context(), profileID); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("failed to clear history for %q: %w", profileID, err))
} else {
log.Infof("netquery: successfully cleared history for %s", profileID)
}
}
if err := merr.ErrorOrNil(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
w.WriteHeader(http.StatusNoContent)
},
Name: "Remove connections from profile history",
Description: "Remove all connections from the history database for one or more profiles",
}); err != nil {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
return nil return nil
} }
@@ -163,5 +228,16 @@ func (m *module) start() error {
} }
func (m *module) stop() error { func (m *module) stop() error {
// we don't use m.Module.Ctx here because it is already cancelled when stop is called.
// just give the clean up 1 minute to happen and abort otherwise.
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
if err := m.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil {
// handle the error by just logging it. There's not much we can do here
// and returning an error to the module system doesn't help much as well...
log.Errorf("failed to mark connections in history database as eded: %w", err)
}
return nil return nil
} }

View File

@@ -6,6 +6,7 @@ import (
"reflect" "reflect"
"time" "time"
"golang.org/x/exp/slices"
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
) )
@@ -22,7 +23,7 @@ type (
// ToParamMap returns a map that contains the sqlite compatible value of each struct field of // ToParamMap returns a map that contains the sqlite compatible value of each struct field of
// r using the sqlite column name as a map key. It either uses the name of the // r using the sqlite column name as a map key. It either uses the name of the
// exported struct field or the value of the "sqlite" tag. // exported struct field or the value of the "sqlite" tag.
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) { func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) {
// make sure we work on a struct type // make sure we work on a struct type
val := reflect.Indirect(reflect.ValueOf(r)) val := reflect.Indirect(reflect.ValueOf(r))
if val.Kind() != reflect.Struct { if val.Kind() != reflect.Struct {
@@ -45,6 +46,10 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode
return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err) return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
} }
if slices.Contains(skipFields, colDef.Name) {
continue
}
x, found, err := runEncodeHooks( x, found, err := runEncodeHooks(
colDef, colDef,
fieldType.Type, fieldType.Type,

View File

@@ -119,7 +119,7 @@ func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
for idx := range cases { //nolint:paralleltest for idx := range cases { //nolint:paralleltest
c := cases[idx] c := cases[idx]
t.Run(c.Desc, func(t *testing.T) { t.Run(c.Desc, func(t *testing.T) {
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig) res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig, nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, c.Expected, res) assert.Equal(t, c.Expected, res)
}) })

View File

@@ -143,7 +143,23 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q
currentField := reflect.New(valElemType) currentField := reflect.New(valElemType)
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil { if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
return err resultDump := make(map[string]any)
for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ {
name := stmt.ColumnName(colIdx)
switch stmt.ColumnType(colIdx) {
case sqlite.TypeText:
resultDump[name] = stmt.ColumnText(colIdx)
case sqlite.TypeFloat:
resultDump[name] = stmt.ColumnFloat(colIdx)
case sqlite.TypeInteger:
resultDump[name] = stmt.ColumnInt(colIdx)
case sqlite.TypeNull:
resultDump[name] = "<null>"
}
}
return fmt.Errorf("%w: %+v", err, resultDump)
} }
sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField)) sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField))

View File

@@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/safing/portbase/log"
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
) )
@@ -25,6 +26,7 @@ var (
TagTypePrefixVarchar = "varchar" TagTypePrefixVarchar = "varchar"
TagTypeBlob = "blob" TagTypeBlob = "blob"
TagTypeFloat = "float" TagTypeFloat = "float"
TagTypePrefixDefault = "default="
) )
var sqlTypeMap = map[sqlite.ColumnType]string{ var sqlTypeMap = map[sqlite.ColumnType]string{
@@ -52,6 +54,7 @@ type (
AutoIncrement bool AutoIncrement bool
UnixNano bool UnixNano bool
IsTime bool IsTime bool
Default any
} }
) )
@@ -105,6 +108,21 @@ func (def ColumnDef) AsSQL() string {
if def.AutoIncrement { if def.AutoIncrement {
sql += " AUTOINCREMENT" sql += " AUTOINCREMENT"
} }
if def.Default != nil {
sql += " DEFAULT "
switch def.Type {
case sqlite.TypeFloat:
sql += strconv.FormatFloat(def.Default.(float64), 'b', 0, 64)
case sqlite.TypeInteger:
sql += strconv.FormatInt(def.Default.(int64), 10)
case sqlite.TypeText:
sql += fmt.Sprintf("%q", def.Default.(string))
default:
log.Errorf("unsupported default value: %q %q", def.Type, def.Default)
sql = strings.TrimSuffix(sql, " DEFAULT ")
}
sql += " "
}
if !def.Nullable { if !def.Nullable {
sql += " NOT NULL" sql += " NOT NULL"
} }
@@ -160,7 +178,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
kind := normalizeKind(ft.Kind()) kind := normalizeKind(ft.Kind())
switch kind { //nolint:exhaustive switch kind { //nolint:exhaustive
case reflect.Int: case reflect.Int, reflect.Uint:
def.Type = sqlite.TypeInteger def.Type = sqlite.TypeInteger
case reflect.Float64: case reflect.Float64:
@@ -237,6 +255,30 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
def.Length = int(length) def.Length = int(length)
} }
if strings.HasPrefix(k, TagTypePrefixDefault) {
defaultValue := strings.TrimPrefix(k, TagTypePrefixDefault)
switch def.Type {
case sqlite.TypeFloat:
fv, err := strconv.ParseFloat(defaultValue, 64)
if err != nil {
return fmt.Errorf("failed to parse default value as float %q: %w", defaultValue, err)
}
def.Default = fv
case sqlite.TypeInteger:
fv, err := strconv.ParseInt(defaultValue, 10, 0)
if err != nil {
return fmt.Errorf("failed to parse default value as int %q: %w", defaultValue, err)
}
def.Default = fv
case sqlite.TypeText:
def.Default = defaultValue
case sqlite.TypeBlob:
return fmt.Errorf("default values for TypeBlob not yet supported")
default:
return fmt.Errorf("failed to apply default value for unknown sqlite column type %s", def.Type)
}
}
} }
} }
} }

View File

@@ -22,14 +22,14 @@ func TestSchemaBuilder(t *testing.T) {
Int *int `sqlite:",not-null"` Int *int `sqlite:",not-null"`
Float interface{} `sqlite:",float,nullable"` Float interface{} `sqlite:",float,nullable"`
}{}, }{},
`CREATE TABLE Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`, `CREATE TABLE main.Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
}, },
{ {
"Varchar", "Varchar",
struct { struct {
S string `sqlite:",varchar(10)"` S string `sqlite:",varchar(10)"`
}{}, }{},
`CREATE TABLE Varchar ( S VARCHAR(10) NOT NULL );`, `CREATE TABLE main.Varchar ( S VARCHAR(10) NOT NULL );`,
}, },
} }
@@ -38,6 +38,6 @@ func TestSchemaBuilder(t *testing.T) {
res, err := GenerateTableSchema(c.Name, c.Model) res, err := GenerateTableSchema(c.Name, c.Model)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, c.ExpectedSQL, res.CreateStatement(false)) assert.Equal(t, c.ExpectedSQL, res.CreateStatement("main", false))
} }
} }

View File

@@ -14,6 +14,13 @@ import (
"github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/netquery/orm"
) )
// DatabaseName identifies one of the attached SQLite databases that store
// connection data.
type DatabaseName string

// Attached database names.
const (
	// LiveDatabase holds the connections of the current session.
	LiveDatabase DatabaseName = "main"
	// HistoryDatabase holds persisted, historic connections.
	HistoryDatabase DatabaseName = "history"
)
// Collection of Query and Matcher types. // Collection of Query and Matcher types.
// NOTE: whenever adding support for new operators make sure // NOTE: whenever adding support for new operators make sure
// to update UnmarshalJSON as well. // to update UnmarshalJSON as well.
@@ -48,11 +55,19 @@ type (
Distinct bool `json:"distinct"` Distinct bool `json:"distinct"`
} }
Min struct {
Condition *Query `json:"condition,omitempty"`
Field string `json:"field"`
As string `json:"as"`
Distinct bool `json:"distinct"`
}
Select struct { Select struct {
Field string `json:"field"` Field string `json:"field"`
Count *Count `json:"$count,omitempty"` Count *Count `json:"$count,omitempty"`
Sum *Sum `json:"$sum,omitempty"` Sum *Sum `json:"$sum,omitempty"`
Distinct *string `json:"$distinct"` Min *Min `json:"$min,omitempty"`
Distinct *string `json:"$distinct,omitempty"`
} }
Selects []Select Selects []Select
@@ -68,6 +83,9 @@ type (
OrderBy OrderBys `json:"orderBy"` OrderBy OrderBys `json:"orderBy"`
GroupBy []string `json:"groupBy"` GroupBy []string `json:"groupBy"`
TextSearch *TextSearch `json:"textSearch"` TextSearch *TextSearch `json:"textSearch"`
// A list of databases to query. If left empty,
// both the LiveDatabase and the HistoryDatabase are queried.
Databases []DatabaseName `json:"databases"`
Pagination Pagination
@@ -457,6 +475,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
Field string `json:"field"` Field string `json:"field"`
Count *Count `json:"$count"` Count *Count `json:"$count"`
Sum *Sum `json:"$sum"` Sum *Sum `json:"$sum"`
Min *Min `json:"$min"`
Distinct *string `json:"$distinct"` Distinct *string `json:"$distinct"`
} }
@@ -468,12 +487,23 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
sel.Field = res.Field sel.Field = res.Field
sel.Distinct = res.Distinct sel.Distinct = res.Distinct
sel.Sum = res.Sum sel.Sum = res.Sum
sel.Min = res.Min
if sel.Count != nil && sel.Count.As != "" { if sel.Count != nil && sel.Count.As != "" {
if !charOnlyRegexp.MatchString(sel.Count.As) { if !charOnlyRegexp.MatchString(sel.Count.As) {
return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+") return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
} }
} }
if sel.Sum != nil && sel.Sum.As != "" {
if !charOnlyRegexp.MatchString(sel.Sum.As) {
return fmt.Errorf("invalid characters in $sum.as, value must match [a-zA-Z]+")
}
}
if sel.Min != nil && sel.Min.As != "" {
if !charOnlyRegexp.MatchString(sel.Min.As) {
return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+")
}
}
return nil return nil
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/netquery/orm"
"golang.org/x/exp/slices"
) )
var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+") var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
@@ -152,13 +153,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
return "", nil, fmt.Errorf("generating where clause: %w", err) return "", nil, fmt.Errorf("generating where clause: %w", err)
} }
if req.paramMap == nil { req.mergeParams(paramMap)
req.paramMap = make(map[string]interface{})
}
for key, val := range paramMap {
req.paramMap[key] = val
}
if req.TextSearch != nil { if req.TextSearch != nil {
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig) textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
@@ -173,9 +168,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
whereClause += textClause whereClause += textClause
for key, val := range textParams { req.mergeParams(textParams)
req.paramMap[key] = val
}
} }
} }
@@ -190,15 +183,21 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
} }
selectClause := req.generateSelectClause() selectClause := req.generateSelectClause()
inMem := `SELECT *, 'live' as _source FROM main.connections `
inHistory := `SELECT *, 'history' as _source FROM history.connections `
if whereClause != "" { if whereClause != "" {
inMem += " WHERE " + whereClause whereClause = "WHERE " + whereClause
inHistory += " WHERE " + whereClause
} }
source := inMem + " UNION " + inHistory if len(req.Databases) == 0 {
req.Databases = []DatabaseName{LiveDatabase, HistoryDatabase}
}
sources := make([]string, len(req.Databases))
for idx, db := range req.Databases {
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
}
source := strings.Join(sources, " UNION ")
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) ` query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
@@ -210,6 +209,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error { func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
for idx, s := range req.Select { for idx, s := range req.Select {
var field string var field string
switch { switch {
case s.Count != nil: case s.Count != nil:
field = s.Count.Field field = s.Count.Field
@@ -218,6 +218,12 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
case s.Sum != nil: case s.Sum != nil:
// field is not used in case of $sum // field is not used in case of $sum
field = "*" field = "*"
case s.Min != nil:
if s.Min.Field != "" {
field = s.Min.Field
} else {
field = "*"
}
default: default:
field = s.Field field = s.Field
} }
@@ -258,13 +264,40 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
return fmt.Errorf("in $sum: %w", err) return fmt.Errorf("in $sum: %w", err)
} }
req.paramMap = params req.mergeParams(params)
req.selectedFields = append( req.selectedFields = append(
req.selectedFields, req.selectedFields,
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As), fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
) )
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As) req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
case s.Min != nil:
if s.Min.As == "" {
return fmt.Errorf("missing 'as' for $min")
}
var (
clause string
params map[string]any
)
if s.Min.Field != "" {
clause = field
} else {
var err error
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
if err != nil {
return fmt.Errorf("in $min: %w", err)
}
}
req.mergeParams(params)
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
)
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
case s.Distinct != nil: case s.Distinct != nil:
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName)) req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
req.whitelistedFields = append(req.whitelistedFields, colName) req.whitelistedFields = append(req.whitelistedFields, colName)
@@ -277,6 +310,16 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
return nil return nil
} }
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
if req.paramMap == nil {
req.paramMap = make(map[string]any)
}
for key, value := range params {
req.paramMap[key] = value
}
}
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) { func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
if len(req.GroupBy) == 0 { if len(req.GroupBy) == 0 {
return "", nil return "", nil
@@ -339,16 +382,12 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
return colDef.Name, nil return colDef.Name, nil
} }
for _, selected := range req.whitelistedFields { if slices.Contains(req.whitelistedFields, field) {
if field == selected { return field, nil
return field, nil
}
} }
for _, selected := range req.selectedFields { if slices.Contains(req.selectedFields, field) {
if field == selected { return field, nil
return field, nil
}
} }
return "", fmt.Errorf("column name %q not allowed", field) return "", fmt.Errorf("column name %q not allowed", field)

View File

@@ -313,6 +313,10 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (*
return process, nil return process, nil
} }
func (p *Process) GetID() string {
return p.processKey
}
// Builds a unique identifier for a processes. // Builds a unique identifier for a processes.
func getProcessKey(pid int32, createdTime int64) string { func getProcessKey(pid int32, createdTime int64) string {
return fmt.Sprintf("%d-%d", pid, createdTime) return fmt.Sprintf("%d-%d", pid, createdTime)

View File

@@ -6,6 +6,7 @@ import (
"github.com/safing/portbase/config" "github.com/safing/portbase/config"
"github.com/safing/portmaster/profile/endpoints" "github.com/safing/portmaster/profile/endpoints"
"github.com/safing/portmaster/status" "github.com/safing/portmaster/status"
"github.com/safing/spn/access/account"
"github.com/safing/spn/navigator" "github.com/safing/spn/navigator"
) )
@@ -105,7 +106,7 @@ var (
// Setting "Permanent Verdicts" at order 96. // Setting "Permanent Verdicts" at order 96.
CfgOptionEnableHistoryKey = "filter/enableHistory" CfgOptionEnableHistoryKey = "history/enabled"
cfgOptionEnableHistory config.BoolOption cfgOptionEnableHistory config.BoolOption
cfgOptionEnableHistoryOrder = 66 cfgOptionEnableHistoryOrder = 66
@@ -249,12 +250,13 @@ func registerConfiguration() error { //nolint:maintidx
Key: CfgOptionEnableHistoryKey, Key: CfgOptionEnableHistoryKey,
Description: "Whether or not to save connections to the history database", Description: "Whether or not to save connections to the history database",
OptType: config.OptTypeBool, OptType: config.OptTypeBool,
ReleaseLevel: config.ReleaseLevelExperimental, ReleaseLevel: config.ReleaseLevelStable,
ExpertiseLevel: config.ExpertiseLevelExpert, ExpertiseLevel: config.ExpertiseLevelExpert,
DefaultValue: false, DefaultValue: false,
Annotations: config.Annotations{ Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder, config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder,
config.CategoryAnnotation: "Advanced", config.CategoryAnnotation: "History",
config.SettingRequiresFeaturePlan: account.FeatureHistory,
}, },
}) })
if err != nil { if err != nil {