Update netquery to support history module

This commit is contained in:
Patrick Pacher
2023-07-19 11:03:11 +02:00
committed by Daniel
parent cf2b8f26b9
commit dbffa8827b
13 changed files with 391 additions and 56 deletions

View File

@@ -14,6 +14,7 @@ import (
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netquery/orm"
"golang.org/x/exp/slices"
)
var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
@@ -152,13 +153,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
return "", nil, fmt.Errorf("generating where clause: %w", err)
}
if req.paramMap == nil {
req.paramMap = make(map[string]interface{})
}
for key, val := range paramMap {
req.paramMap[key] = val
}
req.mergeParams(paramMap)
if req.TextSearch != nil {
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
@@ -173,9 +168,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
whereClause += textClause
for key, val := range textParams {
req.paramMap[key] = val
}
req.mergeParams(textParams)
}
}
@@ -190,15 +183,21 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
}
selectClause := req.generateSelectClause()
inMem := `SELECT *, 'live' as _source FROM main.connections `
inHistory := `SELECT *, 'history' as _source FROM history.connections `
if whereClause != "" {
inMem += " WHERE " + whereClause
inHistory += " WHERE " + whereClause
whereClause = "WHERE " + whereClause
}
source := inMem + " UNION " + inHistory
if len(req.Databases) == 0 {
req.Databases = []DatabaseName{LiveDatabase, HistoryDatabase}
}
sources := make([]string, len(req.Databases))
for idx, db := range req.Databases {
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
}
source := strings.Join(sources, " UNION ")
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
@@ -210,6 +209,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
for idx, s := range req.Select {
var field string
switch {
case s.Count != nil:
field = s.Count.Field
@@ -218,6 +218,12 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
case s.Sum != nil:
// field is not used in case of $sum
field = "*"
case s.Min != nil:
if s.Min.Field != "" {
field = s.Min.Field
} else {
field = "*"
}
default:
field = s.Field
}
@@ -258,13 +264,40 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
return fmt.Errorf("in $sum: %w", err)
}
req.paramMap = params
req.mergeParams(params)
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
)
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
case s.Min != nil:
if s.Min.As == "" {
return fmt.Errorf("missing 'as' for $min")
}
var (
clause string
params map[string]any
)
if s.Min.Field != "" {
clause = field
} else {
var err error
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
if err != nil {
return fmt.Errorf("in $min: %w", err)
}
}
req.mergeParams(params)
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
)
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
case s.Distinct != nil:
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
req.whitelistedFields = append(req.whitelistedFields, colName)
@@ -277,6 +310,16 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
return nil
}
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
if req.paramMap == nil {
req.paramMap = make(map[string]any)
}
for key, value := range params {
req.paramMap[key] = value
}
}
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
if len(req.GroupBy) == 0 {
return "", nil
@@ -339,16 +382,12 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
return colDef.Name, nil
}
for _, selected := range req.whitelistedFields {
if field == selected {
return field, nil
}
if slices.Contains(req.whitelistedFields, field) {
return field, nil
}
for _, selected := range req.selectedFields {
if field == selected {
return field, nil
}
if slices.Contains(req.selectedFields, field) {
return field, nil
}
return "", fmt.Errorf("column name %q not allowed", field)