Update bandwidth handling and add basic chart support
This commit is contained in:
145
netquery/bandwidth_chart_handler.go
Normal file
145
netquery/bandwidth_chart_handler.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
)
|
||||
|
||||
// BandwidthChartHandler handles requests for connection charts.
|
||||
type BandwidthChartHandler struct {
|
||||
Database *Database
|
||||
}
|
||||
|
||||
type BandwidthChartRequest struct {
|
||||
Profiles []string `json:"profiles"`
|
||||
Connections []string `json:"connections"`
|
||||
}
|
||||
|
||||
func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||
requestPayload, err := ch.parseRequest(req)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
|
||||
if err != nil {
|
||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// actually execute the query against the database and collect the result
|
||||
var result []map[string]interface{}
|
||||
if err := ch.Database.Execute(
|
||||
req.Context(),
|
||||
query,
|
||||
orm.WithNamedArgs(paramMap),
|
||||
orm.WithResult(&result),
|
||||
orm.WithSchema(*ch.Database.Schema),
|
||||
); err != nil {
|
||||
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// send the HTTP status code
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
|
||||
// prepare the result encoder.
|
||||
enc := json.NewEncoder(resp)
|
||||
enc.SetEscapeHTML(false)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
|
||||
"results": result,
|
||||
"query": query,
|
||||
"params": paramMap,
|
||||
})
|
||||
}
|
||||
|
||||
func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl
|
||||
var body io.Reader
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodPost, http.MethodPut:
|
||||
body = req.Body
|
||||
case http.MethodGet:
|
||||
body = strings.NewReader(req.URL.Query().Get("q"))
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid HTTP method")
|
||||
}
|
||||
|
||||
var requestPayload BandwidthChartRequest
|
||||
blob, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read body" + err.Error())
|
||||
}
|
||||
|
||||
body = bytes.NewReader(blob)
|
||||
|
||||
dec := json.NewDecoder(body)
|
||||
dec.DisallowUnknownFields()
|
||||
|
||||
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("invalid query: %w", err)
|
||||
}
|
||||
|
||||
return &requestPayload, nil
|
||||
}
|
||||
|
||||
func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
||||
selects := []string{
|
||||
"(round(time/10, 0)*10) as time",
|
||||
"SUM(incoming) as incoming",
|
||||
"SUM(outgoing) as outgoing",
|
||||
}
|
||||
groupBy := []string{"round(time/10, 0)*10"}
|
||||
whereClause := ""
|
||||
params := make(map[string]any)
|
||||
|
||||
if len(req.Profiles) > 0 {
|
||||
groupBy = []string{"profile", "round(time/10, 0)*10"}
|
||||
selects = append(selects, "profile")
|
||||
clauses := make([]string, len(req.Profiles))
|
||||
|
||||
for idx, p := range req.Profiles {
|
||||
key := fmt.Sprintf(":p%d", idx)
|
||||
clauses[idx] = "profile = " + key
|
||||
params[key] = p
|
||||
}
|
||||
|
||||
whereClause = "WHERE " + strings.Join(clauses, " OR ")
|
||||
} else if len(req.Connections) > 0 {
|
||||
groupBy = []string{"conn_id", "round(time/10, 0)*10"}
|
||||
selects = append(selects, "conn_id")
|
||||
|
||||
clauses := make([]string, len(req.Connections))
|
||||
|
||||
for idx, p := range req.Connections {
|
||||
key := fmt.Sprintf(":c%d", idx)
|
||||
clauses[idx] = "conn_id = " + key
|
||||
params[key] = p
|
||||
}
|
||||
|
||||
whereClause = "WHERE " + strings.Join(clauses, " OR ")
|
||||
}
|
||||
|
||||
template := fmt.Sprintf(
|
||||
`SELECT %s FROM main.bandwidth %s GROUP BY %s ORDER BY time ASC`,
|
||||
strings.Join(selects, ", "),
|
||||
whereClause,
|
||||
strings.Join(groupBy, ", "),
|
||||
)
|
||||
|
||||
return template, params, nil
|
||||
}
|
||||
Reference in New Issue
Block a user