Update netquery to support history module
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -53,3 +53,5 @@ _testmain.go
|
||||
|
||||
# Custom dev scripts
|
||||
win_dev_*
|
||||
go.work
|
||||
go.work.sum
|
||||
|
||||
@@ -2,18 +2,23 @@ package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/jackc/puddle/v2"
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
"github.com/safing/portmaster/network"
|
||||
@@ -46,6 +51,7 @@ type (
|
||||
Schema *orm.TableSchema
|
||||
|
||||
readConnPool *puddle.Pool[*sqlite.Conn]
|
||||
historyPath string
|
||||
|
||||
l sync.Mutex
|
||||
writeConn *sqlite.Conn
|
||||
@@ -82,7 +88,9 @@ type (
|
||||
Latitude float64 `sqlite:"latitude"`
|
||||
Longitude float64 `sqlite:"longitude"`
|
||||
Scope netutils.IPScope `sqlite:"scope"`
|
||||
Verdict network.Verdict `sqlite:"verdict"`
|
||||
WorstVerdict network.Verdict `sqlite:"worst_verdict"`
|
||||
ActiveVerdict network.Verdict `sqlite:"verdict"`
|
||||
FirewallVerdict network.Verdict `sqlite:"firewall_verdict"`
|
||||
Started time.Time `sqlite:"started,text,time"`
|
||||
Ended *time.Time `sqlite:"ended,text,time"`
|
||||
Tunneled bool `sqlite:"tunneled"`
|
||||
@@ -93,6 +101,8 @@ type (
|
||||
Allowed *bool `sqlite:"allowed"`
|
||||
ProfileRevision int `sqlite:"profile_revision"`
|
||||
ExitNode *string `sqlite:"exit_node"`
|
||||
BWIncoming uint64 `sqlite:"bw_incoming,default=0"`
|
||||
BWOutgoing uint64 `sqlite:"bw_outgoing,default=0"`
|
||||
|
||||
// TODO(ppacher): support "NOT" in search query to get rid of the following helper fields
|
||||
Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"
|
||||
@@ -108,21 +118,27 @@ type (
|
||||
// (see Execute). To perform database writes use either Save() or ExecuteWrite().
|
||||
// Note that write connections are serialized by the Database object before being
|
||||
// handed over to SQLite.
|
||||
func New(path string) (*Database, error) {
|
||||
func New(dbPath string) (*Database, error) {
|
||||
historyParentDir := dataroot.Root().ChildDir("databases", 0o700)
|
||||
if err := historyParentDir.Ensure(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure database directory exists: %w", err)
|
||||
}
|
||||
|
||||
historyPath := "file://" + path.Join(historyParentDir.Path, "history.db")
|
||||
|
||||
constructor := func(ctx context.Context) (*sqlite.Conn, error) {
|
||||
c, err := sqlite.OpenConn(
|
||||
path,
|
||||
dbPath,
|
||||
sqlite.OpenReadOnly,
|
||||
sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit.
|
||||
sqlite.OpenSharedCache,
|
||||
//sqlite.OpenMemory,
|
||||
sqlite.OpenURI,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", path, err)
|
||||
return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", dbPath, err)
|
||||
}
|
||||
|
||||
if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE 'file:///tmp/history.db?mode=ro' AS history", nil); err != nil {
|
||||
if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE '"+historyPath+"?mode=ro' AS history", nil); err != nil {
|
||||
return nil, fmt.Errorf("failed to attach history database: %w", err)
|
||||
}
|
||||
|
||||
@@ -150,23 +166,23 @@ func New(path string) (*Database, error) {
|
||||
}
|
||||
|
||||
writeConn, err := sqlite.OpenConn(
|
||||
path,
|
||||
dbPath,
|
||||
sqlite.OpenCreate,
|
||||
sqlite.OpenReadWrite,
|
||||
sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit.
|
||||
sqlite.OpenWAL,
|
||||
sqlite.OpenSharedCache,
|
||||
//sqlite.OpenMemory,
|
||||
sqlite.OpenURI,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err)
|
||||
return nil, fmt.Errorf("failed to open sqlite at %s: %w", dbPath, err)
|
||||
}
|
||||
|
||||
return &Database{
|
||||
readConnPool: pool,
|
||||
Schema: schema,
|
||||
writeConn: writeConn,
|
||||
historyPath: historyPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -197,9 +213,7 @@ func (db *Database) ApplyMigrations() error {
|
||||
db.l.Lock()
|
||||
defer db.l.Unlock()
|
||||
|
||||
// Attach the history database
|
||||
log.Errorf("attaching database history")
|
||||
if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE 'file:///tmp/history.db?mode=rwc' AS 'history';", nil); err != nil {
|
||||
if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE '"+db.historyPath+"?mode=rwc' AS 'history';", nil); err != nil {
|
||||
return fmt.Errorf("failed to attach history database: %w", err)
|
||||
}
|
||||
|
||||
@@ -207,7 +221,7 @@ func (db *Database) ApplyMigrations() error {
|
||||
for _, dbName := range dbNames {
|
||||
// get the create-table SQL statement from the inferred schema
|
||||
sql := db.Schema.CreateStatement(dbName, true)
|
||||
log.Errorf("creating table schema for database %q", dbName)
|
||||
log.Debugf("creating table schema for database %q", dbName)
|
||||
|
||||
// execute the SQL
|
||||
if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil {
|
||||
@@ -285,7 +299,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) {
|
||||
return result[0].Count, nil
|
||||
}
|
||||
|
||||
// Cleanup removes all connections that have ended before threshold.
|
||||
// Cleanup removes all connections that have ended before threshold from the live database.
|
||||
//
|
||||
// NOTE(ppacher): there is no easy way to get the number of removed
|
||||
// rows other than counting them in a first step. Though, that's
|
||||
@@ -323,6 +337,18 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro
|
||||
return result[0].Count, nil
|
||||
}
|
||||
|
||||
func (db *Database) RemoveAllHistoryData(ctx context.Context) error {
|
||||
query := fmt.Sprintf("DELETE FROM %s.connections", HistoryDatabase)
|
||||
return db.ExecuteWrite(ctx, query)
|
||||
}
|
||||
|
||||
func (db *Database) RemoveHistoryForProfile(ctx context.Context, profileID string) error {
|
||||
query := fmt.Sprintf("DELETE FROM %s.connections WHERE profile = :profile", HistoryDatabase)
|
||||
return db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{
|
||||
":profile": profileID,
|
||||
}))
|
||||
}
|
||||
|
||||
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database
|
||||
// as JSON to w.
|
||||
// Any error aborts dumping rows and is returned.
|
||||
@@ -350,13 +376,74 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:un
|
||||
return enc.Encode(conns)
|
||||
}
|
||||
|
||||
// MarkAllHistoryConnectionsEnded marks all connections in the history database as ended.
|
||||
func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
|
||||
query := fmt.Sprintf("UPDATE %s.connections SET active = FALSE, ended = :ended WHERE active = TRUE", HistoryDatabase)
|
||||
|
||||
if err := db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{
|
||||
":ended": time.Now().Format(orm.SqliteTimeFormat),
|
||||
})); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, incoming *uint64, outgoing *uint64) error {
|
||||
data := connID + "-" + processKey
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
dbConnId := hex.EncodeToString(hash[:])
|
||||
|
||||
params := map[string]any{
|
||||
":id": dbConnId,
|
||||
}
|
||||
|
||||
parts := []string{}
|
||||
if incoming != nil {
|
||||
parts = append(parts, "bw_incoming = :bw_incoming")
|
||||
params[":bw_incoming"] = *incoming
|
||||
}
|
||||
|
||||
if outgoing != nil {
|
||||
parts = append(parts, "bw_outgoing = :bw_outgoing")
|
||||
params[":bw_outgoing"] = *outgoing
|
||||
}
|
||||
|
||||
updateSet := strings.Join(parts, ", ")
|
||||
|
||||
updateStmts := []string{
|
||||
fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, LiveDatabase, updateSet),
|
||||
}
|
||||
|
||||
if enableHistory {
|
||||
updateStmts = append(updateStmts,
|
||||
fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, HistoryDatabase, updateSet),
|
||||
)
|
||||
}
|
||||
|
||||
merr := new(multierror.Error)
|
||||
for _, stmt := range updateStmts {
|
||||
if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
|
||||
merr.Errors = append(merr.Errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
return merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// Save inserts the connection conn into the SQLite database. If conn
|
||||
// already exists the table row is updated instead.
|
||||
//
|
||||
// Save uses the database write connection instead of relying on the
|
||||
// connection pool.
|
||||
func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) error {
|
||||
connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig)
|
||||
// convert the connection to a param map where each key is already translated
|
||||
// to the sql column name. We also skip bw_incoming and bw_outgoing since those
|
||||
// will be updated independenly from the connection object.
|
||||
connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig, []string{
|
||||
"bw_incoming",
|
||||
"bw_outgoing",
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode connection for SQL: %w", err)
|
||||
}
|
||||
@@ -387,10 +474,10 @@ func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) err
|
||||
|
||||
// TODO(ppacher): make sure this one can be cached to speed up inserting
|
||||
// and save some CPU cycles for the user
|
||||
dbNames := []string{"main"}
|
||||
dbNames := []DatabaseName{LiveDatabase}
|
||||
|
||||
if enableHistory {
|
||||
dbNames = append(dbNames, "history")
|
||||
dbNames = append(dbNames, HistoryDatabase)
|
||||
}
|
||||
|
||||
for _, dbName := range dbNames {
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/runtime"
|
||||
"github.com/safing/portmaster/network"
|
||||
"github.com/safing/spn/access"
|
||||
"github.com/safing/spn/access/account"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -26,6 +28,21 @@ type (
|
||||
// The ID of Conn is unique and can be trusted to never collide with other
|
||||
// connections of the save device.
|
||||
Save(context.Context, Conn, bool) error
|
||||
|
||||
// MarkAllHistoryConnectionsEnded marks all active connections in the history
|
||||
// database as ended NOW.
|
||||
MarkAllHistoryConnectionsEnded(context.Context) error
|
||||
|
||||
// RemoveHistoryForProfile removes all connections from the history database
|
||||
// for a given profile ID (source/id)
|
||||
RemoveHistoryForProfile(context.Context, string) error
|
||||
|
||||
// RemoveAllHistoryData removes all connections from the history database.
|
||||
RemoveAllHistoryData(context.Context) error
|
||||
|
||||
// UpdateBandwidth updates bandwith data for the connection and optionally also writes
|
||||
// the bandwidth data to the history database.
|
||||
UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, incoming *uint64, outgoing *uint64) error
|
||||
}
|
||||
|
||||
// Manager handles new and updated network.Connections feeds and persists them
|
||||
@@ -100,7 +117,20 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
|
||||
|
||||
log.Tracef("netquery: updating connection %s", conn.ID)
|
||||
|
||||
if err := mng.store.Save(ctx, *model, conn.Process().Profile().HistoryEnabled()); err != nil {
|
||||
// check if we should persist the connection in the history database.
|
||||
// Also make sure the current SPN User/subscription allows use of the history.
|
||||
historyEnabled := conn.Process().Profile().HistoryEnabled()
|
||||
if historyEnabled {
|
||||
user, err := access.GetUser()
|
||||
if err != nil {
|
||||
// there was an error so disable history
|
||||
historyEnabled = false
|
||||
} else if !user.MayUse(account.FeatureHistory) {
|
||||
historyEnabled = false
|
||||
}
|
||||
}
|
||||
|
||||
if err := mng.store.Save(ctx, *model, historyEnabled); err != nil {
|
||||
log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err)
|
||||
|
||||
continue
|
||||
@@ -158,7 +188,9 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
|
||||
IPProtocol: conn.IPProtocol,
|
||||
LocalIP: conn.LocalIP.String(),
|
||||
LocalPort: conn.LocalPort,
|
||||
Verdict: conn.Verdict.Firewall, // TODO: Expose both Worst and Firewall verdicts.
|
||||
FirewallVerdict: conn.Verdict.Firewall,
|
||||
ActiveVerdict: conn.Verdict.Active,
|
||||
WorstVerdict: conn.Verdict.Worst,
|
||||
Started: time.Unix(conn.Started, 0),
|
||||
Tunneled: conn.Tunneled,
|
||||
Encrypted: conn.Encrypted,
|
||||
@@ -250,7 +282,7 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
|
||||
}
|
||||
|
||||
func genConnID(conn *network.Connection) string {
|
||||
data := conn.ID + "-" + time.Unix(conn.Started, 0).String()
|
||||
data := conn.ID + "-" + conn.Process().GetID()
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
@@ -2,15 +2,19 @@ package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/query"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/modules/subsystems"
|
||||
"github.com/safing/portbase/runtime"
|
||||
"github.com/safing/portmaster/network"
|
||||
)
|
||||
@@ -35,6 +39,15 @@ func init() {
|
||||
"network",
|
||||
"database",
|
||||
)
|
||||
|
||||
subsystems.Register(
|
||||
"history",
|
||||
"Network History",
|
||||
"Keep Network History Data",
|
||||
m.Module,
|
||||
"config:history/",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (m *module) prepare() error {
|
||||
@@ -92,6 +105,58 @@ func (m *module) prepare() error {
|
||||
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "netquery/history/clear",
|
||||
MimeType: "application/json",
|
||||
Read: api.PermitUser,
|
||||
Write: api.PermitUser,
|
||||
BelongsTo: m.Module,
|
||||
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
ProfileIDs []string `json:"profileIDs"`
|
||||
}
|
||||
|
||||
defer r.Body.Close()
|
||||
|
||||
dec := json.NewDecoder(r.Body)
|
||||
dec.DisallowUnknownFields()
|
||||
|
||||
if err := dec.Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(body.ProfileIDs) == 0 {
|
||||
if err := m.mng.store.RemoveAllHistoryData(r.Context()); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
} else {
|
||||
merr := new(multierror.Error)
|
||||
for _, profileID := range body.ProfileIDs {
|
||||
if err := m.mng.store.RemoveHistoryForProfile(r.Context(), profileID); err != nil {
|
||||
merr.Errors = append(merr.Errors, fmt.Errorf("failed to clear history for %q: %w", profileID, err))
|
||||
} else {
|
||||
log.Infof("netquery: successfully cleared history for %s", profileID)
|
||||
}
|
||||
}
|
||||
|
||||
if err := merr.ErrorOrNil(); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
},
|
||||
Name: "Remove connections from profile history",
|
||||
Description: "Remove all connections from the history database for one or more profiles",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -163,5 +228,16 @@ func (m *module) start() error {
|
||||
}
|
||||
|
||||
func (m *module) stop() error {
|
||||
// we don't use m.Module.Ctx here because it is already cancelled when stop is called.
|
||||
// just give the clean up 1 minute to happen and abort otherwise.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := m.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil {
|
||||
// handle the error by just logging it. There's not much we can do here
|
||||
// and returning an error to the module system doesn't help much as well...
|
||||
log.Errorf("failed to mark connections in history database as eded: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
@@ -22,7 +23,7 @@ type (
|
||||
// ToParamMap returns a map that contains the sqlite compatible value of each struct field of
|
||||
// r using the sqlite column name as a map key. It either uses the name of the
|
||||
// exported struct field or the value of the "sqlite" tag.
|
||||
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) {
|
||||
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) {
|
||||
// make sure we work on a struct type
|
||||
val := reflect.Indirect(reflect.ValueOf(r))
|
||||
if val.Kind() != reflect.Struct {
|
||||
@@ -45,6 +46,10 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode
|
||||
return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
|
||||
}
|
||||
|
||||
if slices.Contains(skipFields, colDef.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
x, found, err := runEncodeHooks(
|
||||
colDef,
|
||||
fieldType.Type,
|
||||
|
||||
@@ -119,7 +119,7 @@ func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig)
|
||||
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.Expected, res)
|
||||
})
|
||||
|
||||
@@ -143,7 +143,23 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q
|
||||
currentField := reflect.New(valElemType)
|
||||
|
||||
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
|
||||
return err
|
||||
resultDump := make(map[string]any)
|
||||
|
||||
for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ {
|
||||
name := stmt.ColumnName(colIdx)
|
||||
|
||||
switch stmt.ColumnType(colIdx) {
|
||||
case sqlite.TypeText:
|
||||
resultDump[name] = stmt.ColumnText(colIdx)
|
||||
case sqlite.TypeFloat:
|
||||
resultDump[name] = stmt.ColumnFloat(colIdx)
|
||||
case sqlite.TypeInteger:
|
||||
resultDump[name] = stmt.ColumnInt(colIdx)
|
||||
case sqlite.TypeNull:
|
||||
resultDump[name] = "<null>"
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%w: %+v", err, resultDump)
|
||||
}
|
||||
|
||||
sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField))
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
@@ -25,6 +26,7 @@ var (
|
||||
TagTypePrefixVarchar = "varchar"
|
||||
TagTypeBlob = "blob"
|
||||
TagTypeFloat = "float"
|
||||
TagTypePrefixDefault = "default="
|
||||
)
|
||||
|
||||
var sqlTypeMap = map[sqlite.ColumnType]string{
|
||||
@@ -52,6 +54,7 @@ type (
|
||||
AutoIncrement bool
|
||||
UnixNano bool
|
||||
IsTime bool
|
||||
Default any
|
||||
}
|
||||
)
|
||||
|
||||
@@ -105,6 +108,21 @@ func (def ColumnDef) AsSQL() string {
|
||||
if def.AutoIncrement {
|
||||
sql += " AUTOINCREMENT"
|
||||
}
|
||||
if def.Default != nil {
|
||||
sql += " DEFAULT "
|
||||
switch def.Type {
|
||||
case sqlite.TypeFloat:
|
||||
sql += strconv.FormatFloat(def.Default.(float64), 'b', 0, 64)
|
||||
case sqlite.TypeInteger:
|
||||
sql += strconv.FormatInt(def.Default.(int64), 10)
|
||||
case sqlite.TypeText:
|
||||
sql += fmt.Sprintf("%q", def.Default.(string))
|
||||
default:
|
||||
log.Errorf("unsupported default value: %q %q", def.Type, def.Default)
|
||||
sql = strings.TrimSuffix(sql, " DEFAULT ")
|
||||
}
|
||||
sql += " "
|
||||
}
|
||||
if !def.Nullable {
|
||||
sql += " NOT NULL"
|
||||
}
|
||||
@@ -160,7 +178,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
|
||||
kind := normalizeKind(ft.Kind())
|
||||
|
||||
switch kind { //nolint:exhaustive
|
||||
case reflect.Int:
|
||||
case reflect.Int, reflect.Uint:
|
||||
def.Type = sqlite.TypeInteger
|
||||
|
||||
case reflect.Float64:
|
||||
@@ -237,6 +255,30 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
|
||||
def.Length = int(length)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(k, TagTypePrefixDefault) {
|
||||
defaultValue := strings.TrimPrefix(k, TagTypePrefixDefault)
|
||||
switch def.Type {
|
||||
case sqlite.TypeFloat:
|
||||
fv, err := strconv.ParseFloat(defaultValue, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse default value as float %q: %w", defaultValue, err)
|
||||
}
|
||||
def.Default = fv
|
||||
case sqlite.TypeInteger:
|
||||
fv, err := strconv.ParseInt(defaultValue, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse default value as int %q: %w", defaultValue, err)
|
||||
}
|
||||
def.Default = fv
|
||||
case sqlite.TypeText:
|
||||
def.Default = defaultValue
|
||||
case sqlite.TypeBlob:
|
||||
return fmt.Errorf("default values for TypeBlob not yet supported")
|
||||
default:
|
||||
return fmt.Errorf("failed to apply default value for unknown sqlite column type %s", def.Type)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,14 +22,14 @@ func TestSchemaBuilder(t *testing.T) {
|
||||
Int *int `sqlite:",not-null"`
|
||||
Float interface{} `sqlite:",float,nullable"`
|
||||
}{},
|
||||
`CREATE TABLE Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
|
||||
`CREATE TABLE main.Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
|
||||
},
|
||||
{
|
||||
"Varchar",
|
||||
struct {
|
||||
S string `sqlite:",varchar(10)"`
|
||||
}{},
|
||||
`CREATE TABLE Varchar ( S VARCHAR(10) NOT NULL );`,
|
||||
`CREATE TABLE main.Varchar ( S VARCHAR(10) NOT NULL );`,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -38,6 +38,6 @@ func TestSchemaBuilder(t *testing.T) {
|
||||
|
||||
res, err := GenerateTableSchema(c.Name, c.Model)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.ExpectedSQL, res.CreateStatement(false))
|
||||
assert.Equal(t, c.ExpectedSQL, res.CreateStatement("main", false))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,13 @@ import (
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
)
|
||||
|
||||
type DatabaseName string
|
||||
|
||||
const (
|
||||
LiveDatabase = DatabaseName("main")
|
||||
HistoryDatabase = DatabaseName("history")
|
||||
)
|
||||
|
||||
// Collection of Query and Matcher types.
|
||||
// NOTE: whenever adding support for new operators make sure
|
||||
// to update UnmarshalJSON as well.
|
||||
@@ -48,11 +55,19 @@ type (
|
||||
Distinct bool `json:"distinct"`
|
||||
}
|
||||
|
||||
Min struct {
|
||||
Condition *Query `json:"condition,omitempty"`
|
||||
Field string `json:"field"`
|
||||
As string `json:"as"`
|
||||
Distinct bool `json:"distinct"`
|
||||
}
|
||||
|
||||
Select struct {
|
||||
Field string `json:"field"`
|
||||
Count *Count `json:"$count,omitempty"`
|
||||
Sum *Sum `json:"$sum,omitempty"`
|
||||
Distinct *string `json:"$distinct"`
|
||||
Min *Min `json:"$min,omitempty"`
|
||||
Distinct *string `json:"$distinct,omitempty"`
|
||||
}
|
||||
|
||||
Selects []Select
|
||||
@@ -68,6 +83,9 @@ type (
|
||||
OrderBy OrderBys `json:"orderBy"`
|
||||
GroupBy []string `json:"groupBy"`
|
||||
TextSearch *TextSearch `json:"textSearch"`
|
||||
// A list of databases to query. If left empty,
|
||||
// both, the LiveDatabase and the HistoryDatabase are queried
|
||||
Databases []DatabaseName `json:"databases"`
|
||||
|
||||
Pagination
|
||||
|
||||
@@ -457,6 +475,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
|
||||
Field string `json:"field"`
|
||||
Count *Count `json:"$count"`
|
||||
Sum *Sum `json:"$sum"`
|
||||
Min *Min `json:"$min"`
|
||||
Distinct *string `json:"$distinct"`
|
||||
}
|
||||
|
||||
@@ -468,12 +487,23 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
|
||||
sel.Field = res.Field
|
||||
sel.Distinct = res.Distinct
|
||||
sel.Sum = res.Sum
|
||||
sel.Min = res.Min
|
||||
|
||||
if sel.Count != nil && sel.Count.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.Count.As) {
|
||||
return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
if sel.Sum != nil && sel.Sum.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.Sum.As) {
|
||||
return fmt.Errorf("invalid characters in $sum.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
if sel.Min != nil && sel.Min.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.Min.As) {
|
||||
return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
|
||||
@@ -152,13 +153,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
|
||||
return "", nil, fmt.Errorf("generating where clause: %w", err)
|
||||
}
|
||||
|
||||
if req.paramMap == nil {
|
||||
req.paramMap = make(map[string]interface{})
|
||||
}
|
||||
|
||||
for key, val := range paramMap {
|
||||
req.paramMap[key] = val
|
||||
}
|
||||
req.mergeParams(paramMap)
|
||||
|
||||
if req.TextSearch != nil {
|
||||
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
|
||||
@@ -173,9 +168,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
|
||||
|
||||
whereClause += textClause
|
||||
|
||||
for key, val := range textParams {
|
||||
req.paramMap[key] = val
|
||||
}
|
||||
req.mergeParams(textParams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,15 +183,21 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
|
||||
}
|
||||
|
||||
selectClause := req.generateSelectClause()
|
||||
inMem := `SELECT *, 'live' as _source FROM main.connections `
|
||||
inHistory := `SELECT *, 'history' as _source FROM history.connections `
|
||||
|
||||
if whereClause != "" {
|
||||
inMem += " WHERE " + whereClause
|
||||
inHistory += " WHERE " + whereClause
|
||||
whereClause = "WHERE " + whereClause
|
||||
}
|
||||
|
||||
source := inMem + " UNION " + inHistory
|
||||
if len(req.Databases) == 0 {
|
||||
req.Databases = []DatabaseName{LiveDatabase, HistoryDatabase}
|
||||
}
|
||||
|
||||
sources := make([]string, len(req.Databases))
|
||||
for idx, db := range req.Databases {
|
||||
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
|
||||
}
|
||||
|
||||
source := strings.Join(sources, " UNION ")
|
||||
|
||||
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
|
||||
|
||||
@@ -210,6 +209,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
|
||||
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
|
||||
for idx, s := range req.Select {
|
||||
var field string
|
||||
|
||||
switch {
|
||||
case s.Count != nil:
|
||||
field = s.Count.Field
|
||||
@@ -218,6 +218,12 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
|
||||
case s.Sum != nil:
|
||||
// field is not used in case of $sum
|
||||
field = "*"
|
||||
case s.Min != nil:
|
||||
if s.Min.Field != "" {
|
||||
field = s.Min.Field
|
||||
} else {
|
||||
field = "*"
|
||||
}
|
||||
default:
|
||||
field = s.Field
|
||||
}
|
||||
@@ -258,13 +264,40 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
|
||||
return fmt.Errorf("in $sum: %w", err)
|
||||
}
|
||||
|
||||
req.paramMap = params
|
||||
req.mergeParams(params)
|
||||
req.selectedFields = append(
|
||||
req.selectedFields,
|
||||
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
|
||||
)
|
||||
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
|
||||
|
||||
case s.Min != nil:
|
||||
if s.Min.As == "" {
|
||||
return fmt.Errorf("missing 'as' for $min")
|
||||
}
|
||||
|
||||
var (
|
||||
clause string
|
||||
params map[string]any
|
||||
)
|
||||
|
||||
if s.Min.Field != "" {
|
||||
clause = field
|
||||
} else {
|
||||
var err error
|
||||
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in $min: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req.mergeParams(params)
|
||||
req.selectedFields = append(
|
||||
req.selectedFields,
|
||||
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
|
||||
)
|
||||
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
|
||||
|
||||
case s.Distinct != nil:
|
||||
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
|
||||
req.whitelistedFields = append(req.whitelistedFields, colName)
|
||||
@@ -277,6 +310,16 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
|
||||
return nil
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
|
||||
if req.paramMap == nil {
|
||||
req.paramMap = make(map[string]any)
|
||||
}
|
||||
|
||||
for key, value := range params {
|
||||
req.paramMap[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
|
||||
if len(req.GroupBy) == 0 {
|
||||
return "", nil
|
||||
@@ -339,16 +382,12 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
|
||||
return colDef.Name, nil
|
||||
}
|
||||
|
||||
for _, selected := range req.whitelistedFields {
|
||||
if field == selected {
|
||||
return field, nil
|
||||
}
|
||||
if slices.Contains(req.whitelistedFields, field) {
|
||||
return field, nil
|
||||
}
|
||||
|
||||
for _, selected := range req.selectedFields {
|
||||
if field == selected {
|
||||
return field, nil
|
||||
}
|
||||
if slices.Contains(req.selectedFields, field) {
|
||||
return field, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("column name %q not allowed", field)
|
||||
|
||||
@@ -313,6 +313,10 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (*
|
||||
return process, nil
|
||||
}
|
||||
|
||||
func (p *Process) GetID() string {
|
||||
return p.processKey
|
||||
}
|
||||
|
||||
// Builds a unique identifier for a processes.
|
||||
func getProcessKey(pid int32, createdTime int64) string {
|
||||
return fmt.Sprintf("%d-%d", pid, createdTime)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portmaster/profile/endpoints"
|
||||
"github.com/safing/portmaster/status"
|
||||
"github.com/safing/spn/access/account"
|
||||
"github.com/safing/spn/navigator"
|
||||
)
|
||||
|
||||
@@ -105,7 +106,7 @@ var (
|
||||
|
||||
// Setting "Permanent Verdicts" at order 96.
|
||||
|
||||
CfgOptionEnableHistoryKey = "filter/enableHistory"
|
||||
CfgOptionEnableHistoryKey = "history/enabled"
|
||||
cfgOptionEnableHistory config.BoolOption
|
||||
cfgOptionEnableHistoryOrder = 66
|
||||
|
||||
@@ -249,12 +250,13 @@ func registerConfiguration() error { //nolint:maintidx
|
||||
Key: CfgOptionEnableHistoryKey,
|
||||
Description: "Whether or not to save connections to the history database",
|
||||
OptType: config.OptTypeBool,
|
||||
ReleaseLevel: config.ReleaseLevelExperimental,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
DefaultValue: false,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder,
|
||||
config.CategoryAnnotation: "Advanced",
|
||||
config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder,
|
||||
config.CategoryAnnotation: "History",
|
||||
config.SettingRequiresFeaturePlan: account.FeatureHistory,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user