wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
675
service/netquery/query.go
Normal file
675
service/netquery/query.go
Normal file
@@ -0,0 +1,675 @@
|
||||
package netquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/slices"
|
||||
"zombiezen.com/go/sqlite"
|
||||
|
||||
"github.com/safing/portmaster/service/netquery/orm"
|
||||
)
|
||||
|
||||
// DatabaseName is a database name constant.
|
||||
type DatabaseName string
|
||||
|
||||
// Databases.
|
||||
const (
|
||||
LiveDatabase = DatabaseName("main")
|
||||
HistoryDatabase = DatabaseName("history")
|
||||
)
|
||||
|
||||
// Collection of Query and Matcher types.
|
||||
// NOTE: whenever adding support for new operators make sure
|
||||
// to update UnmarshalJSON as well.
|
||||
//
|
||||
//nolint:golint
|
||||
type (
|
||||
Query map[string][]Matcher
|
||||
|
||||
MatchType interface {
|
||||
Operator() string
|
||||
}
|
||||
|
||||
Equal interface{}
|
||||
|
||||
Matcher struct {
|
||||
Equal interface{} `json:"$eq,omitempty"`
|
||||
NotEqual interface{} `json:"$ne,omitempty"`
|
||||
In []interface{} `json:"$in,omitempty"`
|
||||
NotIn []interface{} `json:"$notIn,omitempty"`
|
||||
Like string `json:"$like,omitempty"`
|
||||
Greater *float64 `json:"$gt,omitempty"`
|
||||
GreaterOrEqual *float64 `json:"$ge,omitempty"`
|
||||
Less *float64 `json:"$lt,omitempty"`
|
||||
LessOrEqual *float64 `json:"$le,omitempty"`
|
||||
}
|
||||
|
||||
Count struct {
|
||||
As string `json:"as"`
|
||||
Field string `json:"field"`
|
||||
Distinct bool `json:"distinct"`
|
||||
}
|
||||
|
||||
Sum struct {
|
||||
Condition Query `json:"condition"`
|
||||
Field string `json:"field"`
|
||||
As string `json:"as"`
|
||||
Distinct bool `json:"distinct"`
|
||||
}
|
||||
|
||||
Min struct {
|
||||
Condition *Query `json:"condition,omitempty"`
|
||||
Field string `json:"field"`
|
||||
As string `json:"as"`
|
||||
Distinct bool `json:"distinct"`
|
||||
}
|
||||
|
||||
FieldSelect struct {
|
||||
Field string `json:"field"`
|
||||
As string `json:"as"`
|
||||
}
|
||||
|
||||
Select struct {
|
||||
Field string `json:"field"`
|
||||
FieldSelect *FieldSelect `json:"$field"`
|
||||
Count *Count `json:"$count,omitempty"`
|
||||
Sum *Sum `json:"$sum,omitempty"`
|
||||
Min *Min `json:"$min,omitempty"`
|
||||
Distinct *string `json:"$distinct,omitempty"`
|
||||
}
|
||||
|
||||
Selects []Select
|
||||
|
||||
TextSearch struct {
|
||||
Fields []string `json:"fields"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
QueryActiveConnectionChartPayload struct {
|
||||
Query Query `json:"query"`
|
||||
TextSearch *TextSearch `json:"textSearch"`
|
||||
}
|
||||
|
||||
OrderBy struct {
|
||||
Field string `json:"field"`
|
||||
Desc bool `json:"desc"`
|
||||
}
|
||||
|
||||
OrderBys []OrderBy
|
||||
|
||||
Pagination struct {
|
||||
PageSize int `json:"pageSize"`
|
||||
Page int `json:"page"`
|
||||
}
|
||||
)
|
||||
|
||||
// UnmarshalJSON unmarshals a Query from json.
|
||||
func (query *Query) UnmarshalJSON(blob []byte) error {
|
||||
if *query == nil {
|
||||
*query = make(Query)
|
||||
}
|
||||
|
||||
var model map[string]json.RawMessage
|
||||
|
||||
if err := json.Unmarshal(blob, &model); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for columnName, rawColumnQuery := range model {
|
||||
if len(rawColumnQuery) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch rawColumnQuery[0] {
|
||||
case '{':
|
||||
m, err := parseMatcher(rawColumnQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*query)[columnName] = []Matcher{*m}
|
||||
|
||||
case '[':
|
||||
var rawMatchers []json.RawMessage
|
||||
if err := json.Unmarshal(rawColumnQuery, &rawMatchers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*query)[columnName] = make([]Matcher, len(rawMatchers))
|
||||
for idx, val := range rawMatchers {
|
||||
// this should not happen
|
||||
if len(val) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// if val starts with a { we have a matcher definition
|
||||
if val[0] == '{' {
|
||||
m, err := parseMatcher(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
(*query)[columnName][idx] = *m
|
||||
|
||||
continue
|
||||
} else if val[0] == '[' {
|
||||
return fmt.Errorf("invalid token [ in query for column %s", columnName)
|
||||
}
|
||||
|
||||
// val is a dedicated JSON primitive and not an object or array
|
||||
// so we treat that as an EQUAL condition.
|
||||
var x interface{}
|
||||
if err := json.Unmarshal(val, &x); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*query)[columnName][idx] = Matcher{
|
||||
Equal: x,
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// value is a JSON primitive and not an object or array
|
||||
// so we treat that as an EQUAL condition.
|
||||
var x interface{}
|
||||
if err := json.Unmarshal(rawColumnQuery, &x); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*query)[columnName] = []Matcher{
|
||||
{Equal: x},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(ppacher): right now we only support LIMIT and OFFSET for pagination but that
|
||||
// has an issue that loading the same page twice might yield different results due to
|
||||
// new records shifting the result slice. To overcome this, return a "PageToken" to the
|
||||
// user that includes the time the initial query was created so paginated queries can
|
||||
// ensure new records don't end up in the result set.
|
||||
func (page *Pagination) toSQLLimitOffsetClause() string {
|
||||
limit := page.PageSize
|
||||
|
||||
// default and cap the limit to at most 100 items
|
||||
// per page to avoid out-of-memory conditions when loading
|
||||
// thousands of results at once.
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf("LIMIT %d", limit)
|
||||
|
||||
if page.Page > 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", page.Page*limit)
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func parseMatcher(raw json.RawMessage) (*Matcher, error) {
|
||||
var m Matcher
|
||||
if err := json.Unmarshal(raw, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid query matcher: %w", err)
|
||||
}
|
||||
|
||||
// log.Printf("parsed matcher %s: %+v", string(raw), m)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// Validate validates the matcher.
|
||||
func (match Matcher) Validate() error {
|
||||
found := 0
|
||||
|
||||
if match.Equal != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.NotEqual != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.In != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.NotIn != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.Like != "" {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.Greater != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.GreaterOrEqual != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.Less != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if match.LessOrEqual != nil {
|
||||
found++
|
||||
}
|
||||
|
||||
if found == 0 {
|
||||
return fmt.Errorf("no conditions specified")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (text TextSearch) toSQLConditionClause(_ context.Context, schema *orm.TableSchema, suffix string, _ orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||
var (
|
||||
queryParts = make([]string, 0, len(text.Fields))
|
||||
params = make(map[string]interface{})
|
||||
)
|
||||
|
||||
key := fmt.Sprintf(":t%s", suffix)
|
||||
params[key] = fmt.Sprintf("%%%s%%", text.Value)
|
||||
|
||||
for _, field := range text.Fields {
|
||||
colDef := schema.GetColumnDef(field)
|
||||
if colDef == nil {
|
||||
return "", nil, fmt.Errorf("column %s is not allowed in text-search", colDef.Name)
|
||||
}
|
||||
if colDef.Type != sqlite.TypeText {
|
||||
return "", nil, fmt.Errorf("type of column %s cannot be used in text-search", colDef.Name)
|
||||
}
|
||||
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s LIKE %s", colDef.Name, key))
|
||||
}
|
||||
|
||||
if len(queryParts) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
return "( " + strings.Join(queryParts, " OR ") + " )", params, nil
|
||||
}
|
||||
|
||||
func (match Matcher) toSQLConditionClause(ctx context.Context, suffix string, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||
var (
|
||||
queryParts []string
|
||||
params = make(map[string]interface{})
|
||||
errs = new(multierror.Error)
|
||||
key = fmt.Sprintf("%s%s", colDef.Name, suffix)
|
||||
)
|
||||
|
||||
add := func(operator, suffix string, list bool, values ...interface{}) {
|
||||
var placeholder []string
|
||||
|
||||
for idx, value := range values {
|
||||
var (
|
||||
encodedValue any
|
||||
err error
|
||||
)
|
||||
|
||||
kind := orm.NormalizeKind(reflect.TypeOf(value).Kind())
|
||||
isNumber := slices.Contains([]reflect.Kind{
|
||||
reflect.Uint,
|
||||
reflect.Int,
|
||||
reflect.Float64,
|
||||
}, kind)
|
||||
|
||||
// if we query a time-field that is queried as a number, don't do any encoding
|
||||
// here as the orm.DateTimeEncoder would convert the number to a string.
|
||||
if colDef.IsTime && colDef.Type == sqlite.TypeText && isNumber {
|
||||
encodedValue = value
|
||||
} else {
|
||||
encodedValue, err = orm.EncodeValue(ctx, &colDef, value, encoderConfig)
|
||||
if err != nil {
|
||||
errs.Errors = append(errs.Errors,
|
||||
fmt.Errorf("failed to encode %v for column %s: %w", value, colDef.Name, err),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
uniqKey := fmt.Sprintf(":%s%s%d", key, suffix, idx)
|
||||
placeholder = append(placeholder, uniqKey)
|
||||
params[uniqKey] = encodedValue
|
||||
}
|
||||
|
||||
nameStmt := colDef.Name
|
||||
|
||||
if len(values) > 0 {
|
||||
// NOTE(ppacher): for now we assume that the type of each element of values
|
||||
// is the same. We also can be sure that there is always at least one value.
|
||||
//
|
||||
// TODO(ppacher): if we start supporting values of different types here
|
||||
// we need to revisit the whole behavior as we might need to do more boolean
|
||||
// expression nesting to support that.
|
||||
kind := orm.NormalizeKind(reflect.TypeOf(values[0]).Kind())
|
||||
isNumber := slices.Contains([]reflect.Kind{
|
||||
reflect.Uint,
|
||||
reflect.Int,
|
||||
reflect.Float64,
|
||||
}, kind)
|
||||
|
||||
// if this is a time column that is stored in "text" format and the provided
|
||||
// value is a number type, we need to wrap the property in a strftime() method
|
||||
// call.
|
||||
if colDef.IsTime && colDef.Type == sqlite.TypeText && isNumber {
|
||||
nameStmt = fmt.Sprintf("strftime('%%s', %s)+0", nameStmt)
|
||||
}
|
||||
}
|
||||
|
||||
if len(placeholder) == 1 && !list {
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s %s %s", nameStmt, operator, placeholder[0]))
|
||||
} else {
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", nameStmt, operator, strings.Join(placeholder, ", ")))
|
||||
}
|
||||
}
|
||||
|
||||
if match.Equal != nil {
|
||||
add("=", "eq", false, match.Equal)
|
||||
}
|
||||
|
||||
if match.NotEqual != nil {
|
||||
add("!=", "ne", false, match.NotEqual)
|
||||
}
|
||||
|
||||
if match.In != nil {
|
||||
add("IN", "in", true, match.In...)
|
||||
}
|
||||
|
||||
if match.NotIn != nil {
|
||||
add("NOT IN", "notin", true, match.NotIn...)
|
||||
}
|
||||
|
||||
if match.Like != "" {
|
||||
add("LIKE", "like", false, match.Like)
|
||||
}
|
||||
|
||||
if match.Greater != nil {
|
||||
add(">", "gt", false, *match.Greater)
|
||||
}
|
||||
|
||||
if match.GreaterOrEqual != nil {
|
||||
add(">=", "ge", false, *match.GreaterOrEqual)
|
||||
}
|
||||
|
||||
if match.Less != nil {
|
||||
add("<", "lt", false, *match.Less)
|
||||
}
|
||||
|
||||
if match.LessOrEqual != nil {
|
||||
add("<=", "le", false, *match.LessOrEqual)
|
||||
}
|
||||
|
||||
if len(queryParts) == 0 {
|
||||
// this is an empty matcher without a single condition.
|
||||
// we convert that to a no-op TRUE value
|
||||
return "( 1 = 1 )", nil, errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
if len(queryParts) == 1 {
|
||||
return queryParts[0], params, errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||
if len(query) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// create a lookup map to validate column names
|
||||
lm := make(map[string]orm.ColumnDef, len(m.Columns))
|
||||
for _, col := range m.Columns {
|
||||
lm[col.Name] = col
|
||||
}
|
||||
|
||||
paramMap := make(map[string]interface{})
|
||||
columnStmts := make([]string, 0, len(query))
|
||||
|
||||
// get all keys and sort them so we get a stable output
|
||||
queryKeys := make([]string, 0, len(query))
|
||||
for column := range query {
|
||||
queryKeys = append(queryKeys, column)
|
||||
}
|
||||
sort.Strings(queryKeys)
|
||||
|
||||
// actually create the WHERE clause parts for each
|
||||
// column in query.
|
||||
errs := new(multierror.Error)
|
||||
for _, column := range queryKeys {
|
||||
values := query[column]
|
||||
colDef, ok := lm[column]
|
||||
if !ok {
|
||||
|
||||
errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
queryParts := make([]string, len(values))
|
||||
for idx, val := range values {
|
||||
matcherQuery, params, err := val.toSQLConditionClause(ctx, fmt.Sprintf("%s%d", suffix, idx), "AND", colDef, encoderConfig)
|
||||
if err != nil {
|
||||
errs.Errors = append(errs.Errors,
|
||||
fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),
|
||||
)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// merge parameters up into the superior parameter map
|
||||
for key, val := range params {
|
||||
if _, ok := paramMap[key]; ok {
|
||||
// This is solely a developer mistake when implementing a matcher so no forgiving ...
|
||||
panic("sqlite parameter collision")
|
||||
}
|
||||
|
||||
paramMap[key] = val
|
||||
}
|
||||
|
||||
queryParts[idx] = matcherQuery
|
||||
}
|
||||
|
||||
columnStmts = append(columnStmts,
|
||||
fmt.Sprintf("( %s )", strings.Join(queryParts, " OR ")),
|
||||
)
|
||||
}
|
||||
|
||||
whereClause := strings.Join(columnStmts, " AND ")
|
||||
|
||||
return whereClause, paramMap, errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a Selects from json.
|
||||
func (sel *Selects) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// if we are looking at a slice directly decode into
|
||||
// a []Select
|
||||
if blob[0] == '[' {
|
||||
var result []Select
|
||||
if err := json.Unmarshal(blob, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*sel) = result
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// if it's an object decode into a single select
|
||||
if blob[0] == '{' {
|
||||
var result Select
|
||||
if err := json.Unmarshal(blob, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*sel = []Select{result}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// otherwise this is just the field name
|
||||
var field string
|
||||
if err := json.Unmarshal(blob, &field); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a Select from json.
|
||||
func (sel *Select) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// if we have an object at hand decode the select
|
||||
// directly
|
||||
if blob[0] == '{' {
|
||||
var res struct {
|
||||
Field string `json:"field"`
|
||||
Count *Count `json:"$count"`
|
||||
Sum *Sum `json:"$sum"`
|
||||
Min *Min `json:"$min"`
|
||||
Distinct *string `json:"$distinct"`
|
||||
FieldSelect *FieldSelect `json:"$field"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(blob, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sel.Count = res.Count
|
||||
sel.Field = res.Field
|
||||
sel.FieldSelect = res.FieldSelect
|
||||
sel.Distinct = res.Distinct
|
||||
sel.Sum = res.Sum
|
||||
sel.Min = res.Min
|
||||
|
||||
if sel.Count != nil && sel.Count.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.Count.As) {
|
||||
return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
if sel.Sum != nil && sel.Sum.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.Sum.As) {
|
||||
return fmt.Errorf("invalid characters in $sum.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
if sel.Min != nil && sel.Min.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.Min.As) {
|
||||
return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
if sel.FieldSelect != nil && sel.FieldSelect.As != "" {
|
||||
if !charOnlyRegexp.MatchString(sel.FieldSelect.As) {
|
||||
return fmt.Errorf("invalid characters in $field.as, value must match [a-zA-Z]+")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var x string
|
||||
if err := json.Unmarshal(blob, &x); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sel.Field = x
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a OrderBys from json.
|
||||
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
if blob[0] == '[' {
|
||||
var result []OrderBy
|
||||
if err := json.Unmarshal(blob, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*orderBys = result
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if blob[0] == '{' {
|
||||
var result OrderBy
|
||||
if err := json.Unmarshal(blob, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*orderBys = []OrderBy{result}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var field string
|
||||
if err := json.Unmarshal(blob, &field); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*orderBys = []OrderBy{
|
||||
{
|
||||
Field: field,
|
||||
Desc: false,
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a OrderBy from json.
|
||||
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
if blob[0] == '{' {
|
||||
var res struct {
|
||||
Field string `json:"field"`
|
||||
Desc bool `json:"desc"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(blob, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
orderBy.Desc = res.Desc
|
||||
orderBy.Field = res.Field
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var field string
|
||||
if err := json.Unmarshal(blob, &field); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
orderBy.Field = field
|
||||
orderBy.Desc = false
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user