Merge pull request #1341 from safing/feature/bandwidth-charts

Bandwidth Charts
This commit is contained in:
Daniel Hovie
2023-10-10 16:02:50 +02:00
committed by GitHub
8 changed files with 247 additions and 32 deletions

View File

@@ -679,9 +679,15 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
}
defer conn.Unlock()
bytesIn := bwUpdate.BytesReceived
bytesOut := bwUpdate.BytesSent
// Update stats according to method.
switch bwUpdate.Method {
case packet.Absolute:
bytesIn = bwUpdate.BytesReceived - conn.BytesReceived
bytesOut = bwUpdate.BytesSent - conn.BytesSent
conn.BytesReceived = bwUpdate.BytesReceived
conn.BytesSent = bwUpdate.BytesSent
case packet.Additive:
@@ -697,10 +703,11 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
if err := netquery.DefaultModule.Store.UpdateBandwidth(
ctx,
conn.HistoryEnabled,
fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile),
conn.Process().GetKey(),
conn.ID,
conn.BytesReceived,
conn.BytesSent,
bytesIn,
bytesOut,
); err != nil {
log.Errorf("filter: failed to persist bandwidth data: %s", err)
}

View File

@@ -13,12 +13,12 @@ import (
"github.com/safing/portmaster/netquery/orm"
)
// ChartHandler handles requests for connection charts.
type ChartHandler struct {
// ActiveChartHandler handles requests for connection charts.
type ActiveChartHandler struct {
Database *Database
}
func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
func (ch *ActiveChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl
requestPayload, err := ch.parseRequest(req)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
@@ -62,7 +62,7 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
})
}
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
func (ch *ActiveChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
var body io.Reader
switch req.Method {
@@ -99,10 +99,11 @@ WITH RECURSIVE epoch(x) AS (
UNION ALL
SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
)
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked FROM epoch
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked
FROM epoch
JOIN connections
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
%s
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
%s
GROUP BY round(timestamp/10, 0)*10;`
clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)

View File

@@ -0,0 +1,148 @@
package netquery
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/safing/portmaster/netquery/orm"
)
// BandwidthChartHandler handles requests for connection charts.
type BandwidthChartHandler struct {
Database *Database
}
// BandwidthChartRequest holds a request for a bandwidth chart.
type BandwidthChartRequest struct {
Interval int `json:"interval"`
Query Query `json:"query"`
GroupBy []string `json:"groupBy"`
}
func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl
requestPayload, err := ch.parseRequest(req)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
// actually execute the query against the database and collect the result
var result []map[string]interface{}
if err := ch.Database.Execute(
req.Context(),
query,
orm.WithNamedArgs(paramMap),
orm.WithResult(&result),
orm.WithSchema(*ch.Database.Schema),
); err != nil {
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
return
}
// send the HTTP status code
resp.WriteHeader(http.StatusOK)
// prepare the result encoder.
enc := json.NewEncoder(resp)
enc.SetEscapeHTML(false)
enc.SetIndent("", " ")
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
"results": result,
"query": query,
"params": paramMap,
})
}
func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl
var body io.Reader
switch req.Method {
case http.MethodPost, http.MethodPut:
body = req.Body
case http.MethodGet:
body = strings.NewReader(req.URL.Query().Get("q"))
default:
return nil, fmt.Errorf("invalid HTTP method")
}
var requestPayload BandwidthChartRequest
blob, err := io.ReadAll(body)
if err != nil {
return nil, fmt.Errorf("failed to read body" + err.Error())
}
body = bytes.NewReader(blob)
dec := json.NewDecoder(body)
dec.DisallowUnknownFields()
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("invalid query: %w", err)
}
return &requestPayload, nil
}
func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
if req.Interval == 0 {
req.Interval = 10
}
interval := fmt.Sprintf("round(time/%d, 0)*%d", req.Interval, req.Interval)
// make sure there are only allowed fields specified in the request group-by
for _, gb := range req.GroupBy {
def := schema.GetColumnDef(gb)
if def == nil {
return "", nil, fmt.Errorf("unsupported groupBy key: %q", gb)
}
}
selects := append([]string{
interval + " as timestamp",
"SUM(incoming) as incoming",
"SUM(outgoing) as outgoing",
}, req.GroupBy...)
groupBy := append([]string{interval}, req.GroupBy...)
whereClause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
if err != nil {
return "", nil, err
}
if whereClause != "" {
whereClause = "WHERE " + whereClause
}
template := fmt.Sprintf(
`SELECT %s
FROM main.bandwidth AS bw
JOIN main.connections AS conns
ON bw.conn_id = conns.id
%s
GROUP BY %s
ORDER BY time ASC`,
strings.Join(selects, ", "),
whereClause,
strings.Join(groupBy, ", "),
)
return template, params, nil
}

View File

@@ -299,6 +299,19 @@ func (db *Database) ApplyMigrations() error {
}
}
bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth (
conn_id TEXT NOT NULL,
time INTEGER NOT NULL,
incoming INTEGER NOT NULL,
outgoing INTEGER NOT NULL,
CONSTRAINT fk_conn_id
FOREIGN KEY(conn_id) REFERENCES connections(id)
ON DELETE CASCADE
)`
if err := sqlitex.ExecuteTransient(db.writeConn, bwSchema, nil); err != nil {
return fmt.Errorf("failed to create main.bandwidth database: %w", err)
}
return nil
}
@@ -535,21 +548,16 @@ func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database.
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
params := map[string]any{
":id": makeNqIDFromParts(processKey, connID),
}
parts := []string{}
if bytesReceived != 0 {
parts = append(parts, "bytes_received = :bytes_received")
params[":bytes_received"] = bytesReceived
}
if bytesSent != 0 {
parts = append(parts, "bytes_sent = :bytes_sent")
params[":bytes_sent"] = bytesSent
}
parts = append(parts, "bytes_received = (bytes_received + :bytes_received)")
params[":bytes_received"] = bytesReceived
parts = append(parts, "bytes_sent = (bytes_sent + :bytes_sent)")
params[":bytes_sent"] = bytesSent
updateSet := strings.Join(parts, ", ")
@@ -570,6 +578,13 @@ func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, pro
}
}
// also add the date to the in-memory bandwidth database
params[":time"] = time.Now().Unix()
stmt := "INSERT INTO main.bandwidth (conn_id, time, incoming, outgoing) VALUES(:id, :time, :bytes_received, :bytes_sent)"
if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err))
}
return merr.ErrorOrNil()
}

View File

@@ -38,7 +38,7 @@ type (
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database.
UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
// CleanupHistory deletes data outside of the retention time frame from the history database.
CleanupHistory(ctx context.Context) error

View File

@@ -87,7 +87,11 @@ func (m *module) prepare() error {
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
}
chartHandler := &ChartHandler{
chartHandler := &ActiveChartHandler{
Database: m.Store,
}
bwChartHandler := &BandwidthChartHandler{
Database: m.Store,
}
@@ -129,6 +133,19 @@ func (m *module) prepare() error {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
if err := api.RegisterEndpoint(api.Endpoint{
// TODO: Use query parameters instead.
Path: "netquery/charts/bandwidth",
MimeType: "application/json",
Write: api.PermitUser,
BelongsTo: m.Module,
HandlerFunc: bwChartHandler.ServeHTTP,
Name: "Bandwidth Chart",
Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.",
}); err != nil {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
if err := api.RegisterEndpoint(api.Endpoint{
Name: "Remove connections from profile history",
Description: "Remove all connections from the history database for one or more profiles",
@@ -137,7 +154,6 @@ func (m *module) prepare() error {
Write: api.PermitUser,
BelongsTo: m.Module,
ActionFunc: func(ar *api.Request) (msg string, err error) {
// TODO: Use query parameters instead.
var body struct {
ProfileIDs []string `json:"profileIDs"`
}

View File

@@ -71,12 +71,18 @@ type (
Distinct bool `json:"distinct"`
}
FieldSelect struct {
Field string `json:"field"`
As string `json:"as"`
}
Select struct {
Field string `json:"field"`
Count *Count `json:"$count,omitempty"`
Sum *Sum `json:"$sum,omitempty"`
Min *Min `json:"$min,omitempty"`
Distinct *string `json:"$distinct,omitempty"`
Field string `json:"field"`
FieldSelect *FieldSelect `json:"$field"`
Count *Count `json:"$count,omitempty"`
Sum *Sum `json:"$sum,omitempty"`
Min *Min `json:"$min,omitempty"`
Distinct *string `json:"$distinct,omitempty"`
}
Selects []Select
@@ -449,6 +455,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T
values := query[column]
colDef, ok := lm[column]
if !ok {
errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))
continue
@@ -538,11 +545,12 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
// directly
if blob[0] == '{' {
var res struct {
Field string `json:"field"`
Count *Count `json:"$count"`
Sum *Sum `json:"$sum"`
Min *Min `json:"$min"`
Distinct *string `json:"$distinct"`
Field string `json:"field"`
Count *Count `json:"$count"`
Sum *Sum `json:"$sum"`
Min *Min `json:"$min"`
Distinct *string `json:"$distinct"`
FieldSelect *FieldSelect `json:"$field"`
}
if err := json.Unmarshal(blob, &res); err != nil {
@@ -551,6 +559,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
sel.Count = res.Count
sel.Field = res.Field
sel.FieldSelect = res.FieldSelect
sel.Distinct = res.Distinct
sel.Sum = res.Sum
sel.Min = res.Min
@@ -570,6 +579,11 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+")
}
}
if sel.FieldSelect != nil && sel.FieldSelect.As != "" {
if !charOnlyRegexp.MatchString(sel.FieldSelect.As) {
return fmt.Errorf("invalid characters in $field.as, value must match [a-zA-Z]+")
}
}
return nil
}

View File

@@ -126,6 +126,8 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
} else {
field = "*"
}
case s.FieldSelect != nil:
field = s.FieldSelect.Field
default:
field = s.Field
}
@@ -141,6 +143,18 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
}
switch {
case s.FieldSelect != nil:
as := s.FieldSelect.As
if as == "" {
as = s.FieldSelect.Field
}
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("%s AS %s", s.FieldSelect.Field, as),
)
req.whitelistedFields = append(req.whitelistedFields, as)
case s.Count != nil:
as := s.Count.As
if as == "" {