diff --git a/firewall/packet_handler.go b/firewall/packet_handler.go index 0ddf3a9a..e30d68aa 100644 --- a/firewall/packet_handler.go +++ b/firewall/packet_handler.go @@ -679,9 +679,15 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) { } defer conn.Unlock() + bytesIn := bwUpdate.BytesReceived + bytesOut := bwUpdate.BytesSent + // Update stats according to method. switch bwUpdate.Method { case packet.Absolute: + bytesIn = bwUpdate.BytesReceived - conn.BytesReceived + bytesOut = bwUpdate.BytesSent - conn.BytesSent + conn.BytesReceived = bwUpdate.BytesReceived conn.BytesSent = bwUpdate.BytesSent case packet.Additive: @@ -697,10 +703,11 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) { if err := netquery.DefaultModule.Store.UpdateBandwidth( ctx, conn.HistoryEnabled, + fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile), conn.Process().GetKey(), conn.ID, - conn.BytesReceived, - conn.BytesSent, + bytesIn, + bytesOut, ); err != nil { log.Errorf("filter: failed to persist bandwidth data: %s", err) } diff --git a/netquery/chart_handler.go b/netquery/active_chart_handler.go similarity index 85% rename from netquery/chart_handler.go rename to netquery/active_chart_handler.go index a44f03ac..08628394 100644 --- a/netquery/chart_handler.go +++ b/netquery/active_chart_handler.go @@ -13,12 +13,12 @@ import ( "github.com/safing/portmaster/netquery/orm" ) -// ChartHandler handles requests for connection charts. -type ChartHandler struct { +// ActiveChartHandler handles requests for connection charts. +type ActiveChartHandler struct { Database *Database } -func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (ch *ActiveChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl requestPayload, err := ch.parseRequest(req) if err != nil { http.Error(resp, err.Error(), http.StatusBadRequest) @@ -62,7 +62,7 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { }) } -func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl +func (ch *ActiveChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl var body io.Reader switch req.Method { @@ -99,10 +99,11 @@ WITH RECURSIVE epoch(x) AS ( UNION ALL SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0 ) -SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked FROM epoch +SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked + FROM epoch JOIN connections - ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0) - %s + ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0) + %s GROUP BY round(timestamp/10, 0)*10;` clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig) diff --git a/netquery/bandwidth_chart_handler.go b/netquery/bandwidth_chart_handler.go new file mode 100644 index 00000000..5bb5b526 --- /dev/null +++ b/netquery/bandwidth_chart_handler.go @@ -0,0 +1,148 @@ +package netquery + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/safing/portmaster/netquery/orm" +) + +// BandwidthChartHandler handles requests for connection charts. +type BandwidthChartHandler struct { + Database *Database +} + +// BandwidthChartRequest holds a request for a bandwidth chart. +type BandwidthChartRequest struct { + Interval int `json:"interval"` + Query Query `json:"query"` + GroupBy []string `json:"groupBy"` +} + +func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl + requestPayload, err := ch.parseRequest(req) + if err != nil { + http.Error(resp, err.Error(), http.StatusBadRequest) + + return + } + + query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema) + if err != nil { + http.Error(resp, err.Error(), http.StatusBadRequest) + + return + } + + // actually execute the query against the database and collect the result + var result []map[string]interface{} + if err := ch.Database.Execute( + req.Context(), + query, + orm.WithNamedArgs(paramMap), + orm.WithResult(&result), + orm.WithSchema(*ch.Database.Schema), + ); err != nil { + http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError) + + return + } + + // send the HTTP status code + resp.WriteHeader(http.StatusOK) + + // prepare the result encoder. + enc := json.NewEncoder(resp) + enc.SetEscapeHTML(false) + enc.SetIndent("", " ") + + _ = enc.Encode(map[string]interface{}{ //nolint:errchkjson + "results": result, + "query": query, + "params": paramMap, + }) +} + +func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl + var body io.Reader + + switch req.Method { + case http.MethodPost, http.MethodPut: + body = req.Body + case http.MethodGet: + body = strings.NewReader(req.URL.Query().Get("q")) + default: + return nil, fmt.Errorf("invalid HTTP method") + } + + var requestPayload BandwidthChartRequest + blob, err := io.ReadAll(body) + if err != nil { + return nil, fmt.Errorf("failed to read body" + err.Error()) + } + + body = bytes.NewReader(blob) + + dec := json.NewDecoder(body) + dec.DisallowUnknownFields() + + if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("invalid query: %w", err) + } + + return &requestPayload, nil +} + +func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) { + if req.Interval == 0 { + req.Interval = 10 + } + + interval := fmt.Sprintf("round(time/%d, 0)*%d", req.Interval, req.Interval) + + // make sure there are only allowed fields specified in the request group-by + for _, gb := range req.GroupBy { + def := schema.GetColumnDef(gb) + if def == nil { + return "", nil, fmt.Errorf("unsupported groupBy key: %q", gb) + } + } + + selects := append([]string{ + interval + " as timestamp", + "SUM(incoming) as incoming", + "SUM(outgoing) as outgoing", + }, req.GroupBy...) + + groupBy := append([]string{interval}, req.GroupBy...) + + whereClause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig) + if err != nil { + return "", nil, err + } + + if whereClause != "" { + whereClause = "WHERE " + whereClause + } + + template := fmt.Sprintf( + `SELECT %s + FROM main.bandwidth AS bw + JOIN main.connections AS conns + ON bw.conn_id = conns.id + %s + GROUP BY %s + ORDER BY time ASC`, + strings.Join(selects, ", "), + whereClause, + strings.Join(groupBy, ", "), + ) + + return template, params, nil +} diff --git a/netquery/database.go b/netquery/database.go index 807c0097..7adae82b 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -299,6 +299,19 @@ func (db *Database) ApplyMigrations() error { } } + bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth ( + conn_id TEXT NOT NULL, + time INTEGER NOT NULL, + incoming INTEGER NOT NULL, + outgoing INTEGER NOT NULL, + CONSTRAINT fk_conn_id + FOREIGN KEY(conn_id) REFERENCES connections(id) + ON DELETE CASCADE + )` + if err := sqlitex.ExecuteTransient(db.writeConn, bwSchema, nil); err != nil { + return fmt.Errorf("failed to create main.bandwidth database: %w", err) + } + return nil } @@ -535,21 +548,16 @@ func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error { // UpdateBandwidth updates bandwidth data for the connection and optionally also writes // the bandwidth data to the history database. -func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error { +func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error { params := map[string]any{ ":id": makeNqIDFromParts(processKey, connID), } parts := []string{} - if bytesReceived != 0 { - parts = append(parts, "bytes_received = :bytes_received") - params[":bytes_received"] = bytesReceived - } - - if bytesSent != 0 { - parts = append(parts, "bytes_sent = :bytes_sent") - params[":bytes_sent"] = bytesSent - } + parts = append(parts, "bytes_received = (bytes_received + :bytes_received)") + params[":bytes_received"] = bytesReceived + parts = append(parts, "bytes_sent = (bytes_sent + :bytes_sent)") + params[":bytes_sent"] = bytesSent updateSet := strings.Join(parts, ", ") @@ -570,6 +578,13 @@ func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, pro } } + // also add the date to the in-memory bandwidth database + params[":time"] = time.Now().Unix() + stmt := "INSERT INTO main.bandwidth (conn_id, time, incoming, outgoing) VALUES(:id, :time, :bytes_received, :bytes_sent)" + if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil { + merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err)) + } + return merr.ErrorOrNil() } diff --git a/netquery/manager.go b/netquery/manager.go index 8749f482..810782a8 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -38,7 +38,7 @@ type ( // UpdateBandwidth updates bandwidth data for the connection and optionally also writes // the bandwidth data to the history database. - UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error + UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error // CleanupHistory deletes data outside of the retention time frame from the history database. CleanupHistory(ctx context.Context) error diff --git a/netquery/module_api.go b/netquery/module_api.go index 39647186..63a3d07f 100644 --- a/netquery/module_api.go +++ b/netquery/module_api.go @@ -87,7 +87,11 @@ func (m *module) prepare() error { IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), } - chartHandler := &ChartHandler{ + chartHandler := &ActiveChartHandler{ + Database: m.Store, + } + + bwChartHandler := &BandwidthChartHandler{ Database: m.Store, } @@ -129,6 +133,19 @@ func (m *module) prepare() error { return fmt.Errorf("failed to register API endpoint: %w", err) } + if err := api.RegisterEndpoint(api.Endpoint{ + // TODO: Use query parameters instead. + Path: "netquery/charts/bandwidth", + MimeType: "application/json", + Write: api.PermitUser, + BelongsTo: m.Module, + HandlerFunc: bwChartHandler.ServeHTTP, + Name: "Bandwidth Chart", + Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.", + }); err != nil { + return fmt.Errorf("failed to register API endpoint: %w", err) + } + if err := api.RegisterEndpoint(api.Endpoint{ Name: "Remove connections from profile history", Description: "Remove all connections from the history database for one or more profiles", @@ -137,7 +154,6 @@ func (m *module) prepare() error { Write: api.PermitUser, BelongsTo: m.Module, ActionFunc: func(ar *api.Request) (msg string, err error) { - // TODO: Use query parameters instead. var body struct { ProfileIDs []string `json:"profileIDs"` } diff --git a/netquery/query.go b/netquery/query.go index f62cd958..2b81bfb1 100644 --- a/netquery/query.go +++ b/netquery/query.go @@ -71,12 +71,18 @@ type ( Distinct bool `json:"distinct"` } + FieldSelect struct { + Field string `json:"field"` + As string `json:"as"` + } + Select struct { - Field string `json:"field"` - Count *Count `json:"$count,omitempty"` - Sum *Sum `json:"$sum,omitempty"` - Min *Min `json:"$min,omitempty"` - Distinct *string `json:"$distinct,omitempty"` + Field string `json:"field"` + FieldSelect *FieldSelect `json:"$field"` + Count *Count `json:"$count,omitempty"` + Sum *Sum `json:"$sum,omitempty"` + Min *Min `json:"$min,omitempty"` + Distinct *string `json:"$distinct,omitempty"` } Selects []Select @@ -449,6 +455,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T values := query[column] colDef, ok := lm[column] if !ok { + errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column)) continue @@ -538,11 +545,12 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { // directly if blob[0] == '{' { var res struct { - Field string `json:"field"` - Count *Count `json:"$count"` - Sum *Sum `json:"$sum"` - Min *Min `json:"$min"` - Distinct *string `json:"$distinct"` + Field string `json:"field"` + Count *Count `json:"$count"` + Sum *Sum `json:"$sum"` + Min *Min `json:"$min"` + Distinct *string `json:"$distinct"` + FieldSelect *FieldSelect `json:"$field"` } if err := json.Unmarshal(blob, &res); err != nil { @@ -551,6 +559,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { sel.Count = res.Count sel.Field = res.Field + sel.FieldSelect = res.FieldSelect sel.Distinct = res.Distinct sel.Sum = res.Sum sel.Min = res.Min @@ -570,6 +579,11 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+") } } + if sel.FieldSelect != nil && sel.FieldSelect.As != "" { + if !charOnlyRegexp.MatchString(sel.FieldSelect.As) { + return fmt.Errorf("invalid characters in $field.as, value must match [a-zA-Z]+") + } + } return nil } diff --git a/netquery/query_request.go b/netquery/query_request.go index b4a07041..ea5162a9 100644 --- a/netquery/query_request.go +++ b/netquery/query_request.go @@ -126,6 +126,8 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem } else { field = "*" } + case s.FieldSelect != nil: + field = s.FieldSelect.Field default: field = s.Field } @@ -141,6 +143,18 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem } switch { + case s.FieldSelect != nil: + as := s.FieldSelect.As + if as == "" { + as = s.FieldSelect.Field + } + + req.selectedFields = append( + req.selectedFields, + fmt.Sprintf("%s AS %s", s.FieldSelect.Field, as), + ) + req.whitelistedFields = append(req.whitelistedFields, as) + case s.Count != nil: as := s.Count.As if as == "" {