Merge pull request #1341 from safing/feature/bandwidth-charts

Bandwidth Charts
This commit is contained in:
Daniel Hovie
2023-10-10 16:02:50 +02:00
committed by GitHub
8 changed files with 247 additions and 32 deletions

View File

@@ -679,9 +679,15 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
} }
defer conn.Unlock() defer conn.Unlock()
bytesIn := bwUpdate.BytesReceived
bytesOut := bwUpdate.BytesSent
// Update stats according to method. // Update stats according to method.
switch bwUpdate.Method { switch bwUpdate.Method {
case packet.Absolute: case packet.Absolute:
bytesIn = bwUpdate.BytesReceived - conn.BytesReceived
bytesOut = bwUpdate.BytesSent - conn.BytesSent
conn.BytesReceived = bwUpdate.BytesReceived conn.BytesReceived = bwUpdate.BytesReceived
conn.BytesSent = bwUpdate.BytesSent conn.BytesSent = bwUpdate.BytesSent
case packet.Additive: case packet.Additive:
@@ -697,10 +703,11 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
if err := netquery.DefaultModule.Store.UpdateBandwidth( if err := netquery.DefaultModule.Store.UpdateBandwidth(
ctx, ctx,
conn.HistoryEnabled, conn.HistoryEnabled,
fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile),
conn.Process().GetKey(), conn.Process().GetKey(),
conn.ID, conn.ID,
conn.BytesReceived, bytesIn,
conn.BytesSent, bytesOut,
); err != nil { ); err != nil {
log.Errorf("filter: failed to persist bandwidth data: %s", err) log.Errorf("filter: failed to persist bandwidth data: %s", err)
} }

View File

@@ -13,12 +13,12 @@ import (
"github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/netquery/orm"
) )
// ChartHandler handles requests for connection charts. // ActiveChartHandler handles requests for connection charts.
type ChartHandler struct { type ActiveChartHandler struct {
Database *Database Database *Database
} }
func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { func (ch *ActiveChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl
requestPayload, err := ch.parseRequest(req) requestPayload, err := ch.parseRequest(req)
if err != nil { if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest) http.Error(resp, err.Error(), http.StatusBadRequest)
@@ -62,7 +62,7 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
}) })
} }
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl func (ch *ActiveChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
var body io.Reader var body io.Reader
switch req.Method { switch req.Method {
@@ -99,10 +99,11 @@ WITH RECURSIVE epoch(x) AS (
UNION ALL UNION ALL
SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0 SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
) )
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked FROM epoch SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked
FROM epoch
JOIN connections JOIN connections
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0) ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
%s %s
GROUP BY round(timestamp/10, 0)*10;` GROUP BY round(timestamp/10, 0)*10;`
clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig) clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)

View File

@@ -0,0 +1,148 @@
package netquery
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/safing/portmaster/netquery/orm"
)
// BandwidthChartHandler handles requests for connection charts.
type BandwidthChartHandler struct {
Database *Database
}
// BandwidthChartRequest holds a request for a bandwidth chart.
type BandwidthChartRequest struct {
Interval int `json:"interval"`
Query Query `json:"query"`
GroupBy []string `json:"groupBy"`
}
func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl
requestPayload, err := ch.parseRequest(req)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
// actually execute the query against the database and collect the result
var result []map[string]interface{}
if err := ch.Database.Execute(
req.Context(),
query,
orm.WithNamedArgs(paramMap),
orm.WithResult(&result),
orm.WithSchema(*ch.Database.Schema),
); err != nil {
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
return
}
// send the HTTP status code
resp.WriteHeader(http.StatusOK)
// prepare the result encoder.
enc := json.NewEncoder(resp)
enc.SetEscapeHTML(false)
enc.SetIndent("", " ")
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
"results": result,
"query": query,
"params": paramMap,
})
}
func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl
var body io.Reader
switch req.Method {
case http.MethodPost, http.MethodPut:
body = req.Body
case http.MethodGet:
body = strings.NewReader(req.URL.Query().Get("q"))
default:
return nil, fmt.Errorf("invalid HTTP method")
}
var requestPayload BandwidthChartRequest
blob, err := io.ReadAll(body)
if err != nil {
return nil, fmt.Errorf("failed to read body" + err.Error())
}
body = bytes.NewReader(blob)
dec := json.NewDecoder(body)
dec.DisallowUnknownFields()
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("invalid query: %w", err)
}
return &requestPayload, nil
}
func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
if req.Interval == 0 {
req.Interval = 10
}
interval := fmt.Sprintf("round(time/%d, 0)*%d", req.Interval, req.Interval)
// make sure there are only allowed fields specified in the request group-by
for _, gb := range req.GroupBy {
def := schema.GetColumnDef(gb)
if def == nil {
return "", nil, fmt.Errorf("unsupported groupBy key: %q", gb)
}
}
selects := append([]string{
interval + " as timestamp",
"SUM(incoming) as incoming",
"SUM(outgoing) as outgoing",
}, req.GroupBy...)
groupBy := append([]string{interval}, req.GroupBy...)
whereClause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
if err != nil {
return "", nil, err
}
if whereClause != "" {
whereClause = "WHERE " + whereClause
}
template := fmt.Sprintf(
`SELECT %s
FROM main.bandwidth AS bw
JOIN main.connections AS conns
ON bw.conn_id = conns.id
%s
GROUP BY %s
ORDER BY time ASC`,
strings.Join(selects, ", "),
whereClause,
strings.Join(groupBy, ", "),
)
return template, params, nil
}

View File

@@ -299,6 +299,19 @@ func (db *Database) ApplyMigrations() error {
} }
} }
bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth (
conn_id TEXT NOT NULL,
time INTEGER NOT NULL,
incoming INTEGER NOT NULL,
outgoing INTEGER NOT NULL,
CONSTRAINT fk_conn_id
FOREIGN KEY(conn_id) REFERENCES connections(id)
ON DELETE CASCADE
)`
if err := sqlitex.ExecuteTransient(db.writeConn, bwSchema, nil); err != nil {
return fmt.Errorf("failed to create main.bandwidth database: %w", err)
}
return nil return nil
} }
@@ -535,21 +548,16 @@ func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes // UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database. // the bandwidth data to the history database.
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error { func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
params := map[string]any{ params := map[string]any{
":id": makeNqIDFromParts(processKey, connID), ":id": makeNqIDFromParts(processKey, connID),
} }
parts := []string{} parts := []string{}
if bytesReceived != 0 { parts = append(parts, "bytes_received = (bytes_received + :bytes_received)")
parts = append(parts, "bytes_received = :bytes_received") params[":bytes_received"] = bytesReceived
params[":bytes_received"] = bytesReceived parts = append(parts, "bytes_sent = (bytes_sent + :bytes_sent)")
} params[":bytes_sent"] = bytesSent
if bytesSent != 0 {
parts = append(parts, "bytes_sent = :bytes_sent")
params[":bytes_sent"] = bytesSent
}
updateSet := strings.Join(parts, ", ") updateSet := strings.Join(parts, ", ")
@@ -570,6 +578,13 @@ func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, pro
} }
} }
// also add the date to the in-memory bandwidth database
params[":time"] = time.Now().Unix()
stmt := "INSERT INTO main.bandwidth (conn_id, time, incoming, outgoing) VALUES(:id, :time, :bytes_received, :bytes_sent)"
if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err))
}
return merr.ErrorOrNil() return merr.ErrorOrNil()
} }

View File

@@ -38,7 +38,7 @@ type (
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes // UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database. // the bandwidth data to the history database.
UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
// CleanupHistory deletes data outside of the retention time frame from the history database. // CleanupHistory deletes data outside of the retention time frame from the history database.
CleanupHistory(ctx context.Context) error CleanupHistory(ctx context.Context) error

View File

@@ -87,7 +87,11 @@ func (m *module) prepare() error {
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
} }
chartHandler := &ChartHandler{ chartHandler := &ActiveChartHandler{
Database: m.Store,
}
bwChartHandler := &BandwidthChartHandler{
Database: m.Store, Database: m.Store,
} }
@@ -129,6 +133,19 @@ func (m *module) prepare() error {
return fmt.Errorf("failed to register API endpoint: %w", err) return fmt.Errorf("failed to register API endpoint: %w", err)
} }
if err := api.RegisterEndpoint(api.Endpoint{
// TODO: Use query parameters instead.
Path: "netquery/charts/bandwidth",
MimeType: "application/json",
Write: api.PermitUser,
BelongsTo: m.Module,
HandlerFunc: bwChartHandler.ServeHTTP,
Name: "Bandwidth Chart",
Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.",
}); err != nil {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
if err := api.RegisterEndpoint(api.Endpoint{ if err := api.RegisterEndpoint(api.Endpoint{
Name: "Remove connections from profile history", Name: "Remove connections from profile history",
Description: "Remove all connections from the history database for one or more profiles", Description: "Remove all connections from the history database for one or more profiles",
@@ -137,7 +154,6 @@ func (m *module) prepare() error {
Write: api.PermitUser, Write: api.PermitUser,
BelongsTo: m.Module, BelongsTo: m.Module,
ActionFunc: func(ar *api.Request) (msg string, err error) { ActionFunc: func(ar *api.Request) (msg string, err error) {
// TODO: Use query parameters instead.
var body struct { var body struct {
ProfileIDs []string `json:"profileIDs"` ProfileIDs []string `json:"profileIDs"`
} }

View File

@@ -71,12 +71,18 @@ type (
Distinct bool `json:"distinct"` Distinct bool `json:"distinct"`
} }
FieldSelect struct {
Field string `json:"field"`
As string `json:"as"`
}
Select struct { Select struct {
Field string `json:"field"` Field string `json:"field"`
Count *Count `json:"$count,omitempty"` FieldSelect *FieldSelect `json:"$field"`
Sum *Sum `json:"$sum,omitempty"` Count *Count `json:"$count,omitempty"`
Min *Min `json:"$min,omitempty"` Sum *Sum `json:"$sum,omitempty"`
Distinct *string `json:"$distinct,omitempty"` Min *Min `json:"$min,omitempty"`
Distinct *string `json:"$distinct,omitempty"`
} }
Selects []Select Selects []Select
@@ -449,6 +455,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T
values := query[column] values := query[column]
colDef, ok := lm[column] colDef, ok := lm[column]
if !ok { if !ok {
errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column)) errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))
continue continue
@@ -538,11 +545,12 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
// directly // directly
if blob[0] == '{' { if blob[0] == '{' {
var res struct { var res struct {
Field string `json:"field"` Field string `json:"field"`
Count *Count `json:"$count"` Count *Count `json:"$count"`
Sum *Sum `json:"$sum"` Sum *Sum `json:"$sum"`
Min *Min `json:"$min"` Min *Min `json:"$min"`
Distinct *string `json:"$distinct"` Distinct *string `json:"$distinct"`
FieldSelect *FieldSelect `json:"$field"`
} }
if err := json.Unmarshal(blob, &res); err != nil { if err := json.Unmarshal(blob, &res); err != nil {
@@ -551,6 +559,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
sel.Count = res.Count sel.Count = res.Count
sel.Field = res.Field sel.Field = res.Field
sel.FieldSelect = res.FieldSelect
sel.Distinct = res.Distinct sel.Distinct = res.Distinct
sel.Sum = res.Sum sel.Sum = res.Sum
sel.Min = res.Min sel.Min = res.Min
@@ -570,6 +579,11 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+") return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+")
} }
} }
if sel.FieldSelect != nil && sel.FieldSelect.As != "" {
if !charOnlyRegexp.MatchString(sel.FieldSelect.As) {
return fmt.Errorf("invalid characters in $field.as, value must match [a-zA-Z]+")
}
}
return nil return nil
} }

View File

@@ -126,6 +126,8 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
} else { } else {
field = "*" field = "*"
} }
case s.FieldSelect != nil:
field = s.FieldSelect.Field
default: default:
field = s.Field field = s.Field
} }
@@ -141,6 +143,18 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
} }
switch { switch {
case s.FieldSelect != nil:
as := s.FieldSelect.As
if as == "" {
as = s.FieldSelect.Field
}
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("%s AS %s", s.FieldSelect.Field, as),
)
req.whitelistedFields = append(req.whitelistedFields, as)
case s.Count != nil: case s.Count != nil:
as := s.Count.As as := s.Count.As
if as == "" { if as == "" {