wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
141
service/netquery/active_chart_handler.go
Normal file
141
service/netquery/active_chart_handler.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/service/netquery/orm"
|
||||
)
|
||||
|
||||
// ActiveChartHandler handles requests for connection charts.
|
||||
type ActiveChartHandler struct {
|
||||
Database *Database
|
||||
}
|
||||
|
||||
func (ch *ActiveChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl
|
||||
requestPayload, err := ch.parseRequest(req)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// actually execute the query against the database and collect the result
|
||||
var result []map[string]interface{}
|
||||
if err := ch.Database.Execute(
|
||||
req.Context(),
|
||||
query,
|
||||
orm.WithNamedArgs(paramMap),
|
||||
orm.WithResult(&result),
|
||||
orm.WithSchema(*ch.Database.Schema),
|
||||
); err != nil {
|
||||
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// send the HTTP status code
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
|
||||
// prepare the result encoder.
|
||||
enc := json.NewEncoder(resp)
|
||||
enc.SetEscapeHTML(false)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
|
||||
"results": result,
|
||||
"query": query,
|
||||
"params": paramMap,
|
||||
})
|
||||
}
|
||||
|
||||
func (ch *ActiveChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
|
||||
var body io.Reader
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodPost, http.MethodPut:
|
||||
body = req.Body
|
||||
case http.MethodGet:
|
||||
body = strings.NewReader(req.URL.Query().Get("q"))
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid HTTP method")
|
||||
}
|
||||
|
||||
var requestPayload QueryActiveConnectionChartPayload
|
||||
blob, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read body" + err.Error())
|
||||
}
|
||||
|
||||
body = bytes.NewReader(blob)
|
||||
|
||||
dec := json.NewDecoder(body)
|
||||
dec.DisallowUnknownFields()
|
||||
|
||||
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("invalid query: %w", err)
|
||||
}
|
||||
|
||||
return &requestPayload, nil
|
||||
}
|
||||
|
||||
func (req *QueryActiveConnectionChartPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
||||
template := `
|
||||
WITH RECURSIVE epoch(x) AS (
|
||||
SELECT strftime('%%s')-600
|
||||
UNION ALL
|
||||
SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
|
||||
)
|
||||
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked
|
||||
FROM epoch
|
||||
JOIN connections
|
||||
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
|
||||
%s
|
||||
GROUP BY round(timestamp/10, 0)*10;`
|
||||
|
||||
clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
params = make(map[string]interface{})
|
||||
}
|
||||
|
||||
if req.TextSearch != nil {
|
||||
textSearch, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if textSearch != "" {
|
||||
if clause != "" {
|
||||
clause += " AND "
|
||||
}
|
||||
clause += textSearch
|
||||
|
||||
for key, val := range textParams {
|
||||
params[key] = val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if clause == "" {
|
||||
return fmt.Sprintf(template, ""), map[string]interface{}{}, nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf(template, "WHERE ( "+clause+")"), params, nil
|
||||
}
|
||||
148
service/netquery/bandwidth_chart_handler.go
Normal file
148
service/netquery/bandwidth_chart_handler.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/service/netquery/orm"
|
||||
)
|
||||
|
||||
// BandwidthChartHandler handles requests for connection charts.
|
||||
type BandwidthChartHandler struct {
|
||||
Database *Database
|
||||
}
|
||||
|
||||
// BandwidthChartRequest holds a request for a bandwidth chart.
|
||||
type BandwidthChartRequest struct {
|
||||
Interval int `json:"interval"`
|
||||
Query Query `json:"query"`
|
||||
GroupBy []string `json:"groupBy"`
|
||||
}
|
||||
|
||||
func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl
|
||||
requestPayload, err := ch.parseRequest(req)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// actually execute the query against the database and collect the result
|
||||
var result []map[string]interface{}
|
||||
if err := ch.Database.Execute(
|
||||
req.Context(),
|
||||
query,
|
||||
orm.WithNamedArgs(paramMap),
|
||||
orm.WithResult(&result),
|
||||
orm.WithSchema(*ch.Database.Schema),
|
||||
); err != nil {
|
||||
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// send the HTTP status code
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
|
||||
// prepare the result encoder.
|
||||
enc := json.NewEncoder(resp)
|
||||
enc.SetEscapeHTML(false)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
|
||||
"results": result,
|
||||
"query": query,
|
||||
"params": paramMap,
|
||||
})
|
||||
}
|
||||
|
||||
func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl
|
||||
var body io.Reader
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodPost, http.MethodPut:
|
||||
body = req.Body
|
||||
case http.MethodGet:
|
||||
body = strings.NewReader(req.URL.Query().Get("q"))
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid HTTP method")
|
||||
}
|
||||
|
||||
var requestPayload BandwidthChartRequest
|
||||
blob, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read body" + err.Error())
|
||||
}
|
||||
|
||||
body = bytes.NewReader(blob)
|
||||
|
||||
dec := json.NewDecoder(body)
|
||||
dec.DisallowUnknownFields()
|
||||
|
||||
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("invalid query: %w", err)
|
||||
}
|
||||
|
||||
return &requestPayload, nil
|
||||
}
|
||||
|
||||
func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
||||
if req.Interval == 0 {
|
||||
req.Interval = 10
|
||||
}
|
||||
|
||||
interval := fmt.Sprintf("round(time/%d, 0)*%d", req.Interval, req.Interval)
|
||||
|
||||
// make sure there are only allowed fields specified in the request group-by
|
||||
for _, gb := range req.GroupBy {
|
||||
def := schema.GetColumnDef(gb)
|
||||
if def == nil {
|
||||
return "", nil, fmt.Errorf("unsupported groupBy key: %q", gb)
|
||||
}
|
||||
}
|
||||
|
||||
selects := append([]string{
|
||||
interval + " as timestamp",
|
||||
"SUM(incoming) as incoming",
|
||||
"SUM(outgoing) as outgoing",
|
||||
}, req.GroupBy...)
|
||||
|
||||
groupBy := append([]string{interval}, req.GroupBy...)
|
||||
|
||||
whereClause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if whereClause != "" {
|
||||
whereClause = "WHERE " + whereClause
|
||||
}
|
||||
|
||||
template := fmt.Sprintf(
|
||||
`SELECT %s
|
||||
FROM main.bandwidth AS bw
|
||||
JOIN main.connections AS conns
|
||||
ON bw.conn_id = conns.id
|
||||
%s
|
||||
GROUP BY %s
|
||||
ORDER BY time ASC`,
|
||||
strings.Join(selects, ", "),
|
||||
whereClause,
|
||||
strings.Join(groupBy, ", "),
|
||||
)
|
||||
|
||||
return template, params, nil
|
||||
}
|
||||
681
service/netquery/database.go
Normal file
681
service/netquery/database.go
Normal file
@@ -0,0 +1,681 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/jackc/puddle/v2"
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/netquery/orm"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/profile"
|
||||
)
|
||||
|
||||
// InMemory is the "file path" to open a new in-memory database.
|
||||
const InMemory = "file:inmem.db?mode=memory"
|
||||
|
||||
// Available connection types as their string representation.
|
||||
const (
|
||||
ConnTypeDNS = "dns"
|
||||
ConnTypeIP = "ip"
|
||||
)
|
||||
|
||||
// ConnectionTypeToString is a lookup map to get the string representation
|
||||
// of a network.ConnectionType as used by this package.
|
||||
var ConnectionTypeToString = map[network.ConnectionType]string{
|
||||
network.DNSRequest: ConnTypeDNS,
|
||||
network.IPConnection: ConnTypeIP,
|
||||
}
|
||||
|
||||
type (
|
||||
// Database represents a SQLite3 backed connection database.
|
||||
// It's use is tailored for persistence and querying of network.Connection.
|
||||
// Access to the underlying SQLite database is synchronized.
|
||||
//
|
||||
Database struct {
|
||||
Schema *orm.TableSchema
|
||||
|
||||
readConnPool *puddle.Pool[*sqlite.Conn]
|
||||
historyPath string
|
||||
|
||||
l sync.Mutex
|
||||
writeConn *sqlite.Conn
|
||||
}
|
||||
|
||||
// BatchExecute executes multiple queries in one transaction.
|
||||
BatchExecute struct {
|
||||
ID string
|
||||
SQL string
|
||||
Params map[string]any
|
||||
Result *[]map[string]any
|
||||
}
|
||||
|
||||
// Conn is a network connection that is stored in a SQLite database and accepted
|
||||
// by the *Database type of this package. This also defines, using the ./orm package,
|
||||
// the table schema and the model that is exposed via the runtime database as well as
|
||||
// the query API.
|
||||
//
|
||||
// Use ConvertConnection from this package to convert a network.Connection to this
|
||||
// representation.
|
||||
Conn struct { //nolint:maligned
|
||||
// ID is a device-unique identifier for the connection. It is built
|
||||
// from network.Connection by hashing the connection ID and the start
|
||||
// time. We cannot just use the network.Connection.ID because it is only unique
|
||||
// as long as the connection is still active and might be, although unlikely,
|
||||
// reused afterwards.
|
||||
ID string `sqlite:"id,primary"`
|
||||
ProfileID string `sqlite:"profile"`
|
||||
Path string `sqlite:"path"`
|
||||
Type string `sqlite:"type,varchar(8)"`
|
||||
External bool `sqlite:"external"`
|
||||
IPVersion packet.IPVersion `sqlite:"ip_version"`
|
||||
IPProtocol packet.IPProtocol `sqlite:"ip_protocol"`
|
||||
LocalIP string `sqlite:"local_ip"`
|
||||
LocalPort uint16 `sqlite:"local_port"`
|
||||
RemoteIP string `sqlite:"remote_ip"`
|
||||
RemotePort uint16 `sqlite:"remote_port"`
|
||||
Domain string `sqlite:"domain"`
|
||||
Country string `sqlite:"country,varchar(2)"`
|
||||
ASN uint `sqlite:"asn"`
|
||||
ASOwner string `sqlite:"as_owner"`
|
||||
Latitude float64 `sqlite:"latitude"`
|
||||
Longitude float64 `sqlite:"longitude"`
|
||||
Scope netutils.IPScope `sqlite:"scope"`
|
||||
WorstVerdict network.Verdict `sqlite:"worst_verdict"`
|
||||
ActiveVerdict network.Verdict `sqlite:"verdict"`
|
||||
FirewallVerdict network.Verdict `sqlite:"firewall_verdict"`
|
||||
Started time.Time `sqlite:"started,text,time"`
|
||||
Ended *time.Time `sqlite:"ended,text,time"`
|
||||
Tunneled bool `sqlite:"tunneled"`
|
||||
Encrypted bool `sqlite:"encrypted"`
|
||||
Internal bool `sqlite:"internal"`
|
||||
Direction string `sqlite:"direction"`
|
||||
ExtraData json.RawMessage `sqlite:"extra_data"`
|
||||
Allowed *bool `sqlite:"allowed"`
|
||||
ProfileRevision int `sqlite:"profile_revision"`
|
||||
ExitNode *string `sqlite:"exit_node"`
|
||||
BytesReceived uint64 `sqlite:"bytes_received,default=0"`
|
||||
BytesSent uint64 `sqlite:"bytes_sent,default=0"`
|
||||
|
||||
// TODO(ppacher): support "NOT" in search query to get rid of the following helper fields
|
||||
Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"
|
||||
|
||||
// TODO(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here
|
||||
ProfileName string `sqlite:"profile_name"`
|
||||
}
|
||||
)
|
||||
|
||||
// New opens a new in-memory database named path and attaches a persistent history database.
|
||||
//
|
||||
// The returned Database used connection pooling for read-only connections
|
||||
// (see Execute). To perform database writes use either Save() or ExecuteWrite().
|
||||
// Note that write connections are serialized by the Database object before being
|
||||
// handed over to SQLite.
|
||||
func New(dbPath string) (*Database, error) {
|
||||
historyParentDir := dataroot.Root().ChildDir("databases", 0o700)
|
||||
if err := historyParentDir.Ensure(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure database directory exists: %w", err)
|
||||
}
|
||||
|
||||
// Get file location of history database.
|
||||
historyFile := filepath.Join(historyParentDir.Path, "history.db")
|
||||
// Convert to SQLite URI path.
|
||||
historyURI := "file:///" + strings.TrimPrefix(filepath.ToSlash(historyFile), "/")
|
||||
|
||||
constructor := func(ctx context.Context) (*sqlite.Conn, error) {
|
||||
c, err := sqlite.OpenConn(
|
||||
dbPath,
|
||||
sqlite.OpenReadOnly,
|
||||
sqlite.OpenSharedCache,
|
||||
sqlite.OpenURI,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", dbPath, err)
|
||||
}
|
||||
|
||||
if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE '"+historyURI+"?mode=ro' AS history", nil); err != nil {
|
||||
return nil, fmt.Errorf("failed to attach history database: %w", err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
destructor := func(resource *sqlite.Conn) {
|
||||
if err := resource.Close(); err != nil {
|
||||
log.Errorf("failed to close pooled SQlite database connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
pool, err := puddle.NewPool(&puddle.Config[*sqlite.Conn]{
|
||||
Constructor: constructor,
|
||||
Destructor: destructor,
|
||||
MaxSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
schema, err := orm.GenerateTableSchema("connections", Conn{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
writeConn, err := sqlite.OpenConn(
|
||||
dbPath,
|
||||
sqlite.OpenCreate,
|
||||
sqlite.OpenReadWrite,
|
||||
sqlite.OpenWAL,
|
||||
sqlite.OpenSharedCache,
|
||||
sqlite.OpenURI,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open sqlite at %s: %w", dbPath, err)
|
||||
}
|
||||
|
||||
return &Database{
|
||||
readConnPool: pool,
|
||||
Schema: schema,
|
||||
writeConn: writeConn,
|
||||
historyPath: historyURI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewInMemory is like New but creates a new in-memory database and
|
||||
// automatically applies the connection table schema.
|
||||
func NewInMemory() (*Database, error) {
|
||||
db, err := New(InMemory)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// this should actually never happen because an in-memory database
|
||||
// always starts empty...
|
||||
if err := db.ApplyMigrations(); err != nil {
|
||||
return nil, fmt.Errorf("failed to prepare database: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Close closes the database, including pools and connections.
|
||||
func (db *Database) Close() error {
|
||||
db.readConnPool.Close()
|
||||
|
||||
if err := db.writeConn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VacuumHistory rewrites the history database in order to purge deleted records.
|
||||
func VacuumHistory(ctx context.Context) (err error) {
|
||||
historyParentDir := dataroot.Root().ChildDir("databases", 0o700)
|
||||
if err := historyParentDir.Ensure(); err != nil {
|
||||
return fmt.Errorf("failed to ensure database directory exists: %w", err)
|
||||
}
|
||||
|
||||
// Get file location of history database.
|
||||
historyFile := filepath.Join(historyParentDir.Path, "history.db")
|
||||
// Convert to SQLite URI path.
|
||||
historyURI := "file:///" + strings.TrimPrefix(filepath.ToSlash(historyFile), "/")
|
||||
|
||||
writeConn, err := sqlite.OpenConn(
|
||||
historyURI,
|
||||
sqlite.OpenCreate,
|
||||
sqlite.OpenReadWrite,
|
||||
sqlite.OpenWAL,
|
||||
sqlite.OpenSharedCache,
|
||||
sqlite.OpenURI,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := writeConn.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
return orm.RunQuery(ctx, writeConn, "VACUUM")
|
||||
}
|
||||
|
||||
// ApplyMigrations applies any table and data migrations that are needed
|
||||
// to bring db up-to-date with the built-in schema.
|
||||
// TODO(ppacher): right now this only applies the current schema and ignores
|
||||
// any data-migrations. Once the history module is implemented this should
|
||||
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration.
|
||||
func (db *Database) ApplyMigrations() error {
|
||||
db.l.Lock()
|
||||
defer db.l.Unlock()
|
||||
|
||||
if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE '"+db.historyPath+"?mode=rwc' AS 'history';", nil); err != nil {
|
||||
return fmt.Errorf("failed to attach history database: %w", err)
|
||||
}
|
||||
|
||||
dbNames := []string{"main", "history"}
|
||||
for _, dbName := range dbNames {
|
||||
// get the create-table SQL statement from the inferred schema
|
||||
sql := db.Schema.CreateStatement(dbName, true)
|
||||
log.Debugf("creating table schema for database %q", dbName)
|
||||
|
||||
// execute the SQL
|
||||
if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil {
|
||||
return fmt.Errorf("failed to create schema on database %q: %w", dbName, err)
|
||||
}
|
||||
|
||||
// create a few indexes
|
||||
indexes := []string{
|
||||
`CREATE INDEX IF NOT EXISTS %sprofile_id_index ON %s (profile)`,
|
||||
`CREATE INDEX IF NOT EXISTS %sstarted_time_index ON %s (strftime('%%s', started)+0)`,
|
||||
`CREATE INDEX IF NOT EXISTS %sstarted_ended_time_index ON %s (strftime('%%s', started)+0, strftime('%%s', ended)+0) WHERE ended IS NOT NULL`,
|
||||
}
|
||||
for _, idx := range indexes {
|
||||
name := ""
|
||||
if dbName != "" {
|
||||
name = dbName + "."
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf(idx, name, db.Schema.Name)
|
||||
|
||||
if err := sqlitex.ExecuteTransient(db.writeConn, stmt, nil); err != nil {
|
||||
return fmt.Errorf("failed to create index on database %q: %q: %w", dbName, idx, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth (
|
||||
conn_id TEXT NOT NULL,
|
||||
time INTEGER NOT NULL,
|
||||
incoming INTEGER NOT NULL,
|
||||
outgoing INTEGER NOT NULL,
|
||||
CONSTRAINT fk_conn_id
|
||||
FOREIGN KEY(conn_id) REFERENCES connections(id)
|
||||
ON DELETE CASCADE
|
||||
)`
|
||||
if err := sqlitex.ExecuteTransient(db.writeConn, bwSchema, nil); err != nil {
|
||||
return fmt.Errorf("failed to create main.bandwidth database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) withConn(ctx context.Context, fn func(conn *sqlite.Conn) error) error {
|
||||
res, err := db.readConnPool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Release()
|
||||
|
||||
return fn(res.Value())
|
||||
}
|
||||
|
||||
// ExecuteWrite executes a custom SQL query using a writable connection against the SQLite
|
||||
// database used by db.
|
||||
// It uses orm.RunQuery() under the hood so please refer to the orm package for
|
||||
// more information about available options.
|
||||
func (db *Database) ExecuteWrite(ctx context.Context, sql string, args ...orm.QueryOption) error {
|
||||
db.l.Lock()
|
||||
defer db.l.Unlock()
|
||||
|
||||
return orm.RunQuery(ctx, db.writeConn, sql, args...)
|
||||
}
|
||||
|
||||
// Execute executes a custom SQL query using a read-only connection against the SQLite
|
||||
// database used by db.
|
||||
// It uses orm.RunQuery() under the hood so please refer to the orm package for
|
||||
// more information about available options.
|
||||
func (db *Database) Execute(ctx context.Context, sql string, args ...orm.QueryOption) error {
|
||||
return db.withConn(ctx, func(conn *sqlite.Conn) error {
|
||||
return orm.RunQuery(ctx, conn, sql, args...)
|
||||
})
|
||||
}
|
||||
|
||||
// ExecuteBatch executes multiple custom SQL query using a read-only connection against the SQLite
|
||||
// database used by db.
|
||||
func (db *Database) ExecuteBatch(ctx context.Context, batches []BatchExecute) error {
|
||||
return db.withConn(ctx, func(conn *sqlite.Conn) error {
|
||||
merr := new(multierror.Error)
|
||||
|
||||
for _, batch := range batches {
|
||||
if err := orm.RunQuery(ctx, conn, batch.SQL, orm.WithNamedArgs(batch.Params), orm.WithResult(batch.Result)); err != nil {
|
||||
merr.Errors = append(merr.Errors, fmt.Errorf("%s: %w", batch.ID, err))
|
||||
}
|
||||
}
|
||||
|
||||
return merr.ErrorOrNil()
|
||||
})
|
||||
}
|
||||
|
||||
// CountRows returns the number of rows stored in the database.
|
||||
func (db *Database) CountRows(ctx context.Context) (int, error) {
|
||||
var result []struct {
|
||||
Count int `sqlite:"count"`
|
||||
}
|
||||
|
||||
if err := db.Execute(ctx, "SELECT COUNT(*) AS count FROM (SELECT * FROM main.connections UNION SELECT * from history.connections)", orm.WithResult(&result)); err != nil {
|
||||
return 0, fmt.Errorf("failed to perform query: %w", err)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
return 0, fmt.Errorf("unexpected number of rows returned, expected 1 got %d", len(result))
|
||||
}
|
||||
|
||||
return result[0].Count, nil
|
||||
}
|
||||
|
||||
// Cleanup removes all connections that have ended before threshold from the live database.
|
||||
//
|
||||
// NOTE(ppacher): there is no easy way to get the number of removed
|
||||
// rows other than counting them in a first step. Though, that's
|
||||
// probably not worth the cylces...
|
||||
func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) {
|
||||
where := `WHERE ended IS NOT NULL
|
||||
AND datetime(ended) < datetime(:threshold)`
|
||||
sql := "DELETE FROM main.connections " + where + ";"
|
||||
|
||||
args := orm.WithNamedArgs(map[string]interface{}{
|
||||
":threshold": threshold.UTC().Format(orm.SqliteTimeFormat),
|
||||
})
|
||||
|
||||
var result []struct {
|
||||
Count int `sqlite:"count"`
|
||||
}
|
||||
if err := db.Execute(
|
||||
ctx,
|
||||
"SELECT COUNT(*) AS count FROM connections "+where,
|
||||
args,
|
||||
orm.WithTransient(),
|
||||
orm.WithResult(&result),
|
||||
); err != nil {
|
||||
return 0, fmt.Errorf("failed to perform query: %w", err)
|
||||
}
|
||||
if len(result) != 1 {
|
||||
return 0, fmt.Errorf("unexpected number of rows, expected 1 got %d", len(result))
|
||||
}
|
||||
|
||||
err := db.ExecuteWrite(ctx, sql, args)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result[0].Count, nil
|
||||
}
|
||||
|
||||
// RemoveAllHistoryData removes all connections from the history database.
|
||||
func (db *Database) RemoveAllHistoryData(ctx context.Context) error {
|
||||
query := fmt.Sprintf("DELETE FROM %s.connections", HistoryDatabase)
|
||||
return db.ExecuteWrite(ctx, query)
|
||||
}
|
||||
|
||||
// RemoveHistoryForProfile removes all connections from the history database
|
||||
// for a given profile ID (source/id).
|
||||
func (db *Database) RemoveHistoryForProfile(ctx context.Context, profileID string) error {
|
||||
query := fmt.Sprintf("DELETE FROM %s.connections WHERE profile = :profile", HistoryDatabase)
|
||||
return db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{
|
||||
":profile": profileID,
|
||||
}))
|
||||
}
|
||||
|
||||
// MigrateProfileID migrates the given profile IDs in the history database.
|
||||
// This needs to be done when profiles are deleted and replaced by a different profile.
|
||||
func (db *Database) MigrateProfileID(ctx context.Context, from string, to string) error {
|
||||
return db.ExecuteWrite(ctx, "UPDATE history.connections SET profile = :to WHERE profile = :from", orm.WithNamedArgs(map[string]any{
|
||||
":from": from,
|
||||
":to": to,
|
||||
}))
|
||||
}
|
||||
|
||||
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database
|
||||
// as JSON to w.
|
||||
// Any error aborts dumping rows and is returned.
|
||||
func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:unused
|
||||
var conns []Conn
|
||||
err := db.withConn(ctx, func(conn *sqlite.Conn) error {
|
||||
return sqlitex.ExecuteTransient(conn, "SELECT * FROM connections", &sqlitex.ExecOptions{
|
||||
ResultFunc: func(stmt *sqlite.Stmt) error {
|
||||
var c Conn
|
||||
if err := orm.DecodeStmt(ctx, db.Schema, stmt, &c, orm.DefaultDecodeConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conns = append(conns, c)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
return enc.Encode(conns)
|
||||
}
|
||||
|
||||
// CleanupHistory deletes history data outside of the (per-app) retention time frame.
|
||||
func (db *Database) CleanupHistory(ctx context.Context) error {
|
||||
// Setup tracer for the clean up process.
|
||||
ctx, tracer := log.AddTracer(ctx)
|
||||
defer tracer.Submit()
|
||||
|
||||
// Get list of profiles in history.
|
||||
query := "SELECT DISTINCT profile FROM history.connections"
|
||||
var result []struct {
|
||||
Profile string `sqlite:"profile"`
|
||||
}
|
||||
if err := db.Execute(ctx, query, orm.WithResult(&result)); err != nil {
|
||||
return fmt.Errorf("failed to get a list of profiles from the history database: %w", err)
|
||||
}
|
||||
|
||||
var (
|
||||
// Get global retention days - do not delete in case of error.
|
||||
globalRetentionDays = config.GetAsInt(profile.CfgOptionKeepHistoryKey, 0)()
|
||||
|
||||
profileName string
|
||||
retentionDays int64
|
||||
|
||||
profileCnt int
|
||||
merr = new(multierror.Error)
|
||||
)
|
||||
for _, row := range result {
|
||||
// Get profile and retention days.
|
||||
id := strings.TrimPrefix(row.Profile, string(profile.SourceLocal)+"/")
|
||||
p, err := profile.GetLocalProfile(id, nil, nil)
|
||||
if err == nil {
|
||||
profileName = p.String()
|
||||
retentionDays = p.LayeredProfile().KeepHistory()
|
||||
} else {
|
||||
// Getting profile failed, fallback to global setting.
|
||||
tracer.Errorf("history: failed to load profile for id %s: %s", id, err)
|
||||
profileName = row.Profile
|
||||
retentionDays = globalRetentionDays
|
||||
}
|
||||
|
||||
// Skip deleting if history should be kept forever.
|
||||
if retentionDays == 0 {
|
||||
tracer.Tracef("history: retention is disabled for %s, skipping", profileName)
|
||||
continue
|
||||
}
|
||||
// Count profiles where connections were deleted.
|
||||
profileCnt++
|
||||
|
||||
// TODO: count cleared connections
|
||||
threshold := time.Now().Add(-1 * time.Duration(retentionDays) * time.Hour * 24)
|
||||
if err := db.ExecuteWrite(ctx,
|
||||
"DELETE FROM history.connections WHERE profile = :profile AND active = FALSE AND datetime(started) < datetime(:threshold)",
|
||||
orm.WithNamedArgs(map[string]any{
|
||||
":profile": row.Profile,
|
||||
":threshold": threshold.Format(orm.SqliteTimeFormat),
|
||||
}),
|
||||
); err != nil {
|
||||
tracer.Warningf("history: failed to delete connections of %s: %s", profileName, err)
|
||||
merr.Errors = append(merr.Errors, fmt.Errorf("profile %s: %w", row.Profile, err))
|
||||
} else {
|
||||
tracer.Debugf(
|
||||
"history: deleted connections older than %d days (before %s) of %s",
|
||||
retentionDays,
|
||||
threshold.Format(time.RFC822),
|
||||
profileName,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Log summary.
|
||||
tracer.Infof("history: deleted connections outside of retention from %d profiles", profileCnt)
|
||||
|
||||
return merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// MarkAllHistoryConnectionsEnded marks all connections in the history database as ended.
|
||||
func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
|
||||
query := fmt.Sprintf("UPDATE %s.connections SET active = FALSE, ended = :ended WHERE active = TRUE", HistoryDatabase)
|
||||
|
||||
if err := db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{
|
||||
":ended": time.Now().Format(orm.SqliteTimeFormat),
|
||||
})); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
|
||||
// the bandwidth data to the history database.
|
||||
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
|
||||
params := map[string]any{
|
||||
":id": makeNqIDFromParts(processKey, connID),
|
||||
}
|
||||
|
||||
parts := []string{}
|
||||
parts = append(parts, "bytes_received = (bytes_received + :bytes_received)")
|
||||
params[":bytes_received"] = bytesReceived
|
||||
parts = append(parts, "bytes_sent = (bytes_sent + :bytes_sent)")
|
||||
params[":bytes_sent"] = bytesSent
|
||||
|
||||
updateSet := strings.Join(parts, ", ")
|
||||
|
||||
updateStmts := []string{
|
||||
fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, LiveDatabase, updateSet),
|
||||
}
|
||||
|
||||
if enableHistory {
|
||||
updateStmts = append(updateStmts,
|
||||
fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, HistoryDatabase, updateSet),
|
||||
)
|
||||
}
|
||||
|
||||
merr := new(multierror.Error)
|
||||
for _, stmt := range updateStmts {
|
||||
if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
|
||||
merr.Errors = append(merr.Errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
// also add the date to the in-memory bandwidth database
|
||||
params[":time"] = time.Now().Unix()
|
||||
stmt := "INSERT INTO main.bandwidth (conn_id, time, incoming, outgoing) VALUES(:id, :time, :bytes_received, :bytes_sent)"
|
||||
if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
|
||||
merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err))
|
||||
}
|
||||
|
||||
return merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// Save inserts the connection conn into the SQLite database. If conn
|
||||
// already exists the table row is updated instead.
|
||||
//
|
||||
// Save uses the database write connection instead of relying on the
|
||||
// connection pool.
|
||||
func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) error {
|
||||
// convert the connection to a param map where each key is already translated
|
||||
// to the sql column name. We also skip bytes_received and bytes_sent since those
|
||||
// will be updated independently from the connection object.
|
||||
connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig, []string{
|
||||
"bytes_received",
|
||||
"bytes_sent",
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode connection for SQL: %w", err)
|
||||
}
|
||||
|
||||
columns := make([]string, 0, len(connMap))
|
||||
placeholders := make([]string, 0, len(connMap))
|
||||
values := make(map[string]interface{}, len(connMap))
|
||||
updateSets := make([]string, 0, len(connMap))
|
||||
|
||||
// sort keys so we get a stable SQLite query that can be better cached.
|
||||
keys := make([]string, 0, len(connMap))
|
||||
for key := range connMap {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
value := connMap[key]
|
||||
|
||||
columns = append(columns, key)
|
||||
placeholders = append(placeholders, ":"+key)
|
||||
values[":"+key] = value
|
||||
updateSets = append(updateSets, fmt.Sprintf("%s = :%s", key, key))
|
||||
}
|
||||
|
||||
db.l.Lock()
|
||||
defer db.l.Unlock()
|
||||
|
||||
// TODO(ppacher): make sure this one can be cached to speed up inserting
|
||||
// and save some CPU cycles for the user
|
||||
dbNames := []DatabaseName{LiveDatabase}
|
||||
|
||||
// TODO: Should we only add ended connection to the history database to save
|
||||
// a couple INSERTs per connection?
|
||||
// This means we need to write the current live DB to the history DB on
|
||||
// shutdown in order to be able to pick the back up after a restart.
|
||||
|
||||
// Save to history DB if enabled.
|
||||
if enableHistory {
|
||||
dbNames = append(dbNames, HistoryDatabase)
|
||||
}
|
||||
|
||||
for _, dbName := range dbNames {
|
||||
sql := fmt.Sprintf(
|
||||
`INSERT INTO %s.connections (%s)
|
||||
VALUES(%s)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
%s
|
||||
`,
|
||||
dbName,
|
||||
strings.Join(columns, ", "),
|
||||
strings.Join(placeholders, ", "),
|
||||
strings.Join(updateSets, ", "),
|
||||
)
|
||||
|
||||
if err := sqlitex.Execute(db.writeConn, sql, &sqlitex.ExecOptions{
|
||||
Named: values,
|
||||
ResultFunc: func(stmt *sqlite.Stmt) error {
|
||||
log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount())
|
||||
return nil
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
277
service/netquery/manager.go
Normal file
277
service/netquery/manager.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/runtime"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
)
|
||||
|
||||
type (
|
||||
// ConnectionStore describes the interface that is used by Manager
|
||||
// to save new or updated connection objects.
|
||||
// It is implemented by the *Database type of this package.
|
||||
ConnectionStore interface {
|
||||
// Save is called to perists the new or updated connection. If required,
|
||||
// It's up to the implementation to figure out if the operation is an
|
||||
// insert or an update.
|
||||
// The ID of Conn is unique and can be trusted to never collide with other
|
||||
// connections of the save device.
|
||||
Save(context.Context, Conn, bool) error
|
||||
|
||||
// MarkAllHistoryConnectionsEnded marks all active connections in the history
|
||||
// database as ended NOW.
|
||||
MarkAllHistoryConnectionsEnded(context.Context) error
|
||||
|
||||
// RemoveAllHistoryData removes all connections from the history database.
|
||||
RemoveAllHistoryData(context.Context) error
|
||||
|
||||
// RemoveHistoryForProfile removes all connections from the history database.
|
||||
// for a given profile ID (source/id)
|
||||
RemoveHistoryForProfile(context.Context, string) error
|
||||
|
||||
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
|
||||
// the bandwidth data to the history database.
|
||||
UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
|
||||
|
||||
// CleanupHistory deletes data outside of the retention time frame from the history database.
|
||||
CleanupHistory(ctx context.Context) error
|
||||
|
||||
// Close closes the connection store. It must not be used afterwards.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Manager handles new and updated network.Connections feeds and persists them
|
||||
// at a connection store.
|
||||
// Manager also registers itself as a runtime database and pushes updates to
|
||||
// connections using the local format.
|
||||
// Users should use this update feed rather than the deprecated "network:" database.
|
||||
Manager struct {
|
||||
store ConnectionStore
|
||||
push runtime.PushFunc
|
||||
runtimeReg *runtime.Registry
|
||||
pushPrefix string
|
||||
}
|
||||
)
|
||||
|
||||
// NewManager returns a new connection manager that persists all newly created or
|
||||
// updated connections at store.
|
||||
func NewManager(store ConnectionStore, pushPrefix string, reg *runtime.Registry) (*Manager, error) {
|
||||
mng := &Manager{
|
||||
store: store,
|
||||
runtimeReg: reg,
|
||||
pushPrefix: pushPrefix,
|
||||
}
|
||||
|
||||
push, err := reg.Register(pushPrefix, runtime.SimpleValueGetterFunc(mng.runtimeGet))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mng.push = push
|
||||
|
||||
return mng, nil
|
||||
}
|
||||
|
||||
func (mng *Manager) runtimeGet(keyOrPrefix string) ([]record.Record, error) {
|
||||
// TODO(ppacher):
|
||||
// we don't yet support querying using the runtime database here ...
|
||||
// consider exposing connection from the database at least by ID.
|
||||
//
|
||||
// NOTE(ppacher):
|
||||
// for debugging purposes use RuntimeQueryRunner to execute plain
|
||||
// SQL queries against the database using portbase/database/runtime.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// HandleFeed starts reading new and updated connections from feed and persists them
|
||||
// in the configured ConnectionStore. HandleFeed blocks until either ctx is cancelled
|
||||
// or feed is closed.
|
||||
// Any errors encountered when processing new or updated connections are logged but
|
||||
// otherwise ignored.
|
||||
// HandleFeed handles and persists updates one after each other! Depending on the system
|
||||
// load the user might want to use a buffered channel for feed.
|
||||
func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connection) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case conn, ok := <-feed:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
conn.Lock()
|
||||
defer conn.Unlock()
|
||||
|
||||
if !conn.DataIsComplete() {
|
||||
return
|
||||
}
|
||||
|
||||
model, err := convertConnection(conn)
|
||||
if err != nil {
|
||||
log.Errorf("netquery: failed to convert connection %s to sqlite model: %s", conn.ID, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// DEBUG:
|
||||
// log.Tracef("netquery: updating connection %s", conn.ID)
|
||||
|
||||
// Save to netquery database.
|
||||
// Do not include internal connections in history.
|
||||
if err := mng.store.Save(ctx, *model, conn.HistoryEnabled); err != nil {
|
||||
log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// we clone the record metadata from the connection
|
||||
// into the new model so the portbase/database layer
|
||||
// can handle NEW/UPDATE correctly.
|
||||
cloned := conn.Meta().Duplicate()
|
||||
|
||||
// push an update for the connection
|
||||
if err := mng.pushConnUpdate(ctx, *cloned, *model); err != nil {
|
||||
log.Errorf("netquery: failed to push update for conn %s via database system: %s", conn.ID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mng *Manager) pushConnUpdate(_ context.Context, meta record.Meta, conn Conn) error {
|
||||
blob, err := json.Marshal(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal connection: %w", err)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s:%s%s", mng.runtimeReg.DatabaseName(), mng.pushPrefix, conn.ID)
|
||||
wrapper, err := record.NewWrapper(
|
||||
key,
|
||||
&meta,
|
||||
dsd.JSON,
|
||||
blob,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create record wrapper: %w", err)
|
||||
}
|
||||
|
||||
mng.push(wrapper)
|
||||
return nil
|
||||
}
|
||||
|
||||
// convertConnection converts conn to the local representation used
|
||||
// to persist the information in SQLite.
|
||||
// The caller must hold the lock to the given network.Connection.
|
||||
func convertConnection(conn *network.Connection) (*Conn, error) {
|
||||
direction := "outbound"
|
||||
if conn.Inbound {
|
||||
direction = "inbound"
|
||||
}
|
||||
|
||||
c := Conn{
|
||||
ID: makeNqIDFromConn(conn),
|
||||
External: conn.External,
|
||||
IPVersion: conn.IPVersion,
|
||||
IPProtocol: conn.IPProtocol,
|
||||
LocalIP: conn.LocalIP.String(),
|
||||
LocalPort: conn.LocalPort,
|
||||
ActiveVerdict: conn.Verdict,
|
||||
Started: time.Unix(conn.Started, 0),
|
||||
Tunneled: conn.Tunneled,
|
||||
Encrypted: conn.Encrypted,
|
||||
Internal: conn.Internal,
|
||||
Direction: direction,
|
||||
Type: ConnectionTypeToString[conn.Type],
|
||||
ProfileID: conn.ProcessContext.Source + "/" + conn.ProcessContext.Profile,
|
||||
Path: conn.ProcessContext.BinaryPath,
|
||||
ProfileRevision: int(conn.ProfileRevisionCounter),
|
||||
ProfileName: conn.ProcessContext.ProfileName,
|
||||
}
|
||||
|
||||
switch conn.Type {
|
||||
case network.DNSRequest:
|
||||
c.Type = "dns"
|
||||
case network.IPConnection:
|
||||
c.Type = "ip"
|
||||
case network.Undefined:
|
||||
c.Type = ""
|
||||
}
|
||||
|
||||
c.Allowed = &conn.ConnectionEstablished
|
||||
|
||||
if conn.Ended > 0 {
|
||||
ended := time.Unix(conn.Ended, 0)
|
||||
c.Ended = &ended
|
||||
c.Active = false
|
||||
} else {
|
||||
c.Active = true
|
||||
}
|
||||
|
||||
extraData := map[string]interface{}{
|
||||
"pid": conn.ProcessContext.PID,
|
||||
"processCreatedAt": conn.ProcessContext.CreatedAt,
|
||||
}
|
||||
|
||||
if conn.TunnelContext != nil {
|
||||
extraData["tunnel"] = conn.TunnelContext
|
||||
exitNode := conn.TunnelContext.GetExitNodeID()
|
||||
c.ExitNode = &exitNode
|
||||
}
|
||||
|
||||
if conn.DNSContext != nil {
|
||||
extraData["dns"] = conn.DNSContext
|
||||
}
|
||||
|
||||
// TODO(ppacher): enable when TLS inspection is merged
|
||||
// if conn.TLSContext != nil {
|
||||
// extraData["tls"] = conn.TLSContext
|
||||
// }
|
||||
|
||||
if conn.Entity != nil {
|
||||
extraData["cname"] = conn.Entity.CNAME
|
||||
extraData["blockedByLists"] = conn.Entity.BlockedByLists
|
||||
extraData["blockedEntities"] = conn.Entity.BlockedEntities
|
||||
extraData["reason"] = conn.Reason
|
||||
|
||||
c.RemoteIP = conn.Entity.IP.String()
|
||||
c.RemotePort = conn.Entity.Port
|
||||
c.Domain = conn.Entity.Domain
|
||||
c.Country = conn.Entity.Country
|
||||
c.ASN = conn.Entity.ASN
|
||||
c.ASOwner = conn.Entity.ASOrg
|
||||
c.Scope = conn.Entity.IPScope
|
||||
if conn.Entity.Coordinates != nil {
|
||||
c.Latitude = conn.Entity.Coordinates.Latitude
|
||||
c.Longitude = conn.Entity.Coordinates.Longitude
|
||||
}
|
||||
}
|
||||
|
||||
// pre-compute the JSON blob for the extra data column
|
||||
// and assign it.
|
||||
extraDataBlob, err := json.Marshal(extraData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal extra data: %w", err)
|
||||
}
|
||||
c.ExtraData = extraDataBlob
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// makeNqIDFromConn creates a netquery connection ID from the network connection.
|
||||
func makeNqIDFromConn(conn *network.Connection) string {
|
||||
return makeNqIDFromParts(conn.Process().GetKey(), conn.ID)
|
||||
}
|
||||
|
||||
// makeNqIDFromParts creates a netquery connection ID from the given network
|
||||
// connection ID and the process key.
|
||||
func makeNqIDFromParts(processKey string, netConnID string) string {
|
||||
return processKey + "-" + netConnID
|
||||
}
|
||||
315
service/netquery/module_api.go
Normal file
315
service/netquery/module_api.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
servertiming "github.com/mitchellh/go-server-timing"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/query"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/modules/subsystems"
|
||||
"github.com/safing/portbase/runtime"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
)
|
||||
|
||||
// DefaultModule is the default netquery module.
|
||||
var DefaultModule *module
|
||||
|
||||
type module struct {
|
||||
*modules.Module
|
||||
|
||||
Store *Database
|
||||
|
||||
db *database.Interface
|
||||
mng *Manager
|
||||
feed chan *network.Connection
|
||||
}
|
||||
|
||||
func init() {
|
||||
DefaultModule = new(module)
|
||||
|
||||
DefaultModule.Module = modules.Register(
|
||||
"netquery",
|
||||
DefaultModule.prepare,
|
||||
DefaultModule.start,
|
||||
DefaultModule.stop,
|
||||
"api",
|
||||
"network",
|
||||
"database",
|
||||
)
|
||||
|
||||
subsystems.Register(
|
||||
"history",
|
||||
"Network History",
|
||||
"Keep Network History Data",
|
||||
DefaultModule.Module,
|
||||
"config:history/",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (m *module) prepare() error {
|
||||
var err error
|
||||
|
||||
m.db = database.NewInterface(&database.Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
|
||||
// TODO: Open database in start() phase.
|
||||
m.Store, err = NewInMemory()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create in-memory database: %w", err)
|
||||
}
|
||||
|
||||
m.mng, err = NewManager(m.Store, "netquery/data/", runtime.DefaultRegistry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create manager: %w", err)
|
||||
}
|
||||
|
||||
m.feed = make(chan *network.Connection, 1000)
|
||||
|
||||
queryHander := &QueryHandler{
|
||||
Database: m.Store,
|
||||
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
|
||||
}
|
||||
|
||||
batchHander := &BatchQueryHandler{
|
||||
Database: m.Store,
|
||||
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
|
||||
}
|
||||
|
||||
chartHandler := &ActiveChartHandler{
|
||||
Database: m.Store,
|
||||
}
|
||||
|
||||
bwChartHandler := &BandwidthChartHandler{
|
||||
Database: m.Store,
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Name: "Query Connections",
|
||||
Description: "Query the in-memory sqlite connection database.",
|
||||
Path: "netquery/query",
|
||||
MimeType: "application/json",
|
||||
Read: api.PermitUser, // Needs read+write as the query is sent using POST data.
|
||||
Write: api.PermitUser, // Needs read+write as the query is sent using POST data.
|
||||
BelongsTo: m.Module,
|
||||
HandlerFunc: servertiming.Middleware(queryHander, nil).ServeHTTP,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Name: "Batch Query Connections",
|
||||
Description: "Batch query the in-memory sqlite connection database.",
|
||||
Path: "netquery/query/batch",
|
||||
MimeType: "application/json",
|
||||
Read: api.PermitUser, // Needs read+write as the query is sent using POST data.
|
||||
Write: api.PermitUser, // Needs read+write as the query is sent using POST data.
|
||||
BelongsTo: m.Module,
|
||||
HandlerFunc: servertiming.Middleware(batchHander, nil).ServeHTTP,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Name: "Active Connections Chart",
|
||||
Description: "Query the in-memory sqlite connection database and return a chart of active connections.",
|
||||
Path: "netquery/charts/connection-active",
|
||||
MimeType: "application/json",
|
||||
Write: api.PermitUser,
|
||||
BelongsTo: m.Module,
|
||||
HandlerFunc: servertiming.Middleware(chartHandler, nil).ServeHTTP,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
// TODO: Use query parameters instead.
|
||||
Path: "netquery/charts/bandwidth",
|
||||
MimeType: "application/json",
|
||||
Write: api.PermitUser,
|
||||
BelongsTo: m.Module,
|
||||
HandlerFunc: bwChartHandler.ServeHTTP,
|
||||
Name: "Bandwidth Chart",
|
||||
Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Name: "Remove connections from profile history",
|
||||
Description: "Remove all connections from the history database for one or more profiles",
|
||||
Path: "netquery/history/clear",
|
||||
MimeType: "application/json",
|
||||
Write: api.PermitUser,
|
||||
BelongsTo: m.Module,
|
||||
ActionFunc: func(ar *api.Request) (msg string, err error) {
|
||||
var body struct {
|
||||
ProfileIDs []string `json:"profileIDs"`
|
||||
}
|
||||
if err := json.Unmarshal(ar.InputData, &body); err != nil {
|
||||
return "", fmt.Errorf("failed to decode parameters in body: %w", err)
|
||||
}
|
||||
|
||||
if len(body.ProfileIDs) == 0 {
|
||||
if err := m.mng.store.RemoveAllHistoryData(ar.Context()); err != nil {
|
||||
return "", fmt.Errorf("failed to remove all history: %w", err)
|
||||
}
|
||||
} else {
|
||||
merr := new(multierror.Error)
|
||||
for _, profileID := range body.ProfileIDs {
|
||||
if err := m.mng.store.RemoveHistoryForProfile(ar.Context(), profileID); err != nil {
|
||||
merr.Errors = append(merr.Errors, fmt.Errorf("failed to clear history for %q: %w", profileID, err))
|
||||
} else {
|
||||
log.Infof("netquery: successfully cleared history for %s", profileID)
|
||||
}
|
||||
}
|
||||
|
||||
if err := merr.ErrorOrNil(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return "Successfully cleared history.", nil
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Name: "Apply connection history retention threshold",
|
||||
Path: "netquery/history/cleanup",
|
||||
Write: api.PermitUser,
|
||||
BelongsTo: m.Module,
|
||||
ActionFunc: func(ar *api.Request) (msg string, err error) {
|
||||
if err := m.Store.CleanupHistory(ar.Context()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "Deleted outdated connections.", nil
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *module) start() error {
|
||||
m.StartServiceWorker("netquery connection feed listener", 0, func(ctx context.Context) error {
|
||||
sub, err := m.db.Subscribe(query.New("network:"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to network tree: %w", err)
|
||||
}
|
||||
defer close(m.feed)
|
||||
defer func() {
|
||||
_ = sub.Cancel()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case rec, ok := <-sub.Feed:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, ok := rec.(*network.Connection)
|
||||
if !ok {
|
||||
// This is fine as we also receive process updates on
|
||||
// this channel.
|
||||
continue
|
||||
}
|
||||
|
||||
m.feed <- conn
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
m.StartServiceWorker("netquery connection feed handler", 0, func(ctx context.Context) error {
|
||||
m.mng.HandleFeed(ctx, m.feed)
|
||||
return nil
|
||||
})
|
||||
|
||||
m.StartServiceWorker("netquery live db cleaner", 0, func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(10 * time.Second):
|
||||
threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
|
||||
count, err := m.Store.Cleanup(ctx, threshold)
|
||||
if err != nil {
|
||||
log.Errorf("netquery: failed to removed old connections from live db: %s", err)
|
||||
} else {
|
||||
log.Tracef("netquery: successfully removed %d old connections from live db that ended before %s", count, threshold)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
m.NewTask("network history cleaner", func(ctx context.Context, _ *modules.Task) error {
|
||||
return m.Store.CleanupHistory(ctx)
|
||||
}).Repeat(time.Hour).Schedule(time.Now().Add(10 * time.Minute))
|
||||
|
||||
// For debugging, provide a simple direct SQL query interface using
|
||||
// the runtime database.
|
||||
// Only expose in development mode.
|
||||
if config.GetAsBool(config.CfgDevModeKey, false)() {
|
||||
_, err := NewRuntimeQueryRunner(m.Store, "netquery/query/", runtime.DefaultRegistry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate profile IDs in history database when profiles are migrated/merged.
|
||||
if err := m.RegisterEventHook(
|
||||
"profiles",
|
||||
"profile migrated",
|
||||
"migrate profile IDs in history database",
|
||||
func(ctx context.Context, data interface{}) error {
|
||||
if profileIDs, ok := data.([]string); ok && len(profileIDs) == 2 {
|
||||
return m.Store.MigrateProfileID(ctx, profileIDs[0], profileIDs[1])
|
||||
}
|
||||
return nil
|
||||
},
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *module) stop() error {
|
||||
// we don't use m.Module.Ctx here because it is already cancelled when stop is called.
|
||||
// just give the clean up 1 minute to happen and abort otherwise.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := m.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil {
|
||||
// handle the error by just logging it. There's not much we can do here
|
||||
// and returning an error to the module system doesn't help much as well...
|
||||
log.Errorf("netquery: failed to mark connections in history database as ended: %s", err)
|
||||
}
|
||||
|
||||
if err := m.mng.store.Close(); err != nil {
|
||||
log.Errorf("netquery: failed to close sqlite database: %s", err)
|
||||
} else {
|
||||
// Clear deleted connections from database.
|
||||
if err := VacuumHistory(ctx); err != nil {
|
||||
log.Errorf("netquery: failed to execute VACUUM in history database: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
483
service/netquery/orm/decoder.go
Normal file
483
service/netquery/orm/decoder.go
Normal file
@@ -0,0 +1,483 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// Commonly used error messages when working with orm.
|
||||
var (
|
||||
errStructExpected = errors.New("encode: can only encode structs to maps")
|
||||
errStructPointerExpected = errors.New("decode: result must be pointer to a struct type or map[string]interface{}")
|
||||
errUnexpectedColumnType = errors.New("decode: unexpected column type")
|
||||
)
|
||||
|
||||
// constants used when transforming data to and from sqlite.
|
||||
var (
|
||||
// sqliteTimeFromat defines the string representation that is
|
||||
// expected by SQLite DATETIME functions.
|
||||
// Note that SQLite itself does not include support for a DATETIME
|
||||
// column type. Instead, dates and times are stored either as INTEGER,
|
||||
// TEXT or REAL.
|
||||
// This package provides support for time.Time being stored as TEXT (using a
|
||||
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
|
||||
// unixepoch and unixnano-epoch where the nano variant is not officially supported by
|
||||
// SQLITE).
|
||||
SqliteTimeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
type (
|
||||
|
||||
// Stmt describes the interface that must be implemented in order to
|
||||
// be decodable to a struct type using DecodeStmt. This interface is implemented
|
||||
// by *sqlite.Stmt.
|
||||
Stmt interface {
|
||||
ColumnCount() int
|
||||
ColumnName(int) string
|
||||
ColumnType(int) sqlite.ColumnType
|
||||
ColumnText(int) string
|
||||
ColumnBool(int) bool
|
||||
ColumnFloat(int) float64
|
||||
ColumnInt(int) int
|
||||
ColumnReader(int) *bytes.Reader
|
||||
}
|
||||
|
||||
// DecodeFunc is called for each non-basic type during decoding.
|
||||
DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)
|
||||
|
||||
// DecodeConfig holds decoding functions.
|
||||
DecodeConfig struct {
|
||||
DecodeHooks []DecodeFunc
|
||||
}
|
||||
)
|
||||
|
||||
// DecodeStmt decodes the current result row loaded in Stmt into the struct or map type result.
|
||||
// Decoding hooks configured in cfg are executed before trying to decode basic types and may
|
||||
// be specified to provide support for special types.
|
||||
// See DatetimeDecoder() for an example of a DecodeHook that handles graceful time.Time conversion.
|
||||
func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result interface{}, cfg DecodeConfig) error {
|
||||
// make sure we got something to decode into ...
|
||||
if result == nil {
|
||||
return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
|
||||
}
|
||||
|
||||
// fast path for decoding into a map
|
||||
if mp, ok := result.(*map[string]interface{}); ok {
|
||||
return decodeIntoMap(ctx, schema, stmt, mp, cfg)
|
||||
}
|
||||
|
||||
// make sure we got a pointer in result
|
||||
if reflect.TypeOf(result).Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
|
||||
}
|
||||
|
||||
// make sure it's a poiter to a struct type
|
||||
t := reflect.ValueOf(result).Elem().Type()
|
||||
if t.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
|
||||
}
|
||||
|
||||
// if result is a nil pointer make sure to allocate some space
|
||||
// for the resulting struct
|
||||
resultValue := reflect.ValueOf(result)
|
||||
if resultValue.IsNil() {
|
||||
resultValue.Set(
|
||||
reflect.New(t),
|
||||
)
|
||||
}
|
||||
|
||||
// we need access to the struct directly and not to the
|
||||
// pointer.
|
||||
target := reflect.Indirect(resultValue)
|
||||
|
||||
// create a lookup map from field name (or sqlite:"" tag)
|
||||
// to the field name
|
||||
lm := make(map[string]string)
|
||||
for i := 0; i < target.NumField(); i++ {
|
||||
fieldType := t.Field(i)
|
||||
|
||||
// skip unexported fields
|
||||
if !fieldType.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
lm[sqlColumnName(fieldType)] = fieldType.Name
|
||||
}
|
||||
|
||||
// iterate over all columns and assign them to the correct
|
||||
// fields
|
||||
for i := 0; i < stmt.ColumnCount(); i++ {
|
||||
colName := stmt.ColumnName(i)
|
||||
fieldName, ok := lm[colName]
|
||||
if !ok {
|
||||
// there's no target field for this column
|
||||
// so we can skip it
|
||||
continue
|
||||
}
|
||||
fieldType, _ := t.FieldByName(fieldName)
|
||||
|
||||
value := target.FieldByName(fieldName)
|
||||
|
||||
colType := stmt.ColumnType(i)
|
||||
|
||||
// if the column is reported as NULL we keep
|
||||
// the field as it is.
|
||||
// TODO(ppacher): should we set it to nil here?
|
||||
if colType == sqlite.TypeNull {
|
||||
continue
|
||||
}
|
||||
|
||||
// if value is a nil pointer we need to allocate some memory
|
||||
// first
|
||||
if getKind(value) == reflect.Ptr && value.IsNil() {
|
||||
storage := reflect.New(fieldType.Type.Elem())
|
||||
|
||||
value.Set(storage)
|
||||
|
||||
// make sure value actually points the
|
||||
// dereferenced target storage
|
||||
value = storage.Elem()
|
||||
}
|
||||
|
||||
colDef := schema.GetColumnDef(colName)
|
||||
|
||||
// execute all decode hooks but make sure we use decodeBasic() as the
|
||||
// last one.
|
||||
columnValue, err := runDecodeHooks(
|
||||
i,
|
||||
colDef,
|
||||
stmt,
|
||||
fieldType,
|
||||
value,
|
||||
append(cfg.DecodeHooks, decodeBasic()),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if we don't have a converted value now we try to
|
||||
// decode basic types
|
||||
if columnValue == nil {
|
||||
return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
|
||||
}
|
||||
|
||||
// Debugging:
|
||||
// log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue)
|
||||
|
||||
// convert it to the target type if conversion is possible
|
||||
newValue := reflect.ValueOf(columnValue)
|
||||
if newValue.Type().ConvertibleTo(value.Type()) {
|
||||
newValue = newValue.Convert(value.Type())
|
||||
}
|
||||
|
||||
// assign the new value to the struct field.
|
||||
value.Set(newValue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DatetimeDecoder is capable of decoding sqlite INTEGER or TEXT storage classes into
|
||||
// time.Time. For INTEGER storage classes, it supports 'unixnano' struct tag value to
|
||||
// decide between Unix or UnixNano epoch timestamps.
|
||||
//
|
||||
// TODO(ppacher): update comment about loc parameter and TEXT storage class parsing.
|
||||
func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
||||
return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
|
||||
// if we have the column definition available we
|
||||
// use the target go type from there.
|
||||
outType := outval.Type()
|
||||
|
||||
if colDef != nil {
|
||||
outType = colDef.GoType
|
||||
}
|
||||
|
||||
// we only care about "time.Time" here
|
||||
if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
|
||||
// log.Printf("not decoding %s %v", outType, colDef)
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
switch stmt.ColumnType(colIdx) { //nolint:exhaustive // Only selecting specific types.
|
||||
case sqlite.TypeInteger:
|
||||
// stored as unix-epoch, if unixnano is set in the struct field tag
|
||||
// we parse it with nano-second resolution
|
||||
// TODO(ppacher): actually split the tag value at "," and search
|
||||
// the slice for "unixnano"
|
||||
if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") {
|
||||
return time.Unix(0, int64(stmt.ColumnInt(colIdx))), true, nil
|
||||
}
|
||||
|
||||
return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), true, nil
|
||||
|
||||
case sqlite.TypeText:
|
||||
// stored ISO8601 but does not have any timezone information
|
||||
// assigned so we always treat it as loc here.
|
||||
t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
|
||||
}
|
||||
|
||||
return t, true, nil
|
||||
|
||||
case sqlite.TypeFloat:
|
||||
// stored as Julian day numbers
|
||||
return nil, false, fmt.Errorf("REAL storage type not support for time.Time")
|
||||
|
||||
case sqlite.TypeNull:
|
||||
return nil, true, nil
|
||||
|
||||
default:
|
||||
return nil, false, fmt.Errorf("unsupported storage type for time.Time: %s", stmt.ColumnType(colIdx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeIntoMap(_ context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
|
||||
if *mp == nil {
|
||||
*mp = make(map[string]interface{})
|
||||
}
|
||||
|
||||
for i := 0; i < stmt.ColumnCount(); i++ {
|
||||
var x interface{}
|
||||
|
||||
colDef := schema.GetColumnDef(stmt.ColumnName(i))
|
||||
|
||||
outVal := reflect.ValueOf(&x).Elem()
|
||||
fieldType := reflect.StructField{}
|
||||
if colDef != nil {
|
||||
outVal = reflect.New(colDef.GoType).Elem()
|
||||
fieldType = reflect.StructField{
|
||||
Type: colDef.GoType,
|
||||
}
|
||||
}
|
||||
|
||||
val, err := runDecodeHooks(
|
||||
i,
|
||||
colDef,
|
||||
stmt,
|
||||
fieldType,
|
||||
outVal,
|
||||
append(cfg.DecodeHooks, decodeBasic()),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode column %s: %w", stmt.ColumnName(i), err)
|
||||
}
|
||||
|
||||
(*mp)[stmt.ColumnName(i)] = val
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeBasic() DecodeFunc {
|
||||
return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (result interface{}, handled bool, err error) {
|
||||
valueKind := getKind(outval)
|
||||
colType := stmt.ColumnType(colIdx)
|
||||
colName := stmt.ColumnName(colIdx)
|
||||
|
||||
errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type())
|
||||
|
||||
// if we have the column definition available we
|
||||
// use the target go type from there.
|
||||
if colDef != nil {
|
||||
valueKind = NormalizeKind(colDef.GoType.Kind())
|
||||
|
||||
// if we have a column definition we try to convert the value to
|
||||
// the actual Go-type that was used in the model.
|
||||
// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
|
||||
// or that type aliases like (type myInt int) are decoded into myInt instead of int
|
||||
defer func() {
|
||||
if handled {
|
||||
t := reflect.New(colDef.GoType).Elem()
|
||||
|
||||
if result == nil || reflect.ValueOf(result).IsZero() {
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.ValueOf(result).Type().ConvertibleTo(colDef.GoType) {
|
||||
result = reflect.ValueOf(result).Convert(colDef.GoType).Interface()
|
||||
}
|
||||
t.Set(reflect.ValueOf(result))
|
||||
|
||||
result = t.Interface()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// log.Printf("decoding %s into kind %s", colName, valueKind)
|
||||
|
||||
if colType == sqlite.TypeNull {
|
||||
if colDef != nil && colDef.Nullable {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
if colDef != nil && !colDef.Nullable {
|
||||
return reflect.New(colDef.GoType).Elem().Interface(), true, nil
|
||||
}
|
||||
|
||||
if outval.Kind() == reflect.Ptr {
|
||||
return nil, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
switch valueKind { //nolint:exhaustive
|
||||
case reflect.String:
|
||||
if colType != sqlite.TypeText {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
return stmt.ColumnText(colIdx), true, nil
|
||||
|
||||
case reflect.Bool:
|
||||
// sqlite does not have a BOOL type, it rather stores a 1/0 in a column
|
||||
// with INTEGER affinity.
|
||||
if colType != sqlite.TypeInteger {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
return stmt.ColumnBool(colIdx), true, nil
|
||||
|
||||
case reflect.Float64:
|
||||
if colType != sqlite.TypeFloat {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
return stmt.ColumnFloat(colIdx), true, nil
|
||||
|
||||
case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ...
|
||||
if colType != sqlite.TypeInteger {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
|
||||
return stmt.ColumnInt(colIdx), true, nil
|
||||
|
||||
case reflect.Slice:
|
||||
if outval.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, false, fmt.Errorf("slices other than []byte for BLOB are not supported")
|
||||
}
|
||||
|
||||
if colType != sqlite.TypeBlob {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
|
||||
columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
|
||||
}
|
||||
|
||||
return columnValue, true, nil
|
||||
|
||||
case reflect.Interface:
|
||||
var (
|
||||
t reflect.Type
|
||||
x interface{}
|
||||
)
|
||||
switch colType {
|
||||
case sqlite.TypeBlob:
|
||||
t = reflect.TypeOf([]byte{})
|
||||
columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
|
||||
}
|
||||
x = columnValue
|
||||
|
||||
case sqlite.TypeFloat:
|
||||
t = reflect.TypeOf(float64(0))
|
||||
x = stmt.ColumnFloat(colIdx)
|
||||
|
||||
case sqlite.TypeInteger:
|
||||
t = reflect.TypeOf(int(0))
|
||||
x = stmt.ColumnInt(colIdx)
|
||||
|
||||
case sqlite.TypeText:
|
||||
t = reflect.TypeOf(string(""))
|
||||
x = stmt.ColumnText(colIdx)
|
||||
|
||||
case sqlite.TypeNull:
|
||||
t = nil
|
||||
x = nil
|
||||
|
||||
default:
|
||||
return nil, false, fmt.Errorf("unsupported column type %s", colType)
|
||||
}
|
||||
|
||||
if t == nil {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
target := reflect.New(t).Elem()
|
||||
target.Set(reflect.ValueOf(x))
|
||||
|
||||
return target.Interface(), true, nil
|
||||
|
||||
default:
|
||||
return nil, false, fmt.Errorf("cannot decode into %s", valueKind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sqlColumnName(fieldType reflect.StructField) string {
|
||||
tagValue, hasTag := fieldType.Tag.Lookup("sqlite")
|
||||
if !hasTag {
|
||||
return fieldType.Name
|
||||
}
|
||||
|
||||
parts := strings.Split(tagValue, ",")
|
||||
if parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return fieldType.Name
|
||||
}
|
||||
|
||||
// runDecodeHooks tries to decode the column value of stmt at index colIdx into outval by running all decode hooks.
|
||||
// The first hook that returns a non-nil interface wins, other hooks will not be executed. If an error is
|
||||
// returned by a decode hook runDecodeHooks stops the error is returned to the caller.
|
||||
func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
|
||||
for _, fn := range hooks {
|
||||
res, end, err := fn(colIdx, colDef, stmt, fieldDef, outval)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if end {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getKind returns the kind of value but normalized Int, Uint and Float variants.
|
||||
// to their base type.
|
||||
func getKind(val reflect.Value) reflect.Kind {
|
||||
kind := val.Kind()
|
||||
return NormalizeKind(kind)
|
||||
}
|
||||
|
||||
// NormalizeKind returns a normalized kind of the given kind.
|
||||
func NormalizeKind(kind reflect.Kind) reflect.Kind {
|
||||
switch {
|
||||
case kind >= reflect.Int && kind <= reflect.Int64:
|
||||
return reflect.Int
|
||||
case kind >= reflect.Uint && kind <= reflect.Uint64:
|
||||
return reflect.Uint
|
||||
case kind >= reflect.Float32 && kind <= reflect.Float64:
|
||||
return reflect.Float64
|
||||
default:
|
||||
return kind
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultDecodeConfig holds the default decoding configuration.
|
||||
var DefaultDecodeConfig = DecodeConfig{
|
||||
DecodeHooks: []DecodeFunc{
|
||||
DatetimeDecoder(time.UTC),
|
||||
},
|
||||
}
|
||||
570
service/netquery/orm/decoder_test.go
Normal file
570
service/netquery/orm/decoder_test.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
type testStmt struct {
|
||||
columns []string
|
||||
values []interface{}
|
||||
types []sqlite.ColumnType
|
||||
}
|
||||
|
||||
func (ts testStmt) ColumnCount() int { return len(ts.columns) }
|
||||
func (ts testStmt) ColumnName(i int) string { return ts.columns[i] }
|
||||
func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] }
|
||||
|
||||
// Compile time check.
|
||||
var _ Stmt = new(testStmt)
|
||||
|
||||
type exampleFieldTypes struct {
|
||||
S string
|
||||
I int
|
||||
F float64
|
||||
B bool
|
||||
}
|
||||
|
||||
type examplePointerTypes struct {
|
||||
S *string
|
||||
I *int
|
||||
F *float64
|
||||
B *bool
|
||||
}
|
||||
|
||||
type exampleStructTags struct {
|
||||
S string `sqlite:"col_string"`
|
||||
I int `sqlite:"col_int"`
|
||||
}
|
||||
|
||||
type exampleIntConv struct {
|
||||
I8 int8
|
||||
I16 int16
|
||||
I32 int32
|
||||
I64 int64
|
||||
I int
|
||||
}
|
||||
|
||||
type exampleBlobTypes struct {
|
||||
B []byte
|
||||
}
|
||||
|
||||
type exampleJSONRawTypes struct {
|
||||
B json.RawMessage
|
||||
}
|
||||
|
||||
type exampleTimeTypes struct {
|
||||
T time.Time
|
||||
TP *time.Time
|
||||
}
|
||||
|
||||
type exampleInterface struct {
|
||||
I interface{}
|
||||
IP *interface{}
|
||||
}
|
||||
|
||||
func (ett *exampleTimeTypes) Equal(other interface{}) bool {
|
||||
oett, ok := other.(*exampleTimeTypes)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return ett.T.Equal(oett.T) && (ett.TP != nil && oett.TP != nil && ett.TP.Equal(*oett.TP)) || (ett.TP == nil && oett.TP == nil)
|
||||
}
|
||||
|
||||
type myInt int
|
||||
|
||||
type exampleTimeNano struct {
|
||||
T time.Time `sqlite:",unixnano"`
|
||||
}
|
||||
|
||||
func (etn *exampleTimeNano) Equal(other interface{}) bool {
|
||||
oetn, ok := other.(*exampleTimeNano)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return etn.T.Equal(oetn.T)
|
||||
}
|
||||
|
||||
func TestDecoder(t *testing.T) { //nolint:maintidx,tparallel
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
Stmt testStmt
|
||||
ColumnDef []ColumnDef
|
||||
Result interface{}
|
||||
Expected interface{}
|
||||
}{
|
||||
{
|
||||
"Decoding into nil is not allowed",
|
||||
testStmt{
|
||||
columns: nil,
|
||||
values: nil,
|
||||
types: nil,
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Decoding into basic types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
1,
|
||||
1.2,
|
||||
true,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
S: "string value",
|
||||
I: 1,
|
||||
F: 1.2,
|
||||
B: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into basic types with different order",
|
||||
testStmt{
|
||||
columns: []string{"I", "S", "B", "F"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
},
|
||||
values: []interface{}{
|
||||
1,
|
||||
"string value",
|
||||
true,
|
||||
1.2,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
S: "string value",
|
||||
I: 1,
|
||||
F: 1.2,
|
||||
B: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into basic types with missing values",
|
||||
testStmt{
|
||||
columns: []string{"F", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
1.2,
|
||||
true,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
F: 1.2,
|
||||
B: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into pointer types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
1,
|
||||
1.2,
|
||||
true,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&examplePointerTypes{},
|
||||
func() interface{} {
|
||||
s := "string value"
|
||||
i := 1
|
||||
f := 1.2
|
||||
b := true
|
||||
|
||||
return &examplePointerTypes{
|
||||
S: &s,
|
||||
I: &i,
|
||||
F: &f,
|
||||
B: &b,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into pointer types with missing values",
|
||||
testStmt{
|
||||
columns: []string{"S", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
true,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&examplePointerTypes{},
|
||||
func() interface{} {
|
||||
s := "string value"
|
||||
b := true
|
||||
|
||||
return &examplePointerTypes{
|
||||
S: &s,
|
||||
B: &b,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into fields with struct tags",
|
||||
testStmt{
|
||||
columns: []string{"col_string", "col_int"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
1,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleStructTags{},
|
||||
&exampleStructTags{
|
||||
S: "string value",
|
||||
I: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into correct int type",
|
||||
testStmt{
|
||||
columns: []string{"I8", "I16", "I32", "I64", "I"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleIntConv{},
|
||||
&exampleIntConv{
|
||||
1, 1, 1, 1, 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling NULL values for basic types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeFloat,
|
||||
},
|
||||
values: []interface{}{
|
||||
// we use nil here but actually that does not matter
|
||||
nil,
|
||||
nil,
|
||||
1.0,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
F: 1.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling NULL values for pointer types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeFloat,
|
||||
},
|
||||
values: []interface{}{
|
||||
// we use nil here but actually that does not matter
|
||||
nil,
|
||||
nil,
|
||||
1.0,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&examplePointerTypes{},
|
||||
func() interface{} {
|
||||
f := 1.0
|
||||
|
||||
return &examplePointerTypes{F: &f}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling blob types",
|
||||
testStmt{
|
||||
columns: []string{"B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
([]byte)("hello world"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleBlobTypes{},
|
||||
&exampleBlobTypes{
|
||||
B: ([]byte)("hello world"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling blob types as json.RawMessage",
|
||||
testStmt{
|
||||
columns: []string{"B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
([]byte)("hello world"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleJSONRawTypes{},
|
||||
&exampleJSONRawTypes{
|
||||
B: (json.RawMessage)("hello world"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling time.Time and pointers to it",
|
||||
testStmt{
|
||||
columns: []string{"T", "TP"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
int(refTime.Unix()),
|
||||
int(refTime.Unix()),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleTimeTypes{},
|
||||
&exampleTimeTypes{
|
||||
T: refTime,
|
||||
TP: &refTime,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling time.Time in nano-second resolution (struct tags)",
|
||||
testStmt{
|
||||
columns: []string{"T", "TP"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
int(refTime.UnixNano()),
|
||||
int(refTime.UnixNano()),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleTimeNano{},
|
||||
&exampleTimeNano{
|
||||
T: refTime,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into interface",
|
||||
testStmt{
|
||||
columns: []string{"I", "IP"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeText,
|
||||
},
|
||||
values: []interface{}{
|
||||
"value1",
|
||||
"value2",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleInterface{},
|
||||
func() interface{} {
|
||||
var x interface{} = "value2"
|
||||
|
||||
return &exampleInterface{
|
||||
I: "value1",
|
||||
IP: &x,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into map[string]interface{}",
|
||||
testStmt{
|
||||
columns: []string{"I", "F", "S", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
1,
|
||||
1.1,
|
||||
"string value",
|
||||
[]byte("blob value"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
new(map[string]interface{}),
|
||||
&map[string]interface{}{
|
||||
"I": 1,
|
||||
"F": 1.1,
|
||||
"S": "string value",
|
||||
"B": []byte("blob value"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding using type-hints",
|
||||
testStmt{
|
||||
columns: []string{"B", "T"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeText,
|
||||
},
|
||||
values: []interface{}{
|
||||
true,
|
||||
refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
},
|
||||
[]ColumnDef{
|
||||
{
|
||||
Name: "B",
|
||||
Type: sqlite.TypeInteger,
|
||||
GoType: reflect.TypeOf(true),
|
||||
},
|
||||
{
|
||||
Name: "T",
|
||||
Type: sqlite.TypeText,
|
||||
GoType: reflect.TypeOf(time.Time{}),
|
||||
IsTime: true,
|
||||
},
|
||||
},
|
||||
new(map[string]interface{}),
|
||||
&map[string]interface{}{
|
||||
"B": true,
|
||||
"T": refTime,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into type aliases",
|
||||
testStmt{
|
||||
columns: []string{"B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
[]byte(`{"foo": "bar}`),
|
||||
},
|
||||
},
|
||||
[]ColumnDef{
|
||||
{
|
||||
Name: "B",
|
||||
Type: sqlite.TypeBlob,
|
||||
GoType: reflect.TypeOf(json.RawMessage(`{"foo": "bar}`)),
|
||||
},
|
||||
},
|
||||
new(map[string]interface{}),
|
||||
&map[string]interface{}{
|
||||
"B": json.RawMessage(`{"foo": "bar}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into type aliases #2",
|
||||
testStmt{
|
||||
columns: []string{"I"},
|
||||
types: []sqlite.ColumnType{sqlite.TypeInteger},
|
||||
values: []interface{}{
|
||||
10,
|
||||
},
|
||||
},
|
||||
[]ColumnDef{
|
||||
{
|
||||
Name: "I",
|
||||
Type: sqlite.TypeInteger,
|
||||
GoType: reflect.TypeOf(myInt(0)),
|
||||
},
|
||||
},
|
||||
new(map[string]interface{}),
|
||||
&map[string]interface{}{
|
||||
"I": myInt(10),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
// log.Println(c.Desc)
|
||||
err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
|
||||
if fn, ok := c.Expected.(func() interface{}); ok {
|
||||
c.Expected = fn()
|
||||
}
|
||||
|
||||
if c.Expected == nil {
|
||||
assert.Error(t, err, c.Desc)
|
||||
} else {
|
||||
assert.NoError(t, err, c.Desc)
|
||||
|
||||
if equaler, ok := c.Expected.(interface{ Equal(x interface{}) bool }); ok {
|
||||
assert.True(t, equaler.Equal(c.Result))
|
||||
} else {
|
||||
assert.Equal(t, c.Expected, c.Result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
258
service/netquery/orm/encoder.go
Normal file
258
service/netquery/orm/encoder.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
type (
|
||||
// EncodeFunc is called for each non-basic type during encoding.
|
||||
EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)
|
||||
|
||||
// EncodeConfig holds encoding functions.
|
||||
EncodeConfig struct {
|
||||
EncodeHooks []EncodeFunc
|
||||
}
|
||||
)
|
||||
|
||||
// ToParamMap returns a map that contains the sqlite compatible value of each struct field of
|
||||
// r using the sqlite column name as a map key. It either uses the name of the
|
||||
// exported struct field or the value of the "sqlite" tag.
|
||||
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) {
|
||||
// make sure we work on a struct type
|
||||
val := reflect.Indirect(reflect.ValueOf(r))
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("%w, got %T", errStructExpected, r)
|
||||
}
|
||||
|
||||
res := make(map[string]interface{}, val.NumField())
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
fieldType := val.Type().Field(i)
|
||||
field := val.Field(i)
|
||||
|
||||
// skip unexported fields
|
||||
if !fieldType.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
colDef, err := getColumnDef(fieldType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
|
||||
}
|
||||
|
||||
if slices.Contains(skipFields, colDef.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
x, found, err := runEncodeHooks(
|
||||
colDef,
|
||||
fieldType.Type,
|
||||
field,
|
||||
append(
|
||||
cfg.EncodeHooks,
|
||||
encodeBasic(),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to run encode hooks: %w", err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
if reflect.Indirect(field).IsValid() {
|
||||
x = reflect.Indirect(field).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
res[keyPrefix+sqlColumnName(fieldType)] = x
|
||||
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// EncodeValue encodes the given value.
|
||||
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) {
|
||||
fieldValue := reflect.ValueOf(val)
|
||||
fieldType := reflect.TypeOf(val)
|
||||
|
||||
x, found, err := runEncodeHooks(
|
||||
colDef,
|
||||
fieldType,
|
||||
fieldValue,
|
||||
append(
|
||||
cfg.EncodeHooks,
|
||||
encodeBasic(),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to run encode hooks: %w", err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
if reflect.Indirect(fieldValue).IsValid() {
|
||||
x = reflect.Indirect(fieldValue).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return x, nil
|
||||
}
|
||||
|
||||
func encodeBasic() EncodeFunc {
|
||||
return func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
||||
kind := valType.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
valType = valType.Elem()
|
||||
kind = valType.Kind()
|
||||
|
||||
if val.IsNil() {
|
||||
if !col.Nullable {
|
||||
// we need to set the zero value here since the column
|
||||
// is not marked as nullable
|
||||
return reflect.New(valType).Elem().Interface(), true, nil
|
||||
}
|
||||
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
switch NormalizeKind(kind) { //nolint:exhaustive
|
||||
case reflect.String,
|
||||
reflect.Float64,
|
||||
reflect.Bool,
|
||||
reflect.Int,
|
||||
reflect.Uint:
|
||||
// sqlite package handles conversion of those types
|
||||
// already
|
||||
return val.Interface(), true, nil
|
||||
|
||||
case reflect.Slice:
|
||||
if valType.Elem().Kind() == reflect.Uint8 {
|
||||
// this is []byte
|
||||
return val.Interface(), true, nil
|
||||
}
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
return nil, false, fmt.Errorf("cannot convert value of kind %s for use in SQLite", kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DatetimeEncoder returns a new datetime encoder for the given time zone.
|
||||
func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||
return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
||||
// if fieldType holds a pointer we need to dereference the value
|
||||
ft := valType.String()
|
||||
if valType.Kind() == reflect.Ptr {
|
||||
ft = valType.Elem().String()
|
||||
val = reflect.Indirect(val)
|
||||
}
|
||||
|
||||
normalizedKind := NormalizeKind(valType.Kind())
|
||||
|
||||
// we only care about "time.Time" here
|
||||
var t time.Time
|
||||
switch {
|
||||
case ft == "time.Time":
|
||||
// handle the zero time as a NULL.
|
||||
if !val.IsValid() || val.IsZero() {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
var ok bool
|
||||
valInterface := val.Interface()
|
||||
t, ok = valInterface.(time.Time)
|
||||
if !ok {
|
||||
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
|
||||
}
|
||||
|
||||
case valType.Kind() == reflect.String && colDef.IsTime:
|
||||
var err error
|
||||
t, err = time.Parse(time.RFC3339, val.String())
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
|
||||
}
|
||||
|
||||
case (normalizedKind == reflect.Int || normalizedKind == reflect.Uint || normalizedKind == reflect.Float64) && colDef.IsTime:
|
||||
seconds := int64(0)
|
||||
switch normalizedKind { //nolint:exhaustive // Previous switch case assures these types.
|
||||
case reflect.Int:
|
||||
seconds = val.Int()
|
||||
case reflect.Uint:
|
||||
seconds = int64(val.Uint())
|
||||
case reflect.Float64:
|
||||
seconds = int64(val.Float())
|
||||
}
|
||||
|
||||
t = time.Unix(seconds, 0)
|
||||
|
||||
default:
|
||||
// we don't care ...
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
switch colDef.Type { //nolint:exhaustive
|
||||
case sqlite.TypeInteger:
|
||||
if colDef.UnixNano {
|
||||
return t.UnixNano(), true, nil
|
||||
}
|
||||
return t.Unix(), true, nil
|
||||
|
||||
case sqlite.TypeText:
|
||||
str := t.In(loc).Format(SqliteTimeFormat)
|
||||
|
||||
return str, true, nil
|
||||
}
|
||||
|
||||
return nil, false, fmt.Errorf("cannot store time.Time in %s", colDef.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
|
||||
if valType == nil {
|
||||
if !colDef.Nullable {
|
||||
switch colDef.Type { //nolint:exhaustive
|
||||
case sqlite.TypeBlob:
|
||||
return []byte{}, true, nil
|
||||
case sqlite.TypeFloat:
|
||||
return 0.0, true, nil
|
||||
case sqlite.TypeText:
|
||||
return "", true, nil
|
||||
case sqlite.TypeInteger:
|
||||
return 0, true, nil
|
||||
default:
|
||||
return nil, false, fmt.Errorf("unsupported sqlite data type: %s", colDef.Type)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
for _, fn := range hooks {
|
||||
res, end, err := fn(colDef, valType, val)
|
||||
if err != nil {
|
||||
return res, false, err
|
||||
}
|
||||
|
||||
if end {
|
||||
return res, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// DefaultEncodeConfig holds the default encoding configuration.
|
||||
var DefaultEncodeConfig = EncodeConfig{
|
||||
EncodeHooks: []EncodeFunc{
|
||||
DatetimeEncoder(time.UTC),
|
||||
},
|
||||
}
|
||||
260
service/netquery/orm/encoder_test.go
Normal file
260
service/netquery/orm/encoder_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
Input interface{}
|
||||
Expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
"Encode basic types",
|
||||
struct {
|
||||
I int
|
||||
F float64
|
||||
S string
|
||||
B []byte
|
||||
}{
|
||||
I: 1,
|
||||
F: 1.2,
|
||||
S: "string",
|
||||
B: ([]byte)("bytes"),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"I": 1,
|
||||
"F": 1.2,
|
||||
"S": "string",
|
||||
"B": ([]byte)("bytes"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Encode using struct tags",
|
||||
struct {
|
||||
I int `sqlite:"col_int"`
|
||||
S string `sqlite:"col_string"`
|
||||
}{
|
||||
I: 1,
|
||||
S: "string value",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"col_int": 1,
|
||||
"col_string": "string value",
|
||||
},
|
||||
},
|
||||
{
|
||||
"Ignore Private fields",
|
||||
struct {
|
||||
I int
|
||||
s string
|
||||
}{
|
||||
I: 1,
|
||||
s: "string value",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"I": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handle Pointers",
|
||||
struct {
|
||||
I *int
|
||||
S *string
|
||||
}{
|
||||
I: new(int),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"I": 0,
|
||||
"S": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handle time.Time types",
|
||||
struct {
|
||||
TinInt time.Time `sqlite:",integer,unixnano"`
|
||||
TinString time.Time `sqlite:",text"`
|
||||
}{
|
||||
TinInt: refTime,
|
||||
TinString: refTime,
|
||||
},
|
||||
map[string]interface{}{
|
||||
"TinInt": refTime.UnixNano(),
|
||||
"TinString": refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handle time.Time pointer types",
|
||||
struct {
|
||||
TinInt *time.Time `sqlite:",integer,unixnano"`
|
||||
TinString *time.Time `sqlite:",text"`
|
||||
Tnil1 *time.Time `sqlite:",text"`
|
||||
Tnil2 *time.Time `sqlite:",text"`
|
||||
}{
|
||||
TinInt: &refTime,
|
||||
TinString: &refTime,
|
||||
Tnil1: nil,
|
||||
Tnil2: (*time.Time)(nil),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"TinInt": refTime.UnixNano(),
|
||||
"TinString": refTime.Format(SqliteTimeFormat),
|
||||
"Tnil1": nil,
|
||||
"Tnil2": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.Expected, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeValue(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
Column ColumnDef
|
||||
Input interface{}
|
||||
Output interface{}
|
||||
}{
|
||||
{
|
||||
"Special value time.Time as text",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
refTime,
|
||||
refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
{
|
||||
"Special value time.Time as unix-epoch",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeInteger,
|
||||
},
|
||||
refTime,
|
||||
refTime.Unix(),
|
||||
},
|
||||
{
|
||||
"Special value time.Time as unixnano-epoch",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeInteger,
|
||||
UnixNano: true,
|
||||
},
|
||||
refTime,
|
||||
refTime.UnixNano(),
|
||||
},
|
||||
{
|
||||
"Special value zero time",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
time.Time{},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Special value zero time pointer",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
new(time.Time),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Special value *time.Time as text",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
&refTime,
|
||||
refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
{
|
||||
"Special value untyped nil",
|
||||
ColumnDef{
|
||||
Nullable: true,
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Special value typed nil",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
(*time.Time)(nil),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Time formated as string",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
refTime.In(time.Local).Format(time.RFC3339),
|
||||
refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
{
|
||||
"Nullable integer",
|
||||
ColumnDef{
|
||||
Type: sqlite.TypeInteger,
|
||||
Nullable: true,
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Not-Null integer",
|
||||
ColumnDef{
|
||||
Name: "test",
|
||||
Type: sqlite.TypeInteger,
|
||||
},
|
||||
nil,
|
||||
0,
|
||||
},
|
||||
{
|
||||
"Not-Null string",
|
||||
ColumnDef{
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.Output, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
186
service/netquery/orm/query_runner.go
Normal file
186
service/netquery/orm/query_runner.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
)
|
||||
|
||||
type (
|
||||
// QueryOption can be specified at RunQuery to alter the behavior
|
||||
// of the executed query.
|
||||
QueryOption func(opts *queryOpts)
|
||||
|
||||
queryOpts struct {
|
||||
Transient bool
|
||||
Args []interface{}
|
||||
NamedArgs map[string]interface{}
|
||||
Result interface{}
|
||||
DecodeConfig DecodeConfig
|
||||
Schema TableSchema
|
||||
}
|
||||
)
|
||||
|
||||
// WithTransient marks the query as transient.
|
||||
//
|
||||
// Transient queries will not be cached for later
|
||||
// re-use after they have been prepared.
|
||||
func WithTransient() QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Transient = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithArgs adds a list of arguments for the query. Arguments
|
||||
// are applied in order.
|
||||
//
|
||||
// See SQL Language Expression documentation of SQLite for
|
||||
// details: https://sqlite.org/lang_expr.html
|
||||
func WithArgs(args ...interface{}) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Args = args
|
||||
}
|
||||
}
|
||||
|
||||
// WithNamedArgs adds args to the query. The query must used
|
||||
// named argument placeholders. According to the SQLite spec,
|
||||
// arguments must either start with ':', '@' or '$'.
|
||||
//
|
||||
// See SQL Language Expression documentation of SQLite for
|
||||
// details: https://sqlite.org/lang_expr.html
|
||||
func WithNamedArgs(args map[string]interface{}) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.NamedArgs = args
|
||||
}
|
||||
}
|
||||
|
||||
// WithSchema returns a query option that adds the given table
|
||||
// schema to the query.
|
||||
func WithSchema(tbl TableSchema) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Schema = tbl
|
||||
}
|
||||
}
|
||||
|
||||
// WithResult sets the result receiver. result is expected to
|
||||
// be a pointer to a slice of struct or map types.
|
||||
//
|
||||
// For decoding DecodeStmt is used to decode each
|
||||
// row into a new slice element. It thus supports special values
|
||||
// like time.Time. See DecodeStmt() and WithDecodeConfig() for
|
||||
// more information.
|
||||
func WithResult(result interface{}) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Result = result
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecodeConfig configures the DecodeConfig to use when
|
||||
// calling DecodeStmt to decode each row into the result slice.
|
||||
//
|
||||
// If not specified, DefaultDecodeConfig will be used.
|
||||
func WithDecodeConfig(cfg DecodeConfig) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.DecodeConfig = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// RunQuery executes the query stored in sql against the databased opened in
|
||||
// conn. Please refer to the documentation of QueryOption, especially WithResult()
|
||||
// for more information on how to retrieve the resulting rows.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var result []struct{
|
||||
// Count int `sqlite:"rowCount"`
|
||||
// }
|
||||
//
|
||||
// err := RunQuery(ctx, conn, "SELECT COUNT(*) AS rowCount FROM table", WithResult(&result))
|
||||
// fmt.Println(result[0].Count)
|
||||
func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...QueryOption) error {
|
||||
args := queryOpts{
|
||||
DecodeConfig: DefaultDecodeConfig,
|
||||
}
|
||||
|
||||
for _, fn := range modifiers {
|
||||
fn(&args)
|
||||
}
|
||||
|
||||
opts := &sqlitex.ExecOptions{
|
||||
Args: args.Args,
|
||||
Named: args.NamedArgs,
|
||||
}
|
||||
|
||||
var (
|
||||
sliceVal reflect.Value
|
||||
valElemType reflect.Type
|
||||
)
|
||||
|
||||
if args.Result != nil {
|
||||
target := args.Result
|
||||
outVal := reflect.ValueOf(target)
|
||||
if outVal.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("target must be a pointer, got %T", target)
|
||||
}
|
||||
|
||||
sliceVal = reflect.Indirect(outVal)
|
||||
if !sliceVal.IsValid() || sliceVal.IsNil() {
|
||||
newVal := reflect.Zero(outVal.Type().Elem())
|
||||
sliceVal.Set(newVal)
|
||||
}
|
||||
|
||||
kind := sliceVal.Kind()
|
||||
if kind != reflect.Slice {
|
||||
return fmt.Errorf("target but be pointer to slice, got %T", target)
|
||||
}
|
||||
valType := sliceVal.Type()
|
||||
valElemType = valType.Elem()
|
||||
|
||||
opts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
currentField := reflect.New(valElemType)
|
||||
|
||||
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
|
||||
resultDump := make(map[string]any)
|
||||
|
||||
for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ {
|
||||
name := stmt.ColumnName(colIdx)
|
||||
|
||||
switch stmt.ColumnType(colIdx) { //nolint:exhaustive // TODO: handle type BLOB?
|
||||
case sqlite.TypeText:
|
||||
resultDump[name] = stmt.ColumnText(colIdx)
|
||||
case sqlite.TypeFloat:
|
||||
resultDump[name] = stmt.ColumnFloat(colIdx)
|
||||
case sqlite.TypeInteger:
|
||||
resultDump[name] = stmt.ColumnInt(colIdx)
|
||||
case sqlite.TypeNull:
|
||||
resultDump[name] = "<null>"
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%w: %+v", err, resultDump)
|
||||
}
|
||||
|
||||
sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField))
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if args.Transient {
|
||||
err = sqlitex.ExecuteTransient(conn, sql, opts)
|
||||
} else {
|
||||
err = sqlitex.Execute(conn, sql, opts)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if args.Result != nil {
|
||||
reflect.Indirect(reflect.ValueOf(args.Result)).Set(sliceVal)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
288
service/netquery/orm/schema_builder.go
Normal file
288
service/netquery/orm/schema_builder.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
var errSkipStructField = errors.New("struct field should be skipped")
|
||||
|
||||
// Struct Tags.
|
||||
var (
|
||||
TagUnixNano = "unixnano"
|
||||
TagPrimaryKey = "primary"
|
||||
TagAutoIncrement = "autoincrement"
|
||||
TagTime = "time"
|
||||
TagNotNull = "not-null"
|
||||
TagNullable = "nullable"
|
||||
TagTypeInt = "integer"
|
||||
TagTypeText = "text"
|
||||
TagTypePrefixVarchar = "varchar"
|
||||
TagTypeBlob = "blob"
|
||||
TagTypeFloat = "float"
|
||||
TagTypePrefixDefault = "default="
|
||||
)
|
||||
|
||||
var sqlTypeMap = map[sqlite.ColumnType]string{
|
||||
sqlite.TypeBlob: "BLOB",
|
||||
sqlite.TypeFloat: "REAL",
|
||||
sqlite.TypeInteger: "INTEGER",
|
||||
sqlite.TypeText: "TEXT",
|
||||
}
|
||||
|
||||
type (
|
||||
// TableSchema defines a SQL table schema.
|
||||
TableSchema struct {
|
||||
Name string
|
||||
Columns []ColumnDef
|
||||
}
|
||||
|
||||
// ColumnDef defines a SQL column.
|
||||
ColumnDef struct { //nolint:maligned
|
||||
Name string
|
||||
Nullable bool
|
||||
Type sqlite.ColumnType
|
||||
GoType reflect.Type
|
||||
Length int
|
||||
PrimaryKey bool
|
||||
AutoIncrement bool
|
||||
UnixNano bool
|
||||
IsTime bool
|
||||
Default any
|
||||
}
|
||||
)
|
||||
|
||||
// GetColumnDef returns the column definition with the given name.
|
||||
func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
|
||||
for _, def := range ts.Columns {
|
||||
if def.Name == name {
|
||||
return &def
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateStatement build the CREATE SQL statement for the table.
|
||||
func (ts TableSchema) CreateStatement(databaseName string, ifNotExists bool) string {
|
||||
sql := "CREATE TABLE"
|
||||
if ifNotExists {
|
||||
sql += " IF NOT EXISTS"
|
||||
}
|
||||
name := ts.Name
|
||||
if databaseName != "" {
|
||||
name = databaseName + "." + ts.Name
|
||||
}
|
||||
|
||||
sql += " " + name + " ( "
|
||||
|
||||
for idx, col := range ts.Columns {
|
||||
sql += col.AsSQL()
|
||||
if idx < len(ts.Columns)-1 {
|
||||
sql += ", "
|
||||
}
|
||||
}
|
||||
|
||||
sql += " );"
|
||||
return sql
|
||||
}
|
||||
|
||||
// AsSQL builds the SQL column definition.
|
||||
func (def ColumnDef) AsSQL() string {
|
||||
sql := def.Name + " "
|
||||
|
||||
if def.Type == sqlite.TypeText && def.Length > 0 {
|
||||
sql += fmt.Sprintf("VARCHAR(%d)", def.Length)
|
||||
} else {
|
||||
sql += sqlTypeMap[def.Type]
|
||||
}
|
||||
|
||||
if def.PrimaryKey {
|
||||
sql += " PRIMARY KEY"
|
||||
}
|
||||
if def.AutoIncrement {
|
||||
sql += " AUTOINCREMENT"
|
||||
}
|
||||
if def.Default != nil {
|
||||
sql += " DEFAULT "
|
||||
switch def.Type { //nolint:exhaustive // TODO: handle types BLOB, NULL?
|
||||
case sqlite.TypeFloat:
|
||||
sql += strconv.FormatFloat(def.Default.(float64), 'b', 0, 64) //nolint:forcetypeassert
|
||||
case sqlite.TypeInteger:
|
||||
sql += strconv.FormatInt(def.Default.(int64), 10) //nolint:forcetypeassert
|
||||
case sqlite.TypeText:
|
||||
sql += fmt.Sprintf("%q", def.Default.(string)) //nolint:forcetypeassert
|
||||
default:
|
||||
log.Errorf("unsupported default value: %q %q", def.Type, def.Default)
|
||||
sql = strings.TrimSuffix(sql, " DEFAULT ")
|
||||
}
|
||||
sql += " "
|
||||
}
|
||||
if !def.Nullable {
|
||||
sql += " NOT NULL"
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
// GenerateTableSchema generates a table schema from the given struct.
|
||||
func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) {
|
||||
ts := &TableSchema{
|
||||
Name: name,
|
||||
}
|
||||
|
||||
val := reflect.Indirect(reflect.ValueOf(d))
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("%w, got %T", errStructExpected, d)
|
||||
}
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
fieldType := val.Type().Field(i)
|
||||
if !fieldType.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
def, err := getColumnDef(fieldType)
|
||||
if err != nil {
|
||||
if errors.Is(err, errSkipStructField) {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("struct field %s: %w", fieldType.Name, err)
|
||||
}
|
||||
|
||||
ts.Columns = append(ts.Columns, *def)
|
||||
}
|
||||
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
|
||||
def := &ColumnDef{
|
||||
Name: fieldType.Name,
|
||||
Nullable: fieldType.Type.Kind() == reflect.Ptr,
|
||||
}
|
||||
|
||||
ft := fieldType.Type
|
||||
|
||||
if fieldType.Type.Kind() == reflect.Ptr {
|
||||
ft = fieldType.Type.Elem()
|
||||
}
|
||||
|
||||
def.GoType = ft
|
||||
kind := NormalizeKind(ft.Kind())
|
||||
|
||||
switch kind { //nolint:exhaustive
|
||||
case reflect.Int, reflect.Uint:
|
||||
def.Type = sqlite.TypeInteger
|
||||
|
||||
case reflect.Float64:
|
||||
def.Type = sqlite.TypeFloat
|
||||
|
||||
case reflect.String:
|
||||
def.Type = sqlite.TypeText
|
||||
|
||||
case reflect.Slice:
|
||||
// only []byte/[]uint8 is supported
|
||||
if ft.Elem().Kind() != reflect.Uint8 {
|
||||
return nil, fmt.Errorf("slices of type %s is not supported", ft.Elem())
|
||||
}
|
||||
|
||||
def.Type = sqlite.TypeBlob
|
||||
}
|
||||
|
||||
if err := applyStructFieldTag(fieldType, def); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return def, nil
|
||||
}
|
||||
|
||||
// applyStructFieldTag parses the sqlite:"" struct field tag and update the column
|
||||
// definition def accordingly.
|
||||
func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
|
||||
parts := strings.Split(fieldType.Tag.Get("sqlite"), ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
if parts[0] == "-" {
|
||||
return errSkipStructField
|
||||
}
|
||||
|
||||
def.Name = parts[0]
|
||||
}
|
||||
|
||||
if len(parts) > 1 {
|
||||
for _, k := range parts[1:] {
|
||||
switch k {
|
||||
// column modifiers
|
||||
case TagPrimaryKey:
|
||||
def.PrimaryKey = true
|
||||
case TagAutoIncrement:
|
||||
def.AutoIncrement = true
|
||||
case TagNotNull:
|
||||
def.Nullable = false
|
||||
case TagNullable:
|
||||
def.Nullable = true
|
||||
case TagUnixNano:
|
||||
def.UnixNano = true
|
||||
case TagTime:
|
||||
def.IsTime = true
|
||||
|
||||
// basic column types
|
||||
case TagTypeInt:
|
||||
def.Type = sqlite.TypeInteger
|
||||
case TagTypeText:
|
||||
def.Type = sqlite.TypeText
|
||||
case TagTypeFloat:
|
||||
def.Type = sqlite.TypeFloat
|
||||
case TagTypeBlob:
|
||||
def.Type = sqlite.TypeBlob
|
||||
|
||||
// advanced column types
|
||||
default:
|
||||
if strings.HasPrefix(k, TagTypePrefixVarchar) {
|
||||
lenStr := strings.TrimSuffix(strings.TrimPrefix(k, TagTypePrefixVarchar+"("), ")")
|
||||
length, err := strconv.ParseInt(lenStr, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse varchar length %q: %w", lenStr, err)
|
||||
}
|
||||
|
||||
def.Type = sqlite.TypeText
|
||||
def.Length = int(length)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(k, TagTypePrefixDefault) {
|
||||
defaultValue := strings.TrimPrefix(k, TagTypePrefixDefault)
|
||||
switch def.Type { //nolint:exhaustive
|
||||
case sqlite.TypeFloat:
|
||||
fv, err := strconv.ParseFloat(defaultValue, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse default value as float %q: %w", defaultValue, err)
|
||||
}
|
||||
def.Default = fv
|
||||
case sqlite.TypeInteger:
|
||||
fv, err := strconv.ParseInt(defaultValue, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse default value as int %q: %w", defaultValue, err)
|
||||
}
|
||||
def.Default = fv
|
||||
case sqlite.TypeText:
|
||||
def.Default = defaultValue
|
||||
case sqlite.TypeBlob:
|
||||
return fmt.Errorf("default values for TypeBlob not yet supported")
|
||||
default:
|
||||
return fmt.Errorf("failed to apply default value for unknown sqlite column type %s", def.Type)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
43
service/netquery/orm/schema_builder_test.go
Normal file
43
service/netquery/orm/schema_builder_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSchemaBuilder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
Model interface{}
|
||||
ExpectedSQL string
|
||||
}{
|
||||
{
|
||||
"Simple",
|
||||
struct {
|
||||
ID int `sqlite:"id,primary,autoincrement"`
|
||||
Text string `sqlite:"text,nullable"`
|
||||
Int *int `sqlite:",not-null"`
|
||||
Float interface{} `sqlite:",float,nullable"`
|
||||
}{},
|
||||
`CREATE TABLE main.Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
|
||||
},
|
||||
{
|
||||
"Varchar",
|
||||
struct {
|
||||
S string `sqlite:",varchar(10)"`
|
||||
}{},
|
||||
`CREATE TABLE main.Varchar ( S VARCHAR(10) NOT NULL );`,
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases {
|
||||
c := cases[idx]
|
||||
|
||||
res, err := GenerateTableSchema(c.Name, c.Model)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.ExpectedSQL, res.CreateStatement("main", false))
|
||||
}
|
||||
}
|
||||
675
service/netquery/query.go
Normal file
675
service/netquery/query.go
Normal file
@@ -0,0 +1,675 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/slices"
|
||||
"zombiezen.com/go/sqlite"
|
||||
|
||||
"github.com/safing/portmaster/service/netquery/orm"
|
||||
)
|
||||
|
||||
// DatabaseName is a database name constant.
|
||||
type DatabaseName string
|
||||
|
||||
// Databases.
|
||||
const (
|
||||
LiveDatabase = DatabaseName("main")
|
||||
HistoryDatabase = DatabaseName("history")
|
||||
)
|
||||
|
||||
// Collection of Query and Matcher types.
|
||||
// NOTE: whenever adding support for new operators make sure
|
||||
// to update UnmarshalJSON as well.
|
||||
//
|
||||
//nolint:golint
|
||||
type (
|
||||
Query map[string][]Matcher
|
||||
|
||||
MatchType interface {
|
||||
Operator() string
|
||||
}
|
||||
|
||||
Equal interface{}
|
||||
|
||||
Matcher struct {
|
||||
Equal interface{} `json:"$eq,omitempty"`
|
||||
NotEqual interface{} `json:"$ne,omitempty"`
|
||||
In []interface{} `json:"$in,omitempty"`
|
||||
NotIn []interface{} `json:"$notIn,omitempty"`
|
||||
Like string `json:"$like,omitempty"`
|
||||
Greater *float64 `json:"$gt,omitempty"`
|
||||
GreaterOrEqual *float64 `json:"$ge,omitempty"`
|
||||
Less *float64 `json:"$lt,omitempty"`
|
||||
LessOrEqual *float64 `json:"$le,omitempty"`
|
||||
}
|
||||
|
||||
Count struct {
|
||||
As string `json:"as"`
|
||||
Field string `json:"field"`
|
||||
Distinct bool `json:"distinct"`
|
||||
}
|
||||
|
||||
Sum struct {
|
||||
Condition Query `json:"condition"`
|
||||
Field string `json:"field"`
|
||||
As string `json:"as"`
|
||||
Distinct bool `json:"distinct"`
|
||||
}
|
||||
|
||||
Min struct {
|
||||
Condition *Query `json:"condition,omitempty"`
|
||||
Field string `json:"field"`
|
||||
As string `json:"as"`
|
||||
Distinct bool `json:"distinct"`
|
||||
}
|
||||
|
||||
FieldSelect struct {
|
||||
Field string `json:"field"`
|
||||
As string `json:"as"`
|
||||
}
|
||||
|
||||
Select struct {
|
||||
Field string `json:"field"`
|
||||
FieldSelect *FieldSelect `json:"$field"`
|
||||
Count *Count `json:"$count,omitempty"`
|
||||
Sum *Sum `json:"$sum,omitempty"`
|
||||
Min *Min `json:"$min,omitempty"`
|
||||
Distinct *string `json:"$distinct,omitempty"`
|
||||
}
|
||||
|
||||
Selects []Select
|
||||
|
||||
TextSearch struct {
|
||||
Fields []string `json:"fields"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
QueryActiveConnectionChartPayload struct {
|
||||
Query Query `json:"query"`
|
||||
TextSearch *TextSearch `json:"textSearch"`
|
||||
}
|
||||
|
||||
OrderBy struct {
|
||||
Field string `json:"field"`
|
||||
Desc bool `json:"desc"`
|
||||
}
|
||||
|
||||
OrderBys []OrderBy
|
||||
|
||||
Pagination struct {
|
||||
PageSize int `json:"pageSize"`
|
||||
Page int `json:"page"`
|
||||
}
|
||||
)
|
||||
|
||||
// UnmarshalJSON unmarshals a Query from json.
|
||||
func (query *Query) UnmarshalJSON(blob []byte) error {
|
||||
if *query == nil {
|
||||
*query = make(Query)
|
||||
}
|
||||
|
||||
var model map[string]json.RawMessage
|
||||
|
||||
if err := json.Unmarshal(blob, &model); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for columnName, rawColumnQuery := range model {
|
||||
if len(rawColumnQuery) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch rawColumnQuery[0] {
|
||||
case '{':
|
||||
m, err := parseMatcher(rawColumnQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*query)[columnName] = []Matcher{*m}
|
||||
|
||||
case '[':
|
||||
var rawMatchers []json.RawMessage
|
||||
if err := json.Unmarshal(rawColumnQuery, &rawMatchers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*query)[columnName] = make([]Matcher, len(rawMatchers))
|
||||
for idx, val := range rawMatchers {
|
||||
// this should not happen
|
||||
if len(val) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// if val starts with a { we have a matcher definition
|
||||
if val[0] == '{' {
|
||||
m, err := parseMatcher(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
(*query)[columnName][idx] = *m
|
||||
|
||||
continue
|
||||
} else if val[0] == '[' {
|
||||
return fmt.Errorf("invalid token [ in query for column %s", columnName)
|
||||
}
|
||||
|
||||
// val is a dedicated JSON primitive and not an object or array
|
||||
// so we treat that as an EQUAL condition.
|
||||
var x interface{}
|
||||
if err := json.Unmarshal(val, &x); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*query)[columnName][idx] = Matcher{
|
||||
Equal: x,
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// value is a JSON primitive and not an object or array
|
||||
// so we treat that as an EQUAL condition.
|
||||
var x interface{}
|
||||
if err := json.Unmarshal(rawColumnQuery, &x); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*query)[columnName] = []Matcher{
|
||||
{Equal: x},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(ppacher): right now we only support LIMIT and OFFSET for pagination but that
|
||||
// has an issue that loading the same page twice might yield different results due to
|
||||
// new records shifting the result slice. To overcome this, return a "PageToken" to the
|
||||
// user that includes the time the initial query was created so paginated queries can
|
||||
// ensure new records don't end up in the result set.
|
||||
func (page *Pagination) toSQLLimitOffsetClause() string {
|
||||
limit := page.PageSize
|
||||
|
||||
// default and cap the limit to at most 100 items
|
||||
// per page to avoid out-of-memory conditions when loading
|
||||
// thousands of results at once.
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf("LIMIT %d", limit)
|
||||
|
||||
if page.Page > 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", page.Page*limit)
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func parseMatcher(raw json.RawMessage) (*Matcher, error) {
|
||||
var m Matcher
|
||||
if err := json.Unmarshal(raw, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid query matcher: %w", err)
|
||||
}
|
||||
|
||||
// log.Printf("parsed matcher %s: %+v", string(raw), m)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// Validate validates the matcher.
|
||||
func (match Matcher) Validate() error {
|
||||
found := 0
|
||||
|
||||
if match.Equal != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.NotEqual != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.In != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.NotIn != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.Like != "" {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.Greater != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.GreaterOrEqual != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.Less != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.LessOrEqual != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if found == 0 {
|
||||
return fmt.Errorf("no conditions specified")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (text TextSearch) toSQLConditionClause(_ context.Context, schema *orm.TableSchema, suffix string, _ orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||
var (
|
||||
queryParts = make([]string, 0, len(text.Fields))
|
||||
params = make(map[string]interface{})
|
||||
)
|
||||
|
||||
key := fmt.Sprintf(":t%s", suffix)
|
||||
params[key] = fmt.Sprintf("%%%s%%", text.Value)
|
||||
|
||||
for _, field := range text.Fields {
|
||||
colDef := schema.GetColumnDef(field)
|
||||
if colDef == nil {
|
||||
return "", nil, fmt.Errorf("column %s is not allowed in text-search", colDef.Name)
|
||||
}
|
||||
if colDef.Type != sqlite.TypeText {
|
||||
return "", nil, fmt.Errorf("type of column %s cannot be used in text-search", colDef.Name)
|
||||
}
|
||||
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s LIKE %s", colDef.Name, key))
|
||||
}
|
||||
|
||||
if len(queryParts) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
return "( " + strings.Join(queryParts, " OR ") + " )", params, nil
|
||||
}
|
||||
|
||||
func (match Matcher) toSQLConditionClause(ctx context.Context, suffix string, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||
var (
|
||||
queryParts []string
|
||||
params = make(map[string]interface{})
|
||||
errs = new(multierror.Error)
|
||||
key = fmt.Sprintf("%s%s", colDef.Name, suffix)
|
||||
)
|
||||
|
||||
add := func(operator, suffix string, list bool, values ...interface{}) {
|
||||
var placeholder []string
|
||||
|
||||
for idx, value := range values {
|
||||
var (
|
||||
encodedValue any
|
||||
err error
|
||||
)
|
||||
|
||||
kind := orm.NormalizeKind(reflect.TypeOf(value).Kind())
|
||||
isNumber := slices.Contains([]reflect.Kind{
|
||||
reflect.Uint,
|
||||
reflect.Int,
|
||||
reflect.Float64,
|
||||
}, kind)
|
||||
|
||||
// if we query a time-field that is queried as a number, don't do any encoding
|
||||
// here as the orm.DateTimeEncoder would convert the number to a string.
|
||||
if colDef.IsTime && colDef.Type == sqlite.TypeText && isNumber {
|
||||
encodedValue = value
|
||||
} else {
|
||||
encodedValue, err = orm.EncodeValue(ctx, &colDef, value, encoderConfig)
|
||||
if err != nil {
|
||||
errs.Errors = append(errs.Errors,
|
||||
fmt.Errorf("failed to encode %v for column %s: %w", value, colDef.Name, err),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
uniqKey := fmt.Sprintf(":%s%s%d", key, suffix, idx)
|
||||
placeholder = append(placeholder, uniqKey)
|
||||
params[uniqKey] = encodedValue
|
||||
}
|
||||
|
||||
nameStmt := colDef.Name
|
||||
|
||||
if len(values) > 0 {
|
||||
// NOTE(ppacher): for now we assume that the type of each element of values
|
||||
// is the same. We also can be sure that there is always at least one value.
|
||||
//
|
||||
// TODO(ppacher): if we start supporting values of different types here
|
||||
// we need to revisit the whole behavior as we might need to do more boolean
|
||||
// expression nesting to support that.
|
||||
kind := orm.NormalizeKind(reflect.TypeOf(values[0]).Kind())
|
||||
isNumber := slices.Contains([]reflect.Kind{
|
||||
reflect.Uint,
|
||||
reflect.Int,
|
||||
reflect.Float64,
|
||||
}, kind)
|
||||
|
||||
// if this is a time column that is stored in "text" format and the provided
|
||||
// value is a number type, we need to wrap the property in a strftime() method
|
||||
// call.
|
||||
if colDef.IsTime && colDef.Type == sqlite.TypeText && isNumber {
|
||||
nameStmt = fmt.Sprintf("strftime('%%s', %s)+0", nameStmt)
|
||||
}
|
||||
}
|
||||
|
||||
if len(placeholder) == 1 && !list {
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s %s %s", nameStmt, operator, placeholder[0]))
|
||||
} else {
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", nameStmt, operator, strings.Join(placeholder, ", ")))
|
||||
}
|
||||
}
|
||||
|
||||
if match.Equal != nil {
|
||||
add("=", "eq", false, match.Equal)
|
||||
}
|
||||
|
||||
if match.NotEqual != nil {
|
||||
add("!=", "ne", false, match.NotEqual)
|
||||
}
|
||||
|
||||
if match.In != nil {
|
||||
add("IN", "in", true, match.In...)
|
||||
}
|
||||
|
||||
if match.NotIn != nil {
|
||||
add("NOT IN", "notin", true, match.NotIn...)
|
||||
}
|
||||
|
||||
if match.Like != "" {
|
||||
add("LIKE", "like", false, match.Like)
|
||||
}
|
||||
|
||||
if match.Greater != nil {
|
||||
add(">", "gt", false, *match.Greater)
|
||||
}
|
||||
|
||||
if match.GreaterOrEqual != nil {
|
||||
add(">=", "ge", false, *match.GreaterOrEqual)
|
||||
}
|
||||
|
||||
if match.Less != nil {
|
||||
add("<", "lt", false, *match.Less)
|
||||
}
|
||||
|
||||
if match.LessOrEqual != nil {
|
||||
add("<=", "le", false, *match.LessOrEqual)
|
||||
}
|
||||
|
||||
if len(queryParts) == 0 {
|
||||
// this is an empty matcher without a single condition.
|
||||
// we convert that to a no-op TRUE value
|
||||
return "( 1 = 1 )", nil, errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
if len(queryParts) == 1 {
|
||||
return queryParts[0], params, errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||
if len(query) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// create a lookup map to validate column names
|
||||
lm := make(map[string]orm.ColumnDef, len(m.Columns))
|
||||
for _, col := range m.Columns {
|
||||
lm[col.Name] = col
|
||||
}
|
||||
|
||||
paramMap := make(map[string]interface{})
|
||||
columnStmts := make([]string, 0, len(query))
|
||||
|
||||
// get all keys and sort them so we get a stable output
|
||||
queryKeys := make([]string, 0, len(query))
|
||||
for column := range query {
|
||||
queryKeys = append(queryKeys, column)
|
||||
}
|
||||
sort.Strings(queryKeys)
|
||||
|
||||
// actually create the WHERE clause parts for each
|
||||
// column in query.
|
||||
errs := new(multierror.Error)
|
||||
for _, column := range queryKeys {
|
||||
values := query[column]
|
||||
colDef, ok := lm[column]
|
||||
if !ok {
|
||||
|
||||
errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
queryParts := make([]string, len(values))
|
||||
for idx, val := range values {
|
||||
matcherQuery, params, err := val.toSQLConditionClause(ctx, fmt.Sprintf("%s%d", suffix, idx), "AND", colDef, encoderConfig)
|
||||
if err != nil {
|
||||
errs.Errors = append(errs.Errors,
|
||||
fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),
|
||||
)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// merge parameters up into the superior parameter map
|
||||
for key, val := range params {
|
||||
if _, ok := paramMap[key]; ok {
|
||||
// This is solely a developer mistake when implementing a matcher so no forgiving ...
|
||||
panic("sqlite parameter collision")
|
||||
}
|
||||
|
||||
paramMap[key] = val
|
||||
}
|
||||
|
||||
queryParts[idx] = matcherQuery
|
||||
}
|
||||
|
||||
columnStmts = append(columnStmts,
|
||||
fmt.Sprintf("( %s )", strings.Join(queryParts, " OR ")),
|
||||
)
|
||||
}
|
||||
|
||||
whereClause := strings.Join(columnStmts, " AND ")
|
||||
|
||||
return whereClause, paramMap, errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a Selects from json.
|
||||
func (sel *Selects) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// if we are looking at a slice directly decode into
|
||||
// a []Select
|
||||
if blob[0] == '[' {
|
||||
var result []Select
|
||||
if err := json.Unmarshal(blob, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*sel) = result
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// if it's an object decode into a single select
|
||||
if blob[0] == '{' {
|
||||
var result Select
|
||||
if err := json.Unmarshal(blob, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*sel = []Select{result}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// otherwise this is just the field name
|
||||
var field string
|
||||
if err := json.Unmarshal(blob, &field); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a Select from json.
|
||||
func (sel *Select) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// if we have an object at hand decode the select
|
||||
// directly
|
||||
if blob[0] == '{' {
|
||||
var res struct {
|
||||
Field string `json:"field"`
|
||||
Count *Count `json:"$count"`
|
||||
Sum *Sum `json:"$sum"`
|
||||
Min *Min `json:"$min"`
|
||||
Distinct *string `json:"$distinct"`
|
||||
FieldSelect *FieldSelect `json:"$field"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(blob, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sel.Count = res.Count
|
||||
sel.Field = res.Field
|
||||
sel.FieldSelect = res.FieldSelect
|
||||
sel.Distinct = res.Distinct
|
||||
sel.Sum = res.Sum
|
||||
sel.Min = res.Min
|
||||
|
||||
if sel.Count != nil && sel.Count.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.Count.As) {
|
||||
return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
if sel.Sum != nil && sel.Sum.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.Sum.As) {
|
||||
return fmt.Errorf("invalid characters in $sum.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
if sel.Min != nil && sel.Min.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.Min.As) {
|
||||
return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
if sel.FieldSelect != nil && sel.FieldSelect.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.FieldSelect.As) {
|
||||
return fmt.Errorf("invalid characters in $field.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var x string
|
||||
if err := json.Unmarshal(blob, &x); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sel.Field = x
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a OrderBys from json.
|
||||
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
if blob[0] == '[' {
|
||||
var result []OrderBy
|
||||
if err := json.Unmarshal(blob, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*orderBys = result
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if blob[0] == '{' {
|
||||
var result OrderBy
|
||||
if err := json.Unmarshal(blob, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*orderBys = []OrderBy{result}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var field string
|
||||
if err := json.Unmarshal(blob, &field); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*orderBys = []OrderBy{
|
||||
{
|
||||
Field: field,
|
||||
Desc: false,
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a OrderBy from json.
|
||||
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
if blob[0] == '{' {
|
||||
var res struct {
|
||||
Field string `json:"field"`
|
||||
Desc bool `json:"desc"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(blob, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
orderBy.Desc = res.Desc
|
||||
orderBy.Field = res.Field
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var field string
|
||||
if err := json.Unmarshal(blob, &field); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
orderBy.Field = field
|
||||
orderBy.Desc = false
|
||||
|
||||
return nil
|
||||
}
|
||||
249
service/netquery/query_handler.go
Normal file
249
service/netquery/query_handler.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
servertiming "github.com/mitchellh/go-server-timing"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/netquery/orm"
|
||||
)
|
||||
|
||||
var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
|
||||
|
||||
type (
|
||||
|
||||
// QueryHandler implements http.Handler and allows to perform SQL
|
||||
// query and aggregate functions on Database.
|
||||
QueryHandler struct {
|
||||
IsDevMode func() bool
|
||||
Database *Database
|
||||
}
|
||||
|
||||
// BatchQueryHandler implements http.Handler and allows to perform SQL
|
||||
// query and aggregate functions on Database in batches.
|
||||
BatchQueryHandler struct {
|
||||
IsDevMode func() bool
|
||||
Database *Database
|
||||
}
|
||||
)
|
||||
|
||||
func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||
timing := servertiming.FromContext(req.Context())
|
||||
|
||||
timingQueryParsed := timing.NewMetric("query_parsed").
|
||||
WithDesc("Query has been parsed").
|
||||
Start()
|
||||
|
||||
requestPayload, err := parseQueryRequestPayload[QueryRequestPayload](req)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
timingQueryParsed.Stop()
|
||||
|
||||
timingQueryBuilt := timing.NewMetric("query_built").
|
||||
WithDesc("The SQL query has been built").
|
||||
Start()
|
||||
|
||||
query, paramMap, err := requestPayload.generateSQL(req.Context(), qh.Database.Schema)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
timingQueryBuilt.Stop()
|
||||
|
||||
timingQueryExecute := timing.NewMetric("sql_exec").
|
||||
WithDesc("SQL query execution time").
|
||||
Start()
|
||||
|
||||
// actually execute the query against the database and collect the result
|
||||
var result []map[string]interface{}
|
||||
if err := qh.Database.Execute(
|
||||
req.Context(),
|
||||
query,
|
||||
orm.WithNamedArgs(paramMap),
|
||||
orm.WithResult(&result),
|
||||
orm.WithSchema(*qh.Database.Schema),
|
||||
); err != nil {
|
||||
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
timingQueryExecute.Stop()
|
||||
|
||||
// send the HTTP status code
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
|
||||
// prepare the result encoder.
|
||||
enc := json.NewEncoder(resp)
|
||||
enc.SetEscapeHTML(false)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
// prepare the result body that, in dev mode, contains
|
||||
// some diagnostics data about the query
|
||||
var resultBody map[string]interface{}
|
||||
if qh.IsDevMode() {
|
||||
resultBody = map[string]interface{}{
|
||||
"sql_prep_stmt": query,
|
||||
"sql_params": paramMap,
|
||||
"query": requestPayload.Query,
|
||||
"orderBy": requestPayload.OrderBy,
|
||||
"groupBy": requestPayload.GroupBy,
|
||||
"selects": requestPayload.Select,
|
||||
}
|
||||
} else {
|
||||
resultBody = make(map[string]interface{})
|
||||
}
|
||||
resultBody["results"] = result
|
||||
|
||||
// and finally stream the response
|
||||
if err := enc.Encode(resultBody); err != nil {
|
||||
// we failed to encode the JSON body to resp so we likely either already sent a
|
||||
// few bytes or the pipe was already closed. In either case, trying to send the
|
||||
// error using http.Error() is non-sense. We just log it out here and that's all
|
||||
// we can do.
|
||||
log.Errorf("failed to encode JSON response: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (batch *BatchQueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||
timing := servertiming.FromContext(req.Context())
|
||||
|
||||
timingQueryParsed := timing.NewMetric("query_parsed").
|
||||
WithDesc("Query has been parsed").
|
||||
Start()
|
||||
|
||||
requestPayload, err := parseQueryRequestPayload[BatchQueryRequestPayload](req)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
timingQueryParsed.Stop()
|
||||
|
||||
response := make(map[string][]map[string]any, len(*requestPayload))
|
||||
|
||||
batches := make([]BatchExecute, 0, len(*requestPayload))
|
||||
|
||||
for key, query := range *requestPayload {
|
||||
|
||||
timingQueryBuilt := timing.NewMetric("query_built_" + key).
|
||||
WithDesc("The SQL query has been built").
|
||||
Start()
|
||||
|
||||
sql, paramMap, err := query.generateSQL(req.Context(), batch.Database.Schema)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
timingQueryBuilt.Stop()
|
||||
|
||||
var result []map[string]any
|
||||
batches = append(batches, BatchExecute{
|
||||
ID: key,
|
||||
SQL: sql,
|
||||
Params: paramMap,
|
||||
Result: &result,
|
||||
})
|
||||
}
|
||||
|
||||
timingQueryExecute := timing.NewMetric("sql_exec").
|
||||
WithDesc("SQL query execution time").
|
||||
Start()
|
||||
|
||||
status := http.StatusOK
|
||||
if err := batch.Database.ExecuteBatch(req.Context(), batches); err != nil {
|
||||
status = http.StatusInternalServerError
|
||||
|
||||
var merr *multierror.Error
|
||||
if errors.As(err, &merr) {
|
||||
for _, e := range merr.Errors {
|
||||
resp.Header().Add("X-Query-Error", e.Error())
|
||||
}
|
||||
} else {
|
||||
// Should not happen, ExecuteBatch always returns a multierror.Error
|
||||
resp.WriteHeader(status)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
timingQueryExecute.Stop()
|
||||
|
||||
// collect the results
|
||||
for _, b := range batches {
|
||||
response[b.ID] = *b.Result
|
||||
}
|
||||
|
||||
// send the HTTP status code
|
||||
resp.WriteHeader(status)
|
||||
|
||||
// prepare the result encoder.
|
||||
enc := json.NewEncoder(resp)
|
||||
enc.SetEscapeHTML(false)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
// and finally stream the response
|
||||
if err := enc.Encode(response); err != nil {
|
||||
// we failed to encode the JSON body to resp so we likely either already sent a
|
||||
// few bytes or the pipe was already closed. In either case, trying to send the
|
||||
// error using http.Error() is non-sense. We just log it out here and that's all
|
||||
// we can do.
|
||||
log.Errorf("failed to encode JSON response: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func parseQueryRequestPayload[T any](req *http.Request) (*T, error) { //nolint:dupl
|
||||
var (
|
||||
body io.Reader
|
||||
requestPayload T
|
||||
)
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodPost, http.MethodPut:
|
||||
body = req.Body
|
||||
case http.MethodGet:
|
||||
body = strings.NewReader(req.URL.Query().Get("q"))
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid HTTP method")
|
||||
}
|
||||
|
||||
blob, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read body" + err.Error())
|
||||
}
|
||||
|
||||
body = bytes.NewReader(blob)
|
||||
|
||||
dec := json.NewDecoder(body)
|
||||
dec.DisallowUnknownFields()
|
||||
|
||||
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("invalid query: %w", err)
|
||||
}
|
||||
|
||||
return &requestPayload, nil
|
||||
}
|
||||
|
||||
// Compile time check.
|
||||
var _ http.Handler = new(QueryHandler)
|
||||
320
service/netquery/query_request.go
Normal file
320
service/netquery/query_request.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/safing/portmaster/service/netquery/orm"
|
||||
)
|
||||
|
||||
type (
|
||||
// QueryRequestPayload describes the payload of a netquery query.
|
||||
QueryRequestPayload struct {
|
||||
Select Selects `json:"select"`
|
||||
Query Query `json:"query"`
|
||||
OrderBy OrderBys `json:"orderBy"`
|
||||
GroupBy []string `json:"groupBy"`
|
||||
TextSearch *TextSearch `json:"textSearch"`
|
||||
// A list of databases to query. If left empty,
|
||||
// both, the LiveDatabase and the HistoryDatabase are queried
|
||||
Databases []DatabaseName `json:"databases"`
|
||||
|
||||
Pagination
|
||||
|
||||
selectedFields []string
|
||||
whitelistedFields []string
|
||||
paramMap map[string]interface{}
|
||||
}
|
||||
|
||||
// BatchQueryRequestPayload describes the payload of a batch netquery
|
||||
// query. The map key is used in the response to identify the results
|
||||
// for each query of the batch request.
|
||||
BatchQueryRequestPayload map[string]QueryRequestPayload
|
||||
)
|
||||
|
||||
func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
||||
if err := req.prepareSelectedFields(ctx, schema); err != nil {
|
||||
return "", nil, fmt.Errorf("perparing selected fields: %w", err)
|
||||
}
|
||||
|
||||
// build the SQL where clause from the payload query
|
||||
whereClause, paramMap, err := req.Query.toSQLWhereClause(
|
||||
ctx,
|
||||
"",
|
||||
schema,
|
||||
orm.DefaultEncodeConfig,
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generating where clause: %w", err)
|
||||
}
|
||||
|
||||
req.mergeParams(paramMap)
|
||||
|
||||
if req.TextSearch != nil {
|
||||
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generating text-search clause: %w", err)
|
||||
}
|
||||
|
||||
if textClause != "" {
|
||||
if whereClause != "" {
|
||||
whereClause += " AND "
|
||||
}
|
||||
|
||||
whereClause += textClause
|
||||
|
||||
req.mergeParams(textParams)
|
||||
}
|
||||
}
|
||||
|
||||
groupByClause, err := req.generateGroupByClause(schema)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generating group-by clause: %w", err)
|
||||
}
|
||||
|
||||
orderByClause, err := req.generateOrderByClause(schema)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generating order-by clause: %w", err)
|
||||
}
|
||||
|
||||
selectClause := req.generateSelectClause()
|
||||
|
||||
if whereClause != "" {
|
||||
whereClause = "WHERE " + whereClause
|
||||
}
|
||||
|
||||
// if no database is specified we default to LiveDatabase only.
|
||||
if len(req.Databases) == 0 {
|
||||
req.Databases = []DatabaseName{LiveDatabase}
|
||||
}
|
||||
|
||||
sources := make([]string, len(req.Databases))
|
||||
for idx, db := range req.Databases {
|
||||
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
|
||||
}
|
||||
|
||||
source := strings.Join(sources, " UNION ")
|
||||
|
||||
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
|
||||
|
||||
query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause()
|
||||
|
||||
return strings.TrimSpace(query), req.paramMap, nil
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
|
||||
for idx, s := range req.Select {
|
||||
var field string
|
||||
|
||||
switch {
|
||||
case s.Count != nil:
|
||||
field = s.Count.Field
|
||||
case s.Distinct != nil:
|
||||
field = *s.Distinct
|
||||
case s.Sum != nil:
|
||||
if s.Sum.Field != "" {
|
||||
field = s.Sum.Field
|
||||
} else {
|
||||
field = "*"
|
||||
}
|
||||
case s.Min != nil:
|
||||
if s.Min.Field != "" {
|
||||
field = s.Min.Field
|
||||
} else {
|
||||
field = "*"
|
||||
}
|
||||
case s.FieldSelect != nil:
|
||||
field = s.FieldSelect.Field
|
||||
default:
|
||||
field = s.Field
|
||||
}
|
||||
|
||||
colName := "*"
|
||||
if field != "*" || (s.Count == nil && s.Sum == nil) {
|
||||
var err error
|
||||
|
||||
colName, err = req.validateColumnName(schema, field)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case s.FieldSelect != nil:
|
||||
as := s.FieldSelect.As
|
||||
if as == "" {
|
||||
as = s.FieldSelect.Field
|
||||
}
|
||||
|
||||
req.selectedFields = append(
|
||||
req.selectedFields,
|
||||
fmt.Sprintf("%s AS %s", s.FieldSelect.Field, as),
|
||||
)
|
||||
req.whitelistedFields = append(req.whitelistedFields, as)
|
||||
|
||||
case s.Count != nil:
|
||||
as := s.Count.As
|
||||
if as == "" {
|
||||
as = fmt.Sprintf("%s_count", colName)
|
||||
}
|
||||
distinct := ""
|
||||
if s.Count.Distinct {
|
||||
distinct = "DISTINCT "
|
||||
}
|
||||
req.selectedFields = append(
|
||||
req.selectedFields,
|
||||
fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as),
|
||||
)
|
||||
req.whitelistedFields = append(req.whitelistedFields, as)
|
||||
|
||||
case s.Sum != nil:
|
||||
if s.Sum.As == "" {
|
||||
return fmt.Errorf("missing 'as' for $sum")
|
||||
}
|
||||
|
||||
var (
|
||||
clause string
|
||||
params map[string]any
|
||||
)
|
||||
|
||||
if s.Sum.Field != "" {
|
||||
clause = s.Sum.Field
|
||||
} else {
|
||||
var err error
|
||||
clause, params, err = s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in $sum: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req.mergeParams(params)
|
||||
req.selectedFields = append(
|
||||
req.selectedFields,
|
||||
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
|
||||
)
|
||||
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
|
||||
|
||||
case s.Min != nil:
|
||||
if s.Min.As == "" {
|
||||
return fmt.Errorf("missing 'as' for $min")
|
||||
}
|
||||
|
||||
var (
|
||||
clause string
|
||||
params map[string]any
|
||||
)
|
||||
|
||||
if s.Min.Field != "" {
|
||||
clause = field
|
||||
} else {
|
||||
var err error
|
||||
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in $min: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req.mergeParams(params)
|
||||
req.selectedFields = append(
|
||||
req.selectedFields,
|
||||
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
|
||||
)
|
||||
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
|
||||
|
||||
case s.Distinct != nil:
|
||||
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
|
||||
req.whitelistedFields = append(req.whitelistedFields, colName)
|
||||
|
||||
default:
|
||||
req.selectedFields = append(req.selectedFields, colName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
|
||||
if req.paramMap == nil {
|
||||
req.paramMap = make(map[string]any)
|
||||
}
|
||||
|
||||
for key, value := range params {
|
||||
req.paramMap[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
|
||||
if len(req.GroupBy) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
groupBys := make([]string, len(req.GroupBy))
|
||||
for idx, name := range req.GroupBy {
|
||||
colName, err := req.validateColumnName(schema, name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
groupBys[idx] = colName
|
||||
}
|
||||
groupByClause := "GROUP BY " + strings.Join(groupBys, ", ")
|
||||
|
||||
// if there are no explicitly selected fields we default to the
|
||||
// group-by columns as that's what's expected most of the time anyway...
|
||||
if len(req.selectedFields) == 0 {
|
||||
req.selectedFields = append(req.selectedFields, groupBys...)
|
||||
}
|
||||
|
||||
return groupByClause, nil
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) generateSelectClause() string {
|
||||
selectClause := "*"
|
||||
if len(req.selectedFields) > 0 {
|
||||
selectClause = strings.Join(req.selectedFields, ", ")
|
||||
}
|
||||
|
||||
return selectClause
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
|
||||
if len(req.OrderBy) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
orderBys := make([]string, len(req.OrderBy))
|
||||
for idx, sort := range req.OrderBy {
|
||||
colName, err := req.validateColumnName(schema, sort.Field)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if sort.Desc {
|
||||
orderBys[idx] = fmt.Sprintf("%s DESC", colName)
|
||||
} else {
|
||||
orderBys[idx] = fmt.Sprintf("%s ASC", colName)
|
||||
}
|
||||
}
|
||||
|
||||
return "ORDER BY " + strings.Join(orderBys, ", "), nil
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) {
|
||||
colDef := schema.GetColumnDef(field)
|
||||
if colDef != nil {
|
||||
return colDef.Name, nil
|
||||
}
|
||||
|
||||
if slices.Contains(req.whitelistedFields, field) {
|
||||
return field, nil
|
||||
}
|
||||
|
||||
if slices.Contains(req.selectedFields, field) {
|
||||
return field, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("column name %q not allowed", field)
|
||||
}
|
||||
250
service/netquery/query_test.go
Normal file
250
service/netquery/query_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/safing/portmaster/service/netquery/orm"
|
||||
)
|
||||
|
||||
func TestUnmarshalQuery(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
Input string
|
||||
Expected Query
|
||||
Error error
|
||||
}{
|
||||
{
|
||||
"Parse a simple query",
|
||||
`{ "domain": ["example.com", "example.at"] }`,
|
||||
Query{
|
||||
"domain": []Matcher{
|
||||
{
|
||||
Equal: "example.com",
|
||||
},
|
||||
{
|
||||
Equal: "example.at",
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Parse a more complex query",
|
||||
`
|
||||
{
|
||||
"domain": [
|
||||
{
|
||||
"$in": [
|
||||
"example.at",
|
||||
"example.com"
|
||||
]
|
||||
},
|
||||
{
|
||||
"$like": "microsoft.%"
|
||||
}
|
||||
],
|
||||
"path": [
|
||||
"/bin/ping",
|
||||
{
|
||||
"$notin": [
|
||||
"/sbin/ping",
|
||||
"/usr/sbin/ping"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
`,
|
||||
Query{
|
||||
"domain": []Matcher{
|
||||
{
|
||||
In: []interface{}{
|
||||
"example.at",
|
||||
"example.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
Like: "microsoft.%",
|
||||
},
|
||||
},
|
||||
"path": []Matcher{
|
||||
{
|
||||
Equal: "/bin/ping",
|
||||
},
|
||||
{
|
||||
NotIn: []interface{}{
|
||||
"/sbin/ping",
|
||||
"/usr/sbin/ping",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range cases { //nolint:paralleltest
|
||||
c := testCase
|
||||
t.Run(c.Name, func(t *testing.T) {
|
||||
var q Query
|
||||
err := json.Unmarshal([]byte(c.Input), &q)
|
||||
|
||||
if c.Error != nil {
|
||||
if assert.Error(t, err) {
|
||||
assert.Equal(t, c.Error.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.Expected, q)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryBuilder(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
cases := []struct {
|
||||
N string
|
||||
Q Query
|
||||
R string
|
||||
P map[string]interface{}
|
||||
E error
|
||||
}{
|
||||
{
|
||||
"No filter",
|
||||
nil,
|
||||
"",
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Simple, one-column filter",
|
||||
Query{"domain": []Matcher{
|
||||
{
|
||||
Equal: "example.com",
|
||||
},
|
||||
{
|
||||
Equal: "example.at",
|
||||
},
|
||||
}},
|
||||
"( domain = :domain0eq0 OR domain = :domain1eq0 )",
|
||||
map[string]interface{}{
|
||||
":domain0eq0": "example.com",
|
||||
":domain1eq0": "example.at",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Two column filter",
|
||||
Query{
|
||||
"domain": []Matcher{
|
||||
{
|
||||
Equal: "example.com",
|
||||
},
|
||||
},
|
||||
"path": []Matcher{
|
||||
{
|
||||
Equal: "/bin/curl",
|
||||
},
|
||||
{
|
||||
Equal: "/bin/ping",
|
||||
},
|
||||
},
|
||||
},
|
||||
"( domain = :domain0eq0 ) AND ( path = :path0eq0 OR path = :path1eq0 )",
|
||||
map[string]interface{}{
|
||||
":domain0eq0": "example.com",
|
||||
":path0eq0": "/bin/curl",
|
||||
":path1eq0": "/bin/ping",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Time based filter",
|
||||
Query{
|
||||
"started": []Matcher{
|
||||
{
|
||||
Equal: now.Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
"( started = :started0eq0 )",
|
||||
map[string]interface{}{
|
||||
":started0eq0": now.In(time.UTC).Format(orm.SqliteTimeFormat),
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Invalid column access",
|
||||
Query{
|
||||
"forbiddenField": []Matcher{{}},
|
||||
},
|
||||
"",
|
||||
nil,
|
||||
fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"), //nolint:golint
|
||||
},
|
||||
{
|
||||
"Complex example",
|
||||
Query{
|
||||
"domain": []Matcher{
|
||||
{
|
||||
In: []interface{}{"example.at", "example.com"},
|
||||
},
|
||||
{
|
||||
Like: "microsoft.%",
|
||||
},
|
||||
},
|
||||
"path": []Matcher{
|
||||
{
|
||||
NotIn: []interface{}{
|
||||
"/bin/ping",
|
||||
"/sbin/ping",
|
||||
"/usr/bin/ping",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"( domain IN ( :domain0in0, :domain0in1 ) OR domain LIKE :domain1like0 ) AND ( path NOT IN ( :path0notin0, :path0notin1, :path0notin2 ) )",
|
||||
map[string]interface{}{
|
||||
":domain0in0": "example.at",
|
||||
":domain0in1": "example.com",
|
||||
":domain1like0": "microsoft.%",
|
||||
":path0notin0": "/bin/ping",
|
||||
":path0notin1": "/sbin/ping",
|
||||
":path0notin2": "/usr/bin/ping",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
tbl, err := orm.GenerateTableSchema("connections", Conn{})
|
||||
require.NoError(t, err)
|
||||
|
||||
for idx, testCase := range cases { //nolint:paralleltest
|
||||
cID := idx
|
||||
c := testCase
|
||||
t.Run(c.N, func(t *testing.T) {
|
||||
str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig)
|
||||
|
||||
if c.E != nil {
|
||||
if assert.Error(t, err) {
|
||||
assert.Equal(t, c.E.Error(), err.Error(), "test case %d", cID)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err, "test case %d", cID)
|
||||
assert.Equal(t, c.P, params, "test case %d", cID)
|
||||
assert.Equal(t, c.R, str, "test case %d", cID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
77
service/netquery/runtime_query_runner.go
Normal file
77
service/netquery/runtime_query_runner.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/runtime"
|
||||
"github.com/safing/portmaster/service/netquery/orm"
|
||||
)
|
||||
|
||||
// RuntimeQueryRunner provides a simple interface for the runtime database
|
||||
// that allows direct SQL queries to be performed against db.
|
||||
// Each resulting row of that query are marshaled as map[string]interface{}
|
||||
// and returned as a single record to the caller.
|
||||
//
|
||||
// Using portbase/database#Query is not possible because portbase/database will
|
||||
// complain about the SQL query being invalid. To work around that issue,
|
||||
// RuntimeQueryRunner uses a 'GET key' request where the SQL query is embedded into
|
||||
// the record key.
|
||||
type RuntimeQueryRunner struct {
|
||||
db *Database
|
||||
reg *runtime.Registry
|
||||
keyPrefix string
|
||||
}
|
||||
|
||||
// NewRuntimeQueryRunner returns a new runtime SQL query runner that parses
|
||||
// and serves SQL queries form GET <prefix>/<plain sql query> requests.
|
||||
func NewRuntimeQueryRunner(db *Database, prefix string, reg *runtime.Registry) (*RuntimeQueryRunner, error) {
|
||||
runner := &RuntimeQueryRunner{
|
||||
db: db,
|
||||
reg: reg,
|
||||
keyPrefix: prefix,
|
||||
}
|
||||
|
||||
if _, err := reg.Register(prefix, runtime.SimpleValueGetterFunc(runner.get)); err != nil {
|
||||
return nil, fmt.Errorf("failed to register runtime value provider: %w", err)
|
||||
}
|
||||
|
||||
return runner, nil
|
||||
}
|
||||
|
||||
func (runner *RuntimeQueryRunner) get(keyOrPrefix string) ([]record.Record, error) {
|
||||
query := strings.TrimPrefix(
|
||||
keyOrPrefix,
|
||||
runner.keyPrefix,
|
||||
)
|
||||
|
||||
log.Infof("netquery: executing custom SQL query: %q", query)
|
||||
|
||||
var result []map[string]interface{}
|
||||
if err := runner.db.Execute(context.Background(), query, orm.WithResult(&result)); err != nil {
|
||||
return nil, fmt.Errorf("failed to perform query %q: %w", query, err)
|
||||
}
|
||||
|
||||
// we need to wrap the result slice into a map as portbase/database attempts
|
||||
// to inject a _meta field.
|
||||
blob, err := json.Marshal(map[string]interface{}{
|
||||
"result": result,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal result: %w", err)
|
||||
}
|
||||
|
||||
// construct a new record wrapper that uses the already prepared JSON blob.
|
||||
key := fmt.Sprintf("%s:%s", runner.reg.DatabaseName(), keyOrPrefix)
|
||||
wrapper, err := record.NewWrapper(key, new(record.Meta), dsd.JSON, blob)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create record wrapper: %w", err)
|
||||
}
|
||||
|
||||
return []record.Record{wrapper}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user