Update netquery to support history module
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
|
||||
@@ -152,13 +153,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
|
||||
return "", nil, fmt.Errorf("generating where clause: %w", err)
|
||||
}
|
||||
|
||||
if req.paramMap == nil {
|
||||
req.paramMap = make(map[string]interface{})
|
||||
}
|
||||
|
||||
for key, val := range paramMap {
|
||||
req.paramMap[key] = val
|
||||
}
|
||||
req.mergeParams(paramMap)
|
||||
|
||||
if req.TextSearch != nil {
|
||||
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
|
||||
@@ -173,9 +168,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
|
||||
|
||||
whereClause += textClause
|
||||
|
||||
for key, val := range textParams {
|
||||
req.paramMap[key] = val
|
||||
}
|
||||
req.mergeParams(textParams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,15 +183,21 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
|
||||
}
|
||||
|
||||
selectClause := req.generateSelectClause()
|
||||
inMem := `SELECT *, 'live' as _source FROM main.connections `
|
||||
inHistory := `SELECT *, 'history' as _source FROM history.connections `
|
||||
|
||||
if whereClause != "" {
|
||||
inMem += " WHERE " + whereClause
|
||||
inHistory += " WHERE " + whereClause
|
||||
whereClause = "WHERE " + whereClause
|
||||
}
|
||||
|
||||
source := inMem + " UNION " + inHistory
|
||||
if len(req.Databases) == 0 {
|
||||
req.Databases = []DatabaseName{LiveDatabase, HistoryDatabase}
|
||||
}
|
||||
|
||||
sources := make([]string, len(req.Databases))
|
||||
for idx, db := range req.Databases {
|
||||
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
|
||||
}
|
||||
|
||||
source := strings.Join(sources, " UNION ")
|
||||
|
||||
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
|
||||
|
||||
@@ -210,6 +209,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
|
||||
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
|
||||
for idx, s := range req.Select {
|
||||
var field string
|
||||
|
||||
switch {
|
||||
case s.Count != nil:
|
||||
field = s.Count.Field
|
||||
@@ -218,6 +218,12 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
|
||||
case s.Sum != nil:
|
||||
// field is not used in case of $sum
|
||||
field = "*"
|
||||
case s.Min != nil:
|
||||
if s.Min.Field != "" {
|
||||
field = s.Min.Field
|
||||
} else {
|
||||
field = "*"
|
||||
}
|
||||
default:
|
||||
field = s.Field
|
||||
}
|
||||
@@ -258,13 +264,40 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
|
||||
return fmt.Errorf("in $sum: %w", err)
|
||||
}
|
||||
|
||||
req.paramMap = params
|
||||
req.mergeParams(params)
|
||||
req.selectedFields = append(
|
||||
req.selectedFields,
|
||||
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
|
||||
)
|
||||
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
|
||||
|
||||
case s.Min != nil:
|
||||
if s.Min.As == "" {
|
||||
return fmt.Errorf("missing 'as' for $min")
|
||||
}
|
||||
|
||||
var (
|
||||
clause string
|
||||
params map[string]any
|
||||
)
|
||||
|
||||
if s.Min.Field != "" {
|
||||
clause = field
|
||||
} else {
|
||||
var err error
|
||||
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in $min: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req.mergeParams(params)
|
||||
req.selectedFields = append(
|
||||
req.selectedFields,
|
||||
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
|
||||
)
|
||||
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
|
||||
|
||||
case s.Distinct != nil:
|
||||
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
|
||||
req.whitelistedFields = append(req.whitelistedFields, colName)
|
||||
@@ -277,6 +310,16 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
|
||||
return nil
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
|
||||
if req.paramMap == nil {
|
||||
req.paramMap = make(map[string]any)
|
||||
}
|
||||
|
||||
for key, value := range params {
|
||||
req.paramMap[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
|
||||
if len(req.GroupBy) == 0 {
|
||||
return "", nil
|
||||
@@ -339,16 +382,12 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
|
||||
return colDef.Name, nil
|
||||
}
|
||||
|
||||
for _, selected := range req.whitelistedFields {
|
||||
if field == selected {
|
||||
return field, nil
|
||||
}
|
||||
if slices.Contains(req.whitelistedFields, field) {
|
||||
return field, nil
|
||||
}
|
||||
|
||||
for _, selected := range req.selectedFields {
|
||||
if field == selected {
|
||||
return field, nil
|
||||
}
|
||||
if slices.Contains(req.selectedFields, field) {
|
||||
return field, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("column name %q not allowed", field)
|
||||
|
||||
Reference in New Issue
Block a user