diff --git a/netquery/chart_handler.go b/netquery/chart_handler.go new file mode 100644 index 00000000..22790fb5 --- /dev/null +++ b/netquery/chart_handler.go @@ -0,0 +1,118 @@ +package netquery + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "strings" + + "github.com/safing/portmaster/netquery/orm" +) + +type ChartHandler struct { + Database *Database +} + +func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + requestPayload, err := ch.parseRequest(req) + if err != nil { + http.Error(resp, err.Error(), http.StatusBadRequest) + + return + } + + query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema) + if err != nil { + http.Error(resp, err.Error(), http.StatusBadRequest) + + return + } + + // actually execute the query against the database and collect the result + var result []map[string]interface{} + if err := ch.Database.Execute( + req.Context(), + query, + orm.WithNamedArgs(paramMap), + orm.WithResult(&result), + orm.WithSchema(*ch.Database.Schema), + ); err != nil { + http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError) + + return + } + + // send the HTTP status code + resp.WriteHeader(http.StatusOK) + + // prepare the result encoder. + enc := json.NewEncoder(resp) + enc.SetEscapeHTML(false) + enc.SetIndent("", " ") + + enc.Encode(map[string]interface{}{ + "results": result, + "query": query, + "params": paramMap, + }) +} + +func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { + var body io.Reader + + switch req.Method { + case http.MethodPost, http.MethodPut: + body = req.Body + case http.MethodGet: + body = strings.NewReader(req.URL.Query().Get("q")) + default: + return nil, fmt.Errorf("invalid HTTP method") + } + + var requestPayload QueryActiveConnectionChartPayload + blob, err := ioutil.ReadAll(body) + if err != nil { + return nil, fmt.Errorf("failed to read body" + err.Error()) + } + + body = bytes.NewReader(blob) + + dec := json.NewDecoder(body) + dec.DisallowUnknownFields() + + if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("invalid query: %w", err) + } + + return &requestPayload, nil +} + +func (req *QueryActiveConnectionChartPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) { + template := ` +WITH RECURSIVE epoch(x) AS ( + SELECT strftime('%%s')-600 + UNION ALL + SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0 +) +SELECT x as timestamp, COUNT(*) AS value FROM epoch + JOIN connections + ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 > timestamp+0) + %s + GROUP BY round(timestamp/10, 0)*10;` + + clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig) + if err != nil { + return "", nil, err + } + + if clause == "" { + return fmt.Sprintf(template, ""), map[string]interface{}{}, nil + } + + return fmt.Sprintf(template, "WHERE ( "+clause+")"), params, nil +} diff --git a/netquery/database.go b/netquery/database.go index 25f9e571..4c53ebdd 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -63,32 +63,35 @@ type ( // time. We cannot just use the network.Connection.ID because it is only unique // as long as the connection is still active and might be, although unlikely, // reused afterwards. - ID string `sqlite:"id,primary"` - ProfileID string `sqlite:"profile"` - Path string `sqlite:"path"` - Type string `sqlite:"type,varchar(8)"` - External bool `sqlite:"external"` - IPVersion packet.IPVersion `sqlite:"ip_version"` - IPProtocol packet.IPProtocol `sqlite:"ip_protocol"` - LocalIP string `sqlite:"local_ip"` - LocalPort uint16 `sqlite:"local_port"` - RemoteIP string `sqlite:"remote_ip"` - RemotePort uint16 `sqlite:"remote_port"` - Domain string `sqlite:"domain"` - Country string `sqlite:"country,varchar(2)"` - ASN uint `sqlite:"asn"` - ASOwner string `sqlite:"as_owner"` - Latitude float64 `sqlite:"latitude"` - Longitude float64 `sqlite:"longitude"` - Scope netutils.IPScope `sqlite:"scope"` - Verdict network.Verdict `sqlite:"verdict"` - Started time.Time `sqlite:"started,text,time"` - Ended *time.Time `sqlite:"ended,text,time"` - Tunneled bool `sqlite:"tunneled"` - Encrypted bool `sqlite:"encrypted"` - Internal bool `sqlite:"internal"` - Inbound bool `sqlite:"inbound"` - ExtraData json.RawMessage `sqlite:"extra_data"` + ID string `sqlite:"id,primary"` + ProfileID string `sqlite:"profile"` + ProfileSource string `sqlite:"profileSource"` + Path string `sqlite:"path"` + Type string `sqlite:"type,varchar(8)"` + External bool `sqlite:"external"` + IPVersion packet.IPVersion `sqlite:"ip_version"` + IPProtocol packet.IPProtocol `sqlite:"ip_protocol"` + LocalIP string `sqlite:"local_ip"` + LocalPort uint16 `sqlite:"local_port"` + RemoteIP string `sqlite:"remote_ip"` + RemotePort uint16 `sqlite:"remote_port"` + Domain string `sqlite:"domain"` + Country string `sqlite:"country,varchar(2)"` + ASN uint `sqlite:"asn"` + ASOwner string `sqlite:"as_owner"` + Latitude float64 `sqlite:"latitude"` + Longitude float64 `sqlite:"longitude"` + Scope netutils.IPScope `sqlite:"scope"` + Verdict network.Verdict `sqlite:"verdict"` + Started time.Time `sqlite:"started,text,time"` + Ended *time.Time `sqlite:"ended,text,time"` + Tunneled bool `sqlite:"tunneled"` + Encrypted bool `sqlite:"encrypted"` + Internal bool `sqlite:"internal"` + Direction string `sqlite:"direction"` + ExtraData json.RawMessage `sqlite:"extra_data"` + Allowed *bool `sqlite:"allowed"` + ProfileRevision int `sqlite:"profile_revision"` } ) @@ -190,11 +193,11 @@ func (db *Database) CountRows(ctx context.Context) (int, error) { // probably not worth the cylces... func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) { where := `WHERE ended IS NOT NULL - AND datetime(ended) < :threshold` + AND datetime(ended) < datetime(:threshold)` sql := "DELETE FROM connections " + where + ";" args := orm.WithNamedArgs(map[string]interface{}{ - ":threshold": threshold, + ":threshold": threshold.UTC().Format(orm.SqliteTimeFormat), }) var result []struct { @@ -232,7 +235,7 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { if err := sqlitex.ExecuteTransient(db.conn, "SELECT * FROM connections", &sqlitex.ExecOptions{ ResultFunc: func(stmt *sqlite.Stmt) error { var c Conn - if err := orm.DecodeStmt(ctx, stmt, &c, orm.DefaultDecodeConfig); err != nil { + if err := orm.DecodeStmt(ctx, db.Schema, stmt, &c, orm.DefaultDecodeConfig); err != nil { return err } diff --git a/netquery/manager.go b/netquery/manager.go index 16a78021..559ccdfe 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -155,23 +155,42 @@ func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn C func convertConnection(conn *network.Connection) (*Conn, error) { conn.Lock() defer conn.Unlock() + direction := "outbound" + if conn.Inbound { + direction = "inbound" + } c := Conn{ - ID: genConnID(conn), - External: conn.External, - IPVersion: conn.IPVersion, - IPProtocol: conn.IPProtocol, - LocalIP: conn.LocalIP.String(), - LocalPort: conn.LocalPort, - Verdict: conn.Verdict, - Started: time.Unix(conn.Started, 0), - Tunneled: conn.Tunneled, - Encrypted: conn.Encrypted, - Internal: conn.Internal, - Inbound: conn.Inbound, - Type: ConnectionTypeToString[conn.Type], - ProfileID: conn.ProcessContext.ProfileName, - Path: conn.ProcessContext.BinaryPath, + ID: genConnID(conn), + External: conn.External, + IPVersion: conn.IPVersion, + IPProtocol: conn.IPProtocol, + LocalIP: conn.LocalIP.String(), + LocalPort: conn.LocalPort, + Verdict: conn.Verdict, + Started: time.Unix(conn.Started, 0), + Tunneled: conn.Tunneled, + Encrypted: conn.Encrypted, + Internal: conn.Internal, + Direction: direction, + Type: ConnectionTypeToString[conn.Type], + ProfileID: conn.ProcessContext.Profile, + ProfileSource: conn.ProcessContext.Source, + Path: conn.ProcessContext.BinaryPath, + ProfileRevision: int(conn.ProfileRevisionCounter), + } + + switch conn.Type { + case network.DNSRequest: + c.Type = "dns" + case network.IPConnection: + c.Type = "ip" + } + + switch conn.Verdict { + case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel: + accepted := true + c.Allowed = &accepted } if conn.Ended > 0 { @@ -181,6 +200,10 @@ func convertConnection(conn *network.Connection) (*Conn, error) { extraData := map[string]interface{}{} + if conn.TunnelContext != nil { + extraData["tunnel"] = conn.TunnelContext + } + if conn.Entity != nil { extraData["cname"] = conn.Entity.CNAME extraData["blockedByLists"] = conn.Entity.BlockedByLists diff --git a/netquery/module_api.go b/netquery/module_api.go index a344ec3b..dd15189e 100644 --- a/netquery/module_api.go +++ b/netquery/module_api.go @@ -63,6 +63,10 @@ func (m *Module) Prepare() error { IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), } + chartHandler := &ChartHandler{ + Database: m.sqlStore, + } + // FIXME(ppacher): use appropriate permissions for this if err := api.RegisterEndpoint(api.Endpoint{ Path: "netquery/query", @@ -77,6 +81,19 @@ func (m *Module) Prepare() error { return fmt.Errorf("failed to register API endpoint: %w", err) } + if err := api.RegisterEndpoint(api.Endpoint{ + Path: "netquery/charts/connection-active", + MimeType: "application/json", + Read: api.PermitAnyone, + Write: api.PermitAnyone, + BelongsTo: m.Module, + HandlerFunc: chartHandler.ServeHTTP, + Name: "Query In-Memory Database", + Description: "Query the in-memory sqlite database", + }); err != nil { + return fmt.Errorf("failed to register API endpoint: %w", err) + } + return nil } @@ -120,11 +137,12 @@ func (mod *Module) Start() error { case <-ctx.Done(): return nil case <-time.After(10 * time.Second): - count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold)) + threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold) + count, err := mod.sqlStore.Cleanup(ctx, threshold) if err != nil { log.Errorf("netquery: failed to count number of rows in memory: %s", err) } else { - log.Infof("netquery: successfully removed %d old rows", count) + log.Infof("netquery: successfully removed %d old rows that ended before %s", count, threshold) } } } @@ -135,7 +153,7 @@ func (mod *Module) Start() error { select { case <-ctx.Done(): return nil - case <-time.After(5 * time.Second): + case <-time.After(1 * time.Second): count, err := mod.sqlStore.CountRows(ctx) if err != nil { log.Errorf("netquery: failed to count number of rows in memory: %s", err) diff --git a/netquery/orm/decoder.go b/netquery/orm/decoder.go index 76359755..6cc16f97 100644 --- a/netquery/orm/decoder.go +++ b/netquery/orm/decoder.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "log" "reflect" "strings" "time" @@ -51,7 +52,7 @@ type ( } // DecodeFunc is called for each non-basic type during decoding. - DecodeFunc func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) + DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) DecodeConfig struct { DecodeHooks []DecodeFunc @@ -63,7 +64,7 @@ type ( // be specified to provide support for special types. // See DatetimeDecoder() for an example of a DecodeHook that handles graceful time.Time conversion. // -func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeConfig) error { +func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result interface{}, cfg DecodeConfig) error { // make sure we got something to decode into ... if result == nil { return fmt.Errorf("%w, got %T", errStructPointerExpected, result) @@ -71,7 +72,7 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo // fast path for decoding into a map if mp, ok := result.(*map[string]interface{}); ok { - return decodeIntoMap(ctx, stmt, mp) + return decodeIntoMap(ctx, schema, stmt, mp, cfg) } // make sure we got a pointer in result @@ -147,10 +148,13 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo value = storage.Elem() } + colDef := schema.GetColumnDef(colName) + // execute all decode hooks but make sure we use decodeBasic() as the // last one. columnValue, err := runDecodeHooks( i, + colDef, stmt, fieldType, value, @@ -188,10 +192,19 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo // FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing // func DatetimeDecoder(loc *time.Location) DecodeFunc { - return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) { + return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) { + // if we have the column definition available we + // use the target go type from there. + outType := outval.Type() + + if colDef != nil { + outType = colDef.GoType + } + // we only care about "time.Time" here - if outval.Type().String() != "time.Time" { - return nil, nil + if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) { + log.Printf("not decoding %s %v", outType, colDef) + return nil, false, nil } switch stmt.ColumnType(colIdx) { @@ -201,39 +214,61 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc { // TODO(ppacher): actually split the tag value at "," and search // the slice for "unixnano" if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") { - return time.Unix(0, int64(stmt.ColumnInt(colIdx))), nil + return time.Unix(0, int64(stmt.ColumnInt(colIdx))), true, nil } - return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), nil + return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), true, nil case sqlite.TypeText: // stored ISO8601 but does not have any timezone information // assigned so we always treat it as loc here. t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc) if err != nil { - return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err) + return nil, false, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err) } - return t, nil + return t, true, nil case sqlite.TypeFloat: // stored as Julian day numbers - return nil, fmt.Errorf("REAL storage type not support for time.Time") + return nil, false, fmt.Errorf("REAL storage type not support for time.Time") + + case sqlite.TypeNull: + return nil, true, nil default: - return nil, fmt.Errorf("unsupported storage type for time.Time: %s", outval.Type()) + return nil, false, fmt.Errorf("unsupported storage type for time.Time: %s", stmt.ColumnType(colIdx)) } } } -func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) error { +func decodeIntoMap(ctx context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error { if *mp == nil { *mp = make(map[string]interface{}) } for i := 0; i < stmt.ColumnCount(); i++ { var x interface{} - val, err := decodeBasic()(i, stmt, reflect.StructField{}, reflect.ValueOf(&x).Elem()) + + colDef := schema.GetColumnDef(stmt.ColumnName(i)) + + outVal := reflect.ValueOf(&x).Elem() + fieldType := reflect.StructField{} + if colDef != nil { + outVal = reflect.New(colDef.GoType).Elem() + fieldType = reflect.StructField{ + Type: colDef.GoType, + } + } + + val, err := runDecodeHooks( + i, + colDef, + stmt, + fieldType, + outVal, + append(cfg.DecodeHooks, decodeBasic()), + ) if err != nil { return fmt.Errorf("failed to decode column %s: %w", stmt.ColumnName(i), err) } @@ -245,56 +280,99 @@ func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) e } func decodeBasic() DecodeFunc { - return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) { + return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (result interface{}, handled bool, err error) { valueKind := getKind(outval) colType := stmt.ColumnType(colIdx) colName := stmt.ColumnName(colIdx) errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type()) + // if we have the column definition available we + // use the target go type from there. + if colDef != nil { + valueKind = normalizeKind(colDef.GoType.Kind()) + + // if we have a column defintion we try to convert the value to + // the actual Go-type that was used in the model. + // this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage + // or that type aliases like (type myInt int) are decoded into myInt instead of int + defer func() { + if handled { + t := reflect.New(colDef.GoType).Elem() + + if result == nil || reflect.ValueOf(result).IsZero() { + return + } + + if reflect.ValueOf(result).Type().ConvertibleTo(colDef.GoType) { + result = reflect.ValueOf(result).Convert(colDef.GoType).Interface() + } + t.Set(reflect.ValueOf(result)) + + result = t.Interface() + } + }() + } + + log.Printf("decoding %s into kind %s", colName, valueKind) + + if colType == sqlite.TypeNull { + if colDef != nil && colDef.Nullable { + return nil, true, nil + } + + if colDef != nil && !colDef.Nullable { + return reflect.New(colDef.GoType).Elem().Interface(), true, nil + } + + if outval.Kind() == reflect.Ptr { + return nil, true, nil + } + } + switch valueKind { case reflect.String: if colType != sqlite.TypeText { - return nil, errInvalidType + return nil, false, errInvalidType } - return stmt.ColumnText(colIdx), nil + return stmt.ColumnText(colIdx), true, nil case reflect.Bool: // sqlite does not have a BOOL type, it rather stores a 1/0 in a column // with INTEGER affinity. if colType != sqlite.TypeInteger { - return nil, errInvalidType + return nil, false, errInvalidType } - return stmt.ColumnBool(colIdx), nil + return stmt.ColumnBool(colIdx), true, nil case reflect.Float64: if colType != sqlite.TypeFloat { - return nil, errInvalidType + return nil, false, errInvalidType } - return stmt.ColumnFloat(colIdx), nil + return stmt.ColumnFloat(colIdx), true, nil case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ... if colType != sqlite.TypeInteger { - return nil, errInvalidType + return nil, false, errInvalidType } - return stmt.ColumnInt(colIdx), nil + return stmt.ColumnInt(colIdx), true, nil case reflect.Slice: if outval.Type().Elem().Kind() != reflect.Uint8 { - return nil, fmt.Errorf("slices other than []byte for BLOB are not supported") + return nil, false, fmt.Errorf("slices other than []byte for BLOB are not supported") } if colType != sqlite.TypeBlob { - return nil, errInvalidType + return nil, false, errInvalidType } columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx)) if err != nil { - return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err) + return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err) } - return columnValue, nil + return columnValue, true, nil case reflect.Interface: var ( @@ -306,7 +384,7 @@ func decodeBasic() DecodeFunc { t = reflect.TypeOf([]byte{}) columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx)) if err != nil { - return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err) + return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err) } x = columnValue @@ -327,20 +405,20 @@ func decodeBasic() DecodeFunc { x = nil default: - return nil, fmt.Errorf("unsupported column type %s", colType) + return nil, false, fmt.Errorf("unsupported column type %s", colType) } if t == nil { - return nil, nil + return nil, true, nil } target := reflect.New(t).Elem() target.Set(reflect.ValueOf(x)) - return target.Interface(), nil + return target.Interface(), true, nil default: - return nil, fmt.Errorf("cannot decode into %s", valueKind) + return nil, false, fmt.Errorf("cannot decode into %s", valueKind) } } } @@ -362,14 +440,14 @@ func sqlColumnName(fieldType reflect.StructField) string { // runDecodeHooks tries to decode the column value of stmt at index colIdx into outval by running all decode hooks. // The first hook that returns a non-nil interface wins, other hooks will not be executed. If an error is // returned by a decode hook runDecodeHooks stops the error is returned to the caller. -func runDecodeHooks(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) { +func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) { for _, fn := range hooks { - res, err := fn(colIdx, stmt, fieldDef, outval) + res, end, err := fn(colIdx, colDef, stmt, fieldDef, outval) if err != nil { return res, err } - if res != nil { + if end { return res, nil } } diff --git a/netquery/orm/decoder_test.go b/netquery/orm/decoder_test.go index c6c836c0..5abd324e 100644 --- a/netquery/orm/decoder_test.go +++ b/netquery/orm/decoder_test.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "encoding/json" + "log" + "reflect" "testing" "time" @@ -82,6 +84,8 @@ func (ett *exampleTimeTypes) Equal(other interface{}) bool { return ett.T.Equal(oett.T) && (ett.TP != nil && oett.TP != nil && ett.TP.Equal(*oett.TP)) || (ett.TP == nil && oett.TP == nil) } +type myInt int + type exampleTimeNano struct { T time.Time `sqlite:",unixnano"` } @@ -100,10 +104,11 @@ func Test_Decoder(t *testing.T) { refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC) cases := []struct { - Desc string - Stmt testStmt - Result interface{} - Expected interface{} + Desc string + Stmt testStmt + ColumnDef []ColumnDef + Result interface{} + Expected interface{} }{ { "Decoding into nil is not allowed", @@ -114,6 +119,7 @@ func Test_Decoder(t *testing.T) { }, nil, nil, + nil, }, { "Decoding into basic types", @@ -132,6 +138,7 @@ func Test_Decoder(t *testing.T) { true, }, }, + nil, &exampleFieldTypes{}, &exampleFieldTypes{ S: "string value", @@ -157,6 +164,7 @@ func Test_Decoder(t *testing.T) { 1.2, }, }, + nil, &exampleFieldTypes{}, &exampleFieldTypes{ S: "string value", @@ -178,6 +186,7 @@ func Test_Decoder(t *testing.T) { true, }, }, + nil, &exampleFieldTypes{}, &exampleFieldTypes{ F: 1.2, @@ -201,6 +210,7 @@ func Test_Decoder(t *testing.T) { true, }, }, + nil, &examplePointerTypes{}, func() interface{} { s := "string value" @@ -231,6 +241,7 @@ func Test_Decoder(t *testing.T) { true, }, }, + nil, &examplePointerTypes{}, func() interface{} { s := "string value" @@ -255,6 +266,7 @@ func Test_Decoder(t *testing.T) { 1, }, }, + nil, &exampleStructTags{}, &exampleStructTags{ S: "string value", @@ -280,6 +292,7 @@ func Test_Decoder(t *testing.T) { 1, }, }, + nil, &exampleIntConv{}, &exampleIntConv{ 1, 1, 1, 1, 1, @@ -301,6 +314,7 @@ func Test_Decoder(t *testing.T) { 1.0, }, }, + nil, &exampleFieldTypes{}, &exampleFieldTypes{ F: 1.0, @@ -322,6 +336,7 @@ func Test_Decoder(t *testing.T) { 1.0, }, }, + nil, &examplePointerTypes{}, func() interface{} { f := 1.0 @@ -340,6 +355,7 @@ func Test_Decoder(t *testing.T) { ([]byte)("hello world"), }, }, + nil, &exampleBlobTypes{}, &exampleBlobTypes{ B: ([]byte)("hello world"), @@ -356,6 +372,7 @@ func Test_Decoder(t *testing.T) { ([]byte)("hello world"), }, }, + nil, &exampleJSONRawTypes{}, &exampleJSONRawTypes{ B: (json.RawMessage)("hello world"), @@ -374,6 +391,7 @@ func Test_Decoder(t *testing.T) { int(refTime.Unix()), }, }, + nil, &exampleTimeTypes{}, &exampleTimeTypes{ T: refTime, @@ -393,6 +411,7 @@ func Test_Decoder(t *testing.T) { int(refTime.UnixNano()), }, }, + nil, &exampleTimeNano{}, &exampleTimeNano{ T: refTime, @@ -411,6 +430,7 @@ func Test_Decoder(t *testing.T) { "value2", }, }, + nil, &exampleInterface{}, func() interface{} { var x interface{} @@ -439,6 +459,7 @@ func Test_Decoder(t *testing.T) { []byte("blob value"), }, }, + nil, new(map[string]interface{}), &map[string]interface{}{ "I": 1, @@ -447,14 +468,91 @@ func Test_Decoder(t *testing.T) { "B": []byte("blob value"), }, }, + { + "Decoding using type-hints", + testStmt{ + columns: []string{"B", "T"}, + types: []sqlite.ColumnType{ + sqlite.TypeInteger, + sqlite.TypeText, + }, + values: []interface{}{ + true, + refTime.Format(SqliteTimeFormat), + }, + }, + []ColumnDef{ + { + Name: "B", + Type: sqlite.TypeInteger, + GoType: reflect.TypeOf(true), + }, + { + Name: "T", + Type: sqlite.TypeText, + GoType: reflect.TypeOf(time.Time{}), + IsTime: true, + }, + }, + new(map[string]interface{}), + &map[string]interface{}{ + "B": true, + "T": refTime, + }, + }, + { + "Decoding into type aliases", + testStmt{ + columns: []string{"B"}, + types: []sqlite.ColumnType{ + sqlite.TypeBlob, + }, + values: []interface{}{ + []byte(`{"foo": "bar}`), + }, + }, + []ColumnDef{ + { + Name: "B", + Type: sqlite.TypeBlob, + GoType: reflect.TypeOf(json.RawMessage(`{"foo": "bar}`)), + }, + }, + new(map[string]interface{}), + &map[string]interface{}{ + "B": json.RawMessage(`{"foo": "bar}`), + }, + }, + { + "Decoding into type aliases #2", + testStmt{ + columns: []string{"I"}, + types: []sqlite.ColumnType{sqlite.TypeInteger}, + values: []interface{}{ + 10, + }, + }, + []ColumnDef{ + { + Name: "I", + Type: sqlite.TypeInteger, + GoType: reflect.TypeOf(myInt(0)), + }, + }, + new(map[string]interface{}), + &map[string]interface{}{ + "I": myInt(10), + }, + }, } for idx := range cases { c := cases[idx] t.Run(c.Desc, func(t *testing.T) { - t.Parallel() + //t.Parallel() - err := DecodeStmt(ctx, c.Stmt, c.Result, DefaultDecodeConfig) + log.Println(c.Desc) + err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig) if fn, ok := c.Expected.(func() interface{}); ok { c.Expected = fn() } diff --git a/netquery/orm/query_runner.go b/netquery/orm/query_runner.go index 6de0c66a..88eceefd 100644 --- a/netquery/orm/query_runner.go +++ b/netquery/orm/query_runner.go @@ -20,6 +20,7 @@ type ( NamedArgs map[string]interface{} Result interface{} DecodeConfig DecodeConfig + Schema TableSchema } ) @@ -56,6 +57,12 @@ func WithNamedArgs(args map[string]interface{}) QueryOption { } } +func WithSchema(tbl TableSchema) QueryOption { + return func(opts *queryOpts) { + opts.Schema = tbl + } +} + // WithResult sets the result receiver. result is expected to // be a pointer to a slice of struct or map types. // @@ -136,7 +143,7 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q currentField = reflect.New(valElemType) - if err := DecodeStmt(ctx, stmt, currentField.Interface(), args.DecodeConfig); err != nil { + if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil { return err } diff --git a/netquery/orm/schema_builder.go b/netquery/orm/schema_builder.go index 68783c3f..5d533be5 100644 --- a/netquery/orm/schema_builder.go +++ b/netquery/orm/schema_builder.go @@ -45,6 +45,7 @@ type ( Name string Nullable bool Type sqlite.ColumnType + GoType reflect.Type Length int PrimaryKey bool AutoIncrement bool @@ -145,6 +146,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) { ft = fieldType.Type.Elem() } + def.GoType = ft kind := normalizeKind(ft.Kind()) switch kind { diff --git a/netquery/query.go b/netquery/query.go index 6febf0b6..929255ee 100644 --- a/netquery/query.go +++ b/netquery/query.go @@ -16,6 +16,12 @@ import ( type ( Query map[string][]Matcher + MatchType interface { + Operator() string + } + + Equal interface{} + Matcher struct { Equal interface{} `json:"$eq,omitempty"` NotEqual interface{} `json:"$ne,omitempty"` @@ -27,12 +33,22 @@ type ( Count struct { As string `json:"as"` Field string `json:"field"` - Distinct bool `json:"distict"` + Distinct bool `json:"distinct"` } + Sum struct { + Condition Query `json:"condition"` + As string `json:"as"` + Distinct bool `json:"distinct"` + } + + // NOTE: whenever adding support for new operators make sure + // to update UnmarshalJSON as well. Select struct { - Field string `json:"field"` - Count *Count `json:"$count"` + Field string `json:"field"` + Count *Count `json:"$count,omitempty"` + Sum *Sum `json:"$sum,omitempty"` + Distinct *string `json:"$distinct"` } Selects []Select @@ -45,6 +61,11 @@ type ( selectedFields []string whitelistedFields []string + paramMap map[string]interface{} + } + + QueryActiveConnectionChartPayload struct { + Query Query `json:"query"` } OrderBy struct { @@ -179,15 +200,15 @@ func (match Matcher) Validate() error { return nil } -func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) { +func (match Matcher) toSQLConditionClause(ctx context.Context, suffix string, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) { var ( queryParts []string params = make(map[string]interface{}) errs = new(multierror.Error) - key = fmt.Sprintf("%s%d", colDef.Name, idx) + key = fmt.Sprintf("%s%s", colDef.Name, suffix) ) - add := func(operator, suffix string, values ...interface{}) { + add := func(operator, suffix string, list bool, values ...interface{}) { var placeholder []string for idx, value := range values { @@ -204,7 +225,7 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct params[uniqKey] = encodedValue } - if len(placeholder) == 1 { + if len(placeholder) == 1 && !list { queryParts = append(queryParts, fmt.Sprintf("%s %s %s", colDef.Name, operator, placeholder[0])) } else { queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", colDef.Name, operator, strings.Join(placeholder, ", "))) @@ -212,23 +233,23 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct } if match.Equal != nil { - add("=", "eq", match.Equal) + add("=", "eq", false, match.Equal) } if match.NotEqual != nil { - add("!=", "ne", match.NotEqual) + add("!=", "ne", false, match.NotEqual) } if match.In != nil { - add("IN", "in", match.In...) + add("IN", "in", true, match.In...) } if match.NotIn != nil { - add("NOT IN", "notin", match.NotIn...) + add("NOT IN", "notin", true, match.NotIn...) } if match.Like != "" { - add("LIKE", "like", match.Like) + add("LIKE", "like", false, match.Like) } if len(queryParts) == 0 { @@ -244,7 +265,7 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil() } -func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) { +func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) { if len(query) == 0 { return "", nil, nil } @@ -279,7 +300,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, enc queryParts := make([]string, len(values)) for idx, val := range values { - matcherQuery, params, err := val.toSQLConditionClause(ctx, idx, "AND", colDef, encoderConfig) + matcherQuery, params, err := val.toSQLConditionClause(ctx, fmt.Sprintf("%s%d", suffix, idx), "AND", colDef, encoderConfig) if err != nil { errs.Errors = append(errs.Errors, fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err), @@ -359,8 +380,10 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { // directly if blob[0] == '{' { var res struct { - Field string `json:"field"` - Count *Count `json:"$count"` + Field string `json:"field"` + Count *Count `json:"$count"` + Sum *Sum `json:"$sum"` + Distinct *string `json:"$distinct"` } if err := json.Unmarshal(blob, &res); err != nil { @@ -369,6 +392,8 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { sel.Count = res.Count sel.Field = res.Field + sel.Distinct = res.Distinct + sel.Sum = res.Sum if sel.Count != nil && sel.Count.As != "" { if !charOnlyRegexp.MatchString(sel.Count.As) { diff --git a/netquery/query_handler.go b/netquery/query_handler.go index a1e4353f..318e494d 100644 --- a/netquery/query_handler.go +++ b/netquery/query_handler.go @@ -58,6 +58,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { query, orm.WithNamedArgs(paramMap), orm.WithResult(&result), + orm.WithSchema(*qh.Database.Schema), ); err != nil { http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError) @@ -139,13 +140,14 @@ func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, e } func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) { - if err := req.prepareSelectedFields(schema); err != nil { + if err := req.prepareSelectedFields(ctx, schema); err != nil { return "", nil, fmt.Errorf("perparing selected fields: %w", err) } // build the SQL where clause from the payload query whereClause, paramMap, err := req.Query.toSQLWhereClause( ctx, + "", schema, orm.DefaultEncodeConfig, ) @@ -153,6 +155,14 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab return "", nil, fmt.Errorf("ganerating where clause: %w", err) } + if req.paramMap == nil { + req.paramMap = make(map[string]interface{}) + } + + for key, val := range paramMap { + req.paramMap[key] = val + } + // build the actual SQL query statement // FIXME(ppacher): add support for group-by and sort-by @@ -173,20 +183,26 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab } query += " " + groupByClause + " " + orderByClause - return query, paramMap, nil + return query, req.paramMap, nil } -func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) error { - for _, s := range req.Select { +func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error { + for idx, s := range req.Select { var field string - if s.Count != nil { + switch { + case s.Count != nil: field = s.Count.Field - } else { + case s.Distinct != nil: + field = *s.Distinct + case s.Sum != nil: + // field is not used in case of $sum + field = "*" + default: field = s.Field } colName := "*" - if field != "*" || s.Count == nil { + if field != "*" || (s.Count == nil && s.Sum == nil) { var err error colName, err = req.validateColumnName(schema, field) @@ -195,7 +211,8 @@ func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) e } } - if s.Count != nil { + switch { + case s.Count != nil: var as = s.Count.As if as == "" { as = fmt.Sprintf("%s_count", colName) @@ -204,9 +221,34 @@ func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) e if s.Count.Distinct { distinct = "DISTINCT " } - req.selectedFields = append(req.selectedFields, fmt.Sprintf("COUNT(%s%s) as %s", distinct, colName, as)) + req.selectedFields = append( + req.selectedFields, + fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as), + ) req.whitelistedFields = append(req.whitelistedFields, as) - } else { + + case s.Sum != nil: + if s.Sum.As == "" { + return fmt.Errorf("missing 'as' for $sum") + } + + clause, params, err := s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig) + if err != nil { + return fmt.Errorf("in $sum: %w", err) + } + + req.paramMap = params + req.selectedFields = append( + req.selectedFields, + fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As), + ) + req.whitelistedFields = append(req.whitelistedFields, s.Sum.As) + + case s.Distinct != nil: + req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName)) + req.whitelistedFields = append(req.whitelistedFields, colName) + + default: req.selectedFields = append(req.selectedFields, colName) } } @@ -251,6 +293,10 @@ func (req *QueryRequestPayload) generateSelectClause() string { } func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) { + if len(req.OrderBy) == 0 { + return "", nil + } + var orderBys = make([]string, len(req.OrderBy)) for idx, sort := range req.OrderBy { colName, err := req.validateColumnName(schema, sort.Field) @@ -286,7 +332,7 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel } } - return "", fmt.Errorf("column name %s not allowed", field) + return "", fmt.Errorf("column name %q not allowed", field) } // compile time check diff --git a/netquery/query_test.go b/netquery/query_test.go index f8e8b3e4..7d9a0393 100644 --- a/netquery/query_test.go +++ b/netquery/query_test.go @@ -228,7 +228,7 @@ func Test_QueryBuilder(t *testing.T) { for idx, c := range cases { t.Run(c.N, func(t *testing.T) { //t.Parallel() - str, params, err := c.Q.toSQLWhereClause(context.TODO(), tbl, orm.DefaultEncodeConfig) + str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig) if c.E != nil { if assert.Error(t, err) {