diff --git a/netquery/bandwidth_chart_handler.go b/netquery/bandwidth_chart_handler.go index 921673d5..e247a830 100644 --- a/netquery/bandwidth_chart_handler.go +++ b/netquery/bandwidth_chart_handler.go @@ -19,9 +19,9 @@ type BandwidthChartHandler struct { } type BandwidthChartRequest struct { - AllProfiles bool `json:"allProfiles"` - Profiles []string `json:"profiles"` - Connections []string `json:"connections"` + Interval int `json:"interval"` + Query Query `json:"query"` + GroupBy []string `json:"groupBy"` } func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { @@ -99,47 +99,45 @@ func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChar } func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) { - selects := []string{ - "(round(time/10, 0)*10) as time", + if req.Interval == 0 { + req.Interval = 10 + } + + interval := fmt.Sprintf("round(time/%d, 0)*%d", req.Interval, req.Interval) + + // make sure there are only allowed fields specified in the request group-by + for _, gb := range req.GroupBy { + def := schema.GetColumnDef(gb) + if def == nil { + return "", nil, fmt.Errorf("unsupported groupBy key: %q", gb) + } + } + + selects := append([]string{ + interval + " as timestamp", "SUM(incoming) as incoming", "SUM(outgoing) as outgoing", + }, req.GroupBy...) + + groupBy := append([]string{interval}, req.GroupBy...) + + whereClause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig) + if err != nil { + return "", nil, err } - groupBy := []string{"round(time/10, 0)*10"} - whereClause := "" - params := make(map[string]any) - if (len(req.Profiles) > 0) || (req.AllProfiles == true) { - groupBy = []string{"profile", "round(time/10, 0)*10"} - selects = append(selects, "profile") - - if !req.AllProfiles { - clauses := make([]string, len(req.Profiles)) - - for idx, p := range req.Profiles { - key := fmt.Sprintf(":p%d", idx) - clauses[idx] = "profile = " + key - params[key] = p - } - - whereClause = "WHERE " + strings.Join(clauses, " OR ") - } - } else if len(req.Connections) > 0 { - groupBy = []string{"conn_id", "round(time/10, 0)*10"} - selects = append(selects, "conn_id") - - clauses := make([]string, len(req.Connections)) - - for idx, p := range req.Connections { - key := fmt.Sprintf(":c%d", idx) - clauses[idx] = "conn_id = " + key - params[key] = p - } - - whereClause = "WHERE " + strings.Join(clauses, " OR ") + if whereClause != "" { + whereClause = "WHERE " + whereClause } template := fmt.Sprintf( - `SELECT %s FROM main.bandwidth %s GROUP BY %s ORDER BY time ASC`, + `SELECT %s + FROM main.bandwidth AS bw + JOIN main.connections AS conns + ON bw.conn_id = conns.id + %s + GROUP BY %s + ORDER BY time ASC`, strings.Join(selects, ", "), whereClause, strings.Join(groupBy, ", "), diff --git a/netquery/database.go b/netquery/database.go index 4750dbe4..7adae82b 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -301,7 +301,6 @@ func (db *Database) ApplyMigrations() error { bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth ( conn_id TEXT NOT NULL, - profile TEXT NOT NULL, time INTEGER NOT NULL, incoming INTEGER NOT NULL, outgoing INTEGER NOT NULL, @@ -555,9 +554,9 @@ func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, pro } parts := []string{} - parts = append(parts, "bytes_received = :bytes_received") + parts = append(parts, "bytes_received = (bytes_received + :bytes_received)") params[":bytes_received"] = bytesReceived - parts = append(parts, "bytes_sent = :bytes_sent") + parts = append(parts, "bytes_sent = (bytes_sent + :bytes_sent)") params[":bytes_sent"] = bytesSent updateSet := strings.Join(parts, ", ") @@ -581,8 +580,7 @@ func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, pro // also add the date to the in-memory bandwidth database params[":time"] = time.Now().Unix() - params[":profile"] = profileKey - stmt := "INSERT INTO main.bandwidth (conn_id, profile, time, incoming, outgoing) VALUES(:id, :profile, :time, :bytes_received, :bytes_sent)" + stmt := "INSERT INTO main.bandwidth (conn_id, time, incoming, outgoing) VALUES(:id, :time, :bytes_received, :bytes_sent)" if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil { merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err)) } diff --git a/netquery/query.go b/netquery/query.go index 2ef943bc..2b81bfb1 100644 --- a/netquery/query.go +++ b/netquery/query.go @@ -455,6 +455,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T values := query[column] colDef, ok := lm[column] if !ok { + errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column)) continue