wip: migrate to mono-repo. SPN has already been moved to spn/

This commit is contained in:
Patrick Pacher
2024-03-15 11:55:13 +01:00
parent b30fd00ccf
commit 8579430db9
577 changed files with 35981 additions and 818 deletions

View File

@@ -0,0 +1,141 @@
package netquery
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/safing/portmaster/service/netquery/orm"
)
// ActiveChartHandler implements http.Handler and serves time-series chart
// data about connections from the netquery database.
type ActiveChartHandler struct {
	// Database is the netquery SQLite database the chart query is
	// executed against.
	Database *Database
}
// ServeHTTP parses the chart request, executes the generated SQL query
// against the database and replies with the rows encoded as JSON. The
// response also echoes the executed query and its parameters for debugging.
func (ch *ActiveChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl
	// Decode the chart request payload from the HTTP request.
	payload, err := ch.parseRequest(req)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)
		return
	}

	// Build the SQL query and its named parameters.
	sqlQuery, args, err := payload.generateSQL(req.Context(), ch.Database.Schema)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)
		return
	}

	// Run the query against the database, collecting all result rows.
	var rows []map[string]interface{}
	err = ch.Database.Execute(
		req.Context(),
		sqlQuery,
		orm.WithNamedArgs(args),
		orm.WithResult(&rows),
		orm.WithSchema(*ch.Database.Schema),
	)
	if err != nil {
		http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
		return
	}

	// Send the HTTP status code before streaming the body.
	resp.WriteHeader(http.StatusOK)

	// Encode the result; encoding errors at this point cannot be
	// reported to the client anymore, hence the ignored error.
	enc := json.NewEncoder(resp)
	enc.SetEscapeHTML(false)
	enc.SetIndent("", " ")
	_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
		"results": rows,
		"query":   sqlQuery,
		"params":  args,
	})
}
// parseRequest decodes the chart query payload from req.
//
// For POST and PUT requests the payload is read from the request body, for
// GET requests it is taken from the "q" URL query parameter. An empty
// payload is allowed and yields a zero-value payload. All other HTTP
// methods are rejected.
func (ch *ActiveChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
	var body io.Reader

	switch req.Method {
	case http.MethodPost, http.MethodPut:
		body = req.Body
	case http.MethodGet:
		body = strings.NewReader(req.URL.Query().Get("q"))
	default:
		return nil, errors.New("invalid HTTP method")
	}

	var requestPayload QueryActiveConnectionChartPayload

	blob, err := io.ReadAll(body)
	if err != nil {
		// BUG FIX: wrap the error with %w instead of concatenating it
		// into the message, so callers can unwrap it.
		return nil, fmt.Errorf("failed to read body: %w", err)
	}

	// BUG FIX: the decoder (with DisallowUnknownFields) was previously
	// created but never used — json.Unmarshal was called instead, so
	// unknown fields were silently accepted. Decode through the decoder
	// so unknown fields are rejected as intended.
	dec := json.NewDecoder(bytes.NewReader(blob))
	dec.DisallowUnknownFields()

	// An empty payload (io.EOF) is fine and yields the zero value.
	if err := dec.Decode(&requestPayload); err != nil && !errors.Is(err, io.EOF) {
		return nil, fmt.Errorf("invalid query: %w", err)
	}

	return &requestPayload, nil
}
// generateSQL builds the SQL query and named-parameter map for an
// active-connection chart from the request payload.
//
// The query uses a recursive CTE to generate one row per second for the
// last 600 seconds, joins every epoch second against the connections that
// were active at that moment and groups the result into 10-second buckets.
// The doubled "%%s" sequences survive the fmt.Sprintf calls below and end
// up as "%s" for SQLite's strftime.
//
// NOTE(review): "value" sums verdicts IN (2, 5, 6) while "countBlocked"
// sums the rest — confirm these verdict numbers match the intended
// allowed/blocked split.
func (req *QueryActiveConnectionChartPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
	template := `
WITH RECURSIVE epoch(x) AS (
SELECT strftime('%%s')-600
UNION ALL
SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
)
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked
FROM epoch
JOIN connections
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
%s
GROUP BY round(timestamp/10, 0)*10;`

	// Build the WHERE clause from the structured query part.
	clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
	if err != nil {
		return "", nil, err
	}

	if params == nil {
		params = make(map[string]interface{})
	}

	// AND the free-text search condition, if any, onto the clause and
	// merge its parameters.
	if req.TextSearch != nil {
		textSearch, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
		if err != nil {
			return "", nil, err
		}

		if textSearch != "" {
			if clause != "" {
				clause += " AND "
			}

			clause += textSearch

			for key, val := range textParams {
				params[key] = val
			}
		}
	}

	// Without any filter the WHERE part is left empty. Note that a fresh,
	// empty parameter map is returned in this branch, discarding anything
	// collected above — presumably no parameters can exist without a
	// clause; verify against toSQLWhereClause.
	if clause == "" {
		return fmt.Sprintf(template, ""), map[string]interface{}{}, nil
	}

	return fmt.Sprintf(template, "WHERE ( "+clause+")"), params, nil
}

View File

@@ -0,0 +1,148 @@
package netquery
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/safing/portmaster/service/netquery/orm"
)
// BandwidthChartHandler handles requests for bandwidth charts
// (incoming/outgoing byte sums over time).
type BandwidthChartHandler struct {
	// Database is the netquery SQLite database the chart query is
	// executed against.
	Database *Database
}

// BandwidthChartRequest holds a request for a bandwidth chart.
type BandwidthChartRequest struct {
	// Interval is the bucket size in seconds; generateSQL defaults it
	// to 10 when 0.
	Interval int `json:"interval"`

	// Query filters which connections contribute to the chart.
	Query Query `json:"query"`

	// GroupBy lists additional connection columns to group the buckets
	// by; values are validated against the table schema in generateSQL.
	GroupBy []string `json:"groupBy"`
}
// ServeHTTP parses the bandwidth chart request, executes the generated SQL
// query against the database and replies with the rows encoded as JSON. The
// response also echoes the executed query and its parameters for debugging.
func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl
	// Decode the chart request payload from the HTTP request.
	payload, err := ch.parseRequest(req)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)
		return
	}

	// Build the SQL query and its named parameters.
	sqlQuery, args, err := payload.generateSQL(req.Context(), ch.Database.Schema)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)
		return
	}

	// Run the query against the database, collecting all result rows.
	var rows []map[string]interface{}
	err = ch.Database.Execute(
		req.Context(),
		sqlQuery,
		orm.WithNamedArgs(args),
		orm.WithResult(&rows),
		orm.WithSchema(*ch.Database.Schema),
	)
	if err != nil {
		http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
		return
	}

	// Send the HTTP status code before streaming the body.
	resp.WriteHeader(http.StatusOK)

	// Encode the result; encoding errors at this point cannot be
	// reported to the client anymore, hence the ignored error.
	enc := json.NewEncoder(resp)
	enc.SetEscapeHTML(false)
	enc.SetIndent("", " ")
	_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
		"results": rows,
		"query":   sqlQuery,
		"params":  args,
	})
}
// parseRequest decodes the bandwidth chart request from req.
//
// For POST and PUT requests the payload is read from the request body, for
// GET requests it is taken from the "q" URL query parameter. An empty
// payload is allowed and yields a zero-value request. All other HTTP
// methods are rejected.
func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl
	var body io.Reader

	switch req.Method {
	case http.MethodPost, http.MethodPut:
		body = req.Body
	case http.MethodGet:
		body = strings.NewReader(req.URL.Query().Get("q"))
	default:
		return nil, errors.New("invalid HTTP method")
	}

	var requestPayload BandwidthChartRequest

	blob, err := io.ReadAll(body)
	if err != nil {
		// BUG FIX: wrap the error with %w instead of concatenating it
		// into the message, so callers can unwrap it.
		return nil, fmt.Errorf("failed to read body: %w", err)
	}

	// BUG FIX: the decoder (with DisallowUnknownFields) was previously
	// created but never used — json.Unmarshal was called instead, so
	// unknown fields were silently accepted. Decode through the decoder
	// so unknown fields are rejected as intended.
	dec := json.NewDecoder(bytes.NewReader(blob))
	dec.DisallowUnknownFields()

	// An empty payload (io.EOF) is fine and yields the zero value.
	if err := dec.Decode(&requestPayload); err != nil && !errors.Is(err, io.EOF) {
		return nil, fmt.Errorf("invalid query: %w", err)
	}

	return &requestPayload, nil
}
// generateSQL builds the bandwidth chart SQL query and its named parameters.
//
// Bandwidth samples from main.bandwidth are joined with their connection
// rows, summed into buckets of req.Interval seconds (defaulting to 10) and
// optionally grouped by additional connection columns from req.GroupBy.
func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
	// Default to 10-second buckets.
	if req.Interval == 0 {
		req.Interval = 10
	}

	// Bucket expression applied to the bandwidth sample time.
	interval := fmt.Sprintf("round(time/%d, 0)*%d", req.Interval, req.Interval)

	// Make sure there are only allowed fields specified in the request
	// group-by. This also guards the string interpolation below against
	// SQL injection since every value must be a known column name.
	for _, gb := range req.GroupBy {
		def := schema.GetColumnDef(gb)
		if def == nil {
			return "", nil, fmt.Errorf("unsupported groupBy key: %q", gb)
		}
	}

	// Select the bucket timestamp, the byte sums and any group-by columns.
	selects := append([]string{
		interval + " as timestamp",
		"SUM(incoming) as incoming",
		"SUM(outgoing) as outgoing",
	}, req.GroupBy...)

	groupBy := append([]string{interval}, req.GroupBy...)

	// Build the WHERE clause from the structured connection query.
	whereClause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
	if err != nil {
		return "", nil, err
	}

	if whereClause != "" {
		whereClause = "WHERE " + whereClause
	}

	template := fmt.Sprintf(
		`SELECT %s
FROM main.bandwidth AS bw
JOIN main.connections AS conns
ON bw.conn_id = conns.id
%s
GROUP BY %s
ORDER BY time ASC`,
		strings.Join(selects, ", "),
		whereClause,
		strings.Join(groupBy, ", "),
	)

	return template, params, nil
}

View File

@@ -0,0 +1,681 @@
package netquery
import (
"context"
"encoding/json"
"fmt"
"io"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/jackc/puddle/v2"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
"github.com/safing/portbase/config"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/netquery/orm"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/profile"
)
// InMemory is the "file path" to open a new in-memory database.
// It is passed as-is to sqlite.OpenConn (with the URI open flag).
const InMemory = "file:inmem.db?mode=memory"

// Available connection types as their string representation.
// These are the values stored in the "type" column of the connections table.
const (
	ConnTypeDNS = "dns"
	ConnTypeIP  = "ip"
)

// ConnectionTypeToString is a lookup map to get the string representation
// of a network.ConnectionType as used by this package.
var ConnectionTypeToString = map[network.ConnectionType]string{
	network.DNSRequest:   ConnTypeDNS,
	network.IPConnection: ConnTypeIP,
}
type (
	// Database represents a SQLite3 backed connection database.
	// Its use is tailored for persistence and querying of network.Connection.
	// Access to the underlying SQLite database is synchronized: reads go
	// through a connection pool while all writes are serialized through a
	// single mutex-guarded write connection.
	Database struct {
		// Schema is the table schema inferred from the Conn model.
		Schema *orm.TableSchema

		// readConnPool pools read-only connections; see Execute and withConn.
		readConnPool *puddle.Pool[*sqlite.Conn]
		// historyPath is the SQLite URI of the persistent history database
		// that gets attached to every connection.
		historyPath string

		// l guards writeConn; see ExecuteWrite and Save.
		l         sync.Mutex
		writeConn *sqlite.Conn
	}

	// BatchExecute executes multiple queries in one transaction.
	BatchExecute struct {
		// ID identifies the batch entry in error messages.
		ID string
		// SQL is the statement to execute.
		SQL string
		// Params holds the named arguments for SQL.
		Params map[string]any
		// Result, if non-nil, receives the decoded result rows.
		Result *[]map[string]any
	}

	// Conn is a network connection that is stored in a SQLite database and accepted
	// by the *Database type of this package. This also defines, using the ./orm package,
	// the table schema and the model that is exposed via the runtime database as well as
	// the query API.
	//
	// Use ConvertConnection from this package to convert a network.Connection to this
	// representation.
	Conn struct { //nolint:maligned
		// ID is a device-unique identifier for the connection. It is built
		// from network.Connection by hashing the connection ID and the start
		// time. We cannot just use the network.Connection.ID because it is only unique
		// as long as the connection is still active and might be, although unlikely,
		// reused afterwards.
		ID              string            `sqlite:"id,primary"`
		ProfileID       string            `sqlite:"profile"`
		Path            string            `sqlite:"path"`
		Type            string            `sqlite:"type,varchar(8)"`
		External        bool              `sqlite:"external"`
		IPVersion       packet.IPVersion  `sqlite:"ip_version"`
		IPProtocol      packet.IPProtocol `sqlite:"ip_protocol"`
		LocalIP         string            `sqlite:"local_ip"`
		LocalPort       uint16            `sqlite:"local_port"`
		RemoteIP        string            `sqlite:"remote_ip"`
		RemotePort      uint16            `sqlite:"remote_port"`
		Domain          string            `sqlite:"domain"`
		Country         string            `sqlite:"country,varchar(2)"`
		ASN             uint              `sqlite:"asn"`
		ASOwner         string            `sqlite:"as_owner"`
		Latitude        float64           `sqlite:"latitude"`
		Longitude       float64           `sqlite:"longitude"`
		Scope           netutils.IPScope  `sqlite:"scope"`
		WorstVerdict    network.Verdict   `sqlite:"worst_verdict"`
		ActiveVerdict   network.Verdict   `sqlite:"verdict"`
		FirewallVerdict network.Verdict   `sqlite:"firewall_verdict"`
		Started         time.Time         `sqlite:"started,text,time"`
		Ended           *time.Time        `sqlite:"ended,text,time"`
		Tunneled        bool              `sqlite:"tunneled"`
		Encrypted       bool              `sqlite:"encrypted"`
		Internal        bool              `sqlite:"internal"`
		Direction       string            `sqlite:"direction"`
		ExtraData       json.RawMessage   `sqlite:"extra_data"`
		Allowed         *bool             `sqlite:"allowed"`
		ProfileRevision int               `sqlite:"profile_revision"`
		ExitNode        *string           `sqlite:"exit_node"`
		BytesReceived   uint64            `sqlite:"bytes_received,default=0"`
		BytesSent       uint64            `sqlite:"bytes_sent,default=0"`

		// TODO(ppacher): support "NOT" in search query to get rid of the following helper fields
		Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"

		// TODO(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here
		ProfileName string `sqlite:"profile_name"`
	}
)
// New opens a new in-memory database named path and attaches a persistent history database.
//
// The returned Database used connection pooling for read-only connections
// (see Execute). To perform database writes use either Save() or ExecuteWrite().
// Note that write connections are serialized by the Database object before being
// handed over to SQLite.
func New(dbPath string) (*Database, error) {
	// Make sure the directory for the persistent history database exists.
	historyParentDir := dataroot.Root().ChildDir("databases", 0o700)
	if err := historyParentDir.Ensure(); err != nil {
		return nil, fmt.Errorf("failed to ensure database directory exists: %w", err)
	}

	// Get file location of history database.
	historyFile := filepath.Join(historyParentDir.Path, "history.db")
	// Convert to SQLite URI path.
	historyURI := "file:///" + strings.TrimPrefix(filepath.ToSlash(historyFile), "/")

	// constructor opens a read-only connection for the pool and attaches
	// the history database read-only.
	constructor := func(ctx context.Context) (*sqlite.Conn, error) {
		c, err := sqlite.OpenConn(
			dbPath,
			sqlite.OpenReadOnly,
			sqlite.OpenSharedCache,
			sqlite.OpenURI,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", dbPath, err)
		}

		if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE '"+historyURI+"?mode=ro' AS history", nil); err != nil {
			return nil, fmt.Errorf("failed to attach history database: %w", err)
		}

		return c, nil
	}

	// destructor closes pooled connections when they are destroyed.
	destructor := func(resource *sqlite.Conn) {
		if err := resource.Close(); err != nil {
			log.Errorf("failed to close pooled SQlite database connection: %s", err)
		}
	}

	// Pool of at most 10 read-only connections.
	pool, err := puddle.NewPool(&puddle.Config[*sqlite.Conn]{
		Constructor: constructor,
		Destructor:  destructor,
		MaxSize:     10,
	})
	if err != nil {
		return nil, err
	}

	// Infer the table schema from the Conn model.
	schema, err := orm.GenerateTableSchema("connections", Conn{})
	if err != nil {
		return nil, err
	}

	// Single writable connection; writes are serialized via Database.l.
	writeConn, err := sqlite.OpenConn(
		dbPath,
		sqlite.OpenCreate,
		sqlite.OpenReadWrite,
		sqlite.OpenWAL,
		sqlite.OpenSharedCache,
		sqlite.OpenURI,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to open sqlite at %s: %w", dbPath, err)
	}

	return &Database{
		readConnPool: pool,
		Schema:       schema,
		writeConn:    writeConn,
		historyPath:  historyURI,
	}, nil
}
// NewInMemory is like New but creates a new in-memory database and
// automatically applies the connection table schema.
func NewInMemory() (*Database, error) {
	db, err := New(InMemory)
	if err != nil {
		return nil, err
	}

	// An in-memory database always starts empty, so applying the
	// migrations here simply creates the initial schema.
	if migrateErr := db.ApplyMigrations(); migrateErr != nil {
		return nil, fmt.Errorf("failed to prepare database: %w", migrateErr)
	}

	return db, nil
}
// Close closes the database, including pools and connections.
func (db *Database) Close() error {
	// Closing the pool destroys all pooled read connections; the write
	// connection is closed separately and its error reported.
	db.readConnPool.Close()

	return db.writeConn.Close()
}
// VacuumHistory rewrites the history database in order to purge deleted records.
//
// It opens its own short-lived write connection to the history database and
// runs VACUUM on it. The named return value err allows the deferred Close to
// surface a close error when the query itself succeeded.
func VacuumHistory(ctx context.Context) (err error) {
	historyParentDir := dataroot.Root().ChildDir("databases", 0o700)
	if err := historyParentDir.Ensure(); err != nil {
		return fmt.Errorf("failed to ensure database directory exists: %w", err)
	}

	// Get file location of history database.
	historyFile := filepath.Join(historyParentDir.Path, "history.db")
	// Convert to SQLite URI path.
	historyURI := "file:///" + strings.TrimPrefix(filepath.ToSlash(historyFile), "/")

	writeConn, err := sqlite.OpenConn(
		historyURI,
		sqlite.OpenCreate,
		sqlite.OpenReadWrite,
		sqlite.OpenWAL,
		sqlite.OpenSharedCache,
		sqlite.OpenURI,
	)
	if err != nil {
		return err
	}

	// Report a close error only if the VACUUM itself succeeded.
	defer func() {
		if closeErr := writeConn.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	return orm.RunQuery(ctx, writeConn, "VACUUM")
}
// ApplyMigrations applies any table and data migrations that are needed
// to bring db up-to-date with the built-in schema.
// TODO(ppacher): right now this only applies the current schema and ignores
// any data-migrations. Once the history module is implemented this should
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration.
func (db *Database) ApplyMigrations() error {
	db.l.Lock()
	defer db.l.Unlock()

	// Attach the persistent history database read-write-create to the
	// write connection.
	if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE '"+db.historyPath+"?mode=rwc' AS 'history';", nil); err != nil {
		return fmt.Errorf("failed to attach history database: %w", err)
	}

	// Create the connections table and its indexes in both the live
	// ("main") and the history database.
	dbNames := []string{"main", "history"}
	for _, dbName := range dbNames {
		// get the create-table SQL statement from the inferred schema
		sql := db.Schema.CreateStatement(dbName, true)

		log.Debugf("creating table schema for database %q", dbName)

		// execute the SQL
		if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil {
			return fmt.Errorf("failed to create schema on database %q: %w", dbName, err)
		}

		// create a few indexes
		// Each entry is a format string: first %s is the "<db>." prefix,
		// second %s is the table name.
		indexes := []string{
			`CREATE INDEX IF NOT EXISTS %sprofile_id_index ON %s (profile)`,
			`CREATE INDEX IF NOT EXISTS %sstarted_time_index ON %s (strftime('%%s', started)+0)`,
			`CREATE INDEX IF NOT EXISTS %sstarted_ended_time_index ON %s (strftime('%%s', started)+0, strftime('%%s', ended)+0) WHERE ended IS NOT NULL`,
		}
		for _, idx := range indexes {
			name := ""
			if dbName != "" {
				name = dbName + "."
			}

			stmt := fmt.Sprintf(idx, name, db.Schema.Name)

			if err := sqlitex.ExecuteTransient(db.writeConn, stmt, nil); err != nil {
				return fmt.Errorf("failed to create index on database %q: %q: %w", dbName, idx, err)
			}
		}
	}

	// The bandwidth table only exists in the live database; rows are
	// removed together with their connection via ON DELETE CASCADE.
	bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth (
conn_id TEXT NOT NULL,
time INTEGER NOT NULL,
incoming INTEGER NOT NULL,
outgoing INTEGER NOT NULL,
CONSTRAINT fk_conn_id
FOREIGN KEY(conn_id) REFERENCES connections(id)
ON DELETE CASCADE
)`

	if err := sqlitex.ExecuteTransient(db.writeConn, bwSchema, nil); err != nil {
		return fmt.Errorf("failed to create main.bandwidth database: %w", err)
	}

	return nil
}
// withConn acquires a read-only connection from the pool, runs fn with it
// and releases the connection back to the pool when fn returns.
func (db *Database) withConn(ctx context.Context, fn func(conn *sqlite.Conn) error) error {
	pooled, acquireErr := db.readConnPool.Acquire(ctx)
	if acquireErr != nil {
		return acquireErr
	}
	defer pooled.Release()

	return fn(pooled.Value())
}
// ExecuteWrite executes a custom SQL query using a writable connection against the SQLite
// database used by db.
// It uses orm.RunQuery() under the hood so please refer to the orm package for
// more information about available options.
//
// All writes share the single write connection and are serialized by db.l.
func (db *Database) ExecuteWrite(ctx context.Context, sql string, args ...orm.QueryOption) error {
	db.l.Lock()
	defer db.l.Unlock()

	return orm.RunQuery(ctx, db.writeConn, sql, args...)
}
// Execute executes a custom SQL query using a read-only connection against the SQLite
// database used by db.
// It uses orm.RunQuery() under the hood so please refer to the orm package for
// more information about available options.
func (db *Database) Execute(ctx context.Context, sql string, args ...orm.QueryOption) error {
	run := func(conn *sqlite.Conn) error {
		return orm.RunQuery(ctx, conn, sql, args...)
	}

	return db.withConn(ctx, run)
}
// ExecuteBatch executes multiple custom SQL query using a read-only connection against the SQLite
// database used by db.
// Failures of individual batch entries are collected; all entries are
// attempted regardless of earlier errors.
func (db *Database) ExecuteBatch(ctx context.Context, batches []BatchExecute) error {
	return db.withConn(ctx, func(conn *sqlite.Conn) error {
		var merr multierror.Error

		for idx := range batches {
			b := batches[idx]

			runErr := orm.RunQuery(ctx, conn, b.SQL, orm.WithNamedArgs(b.Params), orm.WithResult(b.Result))
			if runErr != nil {
				merr.Errors = append(merr.Errors, fmt.Errorf("%s: %w", b.ID, runErr))
			}
		}

		return merr.ErrorOrNil()
	})
}
// CountRows returns the number of rows stored in the database, counting
// both the live (main) and the history connections tables.
func (db *Database) CountRows(ctx context.Context) (int, error) {
	var rows []struct {
		Count int `sqlite:"count"`
	}

	query := "SELECT COUNT(*) AS count FROM (SELECT * FROM main.connections UNION SELECT * from history.connections)"
	if err := db.Execute(ctx, query, orm.WithResult(&rows)); err != nil {
		return 0, fmt.Errorf("failed to perform query: %w", err)
	}

	// A COUNT(*) query must yield exactly one row.
	if len(rows) != 1 {
		return 0, fmt.Errorf("unexpected number of rows returned, expected 1 got %d", len(rows))
	}

	return rows[0].Count, nil
}
// Cleanup removes all connections that have ended before threshold from the live database.
//
// It returns the number of rows that matched the condition when counted.
//
// NOTE(ppacher): there is no easy way to get the number of removed
// rows other than counting them in a first step. Though, that's
// probably not worth the cylces...
//
// NOTE(review): the count runs on a read connection before the DELETE runs
// on the write connection, so the returned count may differ from the number
// of rows actually deleted if the table changes in between — confirm this
// is acceptable for callers.
func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) {
	where := `WHERE ended IS NOT NULL
AND datetime(ended) < datetime(:threshold)`
	sql := "DELETE FROM main.connections " + where + ";"

	args := orm.WithNamedArgs(map[string]interface{}{
		":threshold": threshold.UTC().Format(orm.SqliteTimeFormat),
	})

	// Count the rows that will be deleted.
	var result []struct {
		Count int `sqlite:"count"`
	}
	if err := db.Execute(
		ctx,
		"SELECT COUNT(*) AS count FROM connections "+where,
		args,
		orm.WithTransient(),
		orm.WithResult(&result),
	); err != nil {
		return 0, fmt.Errorf("failed to perform query: %w", err)
	}

	if len(result) != 1 {
		return 0, fmt.Errorf("unexpected number of rows, expected 1 got %d", len(result))
	}

	// Actually delete the matching rows.
	err := db.ExecuteWrite(ctx, sql, args)
	if err != nil {
		return 0, err
	}

	return result[0].Count, nil
}
// RemoveAllHistoryData removes all connections from the history database.
func (db *Database) RemoveAllHistoryData(ctx context.Context) error {
	return db.ExecuteWrite(ctx, fmt.Sprintf("DELETE FROM %s.connections", HistoryDatabase))
}
// RemoveHistoryForProfile removes all connections from the history database
// for a given profile ID (source/id).
func (db *Database) RemoveHistoryForProfile(ctx context.Context, profileID string) error {
	stmt := fmt.Sprintf("DELETE FROM %s.connections WHERE profile = :profile", HistoryDatabase)
	args := orm.WithNamedArgs(map[string]any{
		":profile": profileID,
	})

	return db.ExecuteWrite(ctx, stmt, args)
}
// MigrateProfileID migrates the given profile IDs in the history database.
// This needs to be done when profiles are deleted and replaced by a different profile.
func (db *Database) MigrateProfileID(ctx context.Context, from string, to string) error {
	const stmt = "UPDATE history.connections SET profile = :to WHERE profile = :from"

	params := map[string]any{
		":from": from,
		":to":   to,
	}

	return db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params))
}
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database
// as JSON to w.
// Any error aborts dumping rows and is returned.
func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:unused
	// Collect every row from the connections table into memory first so
	// the read connection is released before encoding starts.
	var conns []Conn
	err := db.withConn(ctx, func(conn *sqlite.Conn) error {
		return sqlitex.ExecuteTransient(conn, "SELECT * FROM connections", &sqlitex.ExecOptions{
			ResultFunc: func(stmt *sqlite.Stmt) error {
				var c Conn
				if err := orm.DecodeStmt(ctx, db.Schema, stmt, &c, orm.DefaultDecodeConfig); err != nil {
					return err
				}

				conns = append(conns, c)
				return nil
			},
		})
	})
	if err != nil {
		return err
	}

	// Write the collected rows as indented JSON.
	enc := json.NewEncoder(w)
	enc.SetIndent("", " ")
	return enc.Encode(conns)
}
// CleanupHistory deletes history data outside of the (per-app) retention time frame.
//
// For every profile found in the history database, the profile's own
// retention setting is used when the profile can be loaded; otherwise the
// global setting is the fallback. A retention of 0 days means "keep
// forever" and skips deletion for that profile. Per-profile failures are
// collected and returned as one multierror.
func (db *Database) CleanupHistory(ctx context.Context) error {
	// Setup tracer for the clean up process.
	ctx, tracer := log.AddTracer(ctx)
	defer tracer.Submit()

	// Get list of profiles in history.
	query := "SELECT DISTINCT profile FROM history.connections"
	var result []struct {
		Profile string `sqlite:"profile"`
	}
	if err := db.Execute(ctx, query, orm.WithResult(&result)); err != nil {
		return fmt.Errorf("failed to get a list of profiles from the history database: %w", err)
	}

	var (
		// Get global retention days - do not delete in case of error.
		globalRetentionDays = config.GetAsInt(profile.CfgOptionKeepHistoryKey, 0)()

		profileName   string
		retentionDays int64
		profileCnt    int
		merr          = new(multierror.Error)
	)
	for _, row := range result {
		// Get profile and retention days.
		id := strings.TrimPrefix(row.Profile, string(profile.SourceLocal)+"/")
		p, err := profile.GetLocalProfile(id, nil, nil)
		if err == nil {
			profileName = p.String()
			retentionDays = p.LayeredProfile().KeepHistory()
		} else {
			// Getting profile failed, fallback to global setting.
			tracer.Errorf("history: failed to load profile for id %s: %s", id, err)
			profileName = row.Profile
			retentionDays = globalRetentionDays
		}

		// Skip deleting if history should be kept forever.
		if retentionDays == 0 {
			tracer.Tracef("history: retention is disabled for %s, skipping", profileName)
			continue
		}

		// Count profiles where connections were deleted.
		profileCnt++

		// TODO: count cleared connections
		// Delete inactive connections older than the retention window.
		threshold := time.Now().Add(-1 * time.Duration(retentionDays) * time.Hour * 24)
		if err := db.ExecuteWrite(ctx,
			"DELETE FROM history.connections WHERE profile = :profile AND active = FALSE AND datetime(started) < datetime(:threshold)",
			orm.WithNamedArgs(map[string]any{
				":profile":   row.Profile,
				":threshold": threshold.Format(orm.SqliteTimeFormat),
			}),
		); err != nil {
			tracer.Warningf("history: failed to delete connections of %s: %s", profileName, err)
			merr.Errors = append(merr.Errors, fmt.Errorf("profile %s: %w", row.Profile, err))
		} else {
			tracer.Debugf(
				"history: deleted connections older than %d days (before %s) of %s",
				retentionDays,
				threshold.Format(time.RFC822),
				profileName,
			)
		}
	}

	// Log summary.
	tracer.Infof("history: deleted connections outside of retention from %d profiles", profileCnt)

	return merr.ErrorOrNil()
}
// MarkAllHistoryConnectionsEnded marks all connections in the history database as ended,
// setting the end time to now.
func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
	stmt := fmt.Sprintf("UPDATE %s.connections SET active = FALSE, ended = :ended WHERE active = TRUE", HistoryDatabase)
	args := map[string]any{
		":ended": time.Now().Format(orm.SqliteTimeFormat),
	}

	return db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(args))
}
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database.
//
// The byte counters are added onto the connection row in the live database
// (and, if enableHistory is set, the history database), and a raw sample is
// appended to main.bandwidth. Errors from the individual statements are
// collected; all statements are attempted.
//
// NOTE(review): profileKey is not referenced in this body — presumably kept
// for interface compatibility with ConnectionStore; confirm.
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
	params := map[string]any{
		":id": makeNqIDFromParts(processKey, connID),
	}

	// Build the "SET ..." part: counters are incremented, not replaced.
	parts := []string{}
	parts = append(parts, "bytes_received = (bytes_received + :bytes_received)")
	params[":bytes_received"] = bytesReceived

	parts = append(parts, "bytes_sent = (bytes_sent + :bytes_sent)")
	params[":bytes_sent"] = bytesSent

	updateSet := strings.Join(parts, ", ")

	updateStmts := []string{
		fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, LiveDatabase, updateSet),
	}
	if enableHistory {
		updateStmts = append(updateStmts,
			fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, HistoryDatabase, updateSet),
		)
	}

	merr := new(multierror.Error)
	for _, stmt := range updateStmts {
		if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
			merr.Errors = append(merr.Errors, err)
		}
	}

	// also add the date to the in-memory bandwidth database
	params[":time"] = time.Now().Unix()
	stmt := "INSERT INTO main.bandwidth (conn_id, time, incoming, outgoing) VALUES(:id, :time, :bytes_received, :bytes_sent)"
	if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
		merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err))
	}

	return merr.ErrorOrNil()
}
// Save inserts the connection conn into the SQLite database. If conn
// already exists the table row is updated instead (upsert on the id column).
//
// Save uses the database write connection instead of relying on the
// connection pool.
func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) error {
	// convert the connection to a param map where each key is already translated
	// to the sql column name. We also skip bytes_received and bytes_sent since those
	// will be updated independently from the connection object.
	connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig, []string{
		"bytes_received",
		"bytes_sent",
	})
	if err != nil {
		return fmt.Errorf("failed to encode connection for SQL: %w", err)
	}

	// Build column list, ":name" placeholders, the named-arg value map
	// and the "col = :col" pairs for the upsert.
	columns := make([]string, 0, len(connMap))
	placeholders := make([]string, 0, len(connMap))
	values := make(map[string]interface{}, len(connMap))
	updateSets := make([]string, 0, len(connMap))

	// sort keys so we get a stable SQLite query that can be better cached.
	keys := make([]string, 0, len(connMap))
	for key := range connMap {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	for _, key := range keys {
		value := connMap[key]

		columns = append(columns, key)
		placeholders = append(placeholders, ":"+key)
		values[":"+key] = value
		updateSets = append(updateSets, fmt.Sprintf("%s = :%s", key, key))
	}

	db.l.Lock()
	defer db.l.Unlock()

	// TODO(ppacher): make sure this one can be cached to speed up inserting
	// and save some CPU cycles for the user
	dbNames := []DatabaseName{LiveDatabase}

	// TODO: Should we only add ended connection to the history database to save
	// a couple INSERTs per connection?
	// This means we need to write the current live DB to the history DB on
	// shutdown in order to be able to pick the back up after a restart.

	// Save to history DB if enabled.
	if enableHistory {
		dbNames = append(dbNames, HistoryDatabase)
	}

	// Upsert into each selected database.
	for _, dbName := range dbNames {
		sql := fmt.Sprintf(
			`INSERT INTO %s.connections (%s)
VALUES(%s)
ON CONFLICT(id) DO UPDATE SET
%s
`,
			dbName,
			strings.Join(columns, ", "),
			strings.Join(placeholders, ", "),
			strings.Join(updateSets, ", "),
		)

		if err := sqlitex.Execute(db.writeConn, sql, &sqlitex.ExecOptions{
			Named: values,
			// An INSERT/UPDATE is not expected to produce result rows;
			// log if it ever does.
			ResultFunc: func(stmt *sqlite.Stmt) error {
				log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount())
				return nil
			},
		}); err != nil {
			log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values)
			return err
		}
	}

	return nil
}

277
service/netquery/manager.go Normal file
View File

@@ -0,0 +1,277 @@
package netquery
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portbase/runtime"
"github.com/safing/portmaster/service/network"
)
type (
	// ConnectionStore describes the interface that is used by Manager
	// to save new or updated connection objects.
	// It is implemented by the *Database type of this package.
	ConnectionStore interface {
		// Save is called to persist the new or updated connection.
		// It's up to the implementation to figure out if the operation is an
		// insert or an update.
		// The ID of Conn is unique and can be trusted to never collide with other
		// connections of the same device.
		Save(context.Context, Conn, bool) error

		// MarkAllHistoryConnectionsEnded marks all active connections in the history
		// database as ended NOW.
		MarkAllHistoryConnectionsEnded(context.Context) error

		// RemoveAllHistoryData removes all connections from the history database.
		RemoveAllHistoryData(context.Context) error

		// RemoveHistoryForProfile removes all connections from the history database
		// for a given profile ID (source/id).
		RemoveHistoryForProfile(context.Context, string) error

		// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
		// the bandwidth data to the history database.
		UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error

		// CleanupHistory deletes data outside of the retention time frame from the history database.
		CleanupHistory(ctx context.Context) error

		// Close closes the connection store. It must not be used afterwards.
		Close() error
	}

	// Manager handles new and updated network.Connections feeds and persists them
	// at a connection store.
	// Manager also registers itself as a runtime database and pushes updates to
	// connections using the local format.
	// Users should use this update feed rather than the deprecated "network:" database.
	Manager struct {
		// store receives every complete connection seen on the feed.
		store ConnectionStore
		// push publishes connection updates via the runtime database.
		push runtime.PushFunc
		// runtimeReg is the runtime registry the manager registered at.
		runtimeReg *runtime.Registry
		// pushPrefix is the runtime key prefix used for pushed records.
		pushPrefix string
	}
)
// NewManager returns a new connection manager that persists all newly created or
// updated connections at store.
func NewManager(store ConnectionStore, pushPrefix string, reg *runtime.Registry) (*Manager, error) {
	mng := &Manager{
		store:      store,
		runtimeReg: reg,
		pushPrefix: pushPrefix,
	}

	// Register at the runtime registry so connection updates can be
	// pushed to subscribers.
	pushFunc, regErr := reg.Register(pushPrefix, runtime.SimpleValueGetterFunc(mng.runtimeGet))
	if regErr != nil {
		return nil, regErr
	}
	mng.push = pushFunc

	return mng, nil
}
// runtimeGet implements the runtime value getter for the registered push
// prefix. It currently always returns an empty result as querying
// connections through the runtime database is not yet supported.
func (mng *Manager) runtimeGet(keyOrPrefix string) ([]record.Record, error) {
	// TODO(ppacher):
	// we don't yet support querying using the runtime database here ...
	// consider exposing connection from the database at least by ID.
	//
	// NOTE(ppacher):
	// for debugging purposes use RuntimeQueryRunner to execute plain
	// SQL queries against the database using portbase/database/runtime.
	return nil, nil
}
// HandleFeed starts reading new and updated connections from feed and persists them
// in the configured ConnectionStore. HandleFeed blocks until either ctx is cancelled
// or feed is closed.
// Any errors encountered when processing new or updated connections are logged but
// otherwise ignored.
// HandleFeed handles and persists updates one after each other! Depending on the system
// load the user might want to use a buffered channel for feed.
func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connection) {
	for {
		select {
		case <-ctx.Done():
			return

		case conn, ok := <-feed:
			if !ok {
				return
			}

			// Process each connection inside a closure so the deferred
			// unlock runs per item instead of at function exit.
			func() {
				conn.Lock()
				defer conn.Unlock()

				// Skip connections that are not fully populated yet.
				if !conn.DataIsComplete() {
					return
				}

				model, err := convertConnection(conn)
				if err != nil {
					log.Errorf("netquery: failed to convert connection %s to sqlite model: %s", conn.ID, err)

					return
				}

				// DEBUG:
				// log.Tracef("netquery: updating connection %s", conn.ID)

				// Save to netquery database.
				// Do not include internal connections in history.
				if err := mng.store.Save(ctx, *model, conn.HistoryEnabled); err != nil {
					log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err)

					return
				}

				// we clone the record metadata from the connection
				// into the new model so the portbase/database layer
				// can handle NEW/UPDATE correctly.
				cloned := conn.Meta().Duplicate()

				// push an update for the connection
				if err := mng.pushConnUpdate(ctx, *cloned, *model); err != nil {
					log.Errorf("netquery: failed to push update for conn %s via database system: %s", conn.ID, err)
				}
			}()
		}
	}
}
// pushConnUpdate marshals conn to JSON, wraps it in a database record carrying
// meta and pushes it to runtime database subscribers under the manager's
// push prefix.
func (mng *Manager) pushConnUpdate(_ context.Context, meta record.Meta, conn Conn) error {
	payload, err := json.Marshal(conn)
	if err != nil {
		return fmt.Errorf("failed to marshal connection: %w", err)
	}

	// Record keys have the form "<database>:<prefix><connection-id>".
	key := mng.runtimeReg.DatabaseName() + ":" + mng.pushPrefix + conn.ID

	wrapper, err := record.NewWrapper(
		key,
		&meta,
		dsd.JSON,
		payload,
	)
	if err != nil {
		return fmt.Errorf("failed to create record wrapper: %w", err)
	}

	mng.push(wrapper)
	return nil
}
// convertConnection converts conn to the local representation used
// to persist the information in SQLite.
// The caller must hold the lock to the given network.Connection.
func convertConnection(conn *network.Connection) (*Conn, error) {
	direction := "outbound"
	if conn.Inbound {
		direction = "inbound"
	}

	c := Conn{
		ID:            makeNqIDFromConn(conn),
		External:      conn.External,
		IPVersion:     conn.IPVersion,
		IPProtocol:    conn.IPProtocol,
		LocalIP:       conn.LocalIP.String(),
		LocalPort:     conn.LocalPort,
		ActiveVerdict: conn.Verdict,
		Started:       time.Unix(conn.Started, 0),
		Tunneled:      conn.Tunneled,
		Encrypted:     conn.Encrypted,
		Internal:      conn.Internal,
		Direction:     direction,
		// NOTE(review): Type is unconditionally overwritten by the switch
		// below for every enumerated network.Connection type — confirm
		// ConnectionTypeToString agrees with the switch before removing
		// either assignment.
		Type:            ConnectionTypeToString[conn.Type],
		ProfileID:       conn.ProcessContext.Source + "/" + conn.ProcessContext.Profile,
		Path:            conn.ProcessContext.BinaryPath,
		ProfileRevision: int(conn.ProfileRevisionCounter),
		ProfileName:     conn.ProcessContext.ProfileName,
	}

	// Map the connection type to its persisted string representation.
	switch conn.Type {
	case network.DNSRequest:
		c.Type = "dns"
	case network.IPConnection:
		c.Type = "ip"
	case network.Undefined:
		c.Type = ""
	}

	// Allowed mirrors whether the connection was actually established.
	c.Allowed = &conn.ConnectionEstablished

	// A connection counts as active until an end timestamp is recorded.
	if conn.Ended > 0 {
		ended := time.Unix(conn.Ended, 0)
		c.Ended = &ended
		c.Active = false
	} else {
		c.Active = true
	}

	// Additional context is persisted as a single JSON blob column.
	extraData := map[string]interface{}{
		"pid":              conn.ProcessContext.PID,
		"processCreatedAt": conn.ProcessContext.CreatedAt,
	}

	if conn.TunnelContext != nil {
		extraData["tunnel"] = conn.TunnelContext
		exitNode := conn.TunnelContext.GetExitNodeID()
		c.ExitNode = &exitNode
	}

	if conn.DNSContext != nil {
		extraData["dns"] = conn.DNSContext
	}

	// TODO(ppacher): enable when TLS inspection is merged
	// if conn.TLSContext != nil {
	// 	extraData["tls"] = conn.TLSContext
	// }

	// Remote endpoint, geo and filter-list information from the entity.
	if conn.Entity != nil {
		extraData["cname"] = conn.Entity.CNAME
		extraData["blockedByLists"] = conn.Entity.BlockedByLists
		extraData["blockedEntities"] = conn.Entity.BlockedEntities
		extraData["reason"] = conn.Reason

		c.RemoteIP = conn.Entity.IP.String()
		c.RemotePort = conn.Entity.Port
		c.Domain = conn.Entity.Domain
		c.Country = conn.Entity.Country
		c.ASN = conn.Entity.ASN
		c.ASOwner = conn.Entity.ASOrg
		c.Scope = conn.Entity.IPScope
		if conn.Entity.Coordinates != nil {
			c.Latitude = conn.Entity.Coordinates.Latitude
			c.Longitude = conn.Entity.Coordinates.Longitude
		}
	}

	// pre-compute the JSON blob for the extra data column
	// and assign it.
	extraDataBlob, err := json.Marshal(extraData)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal extra data: %w", err)
	}
	c.ExtraData = extraDataBlob

	return &c, nil
}
// makeNqIDFromConn creates a netquery connection ID from the network connection.
func makeNqIDFromConn(conn *network.Connection) string {
	processKey := conn.Process().GetKey()
	return makeNqIDFromParts(processKey, conn.ID)
}
// makeNqIDFromParts creates a netquery connection ID from the given network
// connection ID and the process key.
func makeNqIDFromParts(processKey string, netConnID string) string {
	return fmt.Sprintf("%s-%s", processKey, netConnID)
}

View File

@@ -0,0 +1,315 @@
package netquery
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/hashicorp/go-multierror"
servertiming "github.com/mitchellh/go-server-timing"
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portbase/runtime"
"github.com/safing/portmaster/service/network"
)
// DefaultModule is the default netquery module.
var DefaultModule *module

// module implements the netquery subsystem: it maintains an sqlite-backed
// connection database and exposes query and chart API endpoints.
type module struct {
	*modules.Module

	// Store is the sqlite connection database (live and history data).
	Store *Database

	// db is the portbase database interface used to subscribe to the
	// "network:" tree.
	db *database.Interface
	// mng converts and persists connection updates into Store.
	mng *Manager
	// feed transports connections from the subscription worker to the
	// feed handler worker.
	feed chan *network.Connection
}
// init registers the netquery module (depending on the "api", "network" and
// "database" modules) and its "history" subsystem with the module system.
func init() {
	DefaultModule = new(module)

	DefaultModule.Module = modules.Register(
		"netquery",
		DefaultModule.prepare,
		DefaultModule.start,
		DefaultModule.stop,
		"api",
		"network",
		"database",
	)

	subsystems.Register(
		"history",
		"Network History",
		"Keep Network History Data",
		DefaultModule.Module,
		"config:history/",
		nil,
	)
}
// prepare creates the in-memory connection database and its manager and
// registers all netquery API endpoints.
//
// Fixes over the previous revision: the local handler variables were
// misspelled (queryHander/batchHander) and the dev-mode config option was
// constructed twice; it is now created once and shared.
func (m *module) prepare() error {
	var err error

	m.db = database.NewInterface(&database.Options{
		Local:    true,
		Internal: true,
	})

	// TODO: Open database in start() phase.
	m.Store, err = NewInMemory()
	if err != nil {
		return fmt.Errorf("failed to create in-memory database: %w", err)
	}

	m.mng, err = NewManager(m.Store, "netquery/data/", runtime.DefaultRegistry)
	if err != nil {
		return fmt.Errorf("failed to create manager: %w", err)
	}

	m.feed = make(chan *network.Connection, 1000)

	// Both query handlers expose additional details in dev mode.
	isDevMode := config.Concurrent.GetAsBool(config.CfgDevModeKey, false)

	queryHandler := &QueryHandler{
		Database:  m.Store,
		IsDevMode: isDevMode,
	}

	batchHandler := &BatchQueryHandler{
		Database:  m.Store,
		IsDevMode: isDevMode,
	}

	chartHandler := &ActiveChartHandler{
		Database: m.Store,
	}

	bwChartHandler := &BandwidthChartHandler{
		Database: m.Store,
	}

	if err := api.RegisterEndpoint(api.Endpoint{
		Name:        "Query Connections",
		Description: "Query the in-memory sqlite connection database.",
		Path:        "netquery/query",
		MimeType:    "application/json",
		Read:        api.PermitUser, // Needs read+write as the query is sent using POST data.
		Write:       api.PermitUser, // Needs read+write as the query is sent using POST data.
		BelongsTo:   m.Module,
		HandlerFunc: servertiming.Middleware(queryHandler, nil).ServeHTTP,
	}); err != nil {
		return fmt.Errorf("failed to register API endpoint: %w", err)
	}

	if err := api.RegisterEndpoint(api.Endpoint{
		Name:        "Batch Query Connections",
		Description: "Batch query the in-memory sqlite connection database.",
		Path:        "netquery/query/batch",
		MimeType:    "application/json",
		Read:        api.PermitUser, // Needs read+write as the query is sent using POST data.
		Write:       api.PermitUser, // Needs read+write as the query is sent using POST data.
		BelongsTo:   m.Module,
		HandlerFunc: servertiming.Middleware(batchHandler, nil).ServeHTTP,
	}); err != nil {
		return fmt.Errorf("failed to register API endpoint: %w", err)
	}

	if err := api.RegisterEndpoint(api.Endpoint{
		Name:        "Active Connections Chart",
		Description: "Query the in-memory sqlite connection database and return a chart of active connections.",
		Path:        "netquery/charts/connection-active",
		MimeType:    "application/json",
		Write:       api.PermitUser,
		BelongsTo:   m.Module,
		HandlerFunc: servertiming.Middleware(chartHandler, nil).ServeHTTP,
	}); err != nil {
		return fmt.Errorf("failed to register API endpoint: %w", err)
	}

	if err := api.RegisterEndpoint(api.Endpoint{
		// TODO: Use query parameters instead.
		Path:        "netquery/charts/bandwidth",
		MimeType:    "application/json",
		Write:       api.PermitUser,
		BelongsTo:   m.Module,
		HandlerFunc: bwChartHandler.ServeHTTP,
		Name:        "Bandwidth Chart",
		Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.",
	}); err != nil {
		return fmt.Errorf("failed to register API endpoint: %w", err)
	}

	if err := api.RegisterEndpoint(api.Endpoint{
		Name:        "Remove connections from profile history",
		Description: "Remove all connections from the history database for one or more profiles",
		Path:        "netquery/history/clear",
		MimeType:    "application/json",
		Write:       api.PermitUser,
		BelongsTo:   m.Module,
		ActionFunc: func(ar *api.Request) (msg string, err error) {
			var body struct {
				ProfileIDs []string `json:"profileIDs"`
			}
			if err := json.Unmarshal(ar.InputData, &body); err != nil {
				return "", fmt.Errorf("failed to decode parameters in body: %w", err)
			}

			// An empty profile list clears the whole history database.
			if len(body.ProfileIDs) == 0 {
				if err := m.mng.store.RemoveAllHistoryData(ar.Context()); err != nil {
					return "", fmt.Errorf("failed to remove all history: %w", err)
				}
			} else {
				// Collect per-profile errors so one failure does not abort
				// clearing the remaining profiles.
				merr := new(multierror.Error)
				for _, profileID := range body.ProfileIDs {
					if err := m.mng.store.RemoveHistoryForProfile(ar.Context(), profileID); err != nil {
						merr.Errors = append(merr.Errors, fmt.Errorf("failed to clear history for %q: %w", profileID, err))
					} else {
						log.Infof("netquery: successfully cleared history for %s", profileID)
					}
				}

				if err := merr.ErrorOrNil(); err != nil {
					return "", err
				}
			}

			return "Successfully cleared history.", nil
		},
	}); err != nil {
		return fmt.Errorf("failed to register API endpoint: %w", err)
	}

	if err := api.RegisterEndpoint(api.Endpoint{
		Name:      "Apply connection history retention threshold",
		Path:      "netquery/history/cleanup",
		Write:     api.PermitUser,
		BelongsTo: m.Module,
		ActionFunc: func(ar *api.Request) (msg string, err error) {
			if err := m.Store.CleanupHistory(ar.Context()); err != nil {
				return "", err
			}
			return "Deleted outdated connections.", nil
		},
	}); err != nil {
		return fmt.Errorf("failed to register API endpoint: %w", err)
	}

	return nil
}
// start launches the service workers that subscribe to the "network:" tree,
// feed connection updates to the manager and periodically clean the live and
// history databases. It also installs the dev-mode SQL runner and the profile
// migration event hook.
//
// Fix over the previous revision: the forward into m.feed is now guarded by a
// select on ctx.Done(). Previously, a full feed buffer combined with a stopped
// feed-handler worker could block this worker forever during shutdown.
func (m *module) start() error {
	m.StartServiceWorker("netquery connection feed listener", 0, func(ctx context.Context) error {
		sub, err := m.db.Subscribe(query.New("network:"))
		if err != nil {
			return fmt.Errorf("failed to subscribe to network tree: %w", err)
		}
		// Deferred calls run LIFO: the subscription is cancelled first,
		// then the feed channel is closed to stop the feed handler.
		defer close(m.feed)
		defer func() {
			_ = sub.Cancel()
		}()

		for {
			select {
			case <-ctx.Done():
				return nil
			case rec, ok := <-sub.Feed:
				if !ok {
					return nil
				}

				conn, ok := rec.(*network.Connection)
				if !ok {
					// This is fine as we also receive process updates on
					// this channel.
					continue
				}

				// Forward to the feed handler, but never let a full feed
				// buffer block shutdown.
				select {
				case m.feed <- conn:
				case <-ctx.Done():
					return nil
				}
			}
		}
	})

	m.StartServiceWorker("netquery connection feed handler", 0, func(ctx context.Context) error {
		m.mng.HandleFeed(ctx, m.feed)
		return nil
	})

	m.StartServiceWorker("netquery live db cleaner", 0, func(ctx context.Context) error {
		// Periodically drop connections that ended before the retention
		// threshold from the live database.
		for {
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(10 * time.Second):
				threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
				count, err := m.Store.Cleanup(ctx, threshold)
				if err != nil {
					log.Errorf("netquery: failed to removed old connections from live db: %s", err)
				} else {
					log.Tracef("netquery: successfully removed %d old connections from live db that ended before %s", count, threshold)
				}
			}
		}
	})

	// Apply the history retention policy once shortly after start and then
	// once per hour.
	m.NewTask("network history cleaner", func(ctx context.Context, _ *modules.Task) error {
		return m.Store.CleanupHistory(ctx)
	}).Repeat(time.Hour).Schedule(time.Now().Add(10 * time.Minute))

	// For debugging, provide a simple direct SQL query interface using
	// the runtime database.
	// Only expose in development mode.
	if config.GetAsBool(config.CfgDevModeKey, false)() {
		_, err := NewRuntimeQueryRunner(m.Store, "netquery/query/", runtime.DefaultRegistry)
		if err != nil {
			return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
		}
	}

	// Migrate profile IDs in history database when profiles are migrated/merged.
	if err := m.RegisterEventHook(
		"profiles",
		"profile migrated",
		"migrate profile IDs in history database",
		func(ctx context.Context, data interface{}) error {
			if profileIDs, ok := data.([]string); ok && len(profileIDs) == 2 {
				return m.Store.MigrateProfileID(ctx, profileIDs[0], profileIDs[1])
			}

			return nil
		},
	); err != nil {
		return err
	}

	return nil
}
// stop marks all history connections as ended, closes the sqlite database
// and — only if closing succeeded — vacuums the history database. All errors
// are logged instead of being returned.
func (m *module) stop() error {
	// we don't use m.Module.Ctx here because it is already cancelled when stop is called.
	// just give the clean up 1 minute to happen and abort otherwise.
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	if err := m.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil {
		// handle the error by just logging it. There's not much we can do here
		// and returning an error to the module system doesn't help much as well...
		log.Errorf("netquery: failed to mark connections in history database as ended: %s", err)
	}

	err := m.mng.store.Close()
	if err != nil {
		log.Errorf("netquery: failed to close sqlite database: %s", err)
		return nil
	}

	// Clear deleted connections from database.
	if err := VacuumHistory(ctx); err != nil {
		log.Errorf("netquery: failed to execute VACUUM in history database: %s", err)
	}

	return nil
}

View File

@@ -0,0 +1,483 @@
package orm
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"reflect"
"strings"
"time"
"zombiezen.com/go/sqlite"
)
// Commonly used error messages when working with orm.
var (
	// errStructExpected is returned when a non-struct value should be encoded into a map.
	errStructExpected = errors.New("encode: can only encode structs to maps")
	// errStructPointerExpected is returned when the decode target is neither a
	// pointer to a struct nor a *map[string]interface{}.
	errStructPointerExpected = errors.New("decode: result must be pointer to a struct type or map[string]interface{}")
	// errUnexpectedColumnType is returned when a column's sqlite storage class
	// does not match the target Go type.
	errUnexpectedColumnType = errors.New("decode: unexpected column type")
)
// constants used when transforming data to and from sqlite.
var (
	// SqliteTimeFormat defines the string representation that is
	// expected by SQLite DATETIME functions.
	// Note that SQLite itself does not include support for a DATETIME
	// column type. Instead, dates and times are stored either as INTEGER,
	// TEXT or REAL.
	// This package provides support for time.Time being stored as TEXT (using a
	// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
	// unixepoch and unixnano-epoch where the nano variant is not officially supported by
	// SQLITE).
	SqliteTimeFormat = "2006-01-02 15:04:05"
)
type (
	// Stmt describes the interface that must be implemented in order to
	// be decodable to a struct type using DecodeStmt. This interface is implemented
	// by *sqlite.Stmt.
	Stmt interface {
		ColumnCount() int
		ColumnName(int) string
		ColumnType(int) sqlite.ColumnType
		ColumnText(int) string
		ColumnBool(int) bool
		ColumnFloat(int) float64
		ColumnInt(int) int
		ColumnReader(int) *bytes.Reader
	}

	// DecodeFunc is called for each non-basic type during decoding.
	// It returns the decoded value, whether the column was handled by
	// this hook, and an error, if any.
	DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)

	// DecodeConfig holds decoding functions.
	DecodeConfig struct {
		// DecodeHooks are tried in order; decodeBasic() is always appended
		// as the final fallback.
		DecodeHooks []DecodeFunc
	}
)
// DecodeStmt decodes the current result row loaded in Stmt into the struct or map type result.
// Decoding hooks configured in cfg are executed before trying to decode basic types and may
// be specified to provide support for special types.
// See DatetimeDecoder() for an example of a DecodeHook that handles graceful time.Time conversion.
func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result interface{}, cfg DecodeConfig) error {
	// make sure we got something to decode into ...
	if result == nil {
		return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
	}

	// fast path for decoding into a map
	if mp, ok := result.(*map[string]interface{}); ok {
		return decodeIntoMap(ctx, schema, stmt, mp, cfg)
	}

	// make sure we got a pointer in result
	if reflect.TypeOf(result).Kind() != reflect.Ptr {
		return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
	}

	// make sure it's a pointer to a struct type
	t := reflect.ValueOf(result).Elem().Type()
	if t.Kind() != reflect.Struct {
		return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
	}

	// if result is a nil pointer make sure to allocate some space
	// for the resulting struct.
	// NOTE(review): reflect.ValueOf(result) is not addressable, so Set would
	// panic if this branch were ever taken with a typed nil pointer — confirm
	// callers never pass one.
	resultValue := reflect.ValueOf(result)
	if resultValue.IsNil() {
		resultValue.Set(
			reflect.New(t),
		)
	}

	// we need access to the struct directly and not to the
	// pointer.
	target := reflect.Indirect(resultValue)

	// create a lookup map from column name (the sqlite:"" tag value or the
	// field name itself) to the Go field name.
	lm := make(map[string]string)
	for i := 0; i < target.NumField(); i++ {
		fieldType := t.Field(i)

		// skip unexported fields
		if !fieldType.IsExported() {
			continue
		}

		lm[sqlColumnName(fieldType)] = fieldType.Name
	}

	// iterate over all columns and assign them to the correct
	// fields
	for i := 0; i < stmt.ColumnCount(); i++ {
		colName := stmt.ColumnName(i)
		fieldName, ok := lm[colName]
		if !ok {
			// there's no target field for this column
			// so we can skip it
			continue
		}

		fieldType, _ := t.FieldByName(fieldName)
		value := target.FieldByName(fieldName)

		colType := stmt.ColumnType(i)

		// if the column is reported as NULL we keep
		// the field as it is.
		// TODO(ppacher): should we set it to nil here?
		if colType == sqlite.TypeNull {
			continue
		}

		// if value is a nil pointer we need to allocate some memory
		// first
		if getKind(value) == reflect.Ptr && value.IsNil() {
			storage := reflect.New(fieldType.Type.Elem())
			value.Set(storage)

			// make sure value actually points the
			// dereferenced target storage
			value = storage.Elem()
		}

		colDef := schema.GetColumnDef(colName)

		// execute all decode hooks but make sure we use decodeBasic() as the
		// last one.
		columnValue, err := runDecodeHooks(
			i,
			colDef,
			stmt,
			fieldType,
			value,
			append(cfg.DecodeHooks, decodeBasic()),
		)
		if err != nil {
			return err
		}

		// no hook — including the decodeBasic() fallback — produced a
		// value, so this column cannot be decoded.
		if columnValue == nil {
			return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
		}

		// Debugging:
		// log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue)

		// convert it to the target type if conversion is possible
		newValue := reflect.ValueOf(columnValue)
		if newValue.Type().ConvertibleTo(value.Type()) {
			newValue = newValue.Convert(value.Type())
		}

		// assign the new value to the struct field.
		value.Set(newValue)
	}

	return nil
}
// DatetimeDecoder returns a DecodeFunc capable of decoding sqlite INTEGER or TEXT
// storage classes into time.Time. For INTEGER storage classes, it supports a
// 'unixnano' struct tag value to decide between Unix or UnixNano epoch timestamps.
// TEXT values are parsed using SqliteTimeFormat and interpreted in loc.
//
// Fix over the previous revision: the REAL storage class error message read
// "not support" instead of "not supported".
//
// TODO(ppacher): update comment about loc parameter and TEXT storage class parsing.
func DatetimeDecoder(loc *time.Location) DecodeFunc {
	return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
		// if we have the column definition available we
		// use the target go type from there.
		outType := outval.Type()
		if colDef != nil {
			outType = colDef.GoType
		}

		// we only care about "time.Time" here
		if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
			// log.Printf("not decoding %s %v", outType, colDef)
			return nil, false, nil
		}

		switch stmt.ColumnType(colIdx) { //nolint:exhaustive // Only selecting specific types.
		case sqlite.TypeInteger:
			// stored as unix-epoch, if unixnano is set in the struct field tag
			// we parse it with nano-second resolution
			// TODO(ppacher): actually split the tag value at "," and search
			// the slice for "unixnano"
			if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") {
				return time.Unix(0, int64(stmt.ColumnInt(colIdx))), true, nil
			}

			return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), true, nil

		case sqlite.TypeText:
			// stored ISO8601 but does not have any timezone information
			// assigned so we always treat it as loc here.
			t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc)
			if err != nil {
				return nil, false, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
			}

			return t, true, nil

		case sqlite.TypeFloat:
			// stored as Julian day numbers, which this package does not support.
			return nil, false, fmt.Errorf("REAL storage type not supported for time.Time")

		case sqlite.TypeNull:
			return nil, true, nil

		default:
			return nil, false, fmt.Errorf("unsupported storage type for time.Time: %s", stmt.ColumnType(colIdx))
		}
	}
}
// decodeIntoMap decodes the current row of stmt into *mp, allocating the map
// first if it is nil. Each column value is produced by the decode hooks from
// cfg with decodeBasic() appended as the final fallback.
func decodeIntoMap(_ context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
	if *mp == nil {
		*mp = make(map[string]interface{})
	}

	// The hook chain is identical for every column, so build it once.
	hooks := append(cfg.DecodeHooks, decodeBasic())

	for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ {
		colName := stmt.ColumnName(colIdx)
		colDef := schema.GetColumnDef(colName)

		// Without a column definition we decode into a plain interface{};
		// with one, we decode into the model's declared Go type.
		var anyValue interface{}
		outVal := reflect.ValueOf(&anyValue).Elem()
		fieldDef := reflect.StructField{}
		if colDef != nil {
			outVal = reflect.New(colDef.GoType).Elem()
			fieldDef = reflect.StructField{
				Type: colDef.GoType,
			}
		}

		decoded, err := runDecodeHooks(colIdx, colDef, stmt, fieldDef, outVal, hooks)
		if err != nil {
			return fmt.Errorf("failed to decode column %s: %w", colName, err)
		}

		(*mp)[colName] = decoded
	}

	return nil
}
// decodeBasic returns a DecodeFunc that decodes the basic sqlite storage
// classes (TEXT, INTEGER, FLOAT, BLOB, NULL) into the matching Go kinds.
// It is always appended as the last decode hook and thus serves as the
// fallback for values no other hook handled.
func decodeBasic() DecodeFunc {
	return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (result interface{}, handled bool, err error) {
		valueKind := getKind(outval)
		colType := stmt.ColumnType(colIdx)
		colName := stmt.ColumnName(colIdx)
		errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type())

		// if we have the column definition available we
		// use the target go type from there.
		if colDef != nil {
			valueKind = NormalizeKind(colDef.GoType.Kind())

			// if we have a column definition we try to convert the value to
			// the actual Go-type that was used in the model.
			// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
			// or that type aliases like (type myInt int) are decoded into myInt instead of int
			//
			// The deferred function rewrites the named result after the
			// switch below has produced a raw value.
			defer func() {
				if handled {
					t := reflect.New(colDef.GoType).Elem()

					// nil and zero values are returned as-is.
					if result == nil || reflect.ValueOf(result).IsZero() {
						return
					}

					if reflect.ValueOf(result).Type().ConvertibleTo(colDef.GoType) {
						result = reflect.ValueOf(result).Convert(colDef.GoType).Interface()
					}

					t.Set(reflect.ValueOf(result))
					result = t.Interface()
				}
			}()
		}

		// log.Printf("decoding %s into kind %s", colName, valueKind)

		// NULL handling: nullable targets receive nil, non-nullable targets
		// receive the zero value of their Go type, and plain pointers
		// receive nil. Otherwise fall through to the kind switch below.
		if colType == sqlite.TypeNull {
			if colDef != nil && colDef.Nullable {
				return nil, true, nil
			}

			if colDef != nil && !colDef.Nullable {
				return reflect.New(colDef.GoType).Elem().Interface(), true, nil
			}

			if outval.Kind() == reflect.Ptr {
				return nil, true, nil
			}
		}

		switch valueKind { //nolint:exhaustive
		case reflect.String:
			if colType != sqlite.TypeText {
				return nil, false, errInvalidType
			}
			return stmt.ColumnText(colIdx), true, nil

		case reflect.Bool:
			// sqlite does not have a BOOL type, it rather stores a 1/0 in a column
			// with INTEGER affinity.
			if colType != sqlite.TypeInteger {
				return nil, false, errInvalidType
			}

			return stmt.ColumnBool(colIdx), true, nil

		case reflect.Float64:
			if colType != sqlite.TypeFloat {
				return nil, false, errInvalidType
			}

			return stmt.ColumnFloat(colIdx), true, nil

		case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ...
			if colType != sqlite.TypeInteger {
				return nil, false, errInvalidType
			}

			return stmt.ColumnInt(colIdx), true, nil

		case reflect.Slice:
			// only []byte slices can be decoded (from BLOB columns).
			if outval.Type().Elem().Kind() != reflect.Uint8 {
				return nil, false, fmt.Errorf("slices other than []byte for BLOB are not supported")
			}

			if colType != sqlite.TypeBlob {
				return nil, false, errInvalidType
			}

			columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
			if err != nil {
				return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
			}

			return columnValue, true, nil

		case reflect.Interface:
			// decoding into interface{}: pick the Go type that matches the
			// column's storage class.
			var (
				t reflect.Type
				x interface{}
			)

			switch colType {
			case sqlite.TypeBlob:
				t = reflect.TypeOf([]byte{})
				columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
				if err != nil {
					return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
				}
				x = columnValue

			case sqlite.TypeFloat:
				t = reflect.TypeOf(float64(0))
				x = stmt.ColumnFloat(colIdx)

			case sqlite.TypeInteger:
				t = reflect.TypeOf(int(0))
				x = stmt.ColumnInt(colIdx)

			case sqlite.TypeText:
				t = reflect.TypeOf(string(""))
				x = stmt.ColumnText(colIdx)

			case sqlite.TypeNull:
				t = nil
				x = nil

			default:
				return nil, false, fmt.Errorf("unsupported column type %s", colType)
			}

			if t == nil {
				return nil, true, nil
			}

			target := reflect.New(t).Elem()
			target.Set(reflect.ValueOf(x))

			return target.Interface(), true, nil

		default:
			return nil, false, fmt.Errorf("cannot decode into %s", valueKind)
		}
	}
}
func sqlColumnName(fieldType reflect.StructField) string {
tagValue, hasTag := fieldType.Tag.Lookup("sqlite")
if !hasTag {
return fieldType.Name
}
parts := strings.Split(tagValue, ",")
if parts[0] != "" {
return parts[0]
}
return fieldType.Name
}
// runDecodeHooks tries to decode the column value of stmt at index colIdx into
// outval by running all decode hooks in order. The first hook that reports the
// value as handled wins and its result is returned; later hooks are not
// executed. If a hook returns an error, that error is returned immediately.
func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
	for _, hook := range hooks {
		value, handled, err := hook(colIdx, colDef, stmt, fieldDef, outval)
		if err != nil {
			return value, err
		}
		if handled {
			return value, nil
		}
	}

	return nil, nil
}
// getKind returns the kind of val with sized Int, Uint and Float variants
// normalized to their base kind.
func getKind(val reflect.Value) reflect.Kind {
	return NormalizeKind(val.Kind())
}
// NormalizeKind returns a normalized kind of the given kind.
func NormalizeKind(kind reflect.Kind) reflect.Kind {
switch {
case kind >= reflect.Int && kind <= reflect.Int64:
return reflect.Int
case kind >= reflect.Uint && kind <= reflect.Uint64:
return reflect.Uint
case kind >= reflect.Float32 && kind <= reflect.Float64:
return reflect.Float64
default:
return kind
}
}
// DefaultDecodeConfig holds the default decoding configuration.
// It parses TEXT datetime columns in UTC; decodeBasic() is always appended
// by the decoders and does not need to be listed here.
var DefaultDecodeConfig = DecodeConfig{
	DecodeHooks: []DecodeFunc{
		DatetimeDecoder(time.UTC),
	},
}

View File

@@ -0,0 +1,570 @@
package orm
import (
"bytes"
"context"
"encoding/json"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
"zombiezen.com/go/sqlite"
)
// testStmt is a static, single-row Stmt implementation used to feed
// predefined column names, storage classes and values into the decoder
// under test.
type testStmt struct {
	columns []string            // column names, parallel to values and types
	values  []interface{}       // value of each column in the simulated row
	types   []sqlite.ColumnType // sqlite storage class of each column
}

// The accessors below index into the backing slices; the value getters
// panic on a type mismatch, which is acceptable for test fixtures.
func (ts testStmt) ColumnCount() int { return len(ts.columns) }
func (ts testStmt) ColumnName(i int) string { return ts.columns[i] }
func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) } //nolint:forcetypeassert
func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) } //nolint:forcetypeassert
func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) } //nolint:forcetypeassert
func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) } //nolint:forcetypeassert
func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) } //nolint:forcetypeassert
func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] }

// Compile time check.
var _ Stmt = new(testStmt)
// exampleFieldTypes covers decoding into plain basic field types.
type exampleFieldTypes struct {
	S string
	I int
	F float64
	B bool
}

// examplePointerTypes covers decoding into pointers to basic types
// (including NULL handling).
type examplePointerTypes struct {
	S *string
	I *int
	F *float64
	B *bool
}

// exampleStructTags covers column-name mapping via sqlite struct tags.
type exampleStructTags struct {
	S string `sqlite:"col_string"`
	I int    `sqlite:"col_int"`
}

// exampleIntConv covers conversion of INTEGER columns into the various
// sized Go integer types.
type exampleIntConv struct {
	I8  int8
	I16 int16
	I32 int32
	I64 int64
	I   int
}

// exampleBlobTypes covers decoding BLOB columns into []byte.
type exampleBlobTypes struct {
	B []byte
}

// exampleJSONRawTypes covers decoding BLOB columns into json.RawMessage.
type exampleJSONRawTypes struct {
	B json.RawMessage
}

// exampleTimeTypes covers decoding INTEGER columns into time.Time values
// and pointers.
type exampleTimeTypes struct {
	T  time.Time
	TP *time.Time
}

// exampleInterface covers decoding into interface{} targets.
type exampleInterface struct {
	I  interface{}
	IP *interface{}
}
// Equal reports whether other is an *exampleTimeTypes with an equal T and a
// matching TP (both nil, or both non-nil and pointing at equal times).
//
// Fix over the previous revision: the old single-expression form
// `A && B || C` was grouped as `(A && B) || C` by Go's operator precedence,
// so two values with nil TP compared equal even when their T fields differed.
func (ett *exampleTimeTypes) Equal(other interface{}) bool {
	oett, ok := other.(*exampleTimeTypes)
	if !ok {
		return false
	}

	if !ett.T.Equal(oett.T) {
		return false
	}

	// TP must either be nil on both sides or equal on both sides.
	if ett.TP == nil || oett.TP == nil {
		return ett.TP == nil && oett.TP == nil
	}
	return ett.TP.Equal(*oett.TP)
}
// myInt exercises decoding into a named integer type alias.
type myInt int

// exampleTimeNano covers the ",unixnano" struct tag, which makes the
// DatetimeDecoder interpret INTEGER columns as UnixNano timestamps.
type exampleTimeNano struct {
	T time.Time `sqlite:",unixnano"`
}
// Equal reports whether other is an *exampleTimeNano whose T equals this one.
func (etn *exampleTimeNano) Equal(other interface{}) bool {
	if o, ok := other.(*exampleTimeNano); ok {
		return etn.T.Equal(o.T)
	}
	return false
}
func TestDecoder(t *testing.T) { //nolint:maintidx,tparallel
t.Parallel()
ctx := context.TODO()
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
cases := []struct {
Desc string
Stmt testStmt
ColumnDef []ColumnDef
Result interface{}
Expected interface{}
}{
{
"Decoding into nil is not allowed",
testStmt{
columns: nil,
values: nil,
types: nil,
},
nil,
nil,
nil,
},
{
"Decoding into basic types",
testStmt{
columns: []string{"S", "I", "F", "B"},
types: []sqlite.ColumnType{
sqlite.TypeText,
sqlite.TypeInteger,
sqlite.TypeFloat,
sqlite.TypeInteger,
},
values: []interface{}{
"string value",
1,
1.2,
true,
},
},
nil,
&exampleFieldTypes{},
&exampleFieldTypes{
S: "string value",
I: 1,
F: 1.2,
B: true,
},
},
{
"Decoding into basic types with different order",
testStmt{
columns: []string{"I", "S", "B", "F"},
types: []sqlite.ColumnType{
sqlite.TypeInteger,
sqlite.TypeText,
sqlite.TypeInteger,
sqlite.TypeFloat,
},
values: []interface{}{
1,
"string value",
true,
1.2,
},
},
nil,
&exampleFieldTypes{},
&exampleFieldTypes{
S: "string value",
I: 1,
F: 1.2,
B: true,
},
},
{
"Decoding into basic types with missing values",
testStmt{
columns: []string{"F", "B"},
types: []sqlite.ColumnType{
sqlite.TypeFloat,
sqlite.TypeInteger,
},
values: []interface{}{
1.2,
true,
},
},
nil,
&exampleFieldTypes{},
&exampleFieldTypes{
F: 1.2,
B: true,
},
},
{
"Decoding into pointer types",
testStmt{
columns: []string{"S", "I", "F", "B"},
types: []sqlite.ColumnType{
sqlite.TypeText,
sqlite.TypeInteger,
sqlite.TypeFloat,
sqlite.TypeInteger,
},
values: []interface{}{
"string value",
1,
1.2,
true,
},
},
nil,
&examplePointerTypes{},
func() interface{} {
s := "string value"
i := 1
f := 1.2
b := true
return &examplePointerTypes{
S: &s,
I: &i,
F: &f,
B: &b,
}
},
},
{
"Decoding into pointer types with missing values",
testStmt{
columns: []string{"S", "B"},
types: []sqlite.ColumnType{
sqlite.TypeText,
sqlite.TypeInteger,
sqlite.TypeFloat,
sqlite.TypeInteger,
},
values: []interface{}{
"string value",
true,
},
},
nil,
&examplePointerTypes{},
func() interface{} {
s := "string value"
b := true
return &examplePointerTypes{
S: &s,
B: &b,
}
},
},
{
"Decoding into fields with struct tags",
testStmt{
columns: []string{"col_string", "col_int"},
types: []sqlite.ColumnType{
sqlite.TypeText,
sqlite.TypeInteger,
},
values: []interface{}{
"string value",
1,
},
},
nil,
&exampleStructTags{},
&exampleStructTags{
S: "string value",
I: 1,
},
},
{
"Decoding into correct int type",
testStmt{
columns: []string{"I8", "I16", "I32", "I64", "I"},
types: []sqlite.ColumnType{
sqlite.TypeInteger,
sqlite.TypeInteger,
sqlite.TypeInteger,
sqlite.TypeInteger,
sqlite.TypeInteger,
},
values: []interface{}{
1,
1,
1,
1,
1,
},
},
nil,
&exampleIntConv{},
&exampleIntConv{
1, 1, 1, 1, 1,
},
},
{
"Handling NULL values for basic types",
testStmt{
columns: []string{"S", "I", "F"},
types: []sqlite.ColumnType{
sqlite.TypeNull,
sqlite.TypeNull,
sqlite.TypeFloat,
},
values: []interface{}{
// we use nil here but actually that does not matter
nil,
nil,
1.0,
},
},
nil,
&exampleFieldTypes{},
&exampleFieldTypes{
F: 1.0,
},
},
{
"Handling NULL values for pointer types",
testStmt{
columns: []string{"S", "I", "F"},
types: []sqlite.ColumnType{
sqlite.TypeNull,
sqlite.TypeNull,
sqlite.TypeFloat,
},
values: []interface{}{
// we use nil here but actually that does not matter
nil,
nil,
1.0,
},
},
nil,
&examplePointerTypes{},
func() interface{} {
f := 1.0
return &examplePointerTypes{F: &f}
},
},
{
"Handling blob types",
testStmt{
columns: []string{"B"},
types: []sqlite.ColumnType{
sqlite.TypeBlob,
},
values: []interface{}{
([]byte)("hello world"),
},
},
nil,
&exampleBlobTypes{},
&exampleBlobTypes{
B: ([]byte)("hello world"),
},
},
{
"Handling blob types as json.RawMessage",
testStmt{
columns: []string{"B"},
types: []sqlite.ColumnType{
sqlite.TypeBlob,
},
values: []interface{}{
([]byte)("hello world"),
},
},
nil,
&exampleJSONRawTypes{},
&exampleJSONRawTypes{
B: (json.RawMessage)("hello world"),
},
},
{
"Handling time.Time and pointers to it",
testStmt{
columns: []string{"T", "TP"},
types: []sqlite.ColumnType{
sqlite.TypeInteger,
sqlite.TypeInteger,
},
values: []interface{}{
int(refTime.Unix()),
int(refTime.Unix()),
},
},
nil,
&exampleTimeTypes{},
&exampleTimeTypes{
T: refTime,
TP: &refTime,
},
},
{
"Handling time.Time in nano-second resolution (struct tags)",
testStmt{
columns: []string{"T", "TP"},
types: []sqlite.ColumnType{
sqlite.TypeInteger,
sqlite.TypeInteger,
},
values: []interface{}{
int(refTime.UnixNano()),
int(refTime.UnixNano()),
},
},
nil,
&exampleTimeNano{},
&exampleTimeNano{
T: refTime,
},
},
{
"Decoding into interface",
testStmt{
columns: []string{"I", "IP"},
types: []sqlite.ColumnType{
sqlite.TypeText,
sqlite.TypeText,
},
values: []interface{}{
"value1",
"value2",
},
},
nil,
&exampleInterface{},
func() interface{} {
var x interface{} = "value2"
return &exampleInterface{
I: "value1",
IP: &x,
}
},
},
{
"Decoding into map[string]interface{}",
testStmt{
columns: []string{"I", "F", "S", "B"},
types: []sqlite.ColumnType{
sqlite.TypeInteger,
sqlite.TypeFloat,
sqlite.TypeText,
sqlite.TypeBlob,
},
values: []interface{}{
1,
1.1,
"string value",
[]byte("blob value"),
},
},
nil,
new(map[string]interface{}),
&map[string]interface{}{
"I": 1,
"F": 1.1,
"S": "string value",
"B": []byte("blob value"),
},
},
{
"Decoding using type-hints",
testStmt{
columns: []string{"B", "T"},
types: []sqlite.ColumnType{
sqlite.TypeInteger,
sqlite.TypeText,
},
values: []interface{}{
true,
refTime.Format(SqliteTimeFormat),
},
},
[]ColumnDef{
{
Name: "B",
Type: sqlite.TypeInteger,
GoType: reflect.TypeOf(true),
},
{
Name: "T",
Type: sqlite.TypeText,
GoType: reflect.TypeOf(time.Time{}),
IsTime: true,
},
},
new(map[string]interface{}),
&map[string]interface{}{
"B": true,
"T": refTime,
},
},
{
"Decoding into type aliases",
testStmt{
columns: []string{"B"},
types: []sqlite.ColumnType{
sqlite.TypeBlob,
},
values: []interface{}{
[]byte(`{"foo": "bar}`),
},
},
[]ColumnDef{
{
Name: "B",
Type: sqlite.TypeBlob,
GoType: reflect.TypeOf(json.RawMessage(`{"foo": "bar}`)),
},
},
new(map[string]interface{}),
&map[string]interface{}{
"B": json.RawMessage(`{"foo": "bar}`),
},
},
{
"Decoding into type aliases #2",
testStmt{
columns: []string{"I"},
types: []sqlite.ColumnType{sqlite.TypeInteger},
values: []interface{}{
10,
},
},
[]ColumnDef{
{
Name: "I",
Type: sqlite.TypeInteger,
GoType: reflect.TypeOf(myInt(0)),
},
},
new(map[string]interface{}),
&map[string]interface{}{
"I": myInt(10),
},
},
}
for idx := range cases { //nolint:paralleltest
c := cases[idx]
t.Run(c.Desc, func(t *testing.T) {
// log.Println(c.Desc)
err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
if fn, ok := c.Expected.(func() interface{}); ok {
c.Expected = fn()
}
if c.Expected == nil {
assert.Error(t, err, c.Desc)
} else {
assert.NoError(t, err, c.Desc)
if equaler, ok := c.Expected.(interface{ Equal(x interface{}) bool }); ok {
assert.True(t, equaler.Equal(c.Result))
} else {
assert.Equal(t, c.Expected, c.Result)
}
}
})
}
}

View File

@@ -0,0 +1,258 @@
package orm
import (
"context"
"fmt"
"reflect"
"time"
"golang.org/x/exp/slices"
"zombiezen.com/go/sqlite"
)
type (
	// EncodeFunc is called for each non-basic type during encoding.
	// It returns the encoded value, whether it handled the value, and an
	// error. A hook that returns handled=false passes the value on to the
	// next hook in the chain.
	EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)

	// EncodeConfig holds encoding functions.
	EncodeConfig struct {
		// EncodeHooks are tried in order before the built-in basic-type
		// fallback encoder.
		EncodeHooks []EncodeFunc
	}
)
// ToParamMap returns a map that contains the sqlite compatible value of each struct field of
// r using the sqlite column name as a map key. It either uses the name of the
// exported struct field or the value of the "sqlite" tag.
//
// Fields whose column name is listed in skipFields and unexported fields are
// omitted. keyPrefix is prepended to every map key (e.g. ":" for named SQL
// parameters). ctx is currently unused.
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) {
	// make sure we work on a struct type
	val := reflect.Indirect(reflect.ValueOf(r))
	if val.Kind() != reflect.Struct {
		return nil, fmt.Errorf("%w, got %T", errStructExpected, r)
	}
	res := make(map[string]interface{}, val.NumField())
	for i := 0; i < val.NumField(); i++ {
		fieldType := val.Type().Field(i)
		field := val.Field(i)
		// skip unexported fields
		if !fieldType.IsExported() {
			continue
		}
		colDef, err := getColumnDef(fieldType)
		if err != nil {
			return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
		}
		if slices.Contains(skipFields, colDef.Name) {
			continue
		}
		// encodeBasic() is appended last so user-supplied hooks always
		// get the first chance to handle the value.
		x, found, err := runEncodeHooks(
			colDef,
			fieldType.Type,
			field,
			append(
				cfg.EncodeHooks,
				encodeBasic(),
			),
		)
		if err != nil {
			return nil, fmt.Errorf("failed to run encode hooks: %w", err)
		}
		if !found {
			// no hook handled the value; fall back to the dereferenced
			// raw value if it is valid.
			if reflect.Indirect(field).IsValid() {
				x = reflect.Indirect(field).Interface()
			}
		}
		res[keyPrefix+sqlColumnName(fieldType)] = x
	}
	return res, nil
}
// EncodeValue encodes the given value.
//
// It runs the configured encode hooks (with the basic-type fallback
// appended) against val for the given column definition and returns the
// sqlite-compatible result. If no hook handles the value, the dereferenced
// raw value is returned as-is.
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) {
	refVal := reflect.ValueOf(val)
	hooks := append(cfg.EncodeHooks, encodeBasic())

	encoded, handled, err := runEncodeHooks(colDef, reflect.TypeOf(val), refVal, hooks)
	if err != nil {
		return nil, fmt.Errorf("failed to run encode hooks: %w", err)
	}

	if !handled {
		// no hook claimed the value; fall back to the dereferenced raw
		// value when it is valid.
		if deref := reflect.Indirect(refVal); deref.IsValid() {
			encoded = deref.Interface()
		}
	}

	return encoded, nil
}
// encodeBasic returns an EncodeFunc handling the primitive kinds the sqlite
// package can convert natively (string, float, bool, int, uint and []byte).
// Nil pointers become NULL for nullable columns and the pointed-to type's
// zero value otherwise.
func encodeBasic() EncodeFunc {
	return func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
		if valType.Kind() == reflect.Ptr {
			elemType := valType.Elem()
			if val.IsNil() {
				if col.Nullable {
					return nil, true, nil
				}
				// the column is NOT NULL so we substitute the zero value
				// of the pointed-to type.
				return reflect.New(elemType).Elem().Interface(), true, nil
			}
			valType = elemType
			val = val.Elem()
		}

		kind := valType.Kind()
		switch NormalizeKind(kind) { //nolint:exhaustive
		case reflect.String,
			reflect.Float64,
			reflect.Bool,
			reflect.Int,
			reflect.Uint:
			// sqlite package handles conversion of those types already
			return val.Interface(), true, nil
		case reflect.Slice:
			if valType.Elem().Kind() == reflect.Uint8 {
				// []byte maps onto a BLOB column
				return val.Interface(), true, nil
			}
		}

		return nil, false, fmt.Errorf("cannot convert value of kind %s for use in SQLite", kind)
	}
}
// DatetimeEncoder returns a new datetime encoder for the given time zone.
//
// The returned hook recognizes time.Time values as well as string/number
// values on columns tagged as time. INTEGER columns receive a unix epoch
// (nano-second resolution when UnixNano is set), TEXT columns receive a
// string formatted in loc. The zero time is encoded as NULL.
func DatetimeEncoder(loc *time.Location) EncodeFunc {
	return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
		// if fieldType holds a pointer we need to dereference the value
		ft := valType.String()
		if valType.Kind() == reflect.Ptr {
			ft = valType.Elem().String()
			val = reflect.Indirect(val)
		}
		normalizedKind := NormalizeKind(valType.Kind())
		// we only care about "time.Time" here
		var t time.Time
		switch {
		case ft == "time.Time":
			// handle the zero time as a NULL.
			if !val.IsValid() || val.IsZero() {
				return nil, true, nil
			}
			var ok bool
			valInterface := val.Interface()
			t, ok = valInterface.(time.Time)
			if !ok {
				return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
			}
		case valType.Kind() == reflect.String && colDef.IsTime:
			// string values on time columns are expected to be RFC3339.
			var err error
			t, err = time.Parse(time.RFC3339, val.String())
			if err != nil {
				return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
			}
		case (normalizedKind == reflect.Int || normalizedKind == reflect.Uint || normalizedKind == reflect.Float64) && colDef.IsTime:
			// numeric values on time columns are treated as unix epoch seconds.
			seconds := int64(0)
			switch normalizedKind { //nolint:exhaustive // Previous switch case assures these types.
			case reflect.Int:
				seconds = val.Int()
			case reflect.Uint:
				seconds = int64(val.Uint())
			case reflect.Float64:
				seconds = int64(val.Float())
			}
			t = time.Unix(seconds, 0)
		default:
			// we don't care ...
			return nil, false, nil
		}
		switch colDef.Type { //nolint:exhaustive
		case sqlite.TypeInteger:
			if colDef.UnixNano {
				return t.UnixNano(), true, nil
			}
			return t.Unix(), true, nil
		case sqlite.TypeText:
			str := t.In(loc).Format(SqliteTimeFormat)
			return str, true, nil
		}
		return nil, false, fmt.Errorf("cannot store time.Time in %s", colDef.Type)
	}
}
// runEncodeHooks invokes each hook in order until one reports that it
// handled the value. An untyped nil value (nil valType) is resolved up
// front: NULL for nullable columns, otherwise the zero value matching the
// column's SQL type.
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
	if valType == nil {
		if colDef.Nullable {
			return nil, true, nil
		}
		// non-nullable column: substitute the SQL type's zero value.
		switch colDef.Type { //nolint:exhaustive
		case sqlite.TypeBlob:
			return []byte{}, true, nil
		case sqlite.TypeFloat:
			return 0.0, true, nil
		case sqlite.TypeText:
			return "", true, nil
		case sqlite.TypeInteger:
			return 0, true, nil
		}
		return nil, false, fmt.Errorf("unsupported sqlite data type: %s", colDef.Type)
	}

	for _, hook := range hooks {
		encoded, handled, err := hook(colDef, valType, val)
		if err != nil {
			return encoded, false, err
		}
		if handled {
			return encoded, true, nil
		}
	}

	return nil, false, nil
}
// DefaultEncodeConfig holds the default encoding configuration.
// It installs only the UTC datetime encoder; basic types are always covered
// by the built-in fallback hook appended at call time.
var DefaultEncodeConfig = EncodeConfig{
	EncodeHooks: []EncodeFunc{
		DatetimeEncoder(time.UTC),
	},
}

View File

@@ -0,0 +1,260 @@
package orm
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"zombiezen.com/go/sqlite"
)
// TestEncodeAsMap verifies that ToParamMap converts struct fields into a
// sqlite-compatible parameter map, honoring "sqlite" struct tags, pointer
// fields, unexported-field skipping and time.Time encodings.
func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
	t.Parallel()
	ctx := context.TODO()
	refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
	cases := []struct {
		Desc     string
		Input    interface{}
		Expected map[string]interface{}
	}{
		{
			"Encode basic types",
			struct {
				I int
				F float64
				S string
				B []byte
			}{
				I: 1,
				F: 1.2,
				S: "string",
				B: ([]byte)("bytes"),
			},
			map[string]interface{}{
				"I": 1,
				"F": 1.2,
				"S": "string",
				"B": ([]byte)("bytes"),
			},
		},
		{
			"Encode using struct tags",
			struct {
				I int    `sqlite:"col_int"`
				S string `sqlite:"col_string"`
			}{
				I: 1,
				S: "string value",
			},
			map[string]interface{}{
				"col_int":    1,
				"col_string": "string value",
			},
		},
		{
			"Ignore Private fields",
			struct {
				I int
				s string
			}{
				I: 1,
				s: "string value",
			},
			map[string]interface{}{
				"I": 1,
			},
		},
		{
			"Handle Pointers",
			struct {
				I *int
				S *string
			}{
				I: new(int),
			},
			map[string]interface{}{
				"I": 0,
				"S": nil,
			},
		},
		{
			"Handle time.Time types",
			struct {
				TinInt    time.Time `sqlite:",integer,unixnano"`
				TinString time.Time `sqlite:",text"`
			}{
				TinInt:    refTime,
				TinString: refTime,
			},
			map[string]interface{}{
				"TinInt":    refTime.UnixNano(),
				"TinString": refTime.Format(SqliteTimeFormat),
			},
		},
		{
			"Handle time.Time pointer types",
			struct {
				TinInt    *time.Time `sqlite:",integer,unixnano"`
				TinString *time.Time `sqlite:",text"`
				Tnil1     *time.Time `sqlite:",text"`
				Tnil2     *time.Time `sqlite:",text"`
			}{
				TinInt:    &refTime,
				TinString: &refTime,
				Tnil1:     nil,
				Tnil2:     (*time.Time)(nil),
			},
			map[string]interface{}{
				"TinInt":    refTime.UnixNano(),
				"TinString": refTime.Format(SqliteTimeFormat),
				"Tnil1":     nil,
				"Tnil2":     nil,
			},
		},
	}
	for idx := range cases { //nolint:paralleltest
		c := cases[idx]
		t.Run(c.Desc, func(t *testing.T) {
			res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig, nil)
			assert.NoError(t, err)
			assert.Equal(t, c.Expected, res)
		})
	}
}
// TestEncodeValue verifies EncodeValue's handling of time.Time specials
// (text vs. unix vs. unixnano columns, zero times, typed/untyped nils) and
// of NULL substitution for nullable and non-nullable columns.
func TestEncodeValue(t *testing.T) { //nolint:tparallel
	t.Parallel()
	ctx := context.TODO()
	refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
	cases := []struct {
		Desc   string
		Column ColumnDef
		Input  interface{}
		Output interface{}
	}{
		{
			"Special value time.Time as text",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			refTime,
			refTime.Format(SqliteTimeFormat),
		},
		{
			"Special value time.Time as unix-epoch",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeInteger,
			},
			refTime,
			refTime.Unix(),
		},
		{
			"Special value time.Time as unixnano-epoch",
			ColumnDef{
				IsTime:   true,
				Type:     sqlite.TypeInteger,
				UnixNano: true,
			},
			refTime,
			refTime.UnixNano(),
		},
		{
			"Special value zero time",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			time.Time{},
			nil,
		},
		{
			"Special value zero time pointer",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			new(time.Time),
			nil,
		},
		{
			"Special value *time.Time as text",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			&refTime,
			refTime.Format(SqliteTimeFormat),
		},
		{
			"Special value untyped nil",
			ColumnDef{
				Nullable: true,
				IsTime:   true,
				Type:     sqlite.TypeText,
			},
			nil,
			nil,
		},
		{
			"Special value typed nil",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			(*time.Time)(nil),
			nil,
		},
		{
			"Time formated as string",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			refTime.In(time.Local).Format(time.RFC3339),
			refTime.Format(SqliteTimeFormat),
		},
		{
			"Nullable integer",
			ColumnDef{
				Type:     sqlite.TypeInteger,
				Nullable: true,
			},
			nil,
			nil,
		},
		{
			"Not-Null integer",
			ColumnDef{
				Name: "test",
				Type: sqlite.TypeInteger,
			},
			nil,
			0,
		},
		{
			"Not-Null string",
			ColumnDef{
				Type: sqlite.TypeText,
			},
			nil,
			"",
		},
	}
	for idx := range cases { //nolint:paralleltest
		c := cases[idx]
		t.Run(c.Desc, func(t *testing.T) {
			res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
			assert.NoError(t, err)
			assert.Equal(t, c.Output, res)
		})
	}
}

View File

@@ -0,0 +1,186 @@
package orm
import (
"context"
"fmt"
"reflect"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
type (
	// QueryOption can be specified at RunQuery to alter the behavior
	// of the executed query.
	QueryOption func(opts *queryOpts)

	// queryOpts collects all settings applied by QueryOption modifiers
	// before the query is executed.
	queryOpts struct {
		Transient    bool                   // do not cache the prepared statement
		Args         []interface{}          // positional SQL arguments
		NamedArgs    map[string]interface{} // named SQL arguments
		Result       interface{}            // pointer to a slice receiving decoded rows
		DecodeConfig DecodeConfig           // configuration passed to DecodeStmt
		Schema       TableSchema            // table schema used while decoding
	}
)
// WithTransient marks the query as transient.
//
// Transient queries will not be cached for later
// re-use after they have been prepared.
func WithTransient() QueryOption {
	return func(o *queryOpts) {
		o.Transient = true
	}
}
// WithArgs adds a list of positional arguments for the query;
// they are applied in order.
//
// See SQL Language Expression documentation of SQLite for
// details: https://sqlite.org/lang_expr.html
func WithArgs(args ...interface{}) QueryOption {
	return func(o *queryOpts) {
		o.Args = args
	}
}
// WithNamedArgs adds named arguments to the query. The query must use
// named argument placeholders. According to the SQLite spec,
// arguments must either start with ':', '@' or '$'.
//
// See SQL Language Expression documentation of SQLite for
// details: https://sqlite.org/lang_expr.html
func WithNamedArgs(args map[string]interface{}) QueryOption {
	return func(o *queryOpts) {
		o.NamedArgs = args
	}
}
// WithSchema returns a query option that attaches the given table
// schema to the query.
func WithSchema(tbl TableSchema) QueryOption {
	return func(o *queryOpts) {
		o.Schema = tbl
	}
}
// WithResult sets the result receiver. result is expected to
// be a pointer to a slice of struct or map types.
//
// Each row is decoded into a new slice element via DecodeStmt, which
// supports special values like time.Time. See DecodeStmt() and
// WithDecodeConfig() for more information.
func WithResult(result interface{}) QueryOption {
	return func(o *queryOpts) {
		o.Result = result
	}
}
// WithDecodeConfig configures the DecodeConfig to use when
// calling DecodeStmt to decode each row into the result slice.
//
// If not specified, DefaultDecodeConfig will be used.
func WithDecodeConfig(cfg DecodeConfig) QueryOption {
	return func(o *queryOpts) {
		o.DecodeConfig = cfg
	}
}
// RunQuery executes the query stored in sql against the database opened in
// conn. Please refer to the documentation of QueryOption, especially WithResult()
// for more information on how to retrieve the resulting rows.
//
// Example:
//
//	var result []struct{
//		Count int `sqlite:"rowCount"`
//	}
//
//	err := RunQuery(ctx, conn, "SELECT COUNT(*) AS rowCount FROM table", WithResult(&result))
//	fmt.Println(result[0].Count)
func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...QueryOption) error {
	// apply all query options on top of the defaults.
	args := queryOpts{
		DecodeConfig: DefaultDecodeConfig,
	}
	for _, fn := range modifiers {
		fn(&args)
	}
	opts := &sqlitex.ExecOptions{
		Args:  args.Args,
		Named: args.NamedArgs,
	}
	var (
		sliceVal    reflect.Value
		valElemType reflect.Type
	)
	if args.Result != nil {
		target := args.Result
		outVal := reflect.ValueOf(target)
		if outVal.Kind() != reflect.Ptr {
			return fmt.Errorf("target must be a pointer, got %T", target)
		}
		sliceVal = reflect.Indirect(outVal)
		// Check the kind BEFORE calling IsNil below: reflect.Value.IsNil
		// panics for non-nillable kinds, so a pointer to e.g. a struct
		// previously panicked here instead of returning an error. This
		// also rejects a nil *T target (Indirect yields an invalid value
		// whose Kind is reflect.Invalid).
		if sliceVal.Kind() != reflect.Slice {
			return fmt.Errorf("target must be pointer to slice, got %T", target)
		}
		if sliceVal.IsNil() {
			sliceVal.Set(reflect.Zero(sliceVal.Type()))
		}
		valType := sliceVal.Type()
		valElemType = valType.Elem()
		// decode each result row into a fresh slice element.
		opts.ResultFunc = func(stmt *sqlite.Stmt) error {
			currentField := reflect.New(valElemType)
			if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
				// dump the raw row values so decoding failures are easier
				// to debug.
				resultDump := make(map[string]any)
				for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ {
					name := stmt.ColumnName(colIdx)
					switch stmt.ColumnType(colIdx) { //nolint:exhaustive // TODO: handle type BLOB?
					case sqlite.TypeText:
						resultDump[name] = stmt.ColumnText(colIdx)
					case sqlite.TypeFloat:
						resultDump[name] = stmt.ColumnFloat(colIdx)
					case sqlite.TypeInteger:
						resultDump[name] = stmt.ColumnInt(colIdx)
					case sqlite.TypeNull:
						resultDump[name] = "<null>"
					}
				}
				return fmt.Errorf("%w: %+v", err, resultDump)
			}
			sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField))
			return nil
		}
	}
	var err error
	if args.Transient {
		err = sqlitex.ExecuteTransient(conn, sql, opts)
	} else {
		err = sqlitex.Execute(conn, sql, opts)
	}
	if err != nil {
		return err
	}
	if args.Result != nil {
		// write the collected rows back into the caller's slice.
		reflect.Indirect(reflect.ValueOf(args.Result)).Set(sliceVal)
	}
	return nil
}

View File

@@ -0,0 +1,288 @@
package orm
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"zombiezen.com/go/sqlite"
"github.com/safing/portbase/log"
)
// errSkipStructField is returned by getColumnDef for fields tagged with
// "-" which must not be mapped to a table column.
var errSkipStructField = errors.New("struct field should be skipped")

// Struct Tags.
var (
	TagUnixNano          = "unixnano"
	TagPrimaryKey        = "primary"
	TagAutoIncrement     = "autoincrement"
	TagTime              = "time"
	TagNotNull           = "not-null"
	TagNullable          = "nullable"
	TagTypeInt           = "integer"
	TagTypeText          = "text"
	TagTypePrefixVarchar = "varchar"
	TagTypeBlob          = "blob"
	TagTypeFloat         = "float"
	TagTypePrefixDefault = "default="
)
// sqlTypeMap maps sqlite column types to the SQL keyword used when
// rendering column definitions.
var sqlTypeMap = map[sqlite.ColumnType]string{
	sqlite.TypeBlob:    "BLOB",
	sqlite.TypeFloat:   "REAL",
	sqlite.TypeInteger: "INTEGER",
	sqlite.TypeText:    "TEXT",
}
type (
	// TableSchema defines a SQL table schema.
	TableSchema struct {
		Name    string
		Columns []ColumnDef
	}

	// ColumnDef defines a SQL column.
	ColumnDef struct { //nolint:maligned
		Name          string
		Nullable      bool
		Type          sqlite.ColumnType
		GoType        reflect.Type
		Length        int // VARCHAR length; only used for TEXT columns
		PrimaryKey    bool
		AutoIncrement bool
		UnixNano      bool // time values use nano-second resolution
		IsTime        bool // column stores a time value
		Default       any  // rendered as DEFAULT in the CREATE statement
	}
)
// GetColumnDef returns the column definition with the given name, or nil if
// no such column exists. The returned pointer refers to a copy, so callers
// cannot mutate the schema through it.
func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
	for idx := range ts.Columns {
		if ts.Columns[idx].Name != name {
			continue
		}
		col := ts.Columns[idx]
		return &col
	}
	return nil
}
// CreateStatement builds the CREATE SQL statement for the table. If
// databaseName is non-empty the table name is qualified with it.
func (ts TableSchema) CreateStatement(databaseName string, ifNotExists bool) string {
	qualifiedName := ts.Name
	if databaseName != "" {
		qualifiedName = databaseName + "." + ts.Name
	}

	columns := make([]string, 0, len(ts.Columns))
	for _, col := range ts.Columns {
		columns = append(columns, col.AsSQL())
	}

	stmt := "CREATE TABLE"
	if ifNotExists {
		stmt += " IF NOT EXISTS"
	}

	return stmt + " " + qualifiedName + " ( " + strings.Join(columns, ", ") + " );"
}
// AsSQL builds the SQL column definition: name, type, PRIMARY KEY /
// AUTOINCREMENT modifiers, an optional DEFAULT value and NOT NULL for
// non-nullable columns.
func (def ColumnDef) AsSQL() string {
	sql := def.Name + " "

	if def.Type == sqlite.TypeText && def.Length > 0 {
		sql += fmt.Sprintf("VARCHAR(%d)", def.Length)
	} else {
		sql += sqlTypeMap[def.Type]
	}

	if def.PrimaryKey {
		sql += " PRIMARY KEY"
	}
	if def.AutoIncrement {
		sql += " AUTOINCREMENT"
	}
	if def.Default != nil {
		sql += " DEFAULT "
		switch def.Type { //nolint:exhaustive // TODO: handle types BLOB, NULL?
		case sqlite.TypeFloat:
			// Use the 'g' verb to emit a decimal literal. The previous
			// 'b' verb produced a binary-exponent form (e.g.
			// "4503599627370496p-52") which is not valid SQL.
			sql += strconv.FormatFloat(def.Default.(float64), 'g', -1, 64) //nolint:forcetypeassert
		case sqlite.TypeInteger:
			sql += strconv.FormatInt(def.Default.(int64), 10) //nolint:forcetypeassert
		case sqlite.TypeText:
			// SQL string literals are single-quoted with embedded single
			// quotes doubled. %q emitted Go-style double quotes, which
			// SQLite prefers to interpret as identifiers.
			sql += "'" + strings.ReplaceAll(def.Default.(string), "'", "''") + "'" //nolint:forcetypeassert
		default:
			log.Errorf("unsupported default value: %q %q", def.Type, def.Default)
			sql = strings.TrimSuffix(sql, " DEFAULT ")
		}
		sql += " "
	}
	if !def.Nullable {
		sql += " NOT NULL"
	}
	return sql
}
// GenerateTableSchema generates a table schema from the given struct.
// Unexported fields and fields tagged with "-" are skipped; everything
// else is converted to a column definition via getColumnDef.
func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) {
	structVal := reflect.Indirect(reflect.ValueOf(d))
	if structVal.Kind() != reflect.Struct {
		return nil, fmt.Errorf("%w, got %T", errStructExpected, d)
	}

	schema := &TableSchema{Name: name}
	structType := structVal.Type()

	for idx := 0; idx < structVal.NumField(); idx++ {
		field := structType.Field(idx)

		// unexported fields cannot be encoded; skip them.
		if !field.IsExported() {
			continue
		}

		colDef, err := getColumnDef(field)
		switch {
		case errors.Is(err, errSkipStructField):
			// field is explicitly excluded via the "-" tag.
			continue
		case err != nil:
			return nil, fmt.Errorf("struct field %s: %w", field.Name, err)
		}

		schema.Columns = append(schema.Columns, *colDef)
	}

	return schema, nil
}
// getColumnDef derives a column definition from a struct field. The SQL
// type is inferred from the (dereferenced) Go type, pointer fields are
// nullable by default, and the "sqlite" struct tag may override any of it.
func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
	def := &ColumnDef{
		Name:     fieldType.Name,
		Nullable: fieldType.Type.Kind() == reflect.Ptr,
	}
	ft := fieldType.Type
	if fieldType.Type.Kind() == reflect.Ptr {
		ft = fieldType.Type.Elem()
	}
	def.GoType = ft
	kind := NormalizeKind(ft.Kind())
	switch kind { //nolint:exhaustive
	case reflect.Int, reflect.Uint:
		def.Type = sqlite.TypeInteger
	case reflect.Float64:
		def.Type = sqlite.TypeFloat
	case reflect.String:
		def.Type = sqlite.TypeText
	case reflect.Slice:
		// only []byte/[]uint8 is supported
		if ft.Elem().Kind() != reflect.Uint8 {
			return nil, fmt.Errorf("slices of type %s is not supported", ft.Elem())
		}
		def.Type = sqlite.TypeBlob
	}
	// NOTE(review): other kinds (e.g. struct types such as time.Time) fall
	// through with no type set here and appear to rely on the struct tag to
	// choose one — confirm before tightening this switch.
	if err := applyStructFieldTag(fieldType, def); err != nil {
		return nil, err
	}
	return def, nil
}
// applyStructFieldTag parses the sqlite:"" struct field tag and update the column
// definition def accordingly.
//
// Tag format: `sqlite:"<name>[,option...]"`. An empty name keeps the field
// name; "-" skips the field (errSkipStructField). Options cover column
// modifiers (primary, autoincrement, not-null, nullable, unixnano, time),
// basic column types (integer, text, float, blob) and the advanced forms
// "varchar(N)" and "default=<value>".
func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
	parts := strings.Split(fieldType.Tag.Get("sqlite"), ",")
	if len(parts) > 0 && parts[0] != "" {
		if parts[0] == "-" {
			return errSkipStructField
		}
		def.Name = parts[0]
	}
	if len(parts) > 1 {
		for _, k := range parts[1:] {
			switch k {
			// column modifiers
			case TagPrimaryKey:
				def.PrimaryKey = true
			case TagAutoIncrement:
				def.AutoIncrement = true
			case TagNotNull:
				def.Nullable = false
			case TagNullable:
				def.Nullable = true
			case TagUnixNano:
				def.UnixNano = true
			case TagTime:
				def.IsTime = true
			// basic column types
			case TagTypeInt:
				def.Type = sqlite.TypeInteger
			case TagTypeText:
				def.Type = sqlite.TypeText
			case TagTypeFloat:
				def.Type = sqlite.TypeFloat
			case TagTypeBlob:
				def.Type = sqlite.TypeBlob
			// advanced column types
			default:
				if strings.HasPrefix(k, TagTypePrefixVarchar) {
					lenStr := strings.TrimSuffix(strings.TrimPrefix(k, TagTypePrefixVarchar+"("), ")")
					length, err := strconv.ParseInt(lenStr, 10, 0)
					if err != nil {
						return fmt.Errorf("failed to parse varchar length %q: %w", lenStr, err)
					}
					def.Type = sqlite.TypeText
					def.Length = int(length)
				}
				if strings.HasPrefix(k, TagTypePrefixDefault) {
					defaultValue := strings.TrimPrefix(k, TagTypePrefixDefault)
					// the default value is parsed according to the column
					// type determined so far (inferred from the Go type or
					// set by an earlier tag option).
					switch def.Type { //nolint:exhaustive
					case sqlite.TypeFloat:
						fv, err := strconv.ParseFloat(defaultValue, 64)
						if err != nil {
							return fmt.Errorf("failed to parse default value as float %q: %w", defaultValue, err)
						}
						def.Default = fv
					case sqlite.TypeInteger:
						fv, err := strconv.ParseInt(defaultValue, 10, 0)
						if err != nil {
							return fmt.Errorf("failed to parse default value as int %q: %w", defaultValue, err)
						}
						def.Default = fv
					case sqlite.TypeText:
						def.Default = defaultValue
					case sqlite.TypeBlob:
						return fmt.Errorf("default values for TypeBlob not yet supported")
					default:
						return fmt.Errorf("failed to apply default value for unknown sqlite column type %s", def.Type)
					}
				}
			}
		}
	}
	return nil
}

View File

@@ -0,0 +1,43 @@
package orm
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestSchemaBuilder verifies that GenerateTableSchema maps struct fields and
// their sqlite struct tags onto the expected CREATE TABLE statement.
func TestSchemaBuilder(t *testing.T) {
	t.Parallel()
	cases := []struct {
		Name        string
		Model       interface{}
		ExpectedSQL string
	}{
		{
			"Simple",
			struct {
				ID    int         `sqlite:"id,primary,autoincrement"`
				Text  string      `sqlite:"text,nullable"`
				Int   *int        `sqlite:",not-null"`
				Float interface{} `sqlite:",float,nullable"`
			}{},
			`CREATE TABLE main.Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
		},
		{
			"Varchar",
			struct {
				S string `sqlite:",varchar(10)"`
			}{},
			`CREATE TABLE main.Varchar ( S VARCHAR(10) NOT NULL );`,
		},
	}
	for idx := range cases {
		c := cases[idx]
		// run each case as a named sub-test — consistent with the other
		// table-driven tests in this package — so a failure identifies
		// the offending case.
		t.Run(c.Name, func(t *testing.T) {
			res, err := GenerateTableSchema(c.Name, c.Model)
			assert.NoError(t, err)
			assert.Equal(t, c.ExpectedSQL, res.CreateStatement("main", false))
		})
	}
}

675
service/netquery/query.go Normal file
View File

@@ -0,0 +1,675 @@
package netquery
import (
"context"
"encoding/json"
"fmt"
"io"
"reflect"
"sort"
"strings"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/slices"
"zombiezen.com/go/sqlite"
"github.com/safing/portmaster/service/netquery/orm"
)
// DatabaseName is a database name constant.
type DatabaseName string

// Databases.
const (
	LiveDatabase    = DatabaseName("main")
	HistoryDatabase = DatabaseName("history")
)
// Collection of Query and Matcher types.
// NOTE: whenever adding support for new operators make sure
// to update UnmarshalJSON as well.
//
//nolint:golint
type (
	// Query maps a column name to the list of matchers applied to it.
	Query map[string][]Matcher

	MatchType interface {
		Operator() string
	}

	Equal interface{}

	// Matcher describes a single comparison against a column value.
	// At least one operator must be set; see Validate().
	Matcher struct {
		Equal          interface{}   `json:"$eq,omitempty"`
		NotEqual       interface{}   `json:"$ne,omitempty"`
		In             []interface{} `json:"$in,omitempty"`
		NotIn          []interface{} `json:"$notIn,omitempty"`
		Like           string        `json:"$like,omitempty"`
		Greater        *float64      `json:"$gt,omitempty"`
		GreaterOrEqual *float64      `json:"$ge,omitempty"`
		Less           *float64      `json:"$lt,omitempty"`
		LessOrEqual    *float64      `json:"$le,omitempty"`
	}

	// Count describes a $count aggregate selection.
	Count struct {
		As       string `json:"as"`
		Field    string `json:"field"`
		Distinct bool   `json:"distinct"`
	}

	// Sum describes a $sum aggregate selection with a condition.
	Sum struct {
		Condition Query  `json:"condition"`
		Field     string `json:"field"`
		As        string `json:"as"`
		Distinct  bool   `json:"distinct"`
	}

	// Min describes a $min aggregate selection with an optional condition.
	Min struct {
		Condition *Query `json:"condition,omitempty"`
		Field     string `json:"field"`
		As        string `json:"as"`
		Distinct  bool   `json:"distinct"`
	}

	// FieldSelect selects a field under an alias.
	FieldSelect struct {
		Field string `json:"field"`
		As    string `json:"as"`
	}

	// Select describes one entry of a select clause.
	Select struct {
		Field       string       `json:"field"`
		FieldSelect *FieldSelect `json:"$field"`
		Count       *Count       `json:"$count,omitempty"`
		Sum         *Sum         `json:"$sum,omitempty"`
		Min         *Min         `json:"$min,omitempty"`
		Distinct    *string      `json:"$distinct,omitempty"`
	}

	Selects []Select

	// TextSearch describes a sub-string search over a set of text columns.
	TextSearch struct {
		Fields []string `json:"fields"`
		Value  string   `json:"value"`
	}

	// QueryActiveConnectionChartPayload is the request payload for the
	// active-connection chart handler.
	QueryActiveConnectionChartPayload struct {
		Query      Query       `json:"query"`
		TextSearch *TextSearch `json:"textSearch"`
	}

	// OrderBy orders results by a field, optionally descending.
	OrderBy struct {
		Field string `json:"field"`
		Desc  bool   `json:"desc"`
	}

	OrderBys []OrderBy

	// Pagination selects a page of results; see toSQLLimitOffsetClause.
	Pagination struct {
		PageSize int `json:"pageSize"`
		Page     int `json:"page"`
	}
)
// UnmarshalJSON unmarshals a Query from json.
//
// Each column accepts either a single JSON primitive (treated as an
// equality condition), a matcher object (e.g. {"$eq": ...}) or an array
// mixing both forms. Nested arrays are rejected.
func (query *Query) UnmarshalJSON(blob []byte) error {
	if *query == nil {
		*query = make(Query)
	}
	var model map[string]json.RawMessage
	if err := json.Unmarshal(blob, &model); err != nil {
		return err
	}
	for columnName, rawColumnQuery := range model {
		if len(rawColumnQuery) == 0 {
			continue
		}
		// dispatch on the first byte of the raw value to distinguish
		// object, array and primitive form.
		switch rawColumnQuery[0] {
		case '{':
			m, err := parseMatcher(rawColumnQuery)
			if err != nil {
				return err
			}
			(*query)[columnName] = []Matcher{*m}
		case '[':
			var rawMatchers []json.RawMessage
			if err := json.Unmarshal(rawColumnQuery, &rawMatchers); err != nil {
				return err
			}
			(*query)[columnName] = make([]Matcher, len(rawMatchers))
			for idx, val := range rawMatchers {
				// this should not happen
				if len(val) == 0 {
					continue
				}
				// if val starts with a { we have a matcher definition
				if val[0] == '{' {
					m, err := parseMatcher(val)
					if err != nil {
						return err
					}
					(*query)[columnName][idx] = *m
					continue
				} else if val[0] == '[' {
					return fmt.Errorf("invalid token [ in query for column %s", columnName)
				}
				// val is a dedicated JSON primitive and not an object or array
				// so we treat that as an EQUAL condition.
				var x interface{}
				if err := json.Unmarshal(val, &x); err != nil {
					return err
				}
				(*query)[columnName][idx] = Matcher{
					Equal: x,
				}
			}
		default:
			// value is a JSON primitive and not an object or array
			// so we treat that as an EQUAL condition.
			var x interface{}
			if err := json.Unmarshal(rawColumnQuery, &x); err != nil {
				return err
			}
			(*query)[columnName] = []Matcher{
				{Equal: x},
			}
		}
	}
	return nil
}
// TODO(ppacher): right now we only support LIMIT and OFFSET for pagination but that
// has an issue that loading the same page twice might yield different results due to
// new records shifting the result slice. To overcome this, return a "PageToken" to the
// user that includes the time the initial query was created so paginated queries can
// ensure new records don't end up in the result set.

// toSQLLimitOffsetClause renders the pagination as a "LIMIT n [OFFSET m]"
// clause. The page size is defaulted and capped at 100 items to avoid
// out-of-memory conditions when loading thousands of results at once.
func (page *Pagination) toSQLLimitOffsetClause() string {
	pageSize := page.PageSize
	if pageSize <= 0 || pageSize > 100 {
		pageSize = 100
	}

	if page.Page > 0 {
		return fmt.Sprintf("LIMIT %d OFFSET %d", pageSize, page.Page*pageSize)
	}

	return fmt.Sprintf("LIMIT %d", pageSize)
}
// parseMatcher unmarshals a single Matcher definition from raw JSON and
// validates that at least one operator is set.
func parseMatcher(raw json.RawMessage) (*Matcher, error) {
	m := new(Matcher)
	if err := json.Unmarshal(raw, m); err != nil {
		return nil, err
	}

	if err := m.Validate(); err != nil {
		return nil, fmt.Errorf("invalid query matcher: %w", err)
	}

	return m, nil
}
// Validate validates the matcher. It returns an error when none of the
// supported operators is set.
func (match Matcher) Validate() error {
	operators := []bool{
		match.Equal != nil,
		match.NotEqual != nil,
		match.In != nil,
		match.NotIn != nil,
		match.Like != "",
		match.Greater != nil,
		match.GreaterOrEqual != nil,
		match.Less != nil,
		match.LessOrEqual != nil,
	}

	for _, isSet := range operators {
		if isSet {
			return nil
		}
	}

	return fmt.Errorf("no conditions specified")
}
// toSQLConditionClause converts the text-search into an OR-joined set of
// LIKE conditions over the requested text columns. It returns the clause
// and the named parameters it references, or empty values when no fields
// were requested.
func (text TextSearch) toSQLConditionClause(_ context.Context, schema *orm.TableSchema, suffix string, _ orm.EncodeConfig) (string, map[string]interface{}, error) {
	var (
		queryParts = make([]string, 0, len(text.Fields))
		params     = make(map[string]interface{})
	)
	key := fmt.Sprintf(":t%s", suffix)
	params[key] = fmt.Sprintf("%%%s%%", text.Value)
	for _, field := range text.Fields {
		colDef := schema.GetColumnDef(field)
		if colDef == nil {
			// colDef is nil here, so report the requested field name;
			// dereferencing colDef.Name would panic.
			return "", nil, fmt.Errorf("column %s is not allowed in text-search", field)
		}
		if colDef.Type != sqlite.TypeText {
			return "", nil, fmt.Errorf("type of column %s cannot be used in text-search", colDef.Name)
		}
		queryParts = append(queryParts, fmt.Sprintf("%s LIKE %s", colDef.Name, key))
	}
	if len(queryParts) == 0 {
		return "", nil, nil
	}
	return "( " + strings.Join(queryParts, " OR ") + " )", params, nil
}
// toSQLConditionClause converts match into a SQL condition clause for the
// column described by colDef. One sub-condition is emitted per operator set
// on the matcher (Equal, NotEqual, In, NotIn, Like, Greater, ...) and the
// sub-conditions are combined using conjunction (e.g. "AND"). suffix is
// appended to the named-parameter keys to keep them unique across matchers
// of the same column. Value-encoding failures are collected and returned
// alongside the clause via the multierror result.
func (match Matcher) toSQLConditionClause(ctx context.Context, suffix string, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
	var (
		queryParts []string
		params     = make(map[string]interface{})
		errs       = new(multierror.Error)
		key        = fmt.Sprintf("%s%s", colDef.Name, suffix)
	)

	// add appends one "<column> <operator> <placeholder(s)>" part to
	// queryParts. list marks set operators (IN / NOT IN) whose placeholders
	// are always wrapped in parentheses. The suffix parameter (deliberately
	// shadowing the outer suffix) makes each operator's named parameters
	// unique. On an encoding failure the error is collected in errs and the
	// whole part is skipped.
	add := func(operator, suffix string, list bool, values ...interface{}) {
		var placeholder []string

		for idx, value := range values {
			var (
				encodedValue any
				err          error
			)

			kind := orm.NormalizeKind(reflect.TypeOf(value).Kind())
			isNumber := slices.Contains([]reflect.Kind{
				reflect.Uint,
				reflect.Int,
				reflect.Float64,
			}, kind)

			// if we query a time-field that is queried as a number, don't do any encoding
			// here as the orm.DateTimeEncoder would convert the number to a string.
			if colDef.IsTime && colDef.Type == sqlite.TypeText && isNumber {
				encodedValue = value
			} else {
				encodedValue, err = orm.EncodeValue(ctx, &colDef, value, encoderConfig)
				if err != nil {
					errs.Errors = append(errs.Errors,
						fmt.Errorf("failed to encode %v for column %s: %w", value, colDef.Name, err),
					)

					return
				}
			}

			uniqKey := fmt.Sprintf(":%s%s%d", key, suffix, idx)
			placeholder = append(placeholder, uniqKey)
			params[uniqKey] = encodedValue
		}

		nameStmt := colDef.Name
		if len(values) > 0 {
			// NOTE(ppacher): for now we assume that the type of each element of values
			// is the same. We also can be sure that there is always at least one value.
			//
			// TODO(ppacher): if we start supporting values of different types here
			// we need to revisit the whole behavior as we might need to do more boolean
			// expression nesting to support that.
			kind := orm.NormalizeKind(reflect.TypeOf(values[0]).Kind())
			isNumber := slices.Contains([]reflect.Kind{
				reflect.Uint,
				reflect.Int,
				reflect.Float64,
			}, kind)

			// if this is a time column that is stored in "text" format and the provided
			// value is a number type, we need to wrap the property in a strftime() method
			// call.
			if colDef.IsTime && colDef.Type == sqlite.TypeText && isNumber {
				nameStmt = fmt.Sprintf("strftime('%%s', %s)+0", nameStmt)
			}
		}

		// single scalar comparisons skip the parentheses; lists keep them.
		if len(placeholder) == 1 && !list {
			queryParts = append(queryParts, fmt.Sprintf("%s %s %s", nameStmt, operator, placeholder[0]))
		} else {
			queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", nameStmt, operator, strings.Join(placeholder, ", ")))
		}
	}

	// emit one sub-condition per operator that is set on the matcher.
	if match.Equal != nil {
		add("=", "eq", false, match.Equal)
	}

	if match.NotEqual != nil {
		add("!=", "ne", false, match.NotEqual)
	}

	if match.In != nil {
		add("IN", "in", true, match.In...)
	}

	if match.NotIn != nil {
		add("NOT IN", "notin", true, match.NotIn...)
	}

	if match.Like != "" {
		add("LIKE", "like", false, match.Like)
	}

	if match.Greater != nil {
		add(">", "gt", false, *match.Greater)
	}

	if match.GreaterOrEqual != nil {
		add(">=", "ge", false, *match.GreaterOrEqual)
	}

	if match.Less != nil {
		add("<", "lt", false, *match.Less)
	}

	if match.LessOrEqual != nil {
		add("<=", "le", false, *match.LessOrEqual)
	}

	if len(queryParts) == 0 {
		// this is an empty matcher without a single condition.
		// we convert that to a no-op TRUE value
		return "( 1 = 1 )", nil, errs.ErrorOrNil()
	}

	if len(queryParts) == 1 {
		return queryParts[0], params, errs.ErrorOrNil()
	}

	return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
}
// toSQLWhereClause converts query into the body of a SQL WHERE clause
// (without the WHERE keyword) against the table described by m. Matchers of
// the same column are combined with OR; distinct columns are combined with
// AND. Column keys are processed in sorted order so the generated SQL is
// stable. suffix keeps the named parameters of nested matchers unique.
// Validation and matcher errors are collected and returned via multierror
// together with whatever clause could be built.
func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
	if len(query) == 0 {
		return "", nil, nil
	}

	// create a lookup map to validate column names
	lm := make(map[string]orm.ColumnDef, len(m.Columns))
	for _, col := range m.Columns {
		lm[col.Name] = col
	}

	paramMap := make(map[string]interface{})
	columnStmts := make([]string, 0, len(query))

	// get all keys and sort them so we get a stable output
	queryKeys := make([]string, 0, len(query))
	for column := range query {
		queryKeys = append(queryKeys, column)
	}
	sort.Strings(queryKeys)

	// actually create the WHERE clause parts for each
	// column in query.
	errs := new(multierror.Error)
	for _, column := range queryKeys {
		values := query[column]
		colDef, ok := lm[column]
		if !ok {
			errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))
			continue
		}

		queryParts := make([]string, len(values))
		for idx, val := range values {
			// the per-matcher suffix embeds the matcher index so parameters
			// of sibling matchers for the same column do not collide.
			matcherQuery, params, err := val.toSQLConditionClause(ctx, fmt.Sprintf("%s%d", suffix, idx), "AND", colDef, encoderConfig)
			if err != nil {
				errs.Errors = append(errs.Errors,
					fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),
				)
				continue
			}

			// merge parameters up into the superior parameter map
			for key, val := range params {
				if _, ok := paramMap[key]; ok {
					// This is solely a developer mistake when implementing a matcher so no forgiving ...
					panic("sqlite parameter collision")
				}
				paramMap[key] = val
			}

			queryParts[idx] = matcherQuery
		}

		columnStmts = append(columnStmts,
			fmt.Sprintf("( %s )", strings.Join(queryParts, " OR ")),
		)
	}

	whereClause := strings.Join(columnStmts, " AND ")

	return whereClause, paramMap, errs.ErrorOrNil()
}
// UnmarshalJSON unmarshals a Selects from json. The payload may be an array
// of select objects, a single select object, or a plain string that is
// interpreted as a field name.
func (sel *Selects) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}

	// if we are looking at a slice directly decode into
	// a []Select
	if blob[0] == '[' {
		var result []Select
		if err := json.Unmarshal(blob, &result); err != nil {
			return err
		}

		(*sel) = result

		return nil
	}

	// if it's an object decode into a single select
	if blob[0] == '{' {
		var result Select
		if err := json.Unmarshal(blob, &result); err != nil {
			return err
		}

		*sel = []Select{result}

		return nil
	}

	// otherwise this is just the field name
	var field string
	if err := json.Unmarshal(blob, &field); err != nil {
		return err
	}

	// BUG FIX: the decoded field name was previously discarded, leaving
	// *sel empty for the string shorthand. Store it as a single
	// field-select, mirroring OrderBys.UnmarshalJSON.
	*sel = []Select{{Field: field}}

	return nil
}
// UnmarshalJSON unmarshals a Select from json. The payload is either an
// object carrying the field name and optional aggregate operators or a
// plain string naming the field to select.
func (sel *Select) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}

	// anything that is not an object is treated as a plain field name.
	if blob[0] != '{' {
		var fieldName string
		if err := json.Unmarshal(blob, &fieldName); err != nil {
			return err
		}

		sel.Field = fieldName

		return nil
	}

	var decoded struct {
		Field       string       `json:"field"`
		Count       *Count       `json:"$count"`
		Sum         *Sum         `json:"$sum"`
		Min         *Min         `json:"$min"`
		Distinct    *string      `json:"$distinct"`
		FieldSelect *FieldSelect `json:"$field"`
	}
	if err := json.Unmarshal(blob, &decoded); err != nil {
		return err
	}

	sel.Count = decoded.Count
	sel.Field = decoded.Field
	sel.FieldSelect = decoded.FieldSelect
	sel.Distinct = decoded.Distinct
	sel.Sum = decoded.Sum
	sel.Min = decoded.Min

	// aliases end up verbatim in the generated SQL so they must only
	// consist of letters; checkAlias validates a single "as" value.
	checkAlias := func(operator, alias string) error {
		if alias != "" && !charOnlyRegexp.MatchString(alias) {
			return fmt.Errorf("invalid characters in %s.as, value must match [a-zA-Z]+", operator)
		}

		return nil
	}

	if sel.Count != nil {
		if err := checkAlias("$count", sel.Count.As); err != nil {
			return err
		}
	}
	if sel.Sum != nil {
		if err := checkAlias("$sum", sel.Sum.As); err != nil {
			return err
		}
	}
	if sel.Min != nil {
		if err := checkAlias("$min", sel.Min.As); err != nil {
			return err
		}
	}
	if sel.FieldSelect != nil {
		if err := checkAlias("$field", sel.FieldSelect.As); err != nil {
			return err
		}
	}

	return nil
}
// UnmarshalJSON unmarshals a OrderBys from json. The payload may be an
// array of order-by objects, a single object, or a plain string naming the
// field to sort by in ascending order.
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}

	switch blob[0] {
	case '[':
		var list []OrderBy
		if err := json.Unmarshal(blob, &list); err != nil {
			return err
		}
		*orderBys = list

	case '{':
		var single OrderBy
		if err := json.Unmarshal(blob, &single); err != nil {
			return err
		}
		*orderBys = []OrderBy{single}

	default:
		// a bare string sorts ascending by that field.
		var fieldName string
		if err := json.Unmarshal(blob, &fieldName); err != nil {
			return err
		}
		*orderBys = []OrderBy{
			{
				Field: fieldName,
				Desc:  false,
			},
		}
	}

	return nil
}
// UnmarshalJSON unmarshals a OrderBy from json. The payload is either an
// object with "field" and "desc" keys or a plain string naming the field to
// sort by in ascending order.
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}

	// a bare string sorts ascending by that field.
	if blob[0] != '{' {
		var fieldName string
		if err := json.Unmarshal(blob, &fieldName); err != nil {
			return err
		}

		orderBy.Field = fieldName
		orderBy.Desc = false

		return nil
	}

	var decoded struct {
		Field string `json:"field"`
		Desc  bool   `json:"desc"`
	}
	if err := json.Unmarshal(blob, &decoded); err != nil {
		return err
	}

	orderBy.Desc = decoded.Desc
	orderBy.Field = decoded.Field

	return nil
}

View File

@@ -0,0 +1,249 @@
package netquery
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"github.com/hashicorp/go-multierror"
servertiming "github.com/mitchellh/go-server-timing"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/netquery/orm"
)
// charOnlyRegexp validates user-supplied result aliases ("as" names).
// SECURITY FIX: the pattern is anchored so the WHOLE value must consist of
// letters. The previous unanchored "[a-zA-Z]+" matched any value containing
// at least one letter (e.g. `x; DROP TABLE connections`), and these aliases
// are interpolated verbatim into the generated SQL.
var charOnlyRegexp = regexp.MustCompile("^[a-zA-Z]+$")
type (
	// QueryHandler implements http.Handler and allows to perform SQL
	// query and aggregate functions on Database.
	QueryHandler struct {
		// IsDevMode reports whether the Portmaster runs in development
		// mode; if true, query diagnostics are added to responses.
		IsDevMode func() bool
		// Database is the netquery database the SQL queries run against.
		Database *Database
	}

	// BatchQueryHandler implements http.Handler and allows to perform SQL
	// query and aggregate functions on Database in batches.
	BatchQueryHandler struct {
		// IsDevMode reports whether the Portmaster runs in development mode.
		IsDevMode func() bool
		// Database is the netquery database the SQL queries run against.
		Database *Database
	}
)
// ServeHTTP implements http.Handler. It parses a query payload from the
// request (POST/PUT body or the GET "q" parameter), builds the matching SQL
// statement, executes it against the database and streams the result rows
// as JSON. In dev mode the response additionally carries diagnostics about
// the generated SQL. Server-Timing metrics are recorded for the parse,
// build and execute phases.
func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	timing := servertiming.FromContext(req.Context())

	timingQueryParsed := timing.NewMetric("query_parsed").
		WithDesc("Query has been parsed").
		Start()

	requestPayload, err := parseQueryRequestPayload[QueryRequestPayload](req)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)
		return
	}

	timingQueryParsed.Stop()

	timingQueryBuilt := timing.NewMetric("query_built").
		WithDesc("The SQL query has been built").
		Start()

	query, paramMap, err := requestPayload.generateSQL(req.Context(), qh.Database.Schema)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)
		return
	}

	timingQueryBuilt.Stop()

	timingQueryExecute := timing.NewMetric("sql_exec").
		WithDesc("SQL query execution time").
		Start()

	// actually execute the query against the database and collect the result
	var result []map[string]interface{}
	if err := qh.Database.Execute(
		req.Context(),
		query,
		orm.WithNamedArgs(paramMap),
		orm.WithResult(&result),
		orm.WithSchema(*qh.Database.Schema),
	); err != nil {
		http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
		return
	}

	timingQueryExecute.Stop()

	// send the HTTP status code
	resp.WriteHeader(http.StatusOK)

	// prepare the result encoder.
	enc := json.NewEncoder(resp)
	enc.SetEscapeHTML(false)
	enc.SetIndent("", " ")

	// prepare the result body that, in dev mode, contains
	// some diagnostics data about the query
	var resultBody map[string]interface{}
	if qh.IsDevMode() {
		resultBody = map[string]interface{}{
			"sql_prep_stmt": query,
			"sql_params":    paramMap,
			"query":         requestPayload.Query,
			"orderBy":       requestPayload.OrderBy,
			"groupBy":       requestPayload.GroupBy,
			"selects":       requestPayload.Select,
		}
	} else {
		resultBody = make(map[string]interface{})
	}
	resultBody["results"] = result

	// and finally stream the response
	if err := enc.Encode(resultBody); err != nil {
		// we failed to encode the JSON body to resp so we likely either already sent a
		// few bytes or the pipe was already closed. In either case, trying to send the
		// error using http.Error() is non-sense. We just log it out here and that's all
		// we can do.
		log.Errorf("failed to encode JSON response: %s", err)
		return
	}
}
// ServeHTTP implements http.Handler. It parses a batch payload (a map of
// query-ID to query), builds all SQL statements, executes them in one batch
// and responds with a map of query-ID to result rows. If batch execution
// fails, the individual errors are exposed via X-Query-Error response
// headers and the overall status is 500; any partial results are still
// included in the body.
func (batch *BatchQueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	timing := servertiming.FromContext(req.Context())

	timingQueryParsed := timing.NewMetric("query_parsed").
		WithDesc("Query has been parsed").
		Start()

	requestPayload, err := parseQueryRequestPayload[BatchQueryRequestPayload](req)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)
		return
	}

	timingQueryParsed.Stop()

	response := make(map[string][]map[string]any, len(*requestPayload))
	batches := make([]BatchExecute, 0, len(*requestPayload))

	for key, query := range *requestPayload {
		timingQueryBuilt := timing.NewMetric("query_built_" + key).
			WithDesc("The SQL query has been built").
			Start()

		sql, paramMap, err := query.generateSQL(req.Context(), batch.Database.Schema)
		if err != nil {
			http.Error(resp, err.Error(), http.StatusBadRequest)
			return
		}

		timingQueryBuilt.Stop()

		// Result is filled in by ExecuteBatch; keep the pointer so we can
		// collect the rows after execution.
		var result []map[string]any
		batches = append(batches, BatchExecute{
			ID:     key,
			SQL:    sql,
			Params: paramMap,
			Result: &result,
		})
	}

	timingQueryExecute := timing.NewMetric("sql_exec").
		WithDesc("SQL query execution time").
		Start()

	status := http.StatusOK

	if err := batch.Database.ExecuteBatch(req.Context(), batches); err != nil {
		status = http.StatusInternalServerError

		var merr *multierror.Error
		if errors.As(err, &merr) {
			for _, e := range merr.Errors {
				resp.Header().Add("X-Query-Error", e.Error())
			}
		} else {
			// Should not happen, ExecuteBatch always returns a multierror.Error
			resp.WriteHeader(status)
			return
		}
	}

	timingQueryExecute.Stop()

	// collect the results
	for _, b := range batches {
		response[b.ID] = *b.Result
	}

	// send the HTTP status code
	resp.WriteHeader(status)

	// prepare the result encoder.
	enc := json.NewEncoder(resp)
	enc.SetEscapeHTML(false)
	enc.SetIndent("", " ")

	// and finally stream the response
	if err := enc.Encode(response); err != nil {
		// we failed to encode the JSON body to resp so we likely either already sent a
		// few bytes or the pipe was already closed. In either case, trying to send the
		// error using http.Error() is non-sense. We just log it out here and that's all
		// we can do.
		log.Errorf("failed to encode JSON response: %s", err)
		return
	}
}
func parseQueryRequestPayload[T any](req *http.Request) (*T, error) { //nolint:dupl
var (
body io.Reader
requestPayload T
)
switch req.Method {
case http.MethodPost, http.MethodPut:
body = req.Body
case http.MethodGet:
body = strings.NewReader(req.URL.Query().Get("q"))
default:
return nil, fmt.Errorf("invalid HTTP method")
}
blob, err := io.ReadAll(body)
if err != nil {
return nil, fmt.Errorf("failed to read body" + err.Error())
}
body = bytes.NewReader(blob)
dec := json.NewDecoder(body)
dec.DisallowUnknownFields()
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("invalid query: %w", err)
}
return &requestPayload, nil
}
// Compile time check that QueryHandler implements http.Handler.
// NOTE(review): BatchQueryHandler also implements ServeHTTP but has no
// matching compile-time check here — consider adding one.
var _ http.Handler = new(QueryHandler)

View File

@@ -0,0 +1,320 @@
package netquery
import (
"context"
"fmt"
"strings"
"golang.org/x/exp/slices"
"github.com/safing/portmaster/service/netquery/orm"
)
type (
	// QueryRequestPayload describes the payload of a netquery query.
	QueryRequestPayload struct {
		Select     Selects     `json:"select"`
		Query      Query       `json:"query"`
		OrderBy    OrderBys    `json:"orderBy"`
		GroupBy    []string    `json:"groupBy"`
		TextSearch *TextSearch `json:"textSearch"`

		// A list of databases to query. If left empty, only the
		// LiveDatabase is queried (see generateSQL).
		Databases []DatabaseName `json:"databases"`

		Pagination

		// selectedFields holds the SQL select expressions built by
		// prepareSelectedFields.
		selectedFields []string
		// whitelistedFields holds result aliases that may additionally be
		// referenced by group-by and order-by clauses.
		whitelistedFields []string
		// paramMap collects all named SQL parameters of the final query.
		paramMap map[string]interface{}
	}

	// BatchQueryRequestPayload describes the payload of a batch netquery
	// query. The map key is used in the response to identify the results
	// for each query of the batch request.
	BatchQueryRequestPayload map[string]QueryRequestPayload
)
// generateSQL builds the complete SQL SELECT statement for the request
// payload against schema and returns it together with the named parameters
// it references. The statement unions the selected databases, applies the
// WHERE/text-search conditions, group-by, order-by and pagination clauses.
func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
	if err := req.prepareSelectedFields(ctx, schema); err != nil {
		// TYPO FIX: "perparing" -> "preparing".
		return "", nil, fmt.Errorf("preparing selected fields: %w", err)
	}

	// build the SQL where clause from the payload query
	whereClause, paramMap, err := req.Query.toSQLWhereClause(
		ctx,
		"",
		schema,
		orm.DefaultEncodeConfig,
	)
	if err != nil {
		return "", nil, fmt.Errorf("generating where clause: %w", err)
	}

	req.mergeParams(paramMap)

	// the text-search clause is AND-combined with the regular query.
	if req.TextSearch != nil {
		textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
		if err != nil {
			return "", nil, fmt.Errorf("generating text-search clause: %w", err)
		}

		if textClause != "" {
			if whereClause != "" {
				whereClause += " AND "
			}

			whereClause += textClause

			req.mergeParams(textParams)
		}
	}

	groupByClause, err := req.generateGroupByClause(schema)
	if err != nil {
		return "", nil, fmt.Errorf("generating group-by clause: %w", err)
	}

	orderByClause, err := req.generateOrderByClause(schema)
	if err != nil {
		return "", nil, fmt.Errorf("generating order-by clause: %w", err)
	}

	selectClause := req.generateSelectClause()

	if whereClause != "" {
		whereClause = "WHERE " + whereClause
	}

	// if no database is specified we default to LiveDatabase only.
	if len(req.Databases) == 0 {
		req.Databases = []DatabaseName{LiveDatabase}
	}

	// the WHERE clause is applied per database so aggregates run over the
	// already-filtered union of all sources.
	sources := make([]string, len(req.Databases))
	for idx, db := range req.Databases {
		sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
	}
	source := strings.Join(sources, " UNION ")

	query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
	query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause()

	return strings.TrimSpace(query), req.paramMap, nil
}
// prepareSelectedFields validates all select statements of the payload and
// populates req.selectedFields (the SQL select expressions) as well as
// req.whitelistedFields (result aliases that group-by/order-by clauses may
// reference). Named parameters required by conditional aggregates
// ($sum/$min with a condition) are merged into req.paramMap.
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
	for idx, s := range req.Select {
		// determine which field the select refers to. Aggregates without an
		// explicit field operate on a condition and use "*" as placeholder.
		var field string

		switch {
		case s.Count != nil:
			field = s.Count.Field
		case s.Distinct != nil:
			field = *s.Distinct
		case s.Sum != nil:
			if s.Sum.Field != "" {
				field = s.Sum.Field
			} else {
				field = "*"
			}
		case s.Min != nil:
			if s.Min.Field != "" {
				field = s.Min.Field
			} else {
				field = "*"
			}
		case s.FieldSelect != nil:
			field = s.FieldSelect.Field
		default:
			field = s.Field
		}

		colName := "*"
		// "*" is only acceptable for the aggregate operators; everything
		// else must name a valid column.
		// BUG FIX: s.Min was missing from this exemption, so a $min with a
		// condition (field == "*") always failed column validation.
		if field != "*" || (s.Count == nil && s.Sum == nil && s.Min == nil) {
			var err error

			colName, err = req.validateColumnName(schema, field)
			if err != nil {
				return err
			}
		}

		switch {
		case s.FieldSelect != nil:
			as := s.FieldSelect.As
			if as == "" {
				as = s.FieldSelect.Field
			}
			req.selectedFields = append(
				req.selectedFields,
				fmt.Sprintf("%s AS %s", s.FieldSelect.Field, as),
			)
			req.whitelistedFields = append(req.whitelistedFields, as)

		case s.Count != nil:
			as := s.Count.As
			if as == "" {
				as = fmt.Sprintf("%s_count", colName)
			}
			distinct := ""
			if s.Count.Distinct {
				distinct = "DISTINCT "
			}
			req.selectedFields = append(
				req.selectedFields,
				fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as),
			)
			req.whitelistedFields = append(req.whitelistedFields, as)

		case s.Sum != nil:
			if s.Sum.As == "" {
				return fmt.Errorf("missing 'as' for $sum")
			}

			// sum either a plain field or a boolean condition expression.
			var (
				clause string
				params map[string]any
			)

			if s.Sum.Field != "" {
				clause = s.Sum.Field
			} else {
				var err error
				clause, params, err = s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
				if err != nil {
					return fmt.Errorf("in $sum: %w", err)
				}
			}

			req.mergeParams(params)
			req.selectedFields = append(
				req.selectedFields,
				fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
			)
			req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)

		case s.Min != nil:
			if s.Min.As == "" {
				return fmt.Errorf("missing 'as' for $min")
			}

			// take the minimum of either a plain field or a condition
			// expression, mirroring the $sum handling above.
			var (
				clause string
				params map[string]any
			)

			if s.Min.Field != "" {
				clause = s.Min.Field
			} else {
				var err error
				clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
				if err != nil {
					return fmt.Errorf("in $min: %w", err)
				}
			}

			req.mergeParams(params)
			req.selectedFields = append(
				req.selectedFields,
				fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
			)
			req.whitelistedFields = append(req.whitelistedFields, s.Min.As)

		case s.Distinct != nil:
			req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
			req.whitelistedFields = append(req.whitelistedFields, colName)

		default:
			req.selectedFields = append(req.selectedFields, colName)
		}
	}

	return nil
}
// mergeParams copies all entries of params into the request's named
// parameter map, lazily allocating the map on first use.
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
	if req.paramMap == nil {
		req.paramMap = make(map[string]any, len(params))
	}

	for name, value := range params {
		req.paramMap[name] = value
	}
}
// generateGroupByClause builds the GROUP BY clause from req.GroupBy after
// validating each column name. If no fields have been explicitly selected,
// the group-by columns are used as the select list as well.
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
	if len(req.GroupBy) == 0 {
		return "", nil
	}

	columns := make([]string, 0, len(req.GroupBy))
	for _, fieldName := range req.GroupBy {
		colName, err := req.validateColumnName(schema, fieldName)
		if err != nil {
			return "", err
		}
		columns = append(columns, colName)
	}

	// without an explicit select clause, default to returning the group-by
	// columns themselves as that's what's expected most of the time anyway.
	if len(req.selectedFields) == 0 {
		req.selectedFields = append(req.selectedFields, columns...)
	}

	return "GROUP BY " + strings.Join(columns, ", "), nil
}
// generateSelectClause returns the comma-separated list of prepared select
// expressions, falling back to "*" when nothing was selected explicitly.
func (req *QueryRequestPayload) generateSelectClause() string {
	if len(req.selectedFields) == 0 {
		return "*"
	}

	return strings.Join(req.selectedFields, ", ")
}
// generateOrderByClause builds the ORDER BY clause from req.OrderBy after
// validating each referenced column name.
func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
	if len(req.OrderBy) == 0 {
		return "", nil
	}

	clauses := make([]string, 0, len(req.OrderBy))
	for _, orderBy := range req.OrderBy {
		colName, err := req.validateColumnName(schema, orderBy.Field)
		if err != nil {
			return "", err
		}

		direction := "ASC"
		if orderBy.Desc {
			direction = "DESC"
		}
		clauses = append(clauses, colName+" "+direction)
	}

	return "ORDER BY " + strings.Join(clauses, ", "), nil
}
// validateColumnName resolves field to a referencable column name. A field
// is valid if it is an actual table column, or if it matches an alias that
// was whitelisted or selected by a previous select-preparation step.
func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) {
	if colDef := schema.GetColumnDef(field); colDef != nil {
		return colDef.Name, nil
	}

	if slices.Contains(req.whitelistedFields, field) || slices.Contains(req.selectedFields, field) {
		return field, nil
	}

	return "", fmt.Errorf("column name %q not allowed", field)
}

View File

@@ -0,0 +1,250 @@
package netquery
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/safing/portmaster/service/netquery/orm"
)
// TestUnmarshalQuery verifies that Query JSON payloads — both the plain
// string shorthand and the operator-object form ($in, $like, $notin, ...) —
// decode into the expected Matcher structures.
func TestUnmarshalQuery(t *testing.T) { //nolint:tparallel
	t.Parallel()

	cases := []struct {
		Name     string
		Input    string
		Expected Query
		Error    error
	}{
		{
			"Parse a simple query",
			`{ "domain": ["example.com", "example.at"] }`,
			Query{
				"domain": []Matcher{
					{
						Equal: "example.com",
					},
					{
						Equal: "example.at",
					},
				},
			},
			nil,
		},
		{
			"Parse a more complex query",
			`
			{
				"domain": [
					{
						"$in": [
							"example.at",
							"example.com"
						]
					},
					{
						"$like": "microsoft.%"
					}
				],
				"path": [
					"/bin/ping",
					{
						"$notin": [
							"/sbin/ping",
							"/usr/sbin/ping"
						]
					}
				]
			}
			`,
			Query{
				"domain": []Matcher{
					{
						In: []interface{}{
							"example.at",
							"example.com",
						},
					},
					{
						Like: "microsoft.%",
					},
				},
				"path": []Matcher{
					{
						Equal: "/bin/ping",
					},
					{
						NotIn: []interface{}{
							"/sbin/ping",
							"/usr/sbin/ping",
						},
					},
				},
			},
			nil,
		},
	}

	for _, testCase := range cases { //nolint:paralleltest
		c := testCase
		t.Run(c.Name, func(t *testing.T) {
			var q Query
			err := json.Unmarshal([]byte(c.Input), &q)

			if c.Error != nil {
				if assert.Error(t, err) {
					assert.Equal(t, c.Error.Error(), err.Error())
				}
			} else {
				assert.NoError(t, err)
				assert.Equal(t, c.Expected, q)
			}
		})
	}
}
// TestQueryBuilder verifies that Query.toSQLWhereClause produces the
// expected SQL condition strings and named-parameter maps for a range of
// queries: empty, single- and multi-column filters, time-value encoding,
// invalid column access and nested set operators.
func TestQueryBuilder(t *testing.T) { //nolint:tparallel
	t.Parallel()

	now := time.Now()

	// N = case name, Q = input query, R = expected clause,
	// P = expected parameter map, E = expected error.
	cases := []struct {
		N string
		Q Query
		R string
		P map[string]interface{}
		E error
	}{
		{
			"No filter",
			nil,
			"",
			nil,
			nil,
		},
		{
			"Simple, one-column filter",
			Query{"domain": []Matcher{
				{
					Equal: "example.com",
				},
				{
					Equal: "example.at",
				},
			}},
			"( domain = :domain0eq0 OR domain = :domain1eq0 )",
			map[string]interface{}{
				":domain0eq0": "example.com",
				":domain1eq0": "example.at",
			},
			nil,
		},
		{
			"Two column filter",
			Query{
				"domain": []Matcher{
					{
						Equal: "example.com",
					},
				},
				"path": []Matcher{
					{
						Equal: "/bin/curl",
					},
					{
						Equal: "/bin/ping",
					},
				},
			},
			"( domain = :domain0eq0 ) AND ( path = :path0eq0 OR path = :path1eq0 )",
			map[string]interface{}{
				":domain0eq0": "example.com",
				":path0eq0":   "/bin/curl",
				":path1eq0":   "/bin/ping",
			},
			nil,
		},
		{
			"Time based filter",
			Query{
				"started": []Matcher{
					{
						Equal: now.Format(time.RFC3339),
					},
				},
			},
			"( started = :started0eq0 )",
			map[string]interface{}{
				":started0eq0": now.In(time.UTC).Format(orm.SqliteTimeFormat),
			},
			nil,
		},
		{
			"Invalid column access",
			Query{
				"forbiddenField": []Matcher{{}},
			},
			"",
			nil,
			fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"), //nolint:golint
		},
		{
			"Complex example",
			Query{
				"domain": []Matcher{
					{
						In: []interface{}{"example.at", "example.com"},
					},
					{
						Like: "microsoft.%",
					},
				},
				"path": []Matcher{
					{
						NotIn: []interface{}{
							"/bin/ping",
							"/sbin/ping",
							"/usr/bin/ping",
						},
					},
				},
			},
			"( domain IN ( :domain0in0, :domain0in1 ) OR domain LIKE :domain1like0 ) AND ( path NOT IN ( :path0notin0, :path0notin1, :path0notin2 ) )",
			map[string]interface{}{
				":domain0in0":   "example.at",
				":domain0in1":   "example.com",
				":domain1like0": "microsoft.%",
				":path0notin0":  "/bin/ping",
				":path0notin1":  "/sbin/ping",
				":path0notin2":  "/usr/bin/ping",
			},
			nil,
		},
	}

	tbl, err := orm.GenerateTableSchema("connections", Conn{})
	require.NoError(t, err)

	for idx, testCase := range cases { //nolint:paralleltest
		cID := idx
		c := testCase
		t.Run(c.N, func(t *testing.T) {
			str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig)

			if c.E != nil {
				if assert.Error(t, err) {
					assert.Equal(t, c.E.Error(), err.Error(), "test case %d", cID)
				}
			} else {
				assert.NoError(t, err, "test case %d", cID)
				assert.Equal(t, c.P, params, "test case %d", cID)
				assert.Equal(t, c.R, str, "test case %d", cID)
			}
		})
	}
}

View File

@@ -0,0 +1,77 @@
package netquery
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portbase/runtime"
"github.com/safing/portmaster/service/netquery/orm"
)
// RuntimeQueryRunner provides a simple interface for the runtime database
// that allows direct SQL queries to be performed against db.
// Each resulting row of that query is marshaled as map[string]interface{}
// and all rows are returned together as a single record to the caller.
//
// Using portbase/database#Query is not possible because portbase/database will
// complain about the SQL query being invalid. To work around that issue,
// RuntimeQueryRunner uses a 'GET key' request where the SQL query is embedded into
// the record key.
type RuntimeQueryRunner struct {
	// db is the netquery database the raw SQL queries are executed against.
	db *Database
	// reg is the runtime registry the runner registers its value getter at.
	reg *runtime.Registry
	// keyPrefix is stripped from incoming record keys; the remainder is
	// treated as the plain SQL query.
	keyPrefix string
}
// NewRuntimeQueryRunner returns a new runtime SQL query runner that parses
// and serves SQL queries from GET <prefix>/<plain sql query> requests. It
// registers itself at reg under prefix and returns an error if that
// registration fails.
func NewRuntimeQueryRunner(db *Database, prefix string, reg *runtime.Registry) (*RuntimeQueryRunner, error) {
	runner := new(RuntimeQueryRunner)
	runner.db = db
	runner.reg = reg
	runner.keyPrefix = prefix

	_, err := reg.Register(prefix, runtime.SimpleValueGetterFunc(runner.get))
	if err != nil {
		return nil, fmt.Errorf("failed to register runtime value provider: %w", err)
	}

	return runner, nil
}
// get implements runtime.SimpleValueGetterFunc. It interprets everything
// after the configured key prefix as a plain SQL query, executes it against
// the database and returns all rows wrapped into a single JSON record.
//
// NOTE(review): the query string is executed verbatim without any
// validation or restriction, so access to this runtime key must be treated
// as privileged (this is by design, see the type documentation).
func (runner *RuntimeQueryRunner) get(keyOrPrefix string) ([]record.Record, error) {
	query := strings.TrimPrefix(
		keyOrPrefix,
		runner.keyPrefix,
	)

	log.Infof("netquery: executing custom SQL query: %q", query)

	var result []map[string]interface{}
	if err := runner.db.Execute(context.Background(), query, orm.WithResult(&result)); err != nil {
		return nil, fmt.Errorf("failed to perform query %q: %w", query, err)
	}

	// we need to wrap the result slice into a map as portbase/database attempts
	// to inject a _meta field.
	blob, err := json.Marshal(map[string]interface{}{
		"result": result,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal result: %w", err)
	}

	// construct a new record wrapper that uses the already prepared JSON blob.
	key := fmt.Sprintf("%s:%s", runner.reg.DatabaseName(), keyOrPrefix)
	wrapper, err := record.NewWrapper(key, new(record.Meta), dsd.JSON, blob)
	if err != nil {
		return nil, fmt.Errorf("failed to create record wrapper: %w", err)
	}

	return []record.Record{wrapper}, nil
}