Restructure modules (#1572)
* Move portbase into monorepo * Add new simple module mgr * [WIP] Switch to new simple module mgr * Add StateMgr and more worker variants * [WIP] Switch more modules * [WIP] Switch more modules * [WIP] swtich more modules * [WIP] switch all SPN modules * [WIP] switch all service modules * [WIP] Convert all workers to the new module system * [WIP] add new task system to module manager * [WIP] Add second take for scheduling workers * [WIP] Add FIXME for bugs in new scheduler * [WIP] Add minor improvements to scheduler * [WIP] Add new worker scheduler * [WIP] Fix more bug related to new module system * [WIP] Fix start handing of the new module system * [WIP] Improve startup process * [WIP] Fix minor issues * [WIP] Fix missing subsystem in settings * [WIP] Initialize managers in constructor * [WIP] Move module event initialization to constrictors * [WIP] Fix setting for enabling and disabling the SPN module * [WIP] Move API registeration into module construction * [WIP] Update states mgr for all modules * [WIP] Add CmdLine operation support * Add state helper methods to module group and instance * Add notification and module status handling to status package * Fix starting issues * Remove pilot widget and update security lock to new status data * Remove debug logs * Improve http server shutdown * Add workaround for cleanly shutting down firewall+netquery * Improve logging * Add syncing states with notifications for new module system * Improve starting, stopping, shutdown; resolve FIXMEs/TODOs * [WIP] Fix most unit tests * Review new module system and fix minor issues * Push shutdown and restart events again via API * Set sleep mode via interface * Update example/template module * [WIP] Fix spn/cabin unit test * Remove deprecated UI elements * Make log output more similar for the logging transition phase * Switch spn hub and observer cmds to new module system * Fix log sources * Make worker mgr less error prone * Fix tests and minor issues * Fix observation hub * Improve shutdown and restart handling * Split up big connection.go source file * Move varint and dsd packages to structures repo * Improve expansion test * Fix linter warnings * Fix interception module on windows * Fix linter errors --------- Co-authored-by: Vladimir Stoilov <vladimir@safing.io>
This commit is contained in:
116
base/database/accessor/accessor-json-bytes.go
Normal file
116
base/database/accessor/accessor-json-bytes.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package accessor
|
||||
|
||||
import (
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// JSONBytesAccessor is a json string with get functions.
|
||||
type JSONBytesAccessor struct {
|
||||
json *[]byte
|
||||
}
|
||||
|
||||
// NewJSONBytesAccessor adds the Accessor interface to a JSON bytes string.
|
||||
func NewJSONBytesAccessor(json *[]byte) *JSONBytesAccessor {
|
||||
return &JSONBytesAccessor{
|
||||
json: json,
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value identified by key.
|
||||
func (ja *JSONBytesAccessor) Set(key string, value interface{}) error {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if result.Exists() {
|
||||
err := checkJSONValueType(result, key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
newJSON, err := sjson.SetBytes(*ja.json, key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*ja.json = newJSON
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the value found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) Get(key string) (value interface{}, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() {
|
||||
return nil, false
|
||||
}
|
||||
return result.Value(), true
|
||||
}
|
||||
|
||||
// GetString returns the string found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetString(key string) (value string, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.String {
|
||||
return emptyString, false
|
||||
}
|
||||
return result.String(), true
|
||||
}
|
||||
|
||||
// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetStringArray(key string) (value []string, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() && !result.IsArray() {
|
||||
return nil, false
|
||||
}
|
||||
slice := result.Array()
|
||||
sliceCopy := make([]string, len(slice))
|
||||
for i, res := range slice {
|
||||
if res.Type == gjson.String {
|
||||
sliceCopy[i] = res.String()
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return sliceCopy, true
|
||||
}
|
||||
|
||||
// GetInt returns the int found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetInt(key string) (value int64, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return result.Int(), true
|
||||
}
|
||||
|
||||
// GetFloat returns the float found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetFloat(key string) (value float64, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return result.Float(), true
|
||||
}
|
||||
|
||||
// GetBool returns the bool found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetBool(key string) (value bool, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
switch {
|
||||
case !result.Exists():
|
||||
return false, false
|
||||
case result.Type == gjson.True:
|
||||
return true, true
|
||||
case result.Type == gjson.False:
|
||||
return false, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
// Exists returns the whether the given key exists.
|
||||
func (ja *JSONBytesAccessor) Exists(key string) bool {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
return result.Exists()
|
||||
}
|
||||
|
||||
// Type returns the accessor type as a string.
|
||||
func (ja *JSONBytesAccessor) Type() string {
|
||||
return "JSONBytesAccessor"
|
||||
}
|
||||
140
base/database/accessor/accessor-json-string.go
Normal file
140
base/database/accessor/accessor-json-string.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package accessor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// JSONAccessor is a json string with get functions.
|
||||
type JSONAccessor struct {
|
||||
json *string
|
||||
}
|
||||
|
||||
// NewJSONAccessor adds the Accessor interface to a JSON string.
|
||||
func NewJSONAccessor(json *string) *JSONAccessor {
|
||||
return &JSONAccessor{
|
||||
json: json,
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value identified by key.
|
||||
func (ja *JSONAccessor) Set(key string, value interface{}) error {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if result.Exists() {
|
||||
err := checkJSONValueType(result, key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
newJSON, err := sjson.Set(*ja.json, key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*ja.json = newJSON
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkJSONValueType(jsonValue gjson.Result, key string, value interface{}) error {
|
||||
switch value.(type) {
|
||||
case string:
|
||||
if jsonValue.Type != gjson.String {
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||||
if jsonValue.Type != gjson.Number {
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
|
||||
}
|
||||
case bool:
|
||||
if jsonValue.Type != gjson.True && jsonValue.Type != gjson.False {
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
|
||||
}
|
||||
case []string:
|
||||
if !jsonValue.IsArray() {
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the value found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) Get(key string) (value interface{}, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() {
|
||||
return nil, false
|
||||
}
|
||||
return result.Value(), true
|
||||
}
|
||||
|
||||
// GetString returns the string found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetString(key string) (value string, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.String {
|
||||
return emptyString, false
|
||||
}
|
||||
return result.String(), true
|
||||
}
|
||||
|
||||
// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetStringArray(key string) (value []string, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() && !result.IsArray() {
|
||||
return nil, false
|
||||
}
|
||||
slice := result.Array()
|
||||
sliceCopy := make([]string, len(slice))
|
||||
for i, res := range slice {
|
||||
if res.Type == gjson.String {
|
||||
sliceCopy[i] = res.String()
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return sliceCopy, true
|
||||
}
|
||||
|
||||
// GetInt returns the int found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetInt(key string) (value int64, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return result.Int(), true
|
||||
}
|
||||
|
||||
// GetFloat returns the float found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetFloat(key string) (value float64, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return result.Float(), true
|
||||
}
|
||||
|
||||
// GetBool returns the bool found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetBool(key string) (value bool, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
switch {
|
||||
case !result.Exists():
|
||||
return false, false
|
||||
case result.Type == gjson.True:
|
||||
return true, true
|
||||
case result.Type == gjson.False:
|
||||
return false, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
// Exists returns the whether the given key exists.
|
||||
func (ja *JSONAccessor) Exists(key string) bool {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
return result.Exists()
|
||||
}
|
||||
|
||||
// Type returns the accessor type as a string.
|
||||
func (ja *JSONAccessor) Type() string {
|
||||
return "JSONAccessor"
|
||||
}
|
||||
169
base/database/accessor/accessor-struct.go
Normal file
169
base/database/accessor/accessor-struct.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package accessor
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// StructAccessor is a json string with get functions.
|
||||
type StructAccessor struct {
|
||||
object reflect.Value
|
||||
}
|
||||
|
||||
// NewStructAccessor adds the Accessor interface to a JSON string.
|
||||
func NewStructAccessor(object interface{}) *StructAccessor {
|
||||
return &StructAccessor{
|
||||
object: reflect.ValueOf(object).Elem(),
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value identified by key.
|
||||
func (sa *StructAccessor) Set(key string, value interface{}) error {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() {
|
||||
return errors.New("struct field does not exist")
|
||||
}
|
||||
if !field.CanSet() {
|
||||
return fmt.Errorf("field %s or struct is immutable", field.String())
|
||||
}
|
||||
|
||||
newVal := reflect.ValueOf(value)
|
||||
|
||||
// set directly if type matches
|
||||
if newVal.Kind() == field.Kind() {
|
||||
field.Set(newVal)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handle special cases
|
||||
switch field.Kind() { // nolint:exhaustive
|
||||
|
||||
// ints
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
var newInt int64
|
||||
switch newVal.Kind() { // nolint:exhaustive
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
newInt = newVal.Int()
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
newInt = int64(newVal.Uint())
|
||||
default:
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String())
|
||||
}
|
||||
if field.OverflowInt(newInt) {
|
||||
return fmt.Errorf("setting field %s (%s) to %d would overflow", key, field.Kind().String(), newInt)
|
||||
}
|
||||
field.SetInt(newInt)
|
||||
|
||||
// uints
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
var newUint uint64
|
||||
switch newVal.Kind() { // nolint:exhaustive
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
newUint = uint64(newVal.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
newUint = newVal.Uint()
|
||||
default:
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String())
|
||||
}
|
||||
if field.OverflowUint(newUint) {
|
||||
return fmt.Errorf("setting field %s (%s) to %d would overflow", key, field.Kind().String(), newUint)
|
||||
}
|
||||
field.SetUint(newUint)
|
||||
|
||||
// floats
|
||||
case reflect.Float32, reflect.Float64:
|
||||
switch newVal.Kind() { // nolint:exhaustive
|
||||
case reflect.Float32, reflect.Float64:
|
||||
field.SetFloat(newVal.Float())
|
||||
default:
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String())
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the value found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) Get(key string) (value interface{}, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() || !field.CanInterface() {
|
||||
return nil, false
|
||||
}
|
||||
return field.Interface(), true
|
||||
}
|
||||
|
||||
// GetString returns the string found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetString(key string) (value string, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() || field.Kind() != reflect.String {
|
||||
return "", false
|
||||
}
|
||||
return field.String(), true
|
||||
}
|
||||
|
||||
// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetStringArray(key string) (value []string, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() || field.Kind() != reflect.Slice || !field.CanInterface() {
|
||||
return nil, false
|
||||
}
|
||||
v := field.Interface()
|
||||
slice, ok := v.([]string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return slice, true
|
||||
}
|
||||
|
||||
// GetInt returns the int found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetInt(key string) (value int64, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() {
|
||||
return 0, false
|
||||
}
|
||||
switch field.Kind() { // nolint:exhaustive
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return field.Int(), true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return int64(field.Uint()), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// GetFloat returns the float found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetFloat(key string) (value float64, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() {
|
||||
return 0, false
|
||||
}
|
||||
switch field.Kind() { // nolint:exhaustive
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return field.Float(), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// GetBool returns the bool found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetBool(key string) (value bool, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() || field.Kind() != reflect.Bool {
|
||||
return false, false
|
||||
}
|
||||
return field.Bool(), true
|
||||
}
|
||||
|
||||
// Exists returns the whether the given key exists.
|
||||
func (sa *StructAccessor) Exists(key string) bool {
|
||||
field := sa.object.FieldByName(key)
|
||||
return field.IsValid()
|
||||
}
|
||||
|
||||
// Type returns the accessor type as a string.
|
||||
func (sa *StructAccessor) Type() string {
|
||||
return "StructAccessor"
|
||||
}
|
||||
18
base/database/accessor/accessor.go
Normal file
18
base/database/accessor/accessor.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package accessor
|
||||
|
||||
const (
|
||||
emptyString = ""
|
||||
)
|
||||
|
||||
// Accessor provides an interface to supply the query matcher a method to retrieve values from an object.
|
||||
type Accessor interface {
|
||||
Get(key string) (value interface{}, ok bool)
|
||||
GetString(key string) (value string, ok bool)
|
||||
GetStringArray(key string) (value []string, ok bool)
|
||||
GetInt(key string) (value int64, ok bool)
|
||||
GetFloat(key string) (value float64, ok bool)
|
||||
GetBool(key string) (value bool, ok bool)
|
||||
Exists(key string) bool
|
||||
Set(key string, value interface{}) error
|
||||
Type() string
|
||||
}
|
||||
291
base/database/accessor/accessor_test.go
Normal file
291
base/database/accessor/accessor_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
//nolint:maligned,unparam
|
||||
package accessor
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
)
|
||||
|
||||
type TestStruct struct {
|
||||
S string
|
||||
A []string
|
||||
I int
|
||||
I8 int8
|
||||
I16 int16
|
||||
I32 int32
|
||||
I64 int64
|
||||
UI uint
|
||||
UI8 uint8
|
||||
UI16 uint16
|
||||
UI32 uint32
|
||||
UI64 uint64
|
||||
F32 float32
|
||||
F64 float64
|
||||
B bool
|
||||
}
|
||||
|
||||
var (
|
||||
testStruct = &TestStruct{
|
||||
S: "banana",
|
||||
A: []string{"black", "white"},
|
||||
I: 42,
|
||||
I8: 42,
|
||||
I16: 42,
|
||||
I32: 42,
|
||||
I64: 42,
|
||||
UI: 42,
|
||||
UI8: 42,
|
||||
UI16: 42,
|
||||
UI32: 42,
|
||||
UI64: 42,
|
||||
F32: 42.42,
|
||||
F64: 42.42,
|
||||
B: true,
|
||||
}
|
||||
testJSONBytes, _ = json.Marshal(testStruct) //nolint:errchkjson
|
||||
testJSON = string(testJSONBytes)
|
||||
)
|
||||
|
||||
func testGetString(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue string) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetString(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get string with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get string with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if v != expectedValue {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetStringArray(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue []string) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetStringArray(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get []string with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get []string with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if !utils.StringSliceEqual(v, expectedValue) {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetInt(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue int64) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetInt(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get int with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get int with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if v != expectedValue {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetFloat(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue float64) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetFloat(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get float with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get float with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if int64(v) != int64(expectedValue) {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetBool(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue bool) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetBool(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get bool with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get bool with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if v != expectedValue {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testExists(t *testing.T, acc Accessor, key string, shouldSucceed bool) {
|
||||
t.Helper()
|
||||
|
||||
ok := acc.Exists(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s should report key %s as existing", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should report key %s as non-existing", acc.Type(), key)
|
||||
}
|
||||
}
|
||||
|
||||
func testSet(t *testing.T, acc Accessor, key string, shouldSucceed bool, valueToSet interface{}) {
|
||||
t.Helper()
|
||||
|
||||
err := acc.Set(key, valueToSet)
|
||||
switch {
|
||||
case err != nil && shouldSucceed:
|
||||
t.Errorf("%s failed to set %s to %+v: %s", acc.Type(), key, valueToSet, err)
|
||||
case err == nil && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to set %s to %+v", acc.Type(), key, valueToSet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test interface compliance.
|
||||
accs := []Accessor{
|
||||
NewJSONAccessor(&testJSON),
|
||||
NewJSONBytesAccessor(&testJSONBytes),
|
||||
NewStructAccessor(testStruct),
|
||||
}
|
||||
|
||||
// get
|
||||
for _, acc := range accs {
|
||||
testGetString(t, acc, "S", true, "banana")
|
||||
testGetStringArray(t, acc, "A", true, []string{"black", "white"})
|
||||
testGetInt(t, acc, "I", true, 42)
|
||||
testGetInt(t, acc, "I8", true, 42)
|
||||
testGetInt(t, acc, "I16", true, 42)
|
||||
testGetInt(t, acc, "I32", true, 42)
|
||||
testGetInt(t, acc, "I64", true, 42)
|
||||
testGetInt(t, acc, "UI", true, 42)
|
||||
testGetInt(t, acc, "UI8", true, 42)
|
||||
testGetInt(t, acc, "UI16", true, 42)
|
||||
testGetInt(t, acc, "UI32", true, 42)
|
||||
testGetInt(t, acc, "UI64", true, 42)
|
||||
testGetFloat(t, acc, "F32", true, 42.42)
|
||||
testGetFloat(t, acc, "F64", true, 42.42)
|
||||
testGetBool(t, acc, "B", true, true)
|
||||
}
|
||||
|
||||
// set
|
||||
for _, acc := range accs {
|
||||
testSet(t, acc, "S", true, "coconut")
|
||||
testSet(t, acc, "A", true, []string{"green", "blue"})
|
||||
testSet(t, acc, "I", true, uint32(44))
|
||||
testSet(t, acc, "I8", true, uint64(44))
|
||||
testSet(t, acc, "I16", true, uint8(44))
|
||||
testSet(t, acc, "I32", true, uint16(44))
|
||||
testSet(t, acc, "I64", true, 44)
|
||||
testSet(t, acc, "UI", true, 44)
|
||||
testSet(t, acc, "UI8", true, int64(44))
|
||||
testSet(t, acc, "UI16", true, int32(44))
|
||||
testSet(t, acc, "UI32", true, int8(44))
|
||||
testSet(t, acc, "UI64", true, int16(44))
|
||||
testSet(t, acc, "F32", true, 44.44)
|
||||
testSet(t, acc, "F64", true, 44.44)
|
||||
testSet(t, acc, "B", true, false)
|
||||
}
|
||||
|
||||
// get again to check if new values were set
|
||||
for _, acc := range accs {
|
||||
testGetString(t, acc, "S", true, "coconut")
|
||||
testGetStringArray(t, acc, "A", true, []string{"green", "blue"})
|
||||
testGetInt(t, acc, "I", true, 44)
|
||||
testGetInt(t, acc, "I8", true, 44)
|
||||
testGetInt(t, acc, "I16", true, 44)
|
||||
testGetInt(t, acc, "I32", true, 44)
|
||||
testGetInt(t, acc, "I64", true, 44)
|
||||
testGetInt(t, acc, "UI", true, 44)
|
||||
testGetInt(t, acc, "UI8", true, 44)
|
||||
testGetInt(t, acc, "UI16", true, 44)
|
||||
testGetInt(t, acc, "UI32", true, 44)
|
||||
testGetInt(t, acc, "UI64", true, 44)
|
||||
testGetFloat(t, acc, "F32", true, 44.44)
|
||||
testGetFloat(t, acc, "F64", true, 44.44)
|
||||
testGetBool(t, acc, "B", true, false)
|
||||
}
|
||||
|
||||
// failures
|
||||
for _, acc := range accs {
|
||||
testSet(t, acc, "S", false, true)
|
||||
testSet(t, acc, "S", false, false)
|
||||
testSet(t, acc, "S", false, 1)
|
||||
testSet(t, acc, "S", false, 1.1)
|
||||
|
||||
testSet(t, acc, "A", false, "1")
|
||||
testSet(t, acc, "A", false, true)
|
||||
testSet(t, acc, "A", false, false)
|
||||
testSet(t, acc, "A", false, 1)
|
||||
testSet(t, acc, "A", false, 1.1)
|
||||
|
||||
testSet(t, acc, "I", false, "1")
|
||||
testSet(t, acc, "I8", false, "1")
|
||||
testSet(t, acc, "I16", false, "1")
|
||||
testSet(t, acc, "I32", false, "1")
|
||||
testSet(t, acc, "I64", false, "1")
|
||||
testSet(t, acc, "UI", false, "1")
|
||||
testSet(t, acc, "UI8", false, "1")
|
||||
testSet(t, acc, "UI16", false, "1")
|
||||
testSet(t, acc, "UI32", false, "1")
|
||||
testSet(t, acc, "UI64", false, "1")
|
||||
|
||||
testSet(t, acc, "F32", false, "1.1")
|
||||
testSet(t, acc, "F64", false, "1.1")
|
||||
|
||||
testSet(t, acc, "B", false, "false")
|
||||
testSet(t, acc, "B", false, 1)
|
||||
testSet(t, acc, "B", false, 1.1)
|
||||
}
|
||||
|
||||
// get again to check if values werent changed when an error occurred
|
||||
for _, acc := range accs {
|
||||
testGetString(t, acc, "S", true, "coconut")
|
||||
testGetStringArray(t, acc, "A", true, []string{"green", "blue"})
|
||||
testGetInt(t, acc, "I", true, 44)
|
||||
testGetInt(t, acc, "I8", true, 44)
|
||||
testGetInt(t, acc, "I16", true, 44)
|
||||
testGetInt(t, acc, "I32", true, 44)
|
||||
testGetInt(t, acc, "I64", true, 44)
|
||||
testGetInt(t, acc, "UI", true, 44)
|
||||
testGetInt(t, acc, "UI8", true, 44)
|
||||
testGetInt(t, acc, "UI16", true, 44)
|
||||
testGetInt(t, acc, "UI32", true, 44)
|
||||
testGetInt(t, acc, "UI64", true, 44)
|
||||
testGetFloat(t, acc, "F32", true, 44.44)
|
||||
testGetFloat(t, acc, "F64", true, 44.44)
|
||||
testGetBool(t, acc, "B", true, false)
|
||||
}
|
||||
|
||||
// test existence
|
||||
for _, acc := range accs {
|
||||
testExists(t, acc, "S", true)
|
||||
testExists(t, acc, "A", true)
|
||||
testExists(t, acc, "I", true)
|
||||
testExists(t, acc, "I8", true)
|
||||
testExists(t, acc, "I16", true)
|
||||
testExists(t, acc, "I32", true)
|
||||
testExists(t, acc, "I64", true)
|
||||
testExists(t, acc, "UI", true)
|
||||
testExists(t, acc, "UI8", true)
|
||||
testExists(t, acc, "UI16", true)
|
||||
testExists(t, acc, "UI32", true)
|
||||
testExists(t, acc, "UI64", true)
|
||||
testExists(t, acc, "F32", true)
|
||||
testExists(t, acc, "F64", true)
|
||||
testExists(t, acc, "B", true)
|
||||
}
|
||||
|
||||
// test non-existence
|
||||
for _, acc := range accs {
|
||||
testExists(t, acc, "X", false)
|
||||
}
|
||||
}
|
||||
65
base/database/boilerplate_test.go
Normal file
65
base/database/boilerplate_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
type Example struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
Name string
|
||||
Score int
|
||||
}
|
||||
|
||||
var exampleDB = NewInterface(&Options{
|
||||
Internal: true,
|
||||
Local: true,
|
||||
})
|
||||
|
||||
// GetExample gets an Example from the database.
|
||||
func GetExample(key string) (*Example, error) {
|
||||
r, err := exampleDB.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// unwrap
|
||||
if r.IsWrapped() {
|
||||
// only allocate a new struct, if we need it
|
||||
newExample := &Example{}
|
||||
err = record.Unwrap(r, newExample)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newExample, nil
|
||||
}
|
||||
|
||||
// or adjust type
|
||||
newExample, ok := r.(*Example)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *Example, but %T", r)
|
||||
}
|
||||
return newExample, nil
|
||||
}
|
||||
|
||||
func (e *Example) Save() error {
|
||||
return exampleDB.Put(e)
|
||||
}
|
||||
|
||||
func (e *Example) SaveAs(key string) error {
|
||||
e.SetKey(key)
|
||||
return exampleDB.PutNew(e)
|
||||
}
|
||||
|
||||
func NewExample(key, name string, score int) *Example {
|
||||
newExample := &Example{
|
||||
Name: name,
|
||||
Score: score,
|
||||
}
|
||||
newExample.SetKey(key)
|
||||
return newExample
|
||||
}
|
||||
355
base/database/controller.go
Normal file
355
base/database/controller.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
// A Controller takes care of all the extra database logic.
|
||||
type Controller struct {
|
||||
database *Database
|
||||
storage storage.Interface
|
||||
shadowDelete bool
|
||||
|
||||
hooksLock sync.RWMutex
|
||||
hooks []*RegisteredHook
|
||||
|
||||
subscriptionLock sync.RWMutex
|
||||
subscriptions []*Subscription
|
||||
}
|
||||
|
||||
// newController creates a new controller for a storage.
|
||||
func newController(database *Database, storageInt storage.Interface, shadowDelete bool) *Controller {
|
||||
return &Controller{
|
||||
database: database,
|
||||
storage: storageInt,
|
||||
shadowDelete: shadowDelete,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the storage is read only.
|
||||
func (c *Controller) ReadOnly() bool {
|
||||
return c.storage.ReadOnly()
|
||||
}
|
||||
|
||||
// Injected returns whether the storage is injected.
|
||||
func (c *Controller) Injected() bool {
|
||||
return c.storage.Injected()
|
||||
}
|
||||
|
||||
// Get returns the record with the given key.
|
||||
func (c *Controller) Get(key string) (record.Record, error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return nil, ErrShuttingDown
|
||||
}
|
||||
|
||||
if err := c.runPreGetHooks(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := c.storage.Get(key)
|
||||
if err != nil {
|
||||
// replace not found error
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
r, err = c.runPostGetHooks(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !r.Meta().CheckValidity() {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// GetMeta returns the metadata of the record with the given key.
|
||||
func (c *Controller) GetMeta(key string) (*record.Meta, error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return nil, ErrShuttingDown
|
||||
}
|
||||
|
||||
var m *record.Meta
|
||||
var err error
|
||||
if metaDB, ok := c.storage.(storage.MetaHandler); ok {
|
||||
m, err = metaDB.GetMeta(key)
|
||||
if err != nil {
|
||||
// replace not found error
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
r, err := c.storage.Get(key)
|
||||
if err != nil {
|
||||
// replace not found error
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
m = r.Meta()
|
||||
}
|
||||
|
||||
if !m.CheckValidity() {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Put saves a record in the database, executes any registered
|
||||
// pre-put hooks and finally send an update to all subscribers.
|
||||
// The record must be locked and secured from concurrent access
|
||||
// when calling Put().
|
||||
func (c *Controller) Put(r record.Record) (err error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return ErrShuttingDown
|
||||
}
|
||||
|
||||
if c.ReadOnly() {
|
||||
return ErrReadOnly
|
||||
}
|
||||
|
||||
r, err = c.runPrePutHooks(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !c.shadowDelete && r.Meta().IsDeleted() {
|
||||
// Immediate delete.
|
||||
err = c.storage.Delete(r.DatabaseKey())
|
||||
} else {
|
||||
// Put or shadow delete.
|
||||
r, err = c.storage.Put(r)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r == nil {
|
||||
return errors.New("storage returned nil record after successful put operation")
|
||||
}
|
||||
|
||||
c.notifySubscribers(r)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database. It does not
|
||||
// process any hooks or update subscriptions. Use with care!
|
||||
func (c *Controller) PutMany() (chan<- record.Record, <-chan error) {
|
||||
if shuttingDown.IsSet() {
|
||||
errs := make(chan error, 1)
|
||||
errs <- ErrShuttingDown
|
||||
return make(chan record.Record), errs
|
||||
}
|
||||
|
||||
if c.ReadOnly() {
|
||||
errs := make(chan error, 1)
|
||||
errs <- ErrReadOnly
|
||||
return make(chan record.Record), errs
|
||||
}
|
||||
|
||||
if batcher, ok := c.storage.(storage.Batcher); ok {
|
||||
return batcher.PutMany(c.shadowDelete)
|
||||
}
|
||||
|
||||
errs := make(chan error, 1)
|
||||
errs <- ErrNotImplemented
|
||||
return make(chan record.Record), errs
|
||||
}
|
||||
|
||||
// Query executes the given query on the database.
|
||||
func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return nil, ErrShuttingDown
|
||||
}
|
||||
|
||||
it, err := c.storage.Query(q, local, internal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return it, nil
|
||||
}
|
||||
|
||||
// PushUpdate pushes a record update to subscribers.
|
||||
// The caller must hold the record's lock when calling
|
||||
// PushUpdate.
|
||||
func (c *Controller) PushUpdate(r record.Record) {
|
||||
if c != nil {
|
||||
if shuttingDown.IsSet() {
|
||||
return
|
||||
}
|
||||
|
||||
c.notifySubscribers(r)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) addSubscription(sub *Subscription) {
|
||||
if shuttingDown.IsSet() {
|
||||
return
|
||||
}
|
||||
|
||||
c.subscriptionLock.Lock()
|
||||
defer c.subscriptionLock.Unlock()
|
||||
|
||||
c.subscriptions = append(c.subscriptions, sub)
|
||||
}
|
||||
|
||||
// Maintain runs the Maintain method on the storage.
|
||||
func (c *Controller) Maintain(ctx context.Context) error {
|
||||
if shuttingDown.IsSet() {
|
||||
return ErrShuttingDown
|
||||
}
|
||||
|
||||
if maintainer, ok := c.storage.(storage.Maintainer); ok {
|
||||
return maintainer.Maintain(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaintainThorough runs the MaintainThorough method on the
|
||||
// storage.
|
||||
func (c *Controller) MaintainThorough(ctx context.Context) error {
|
||||
if shuttingDown.IsSet() {
|
||||
return ErrShuttingDown
|
||||
}
|
||||
|
||||
if maintainer, ok := c.storage.(storage.Maintainer); ok {
|
||||
return maintainer.MaintainThorough(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaintainRecordStates runs the record state lifecycle
|
||||
// maintenance on the storage.
|
||||
func (c *Controller) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time) error {
|
||||
if shuttingDown.IsSet() {
|
||||
return ErrShuttingDown
|
||||
}
|
||||
|
||||
return c.storage.MaintainRecordStates(ctx, purgeDeletedBefore, c.shadowDelete)
|
||||
}
|
||||
|
||||
// Purge deletes all records that match the given query.
|
||||
// It returns the number of successful deletes and an error.
|
||||
func (c *Controller) Purge(ctx context.Context, q *query.Query, local, internal bool) (int, error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return 0, ErrShuttingDown
|
||||
}
|
||||
|
||||
if purger, ok := c.storage.(storage.Purger); ok {
|
||||
return purger.Purge(ctx, q, local, internal, c.shadowDelete)
|
||||
}
|
||||
|
||||
return 0, ErrNotImplemented
|
||||
}
|
||||
|
||||
// Shutdown shuts down the storage.
|
||||
func (c *Controller) Shutdown() error {
|
||||
return c.storage.Shutdown()
|
||||
}
|
||||
|
||||
// notifySubscribers notifies all subscribers that are interested
|
||||
// in r. r must be locked when calling notifySubscribers.
|
||||
// Any subscriber that is not blocking on it's feed channel will
|
||||
// be skipped.
|
||||
func (c *Controller) notifySubscribers(r record.Record) {
|
||||
c.subscriptionLock.RLock()
|
||||
defer c.subscriptionLock.RUnlock()
|
||||
|
||||
for _, sub := range c.subscriptions {
|
||||
if r.Meta().CheckPermission(sub.local, sub.internal) && sub.q.Matches(r) {
|
||||
select {
|
||||
case sub.Feed <- r:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) runPreGetHooks(key string) error {
|
||||
c.hooksLock.RLock()
|
||||
defer c.hooksLock.RUnlock()
|
||||
|
||||
for _, hook := range c.hooks {
|
||||
if !hook.h.UsesPreGet() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hook.q.MatchesKey(key) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := hook.h.PreGet(key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) runPostGetHooks(r record.Record) (record.Record, error) {
|
||||
c.hooksLock.RLock()
|
||||
defer c.hooksLock.RUnlock()
|
||||
|
||||
var err error
|
||||
for _, hook := range c.hooks {
|
||||
if !hook.h.UsesPostGet() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hook.q.Matches(r) {
|
||||
continue
|
||||
}
|
||||
|
||||
r, err = hook.h.PostGet(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Controller) runPrePutHooks(r record.Record) (record.Record, error) {
|
||||
c.hooksLock.RLock()
|
||||
defer c.hooksLock.RUnlock()
|
||||
|
||||
var err error
|
||||
for _, hook := range c.hooks {
|
||||
if !hook.h.UsesPrePut() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hook.q.Matches(r) {
|
||||
continue
|
||||
}
|
||||
|
||||
r, err = hook.h.PrePut(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
106
base/database/controllers.go
Normal file
106
base/database/controllers.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
// StorageTypeInjected is the type of injected databases.
|
||||
const StorageTypeInjected = "injected"
|
||||
|
||||
var (
|
||||
controllers = make(map[string]*Controller)
|
||||
controllersLock sync.RWMutex
|
||||
)
|
||||
|
||||
func getController(name string) (*Controller, error) {
|
||||
if !initialized.IsSet() {
|
||||
return nil, errors.New("database not initialized")
|
||||
}
|
||||
|
||||
// return database if already started
|
||||
controllersLock.RLock()
|
||||
controller, ok := controllers[name]
|
||||
controllersLock.RUnlock()
|
||||
if ok {
|
||||
return controller, nil
|
||||
}
|
||||
|
||||
controllersLock.Lock()
|
||||
defer controllersLock.Unlock()
|
||||
|
||||
if shuttingDown.IsSet() {
|
||||
return nil, ErrShuttingDown
|
||||
}
|
||||
|
||||
// get db registration
|
||||
registeredDB, err := getDatabase(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not start database %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Check if database is injected.
|
||||
if registeredDB.StorageType == StorageTypeInjected {
|
||||
return nil, fmt.Errorf("database storage is not injected")
|
||||
}
|
||||
|
||||
// get location
|
||||
dbLocation, err := getLocation(name, registeredDB.StorageType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not start database %s (type %s): %w", name, registeredDB.StorageType, err)
|
||||
}
|
||||
|
||||
// start database
|
||||
storageInt, err := storage.StartDatabase(name, registeredDB.StorageType, dbLocation)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not start database %s (type %s): %w", name, registeredDB.StorageType, err)
|
||||
}
|
||||
|
||||
controller = newController(registeredDB, storageInt, registeredDB.ShadowDelete)
|
||||
controllers[name] = controller
|
||||
return controller, nil
|
||||
}
|
||||
|
||||
// InjectDatabase injects an already running database into the system.
|
||||
func InjectDatabase(name string, storageInt storage.Interface) (*Controller, error) {
|
||||
controllersLock.Lock()
|
||||
defer controllersLock.Unlock()
|
||||
|
||||
if shuttingDown.IsSet() {
|
||||
return nil, ErrShuttingDown
|
||||
}
|
||||
|
||||
_, ok := controllers[name]
|
||||
if ok {
|
||||
return nil, fmt.Errorf(`database "%s" already loaded`, name)
|
||||
}
|
||||
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
// check if database is registered
|
||||
registeredDB, ok := registry[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("database %q not registered", name)
|
||||
}
|
||||
if registeredDB.StorageType != StorageTypeInjected {
|
||||
return nil, fmt.Errorf("database not of type %q", StorageTypeInjected)
|
||||
}
|
||||
|
||||
controller := newController(registeredDB, storageInt, false)
|
||||
controllers[name] = controller
|
||||
return controller, nil
|
||||
}
|
||||
|
||||
// Withdraw withdraws an injected database, but leaves the database registered.
|
||||
func (c *Controller) Withdraw() {
|
||||
if c != nil && c.Injected() {
|
||||
controllersLock.Lock()
|
||||
defer controllersLock.Unlock()
|
||||
|
||||
delete(controllers, c.database.Name)
|
||||
}
|
||||
}
|
||||
26
base/database/database.go
Normal file
26
base/database/database.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Database holds information about a registered database.
|
||||
type Database struct {
|
||||
Name string
|
||||
Description string
|
||||
StorageType string
|
||||
ShadowDelete bool // Whether deleted records should be kept until purged.
|
||||
Registered time.Time
|
||||
LastUpdated time.Time
|
||||
LastLoaded time.Time
|
||||
}
|
||||
|
||||
// Loaded updates the LastLoaded timestamp.
|
||||
func (db *Database) Loaded() {
|
||||
db.LastLoaded = time.Now().Round(time.Second)
|
||||
}
|
||||
|
||||
// Updated updates the LastUpdated timestamp.
|
||||
func (db *Database) Updated() {
|
||||
db.LastUpdated = time.Now().Round(time.Second)
|
||||
}
|
||||
303
base/database/database_test.go
Normal file
303
base/database/database_test.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime/pprof"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
q "github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
_ "github.com/safing/portmaster/base/database/storage/badger"
|
||||
_ "github.com/safing/portmaster/base/database/storage/bbolt"
|
||||
_ "github.com/safing/portmaster/base/database/storage/fstree"
|
||||
_ "github.com/safing/portmaster/base/database/storage/hashmap"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testDir, err := os.MkdirTemp("", "portbase-database-testing-")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = InitializeWithPath(testDir)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
exitCode := m.Run()
|
||||
|
||||
// Clean up the test directory.
|
||||
// Do not defer, as we end this function with a os.Exit call.
|
||||
_ = os.RemoveAll(testDir)
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func makeKey(dbName, key string) string {
|
||||
return fmt.Sprintf("%s:%s", dbName, key)
|
||||
}
|
||||
|
||||
func testDatabase(t *testing.T, storageType string, shadowDelete bool) { //nolint:maintidx,thelper
|
||||
t.Run(fmt.Sprintf("TestStorage_%s_%v", storageType, shadowDelete), func(t *testing.T) {
|
||||
dbName := fmt.Sprintf("testing-%s-%v", storageType, shadowDelete)
|
||||
fmt.Println(dbName)
|
||||
_, err := Register(&Database{
|
||||
Name: dbName,
|
||||
Description: fmt.Sprintf("Unit Test Database for %s", storageType),
|
||||
StorageType: storageType,
|
||||
ShadowDelete: shadowDelete,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dbController, err := getController(dbName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// hook
|
||||
hook, err := RegisterHook(q.New(dbName).MustBeValid(), &HookBase{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// interface
|
||||
db := NewInterface(&Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
|
||||
// sub
|
||||
sub, err := db.Subscribe(q.New(dbName).MustBeValid())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
A := NewExample(dbName+":A", "Herbert", 411)
|
||||
err = A.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
B := NewExample(makeKey(dbName, "B"), "Fritz", 347)
|
||||
err = B.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
C := NewExample(makeKey(dbName, "C"), "Norbert", 217)
|
||||
err = C.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
exists, err := db.Exists(makeKey(dbName, "A"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("record %s should exist!", makeKey(dbName, "A"))
|
||||
}
|
||||
|
||||
A1, err := GetExample(makeKey(dbName, "A"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(A, A1) {
|
||||
log.Fatalf("A and A1 mismatch, A1: %v", A1)
|
||||
}
|
||||
|
||||
cnt := countRecords(t, db, q.New(dbName).Where(
|
||||
q.And(
|
||||
q.Where("Name", q.EndsWith, "bert"),
|
||||
q.Where("Score", q.GreaterThan, 100),
|
||||
),
|
||||
))
|
||||
if cnt != 2 {
|
||||
t.Fatalf("expected two records, got %d", cnt)
|
||||
}
|
||||
|
||||
// test putmany
|
||||
if _, ok := dbController.storage.(storage.Batcher); ok {
|
||||
batchPut := db.PutMany(dbName)
|
||||
records := []record.Record{A, B, C, nil} // nil is to signify finish
|
||||
for _, r := range records {
|
||||
err = batchPut(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// test maintenance
|
||||
if _, ok := dbController.storage.(storage.Maintainer); ok {
|
||||
now := time.Now().UTC()
|
||||
nowUnix := now.Unix()
|
||||
|
||||
// we start with 3 records without expiry
|
||||
cnt := countRecords(t, db, q.New(dbName))
|
||||
if cnt != 3 {
|
||||
t.Fatalf("expected three records, got %d", cnt)
|
||||
}
|
||||
// delete entry
|
||||
A.Meta().Deleted = nowUnix - 61
|
||||
err = A.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// expire entry
|
||||
B.Meta().Expires = nowUnix - 1
|
||||
err = B.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// one left
|
||||
cnt = countRecords(t, db, q.New(dbName))
|
||||
if cnt != 1 {
|
||||
t.Fatalf("expected one record, got %d", cnt)
|
||||
}
|
||||
|
||||
// run maintenance
|
||||
err = dbController.MaintainRecordStates(context.TODO(), now.Add(-60*time.Second))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// one left
|
||||
cnt = countRecords(t, db, q.New(dbName))
|
||||
if cnt != 1 {
|
||||
t.Fatalf("expected one record, got %d", cnt)
|
||||
}
|
||||
|
||||
// check status individually
|
||||
_, err = dbController.storage.Get("A")
|
||||
if !errors.Is(err, storage.ErrNotFound) {
|
||||
t.Errorf("A should be deleted and purged, err=%s", err)
|
||||
}
|
||||
B1, err := dbController.storage.Get("B")
|
||||
if err != nil {
|
||||
t.Fatalf("should exist: %s, original meta: %+v", err, B.Meta())
|
||||
}
|
||||
if B1.Meta().Deleted == 0 {
|
||||
t.Errorf("B should be deleted")
|
||||
}
|
||||
|
||||
// delete last entry
|
||||
C.Meta().Deleted = nowUnix - 1
|
||||
err = C.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// run maintenance
|
||||
err = dbController.MaintainRecordStates(context.TODO(), now)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check status individually
|
||||
B2, err := dbController.storage.Get("B")
|
||||
if err == nil {
|
||||
t.Errorf("B should be deleted and purged, meta: %+v", B2.Meta())
|
||||
} else if !errors.Is(err, storage.ErrNotFound) {
|
||||
t.Errorf("B should be deleted and purged, err=%s", err)
|
||||
}
|
||||
C2, err := dbController.storage.Get("C")
|
||||
if err == nil {
|
||||
t.Errorf("C should be deleted and purged, meta: %+v", C2.Meta())
|
||||
} else if !errors.Is(err, storage.ErrNotFound) {
|
||||
t.Errorf("C should be deleted and purged, err=%s", err)
|
||||
}
|
||||
|
||||
// none left
|
||||
cnt = countRecords(t, db, q.New(dbName))
|
||||
if cnt != 0 {
|
||||
t.Fatalf("expected no records, got %d", cnt)
|
||||
}
|
||||
}
|
||||
|
||||
err = hook.Cancel()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = sub.Cancel()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatabaseSystem(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
// panic after 10 seconds, to check for locks
|
||||
finished := make(chan struct{})
|
||||
defer close(finished)
|
||||
go func() {
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(10 * time.Second):
|
||||
fmt.Println("===== TAKING TOO LONG - PRINTING STACK TRACES =====")
|
||||
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, shadowDelete := range []bool{false, true} {
|
||||
testDatabase(t, "bbolt", shadowDelete)
|
||||
testDatabase(t, "hashmap", shadowDelete)
|
||||
testDatabase(t, "fstree", shadowDelete)
|
||||
// testDatabase(t, "badger", shadowDelete)
|
||||
// TODO: Fix badger tests
|
||||
}
|
||||
|
||||
err := MaintainRecordStates(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = Maintain(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = MaintainThorough(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = Shutdown()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func countRecords(t *testing.T, db *Interface, query *q.Query) int {
|
||||
t.Helper()
|
||||
|
||||
_, err := query.Check()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it, err := db.Query(query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cnt := 0
|
||||
for range it.Next {
|
||||
cnt++
|
||||
}
|
||||
if it.Err() != nil {
|
||||
t.Fatal(it.Err())
|
||||
}
|
||||
return cnt
|
||||
}
|
||||
84
base/database/dbmodule/db.go
Normal file
84
base/database/dbmodule/db.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package dbmodule
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/safing/portmaster/base/database"
|
||||
"github.com/safing/portmaster/base/dataroot"
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
type DBModule struct {
|
||||
mgr *mgr.Manager
|
||||
instance instance
|
||||
}
|
||||
|
||||
func (dbm *DBModule) Manager() *mgr.Manager {
|
||||
return dbm.mgr
|
||||
}
|
||||
|
||||
func (dbm *DBModule) Start() error {
|
||||
return start()
|
||||
}
|
||||
|
||||
func (dbm *DBModule) Stop() error {
|
||||
return stop()
|
||||
}
|
||||
|
||||
var databaseStructureRoot *utils.DirStructure
|
||||
|
||||
// SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure.
|
||||
func SetDatabaseLocation(dirStructureRoot *utils.DirStructure) {
|
||||
if databaseStructureRoot == nil {
|
||||
databaseStructureRoot = dirStructureRoot
|
||||
}
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
SetDatabaseLocation(dataroot.Root())
|
||||
if databaseStructureRoot == nil {
|
||||
return errors.New("database location not specified")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
startMaintenanceTasks()
|
||||
return nil
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
return database.Shutdown()
|
||||
}
|
||||
|
||||
var (
|
||||
module *DBModule
|
||||
shimLoaded atomic.Bool
|
||||
)
|
||||
|
||||
func New(instance instance) (*DBModule, error) {
|
||||
if !shimLoaded.CompareAndSwap(false, true) {
|
||||
return nil, errors.New("only one instance allowed")
|
||||
}
|
||||
|
||||
if err := prep(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := mgr.New("DBModule")
|
||||
module = &DBModule{
|
||||
mgr: m,
|
||||
instance: instance,
|
||||
}
|
||||
|
||||
err := database.Initialize(databaseStructureRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return module, nil
|
||||
}
|
||||
|
||||
type instance interface{}
|
||||
30
base/database/dbmodule/maintenance.go
Normal file
30
base/database/dbmodule/maintenance.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package dbmodule
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
func startMaintenanceTasks() {
|
||||
_ = module.mgr.Repeat("basic maintenance", 10*time.Minute, maintainBasic)
|
||||
_ = module.mgr.Repeat("thorough maintenance", 1*time.Hour, maintainThorough)
|
||||
_ = module.mgr.Repeat("record maintenance", 1*time.Hour, maintainRecords)
|
||||
}
|
||||
|
||||
func maintainBasic(ctx *mgr.WorkerCtx) error {
|
||||
log.Infof("database: running Maintain")
|
||||
return database.Maintain(ctx.Ctx())
|
||||
}
|
||||
|
||||
func maintainThorough(ctx *mgr.WorkerCtx) error {
|
||||
log.Infof("database: running MaintainThorough")
|
||||
return database.MaintainThorough(ctx.Ctx())
|
||||
}
|
||||
|
||||
func maintainRecords(ctx *mgr.WorkerCtx) error {
|
||||
log.Infof("database: running MaintainRecordStates")
|
||||
return database.MaintainRecordStates(ctx.Ctx())
|
||||
}
|
||||
62
base/database/doc.go
Normal file
62
base/database/doc.go
Normal file
@@ -0,0 +1,62 @@
|
||||
/*
|
||||
Package database provides a universal interface for interacting with the database.
|
||||
|
||||
# A Lazy Database
|
||||
|
||||
The database system can handle Go structs as well as serialized data by the dsd package.
|
||||
While data is in transit within the system, it does not know which form it currently has. Only when it reaches its destination, it must ensure that it is either of a certain type or dump it.
|
||||
|
||||
# Record Interface
|
||||
|
||||
The database system uses the Record interface to transparently handle all types of structs that get saved in the database. Structs include the Base struct to fulfill most parts of the Record interface.
|
||||
|
||||
Boilerplate Code:
|
||||
|
||||
type Example struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
Name string
|
||||
Score int
|
||||
}
|
||||
|
||||
var (
|
||||
db = database.NewInterface(nil)
|
||||
)
|
||||
|
||||
// GetExample gets an Example from the database.
|
||||
func GetExample(key string) (*Example, error) {
|
||||
r, err := db.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// unwrap
|
||||
if r.IsWrapped() {
|
||||
// only allocate a new struct, if we need it
|
||||
new := &Example{}
|
||||
err = record.Unwrap(r, new)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return new, nil
|
||||
}
|
||||
|
||||
// or adjust type
|
||||
new, ok := r.(*Example)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *Example, but %T", r)
|
||||
}
|
||||
return new, nil
|
||||
}
|
||||
|
||||
func (e *Example) Save() error {
|
||||
return db.Put(e)
|
||||
}
|
||||
|
||||
func (e *Example) SaveAs(key string) error {
|
||||
e.SetKey(key)
|
||||
return db.PutNew(e)
|
||||
}
|
||||
*/
|
||||
package database
|
||||
14
base/database/errors.go
Normal file
14
base/database/errors.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrNotFound = errors.New("database entry not found")
|
||||
ErrPermissionDenied = errors.New("access to database record denied")
|
||||
ErrReadOnly = errors.New("database is read only")
|
||||
ErrShuttingDown = errors.New("database system is shutting down")
|
||||
ErrNotImplemented = errors.New("not implemented by this storage")
|
||||
)
|
||||
91
base/database/hook.go
Normal file
91
base/database/hook.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
// Hook can be registered for a database query and
|
||||
// will be executed at certain points during the life
|
||||
// cycle of a database record.
|
||||
type Hook interface {
|
||||
// UsesPreGet should return true if the hook's PreGet
|
||||
// should be called prior to loading a database record
|
||||
// from the underlying storage.
|
||||
UsesPreGet() bool
|
||||
// PreGet is called before a database record is loaded from
|
||||
// the underlying storage. A PreGet hookd may be used to
|
||||
// implement more advanced access control on database keys.
|
||||
PreGet(dbKey string) error
|
||||
// UsesPostGet should return true if the hook's PostGet
|
||||
// should be called after loading a database record from
|
||||
// the underlying storage.
|
||||
UsesPostGet() bool
|
||||
// PostGet is called after a record has been loaded form the
|
||||
// underlying storage and may perform additional mutation
|
||||
// or access check based on the records data.
|
||||
// The passed record is already locked by the database system
|
||||
// so users can safely access all data of r.
|
||||
PostGet(r record.Record) (record.Record, error)
|
||||
// UsesPrePut should return true if the hook's PrePut method
|
||||
// should be called prior to saving a record in the database.
|
||||
UsesPrePut() bool
|
||||
// PrePut is called prior to saving (creating or updating) a
|
||||
// record in the database storage. It may be used to perform
|
||||
// extended validation or mutations on the record.
|
||||
// The passed record is already locked by the database system
|
||||
// so users can safely access all data of r.
|
||||
PrePut(r record.Record) (record.Record, error)
|
||||
}
|
||||
|
||||
// RegisteredHook is a registered database hook.
|
||||
type RegisteredHook struct {
|
||||
q *query.Query
|
||||
h Hook
|
||||
}
|
||||
|
||||
// RegisterHook registers a hook for records matching the given
|
||||
// query in the database.
|
||||
func RegisterHook(q *query.Query, hook Hook) (*RegisteredHook, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := getController(q.DatabaseName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rh := &RegisteredHook{
|
||||
q: q,
|
||||
h: hook,
|
||||
}
|
||||
|
||||
c.hooksLock.Lock()
|
||||
defer c.hooksLock.Unlock()
|
||||
c.hooks = append(c.hooks, rh)
|
||||
|
||||
return rh, nil
|
||||
}
|
||||
|
||||
// Cancel unregisteres the hook from the database. Once
|
||||
// Cancel returned the hook's methods will not be called
|
||||
// anymore for updates that matched the registered query.
|
||||
func (h *RegisteredHook) Cancel() error {
|
||||
c, err := getController(h.q.DatabaseName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.hooksLock.Lock()
|
||||
defer c.hooksLock.Unlock()
|
||||
|
||||
for key, hook := range c.hooks {
|
||||
if hook.q == h.q {
|
||||
c.hooks = append(c.hooks[:key], c.hooks[key+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
38
base/database/hookbase.go
Normal file
38
base/database/hookbase.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
// HookBase implements the Hook interface and provides dummy functions to reduce boilerplate.
|
||||
type HookBase struct{}
|
||||
|
||||
// UsesPreGet implements the Hook interface and returns false.
|
||||
func (b *HookBase) UsesPreGet() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// UsesPostGet implements the Hook interface and returns false.
|
||||
func (b *HookBase) UsesPostGet() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// UsesPrePut implements the Hook interface and returns false.
|
||||
func (b *HookBase) UsesPrePut() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// PreGet implements the Hook interface.
|
||||
func (b *HookBase) PreGet(dbKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostGet implements the Hook interface.
|
||||
func (b *HookBase) PostGet(r record.Record) (record.Record, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// PrePut implements the Hook interface.
|
||||
func (b *HookBase) PrePut(r record.Record) (record.Record, error) {
|
||||
return r, nil
|
||||
}
|
||||
585
base/database/interface.go
Normal file
585
base/database/interface.go
Normal file
@@ -0,0 +1,585 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bluele/gcache"
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
const (
|
||||
getDBFromKey = ""
|
||||
)
|
||||
|
||||
// Interface provides a method to access the database with attached options.
|
||||
type Interface struct {
|
||||
options *Options
|
||||
cache gcache.Cache
|
||||
|
||||
writeCache map[string]record.Record
|
||||
writeCacheLock sync.Mutex
|
||||
triggerCacheWrite chan struct{}
|
||||
}
|
||||
|
||||
// Options holds options that may be set for an Interface instance.
|
||||
type Options struct {
|
||||
// Local specifies if the interface is used by an actor on the local device.
|
||||
// Setting both the Local and Internal flags will bring performance
|
||||
// improvements because less checks are needed.
|
||||
Local bool
|
||||
|
||||
// Internal specifies if the interface is used by an actor within the
|
||||
// software. Setting both the Local and Internal flags will bring performance
|
||||
// improvements because less checks are needed.
|
||||
Internal bool
|
||||
|
||||
// AlwaysMakeSecret will have the interface mark all saved records as secret.
|
||||
// This means that they will be only accessible by an internal interface.
|
||||
AlwaysMakeSecret bool
|
||||
|
||||
// AlwaysMakeCrownjewel will have the interface mark all saved records as
|
||||
// crown jewels. This means that they will be only accessible by a local
|
||||
// interface.
|
||||
AlwaysMakeCrownjewel bool
|
||||
|
||||
// AlwaysSetRelativateExpiry will have the interface set a relative expiry,
|
||||
// based on the current time, on all saved records.
|
||||
AlwaysSetRelativateExpiry int64
|
||||
|
||||
// AlwaysSetAbsoluteExpiry will have the interface set an absolute expiry on
|
||||
// all saved records.
|
||||
AlwaysSetAbsoluteExpiry int64
|
||||
|
||||
// CacheSize defines that a cache should be used for this interface and
|
||||
// defines it's size.
|
||||
// Caching comes with an important caveat: If database records are changed
|
||||
// from another interface, the cache will not be invalidated for these
|
||||
// records. It will therefore serve outdated data until that record is
|
||||
// evicted from the cache.
|
||||
CacheSize int
|
||||
|
||||
// DelayCachedWrites defines a database name for which cache writes should
|
||||
// be cached and batched. The database backend must support the Batcher
|
||||
// interface. This option is only valid if used with a cache.
|
||||
// Additionally, this may only be used for internal and local interfaces.
|
||||
// Please note that this means that other interfaces will not be able to
|
||||
// guarantee to serve the latest record if records are written this way.
|
||||
DelayCachedWrites string
|
||||
}
|
||||
|
||||
// Apply applies options to the record metadata.
|
||||
func (o *Options) Apply(r record.Record) {
|
||||
r.UpdateMeta()
|
||||
if o.AlwaysMakeSecret {
|
||||
r.Meta().MakeSecret()
|
||||
}
|
||||
if o.AlwaysMakeCrownjewel {
|
||||
r.Meta().MakeCrownJewel()
|
||||
}
|
||||
if o.AlwaysSetAbsoluteExpiry > 0 {
|
||||
r.Meta().SetAbsoluteExpiry(o.AlwaysSetAbsoluteExpiry)
|
||||
} else if o.AlwaysSetRelativateExpiry > 0 {
|
||||
r.Meta().SetRelativateExpiry(o.AlwaysSetRelativateExpiry)
|
||||
}
|
||||
}
|
||||
|
||||
// HasAllPermissions returns whether the options specify the highest possible
|
||||
// permissions for operations.
|
||||
func (o *Options) HasAllPermissions() bool {
|
||||
return o.Local && o.Internal
|
||||
}
|
||||
|
||||
// hasAccessPermission checks if the interface options permit access to the
|
||||
// given record, locking the record for accessing it's attributes.
|
||||
func (o *Options) hasAccessPermission(r record.Record) bool {
|
||||
// Check if the options specify all permissions, which makes checking the
|
||||
// record unnecessary.
|
||||
if o.HasAllPermissions() {
|
||||
return true
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
// Check permissions against record.
|
||||
return r.Meta().CheckPermission(o.Local, o.Internal)
|
||||
}
|
||||
|
||||
// NewInterface returns a new Interface to the database.
|
||||
func NewInterface(opts *Options) *Interface {
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
newIface := &Interface{
|
||||
options: opts,
|
||||
}
|
||||
if opts.CacheSize > 0 {
|
||||
cacheBuilder := gcache.New(opts.CacheSize).ARC()
|
||||
if opts.DelayCachedWrites != "" {
|
||||
cacheBuilder.EvictedFunc(newIface.cacheEvictHandler)
|
||||
newIface.writeCache = make(map[string]record.Record, opts.CacheSize/2)
|
||||
newIface.triggerCacheWrite = make(chan struct{})
|
||||
}
|
||||
newIface.cache = cacheBuilder.Build()
|
||||
}
|
||||
return newIface
|
||||
}
|
||||
|
||||
// Exists return whether a record with the given key exists.
|
||||
func (i *Interface) Exists(key string) (bool, error) {
|
||||
_, err := i.Get(key)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrNotFound):
|
||||
return false, nil
|
||||
case errors.Is(err, ErrPermissionDenied):
|
||||
return true, nil
|
||||
default:
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Get return the record with the given key.
|
||||
func (i *Interface) Get(key string) (record.Record, error) {
|
||||
r, _, err := i.getRecord(getDBFromKey, key, false)
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (i *Interface) getRecord(dbName string, dbKey string, mustBeWriteable bool) (r record.Record, db *Controller, err error) { //nolint:unparam
|
||||
if dbName == "" {
|
||||
dbName, dbKey = record.ParseKey(dbKey)
|
||||
}
|
||||
|
||||
db, err = getController(dbName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if mustBeWriteable && db.ReadOnly() {
|
||||
return nil, db, ErrReadOnly
|
||||
}
|
||||
|
||||
r = i.checkCache(dbName + ":" + dbKey)
|
||||
if r != nil {
|
||||
if !i.options.hasAccessPermission(r) {
|
||||
return nil, db, ErrPermissionDenied
|
||||
}
|
||||
return r, db, nil
|
||||
}
|
||||
|
||||
r, err = db.Get(dbKey)
|
||||
if err != nil {
|
||||
return nil, db, err
|
||||
}
|
||||
|
||||
if !i.options.hasAccessPermission(r) {
|
||||
return nil, db, ErrPermissionDenied
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
ttl := r.Meta().GetRelativeExpiry()
|
||||
r.Unlock()
|
||||
i.updateCache(
|
||||
r,
|
||||
false, // writing
|
||||
false, // remove
|
||||
ttl, // expiry
|
||||
)
|
||||
|
||||
return r, db, nil
|
||||
}
|
||||
|
||||
func (i *Interface) getMeta(dbName string, dbKey string, mustBeWriteable bool) (m *record.Meta, db *Controller, err error) { //nolint:unparam
|
||||
if dbName == "" {
|
||||
dbName, dbKey = record.ParseKey(dbKey)
|
||||
}
|
||||
|
||||
db, err = getController(dbName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if mustBeWriteable && db.ReadOnly() {
|
||||
return nil, db, ErrReadOnly
|
||||
}
|
||||
|
||||
r := i.checkCache(dbName + ":" + dbKey)
|
||||
if r != nil {
|
||||
if !i.options.hasAccessPermission(r) {
|
||||
return nil, db, ErrPermissionDenied
|
||||
}
|
||||
return r.Meta(), db, nil
|
||||
}
|
||||
|
||||
m, err = db.GetMeta(dbKey)
|
||||
if err != nil {
|
||||
return nil, db, err
|
||||
}
|
||||
|
||||
if !m.CheckPermission(i.options.Local, i.options.Internal) {
|
||||
return nil, db, ErrPermissionDenied
|
||||
}
|
||||
|
||||
return m, db, nil
|
||||
}
|
||||
|
||||
// InsertValue inserts a value into a record.
|
||||
func (i *Interface) InsertValue(key string, attribute string, value interface{}) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
var acc accessor.Accessor
|
||||
if r.IsWrapped() {
|
||||
wrapper, ok := r.(*record.Wrapper)
|
||||
if !ok {
|
||||
return errors.New("record is malformed (reports to be wrapped but is not of type *record.Wrapper)")
|
||||
}
|
||||
acc = accessor.NewJSONBytesAccessor(&wrapper.Data)
|
||||
} else {
|
||||
acc = accessor.NewStructAccessor(r)
|
||||
}
|
||||
|
||||
err = acc.Set(attribute, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set value with %s: %w", acc.Type(), err)
|
||||
}
|
||||
|
||||
i.options.Apply(r)
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// Put saves a record to the database.
|
||||
func (i *Interface) Put(r record.Record) (err error) {
|
||||
// get record or only database
|
||||
var db *Controller
|
||||
if !i.options.HasAllPermissions() {
|
||||
_, db, err = i.getMeta(r.DatabaseName(), r.DatabaseKey(), true)
|
||||
if err != nil && !errors.Is(err, ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
db, err = getController(r.DatabaseName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if database is read only.
|
||||
if db.ReadOnly() {
|
||||
return ErrReadOnly
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
i.options.Apply(r)
|
||||
remove := r.Meta().IsDeleted()
|
||||
ttl := r.Meta().GetRelativeExpiry()
|
||||
r.Unlock()
|
||||
|
||||
// The record may not be locked when updating the cache.
|
||||
written := i.updateCache(r, true, remove, ttl)
|
||||
if written {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// PutNew saves a record to the database as a new record (ie. with new timestamps).
|
||||
func (i *Interface) PutNew(r record.Record) (err error) {
|
||||
// get record or only database
|
||||
var db *Controller
|
||||
if !i.options.HasAllPermissions() {
|
||||
_, db, err = i.getMeta(r.DatabaseName(), r.DatabaseKey(), true)
|
||||
if err != nil && !errors.Is(err, ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
db, err = getController(r.DatabaseName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if database is read only.
|
||||
if db.ReadOnly() {
|
||||
return ErrReadOnly
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
if r.Meta() != nil {
|
||||
r.Meta().Reset()
|
||||
}
|
||||
i.options.Apply(r)
|
||||
remove := r.Meta().IsDeleted()
|
||||
ttl := r.Meta().GetRelativeExpiry()
|
||||
r.Unlock()
|
||||
|
||||
// The record may not be locked when updating the cache.
|
||||
written := i.updateCache(r, true, remove, ttl)
|
||||
if written {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
// Warning: This is nearly a direct database access and omits many things:
|
||||
// - Record locking
|
||||
// - Hooks
|
||||
// - Subscriptions
|
||||
// - Caching
|
||||
// Use with care.
|
||||
func (i *Interface) PutMany(dbName string) (put func(record.Record) error) {
|
||||
interfaceBatch := make(chan record.Record, 100)
|
||||
|
||||
// permission check
|
||||
if !i.options.HasAllPermissions() {
|
||||
return func(r record.Record) error {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
// get database
|
||||
db, err := getController(dbName)
|
||||
if err != nil {
|
||||
return func(r record.Record) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if database is read only.
|
||||
if db.ReadOnly() {
|
||||
return func(r record.Record) error {
|
||||
return ErrReadOnly
|
||||
}
|
||||
}
|
||||
|
||||
// start database access
|
||||
dbBatch, errs := db.PutMany()
|
||||
finished := abool.New()
|
||||
var internalErr error
|
||||
|
||||
// interface options proxy
|
||||
go func() {
|
||||
defer close(dbBatch) // signify that we are finished
|
||||
for {
|
||||
select {
|
||||
case r := <-interfaceBatch:
|
||||
// finished?
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
// apply options
|
||||
i.options.Apply(r)
|
||||
// pass along
|
||||
dbBatch <- r
|
||||
case <-time.After(1 * time.Second):
|
||||
// bail out
|
||||
internalErr = errors.New("timeout: putmany unused for too long")
|
||||
finished.Set()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return func(r record.Record) error {
|
||||
// finished?
|
||||
if finished.IsSet() {
|
||||
// check for internal error
|
||||
if internalErr != nil {
|
||||
return internalErr
|
||||
}
|
||||
// check for previous error
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
default:
|
||||
return errors.New("batch is closed")
|
||||
}
|
||||
}
|
||||
|
||||
// finish?
|
||||
if r == nil {
|
||||
finished.Set()
|
||||
interfaceBatch <- nil // signify that we are finished
|
||||
// do not close, as this fn could be called again with nil.
|
||||
return <-errs
|
||||
}
|
||||
|
||||
// check record scope
|
||||
if r.DatabaseName() != dbName {
|
||||
return errors.New("record out of database scope")
|
||||
}
|
||||
|
||||
// submit
|
||||
select {
|
||||
case interfaceBatch <- r:
|
||||
return nil
|
||||
case err := <-errs:
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetAbsoluteExpiry sets an absolute record expiry.
|
||||
func (i *Interface) SetAbsoluteExpiry(key string, time int64) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().SetAbsoluteExpiry(time)
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// SetRelativateExpiry sets a relative (self-updating) record expiry.
|
||||
func (i *Interface) SetRelativateExpiry(key string, duration int64) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().SetRelativateExpiry(duration)
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// MakeSecret marks the record as a secret, meaning interfacing processes, such as an UI, are denied access to the record.
|
||||
func (i *Interface) MakeSecret(key string) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().MakeSecret()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// MakeCrownJewel marks a record as a crown jewel, meaning it will only be accessible locally.
|
||||
func (i *Interface) MakeCrownJewel(key string) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().MakeCrownJewel()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
func (i *Interface) Delete(key string) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if database is read only.
|
||||
if db.ReadOnly() {
|
||||
return ErrReadOnly
|
||||
}
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().Delete()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// Query executes the given query on the database.
|
||||
// Will not see data that is in the write cache, waiting to be written.
|
||||
// Use with care with caching.
|
||||
func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := getController(q.DatabaseName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: Finish caching system integration.
|
||||
// Flush the cache before we query the database.
|
||||
// i.FlushCache()
|
||||
|
||||
return db.Query(q, i.options.Local, i.options.Internal)
|
||||
}
|
||||
|
||||
// Purge deletes all records that match the given query. It returns the number
|
||||
// of successful deletes and an error.
|
||||
func (i *Interface) Purge(ctx context.Context, q *query.Query) (int, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
db, err := getController(q.DatabaseName())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Check if database is read only before we add to the cache.
|
||||
if db.ReadOnly() {
|
||||
return 0, ErrReadOnly
|
||||
}
|
||||
|
||||
return db.Purge(ctx, q, i.options.Local, i.options.Internal)
|
||||
}
|
||||
|
||||
// Subscribe subscribes to updates matching the given query.
|
||||
func (i *Interface) Subscribe(q *query.Query) (*Subscription, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := getController(q.DatabaseName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sub := &Subscription{
|
||||
q: q,
|
||||
local: i.options.Local,
|
||||
internal: i.options.Internal,
|
||||
Feed: make(chan record.Record, 1000),
|
||||
}
|
||||
c.addSubscription(sub)
|
||||
return sub, nil
|
||||
}
|
||||
227
base/database/interface_cache.go
Normal file
227
base/database/interface_cache.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
// DelayedCacheWriter must be run by the caller of an interface that uses delayed cache writing.
|
||||
func (i *Interface) DelayedCacheWriter(wc *mgr.WorkerCtx) error {
|
||||
// Check if the DelayedCacheWriter should be run at all.
|
||||
if i.options.CacheSize <= 0 || i.options.DelayCachedWrites == "" {
|
||||
return errors.New("delayed cache writer is not applicable to this database interface")
|
||||
}
|
||||
|
||||
// Check if backend support the Batcher interface.
|
||||
batchPut := i.PutMany(i.options.DelayCachedWrites)
|
||||
// End batchPut immediately and check for an error.
|
||||
err := batchPut(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// percentThreshold defines the minimum percentage of entries in the write cache in relation to the cache size that need to be present in order for flushing the cache to the database storage.
|
||||
percentThreshold := 25
|
||||
thresholdWriteTicker := time.NewTicker(5 * time.Second)
|
||||
forceWriteTicker := time.NewTicker(5 * time.Minute)
|
||||
|
||||
for {
|
||||
// Wait for trigger for writing the cache.
|
||||
select {
|
||||
case <-wc.Done():
|
||||
// The caller is shutting down, flush the cache to storage and exit.
|
||||
i.flushWriteCache(0)
|
||||
return nil
|
||||
|
||||
case <-i.triggerCacheWrite:
|
||||
// An entry from the cache was evicted that was also in the write cache.
|
||||
// This makes it likely that other entries that are also present in the
|
||||
// write cache will be evicted soon. Flush the write cache to storage
|
||||
// immediately in order to reduce single writes.
|
||||
i.flushWriteCache(0)
|
||||
|
||||
case <-thresholdWriteTicker.C:
|
||||
// Often check if the write cache has filled up to a certain degree and
|
||||
// flush it to storage before we start evicting to-be-written entries and
|
||||
// slow down the hot path again.
|
||||
i.flushWriteCache(percentThreshold)
|
||||
|
||||
case <-forceWriteTicker.C:
|
||||
// Once in a while, flush the write cache to storage no matter how much
|
||||
// it is filled. We don't want entries lingering around in the write
|
||||
// cache forever. This also reduces the amount of data loss in the event
|
||||
// of a total crash.
|
||||
i.flushWriteCache(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClearCache clears the read cache.
|
||||
func (i *Interface) ClearCache() {
|
||||
// Check if cache is in use.
|
||||
if i.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear all cache entries.
|
||||
i.cache.Purge()
|
||||
}
|
||||
|
||||
// FlushCache writes (and thus clears) the write cache.
|
||||
func (i *Interface) FlushCache() {
|
||||
// Check if write cache is in use.
|
||||
if i.options.DelayCachedWrites != "" {
|
||||
return
|
||||
}
|
||||
|
||||
i.flushWriteCache(0)
|
||||
}
|
||||
|
||||
func (i *Interface) flushWriteCache(percentThreshold int) {
|
||||
i.writeCacheLock.Lock()
|
||||
defer i.writeCacheLock.Unlock()
|
||||
|
||||
// Check if there is anything to do.
|
||||
if len(i.writeCache) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if we reach the given threshold for writing to storage.
|
||||
if (len(i.writeCache)*100)/i.options.CacheSize < percentThreshold {
|
||||
return
|
||||
}
|
||||
|
||||
// Write the full cache in a batch operation.
|
||||
batchPut := i.PutMany(i.options.DelayCachedWrites)
|
||||
for _, r := range i.writeCache {
|
||||
err := batchPut(r)
|
||||
if err != nil {
|
||||
log.Warningf("database: failed to write write-cached entry to %q database: %s", i.options.DelayCachedWrites, err)
|
||||
}
|
||||
}
|
||||
// Finish batch.
|
||||
err := batchPut(nil)
|
||||
if err != nil {
|
||||
log.Warningf("database: failed to finish flushing write cache to %q database: %s", i.options.DelayCachedWrites, err)
|
||||
}
|
||||
|
||||
// Optimized map clearing following the Go1.11 recommendation.
|
||||
for key := range i.writeCache {
|
||||
delete(i.writeCache, key)
|
||||
}
|
||||
}
|
||||
|
||||
// cacheEvictHandler is run by the cache for every entry that gets evicted
|
||||
// from the cache.
|
||||
func (i *Interface) cacheEvictHandler(keyData, _ interface{}) {
|
||||
// Transform the key into a string.
|
||||
key, ok := keyData.(string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the evicted record is one that is to be written.
|
||||
// Lock the write cache until the end of the function.
|
||||
// The read cache is locked anyway for the whole duration.
|
||||
i.writeCacheLock.Lock()
|
||||
defer i.writeCacheLock.Unlock()
|
||||
r, ok := i.writeCache[key]
|
||||
if ok {
|
||||
delete(i.writeCache, key)
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Write record to database in order to mitigate race conditions where the record would appear
|
||||
// as non-existent for a short duration.
|
||||
db, err := getController(r.DatabaseName())
|
||||
if err != nil {
|
||||
log.Warningf("database: failed to write evicted cache entry %q: database %q does not exist", key, r.DatabaseName())
|
||||
return
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
err = db.Put(r)
|
||||
if err != nil {
|
||||
log.Warningf("database: failed to write evicted cache entry %q to database: %s", key, err)
|
||||
}
|
||||
|
||||
// Finally, trigger writing the full write cache because a to-be-written
|
||||
// entry was just evicted from the cache, and this makes it likely that more
|
||||
// to-be-written entries will be evicted shortly.
|
||||
select {
|
||||
case i.triggerCacheWrite <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Interface) checkCache(key string) record.Record {
|
||||
// Check if cache is in use.
|
||||
if i.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if record exists in cache.
|
||||
cacheVal, err := i.cache.Get(key)
|
||||
if err == nil {
|
||||
r, ok := cacheVal.(record.Record)
|
||||
if ok {
|
||||
return r
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateCache updates an entry in the interface cache. The given record may
|
||||
// not be locked, as updating the cache might write an (unrelated) evicted
|
||||
// record to the database in the process. If this happens while the
|
||||
// DelayedCacheWriter flushes the write cache with the same record present,
|
||||
// this will deadlock.
|
||||
func (i *Interface) updateCache(r record.Record, write bool, remove bool, ttl int64) (written bool) {
|
||||
// Check if cache is in use.
|
||||
if i.cache == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if record should be deleted
|
||||
if remove {
|
||||
// Remove entry from cache.
|
||||
i.cache.Remove(r.Key())
|
||||
// Let write through to database storage.
|
||||
return false
|
||||
}
|
||||
|
||||
// Update cache with record.
|
||||
if ttl >= 0 {
|
||||
_ = i.cache.SetWithExpire(
|
||||
r.Key(),
|
||||
r,
|
||||
time.Duration(ttl)*time.Second,
|
||||
)
|
||||
} else {
|
||||
_ = i.cache.Set(
|
||||
r.Key(),
|
||||
r,
|
||||
)
|
||||
}
|
||||
|
||||
// Add record to write cache instead if:
|
||||
// 1. The record is being written.
|
||||
// 2. Write delaying is active.
|
||||
// 3. Write delaying is active for the database of this record.
|
||||
if write && r.DatabaseName() == i.options.DelayCachedWrites {
|
||||
i.writeCacheLock.Lock()
|
||||
defer i.writeCacheLock.Unlock()
|
||||
i.writeCache[r.Key()] = r
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
159
base/database/interface_cache_test.go
Normal file
159
base/database/interface_cache_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo,thelper
|
||||
b.Run(fmt.Sprintf("CacheWriting_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) {
|
||||
// Setup Benchmark.
|
||||
|
||||
// Create database.
|
||||
dbName := fmt.Sprintf("cache-w-benchmark-%s-%d-%d-%v", storageType, cacheSize, sampleSize, delayWrites)
|
||||
_, err := Register(&Database{
|
||||
Name: dbName,
|
||||
Description: fmt.Sprintf("Cache Benchmark Database for %s", storageType),
|
||||
StorageType: storageType,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Create benchmark interface.
|
||||
options := &Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
CacheSize: cacheSize,
|
||||
}
|
||||
if cacheSize > 0 && delayWrites {
|
||||
options.DelayCachedWrites = dbName
|
||||
}
|
||||
db := NewInterface(options)
|
||||
|
||||
// Start
|
||||
m := mgr.New("Cache writing benchmark test")
|
||||
var wg sync.WaitGroup
|
||||
if cacheSize > 0 && delayWrites {
|
||||
wg.Add(1)
|
||||
m.Go("Cache writing benchmark worker", func(wc *mgr.WorkerCtx) error {
|
||||
err := db.DelayedCacheWriter(wc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
wg.Done()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Start Benchmark.
|
||||
b.ResetTimer()
|
||||
for i := range b.N {
|
||||
testRecordID := i % sampleSize
|
||||
r := NewExample(
|
||||
dbName+":"+strconv.Itoa(testRecordID),
|
||||
"A",
|
||||
1,
|
||||
)
|
||||
err = db.Put(r)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// End cache writer and wait
|
||||
m.Cancel()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func benchmarkCacheReadWrite(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo,thelper
|
||||
b.Run(fmt.Sprintf("CacheReadWrite_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) {
|
||||
// Setup Benchmark.
|
||||
|
||||
// Create database.
|
||||
dbName := fmt.Sprintf("cache-rw-benchmark-%s-%d-%d-%v", storageType, cacheSize, sampleSize, delayWrites)
|
||||
_, err := Register(&Database{
|
||||
Name: dbName,
|
||||
Description: fmt.Sprintf("Cache Benchmark Database for %s", storageType),
|
||||
StorageType: storageType,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Create benchmark interface.
|
||||
options := &Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
CacheSize: cacheSize,
|
||||
}
|
||||
if cacheSize > 0 && delayWrites {
|
||||
options.DelayCachedWrites = dbName
|
||||
}
|
||||
db := NewInterface(options)
|
||||
|
||||
// Start
|
||||
m := mgr.New("Cache read/write benchmark test")
|
||||
var wg sync.WaitGroup
|
||||
if cacheSize > 0 && delayWrites {
|
||||
wg.Add(1)
|
||||
m.Go("Cache read/write benchmark worker", func(wc *mgr.WorkerCtx) error {
|
||||
err := db.DelayedCacheWriter(wc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
wg.Done()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Start Benchmark.
|
||||
b.ResetTimer()
|
||||
writing := true
|
||||
for i := range b.N {
|
||||
testRecordID := i % sampleSize
|
||||
key := dbName + ":" + strconv.Itoa(testRecordID)
|
||||
|
||||
if i > 0 && testRecordID == 0 {
|
||||
writing = !writing // switch between reading and writing every samplesize
|
||||
}
|
||||
|
||||
if writing {
|
||||
r := NewExample(key, "A", 1)
|
||||
err = db.Put(r)
|
||||
} else {
|
||||
_, err = db.Get(key)
|
||||
}
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// End cache writer and wait
|
||||
m.Cancel()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCache(b *testing.B) {
|
||||
for _, storageType := range []string{"bbolt", "hashmap"} {
|
||||
benchmarkCacheWriting(b, storageType, 32, 8, false)
|
||||
benchmarkCacheWriting(b, storageType, 32, 8, true)
|
||||
benchmarkCacheWriting(b, storageType, 32, 1024, false)
|
||||
benchmarkCacheWriting(b, storageType, 32, 1024, true)
|
||||
benchmarkCacheWriting(b, storageType, 512, 1024, false)
|
||||
benchmarkCacheWriting(b, storageType, 512, 1024, true)
|
||||
|
||||
benchmarkCacheReadWrite(b, storageType, 32, 8, false)
|
||||
benchmarkCacheReadWrite(b, storageType, 32, 8, true)
|
||||
benchmarkCacheReadWrite(b, storageType, 32, 1024, false)
|
||||
benchmarkCacheReadWrite(b, storageType, 32, 1024, true)
|
||||
benchmarkCacheReadWrite(b, storageType, 512, 1024, false)
|
||||
benchmarkCacheReadWrite(b, storageType, 512, 1024, true)
|
||||
}
|
||||
}
|
||||
54
base/database/iterator/iterator.go
Normal file
54
base/database/iterator/iterator.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package iterator
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
// Iterator defines the iterator structure.
|
||||
type Iterator struct {
|
||||
Next chan record.Record
|
||||
Done chan struct{}
|
||||
|
||||
errLock sync.Mutex
|
||||
err error
|
||||
doneClosed *abool.AtomicBool
|
||||
}
|
||||
|
||||
// New creates a new Iterator.
|
||||
func New() *Iterator {
|
||||
return &Iterator{
|
||||
Next: make(chan record.Record, 10),
|
||||
Done: make(chan struct{}),
|
||||
doneClosed: abool.NewBool(false),
|
||||
}
|
||||
}
|
||||
|
||||
// Finish is called be the storage to signal the end of the query results.
|
||||
func (it *Iterator) Finish(err error) {
|
||||
close(it.Next)
|
||||
if it.doneClosed.SetToIf(false, true) {
|
||||
close(it.Done)
|
||||
}
|
||||
|
||||
it.errLock.Lock()
|
||||
defer it.errLock.Unlock()
|
||||
it.err = err
|
||||
}
|
||||
|
||||
// Cancel is called by the iteration consumer to cancel the running query.
|
||||
func (it *Iterator) Cancel() {
|
||||
if it.doneClosed.SetToIf(false, true) {
|
||||
close(it.Done)
|
||||
}
|
||||
}
|
||||
|
||||
// Err returns the iterator error, if exists.
|
||||
func (it *Iterator) Err() error {
|
||||
it.errLock.Lock()
|
||||
defer it.errLock.Unlock()
|
||||
return it.err
|
||||
}
|
||||
77
base/database/main.go
Normal file
77
base/database/main.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
databasesSubDir = "databases"
|
||||
)
|
||||
|
||||
var (
|
||||
initialized = abool.NewBool(false)
|
||||
|
||||
shuttingDown = abool.NewBool(false)
|
||||
shutdownSignal = make(chan struct{})
|
||||
|
||||
rootStructure *utils.DirStructure
|
||||
databasesStructure *utils.DirStructure
|
||||
)
|
||||
|
||||
// InitializeWithPath initializes the database at the specified location using a path.
|
||||
func InitializeWithPath(dirPath string) error {
|
||||
return Initialize(utils.NewDirStructure(dirPath, 0o0755))
|
||||
}
|
||||
|
||||
// Initialize initializes the database at the specified location using a dir structure.
|
||||
func Initialize(dirStructureRoot *utils.DirStructure) error {
|
||||
if initialized.SetToIf(false, true) {
|
||||
rootStructure = dirStructureRoot
|
||||
|
||||
// ensure root and databases dirs
|
||||
databasesStructure = rootStructure.ChildDir(databasesSubDir, 0o0700)
|
||||
err := databasesStructure.Ensure()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create/open database directory (%s): %w", rootStructure.Path, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return errors.New("database already initialized")
|
||||
}
|
||||
|
||||
// Shutdown shuts down the whole database system.
|
||||
func Shutdown() (err error) {
|
||||
if shuttingDown.SetToIf(false, true) {
|
||||
close(shutdownSignal)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
controllersLock.RLock()
|
||||
defer controllersLock.RUnlock()
|
||||
|
||||
for _, c := range controllers {
|
||||
err = c.Shutdown()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// getLocation returns the storage location for the given name and type.
|
||||
func getLocation(name, storageType string) (string, error) {
|
||||
location := databasesStructure.ChildDir(name, 0o0700).ChildDir(storageType, 0o0700)
|
||||
// check location
|
||||
err := location.Ensure()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf(`failed to create/check database dir "%s": %w`, location.Path, err)
|
||||
}
|
||||
return location.Path, nil
|
||||
}
|
||||
64
base/database/maintenance.go
Normal file
64
base/database/maintenance.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Maintain runs the Maintain method on all storages.
|
||||
func Maintain(ctx context.Context) (err error) {
|
||||
// copy, as we might use the very long
|
||||
all := duplicateControllers()
|
||||
|
||||
for _, c := range all {
|
||||
err = c.Maintain(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MaintainThorough runs the MaintainThorough method on all storages.
|
||||
func MaintainThorough(ctx context.Context) (err error) {
|
||||
// copy, as we might use the very long
|
||||
all := duplicateControllers()
|
||||
|
||||
for _, c := range all {
|
||||
err = c.MaintainThorough(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MaintainRecordStates runs record state lifecycle maintenance on all storages.
|
||||
func MaintainRecordStates(ctx context.Context) (err error) {
|
||||
// delete immediately for now
|
||||
// TODO: increase purge threshold when starting to sync DBs
|
||||
purgeDeletedBefore := time.Now().UTC()
|
||||
|
||||
// copy, as we might use the very long
|
||||
all := duplicateControllers()
|
||||
|
||||
for _, c := range all {
|
||||
err = c.MaintainRecordStates(ctx, purgeDeletedBefore)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func duplicateControllers() (all []*Controller) {
|
||||
controllersLock.RLock()
|
||||
defer controllersLock.RUnlock()
|
||||
|
||||
all = make([]*Controller, 0, len(controllers))
|
||||
for _, c := range controllers {
|
||||
all = append(all, c)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
58
base/database/migration/error.go
Normal file
58
base/database/migration/error.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package migration
|
||||
|
||||
import "errors"
|
||||
|
||||
// DiagnosticStep describes one migration step in the Diagnostics.
|
||||
type DiagnosticStep struct {
|
||||
Version string
|
||||
Description string
|
||||
}
|
||||
|
||||
// Diagnostics holds a detailed error report about a failed migration.
|
||||
type Diagnostics struct { //nolint:errname
|
||||
// Message holds a human readable message of the encountered
|
||||
// error.
|
||||
Message string
|
||||
// Wrapped must be set to the underlying error that was encountered
|
||||
// while preparing or executing migrations.
|
||||
Wrapped error
|
||||
// StartOfMigration is set to the version of the database before
|
||||
// any migrations are applied.
|
||||
StartOfMigration string
|
||||
// LastSuccessfulMigration is set to the version of the database
|
||||
// which has been applied successfully before the error happened.
|
||||
LastSuccessfulMigration string
|
||||
// TargetVersion is set to the version of the database that the
|
||||
// migration run aimed for. That is, it's the last available version
|
||||
// added to the registry.
|
||||
TargetVersion string
|
||||
// ExecutionPlan is a list of migration steps that were planned to
|
||||
// be executed.
|
||||
ExecutionPlan []DiagnosticStep
|
||||
// FailedMigration is the description of the migration that has
|
||||
// failed.
|
||||
FailedMigration string
|
||||
}
|
||||
|
||||
// Error returns a string representation of the migration error.
|
||||
func (err *Diagnostics) Error() string {
|
||||
msg := ""
|
||||
if err.FailedMigration != "" {
|
||||
msg = err.FailedMigration + ": "
|
||||
}
|
||||
if err.Message != "" {
|
||||
msg += err.Message + ": "
|
||||
}
|
||||
msg += err.Wrapped.Error()
|
||||
return msg
|
||||
}
|
||||
|
||||
// Unwrap returns the actual error that happened when executing
|
||||
// a migration. It implements the interface required by the stdlib
|
||||
// errors package to support errors.Is() and errors.As().
|
||||
func (err *Diagnostics) Unwrap() error {
|
||||
if u := errors.Unwrap(err.Wrapped); u != nil {
|
||||
return u
|
||||
}
|
||||
return err.Wrapped
|
||||
}
|
||||
220
base/database/migration/migration.go
Normal file
220
base/database/migration/migration.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/safing/portmaster/base/database"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
// MigrateFunc is called when a migration should be applied to the
|
||||
// database. It receives the current version (from) and the target
|
||||
// version (to) of the database and a dedicated interface for
|
||||
// interacting with data stored in the DB.
|
||||
// A dedicated log.ContextTracer is added to ctx for each migration
|
||||
// run.
|
||||
type MigrateFunc func(ctx context.Context, from, to *version.Version, dbInterface *database.Interface) error
|
||||
|
||||
// Migration represents a registered data-migration that should be applied to
|
||||
// some database. Migrations are stacked on top and executed in order of increasing
|
||||
// version number (see Version field).
|
||||
type Migration struct {
|
||||
// Description provides a short human-readable description of the
|
||||
// migration.
|
||||
Description string
|
||||
// Version should hold the version of the database/subsystem after
|
||||
// the migration has been applied.
|
||||
Version string
|
||||
// MigrateFuc is executed when the migration should be performed.
|
||||
MigrateFunc MigrateFunc
|
||||
}
|
||||
|
||||
// Registry holds a migration stack.
|
||||
type Registry struct {
|
||||
key string
|
||||
|
||||
lock sync.Mutex
|
||||
migrations []Migration
|
||||
}
|
||||
|
||||
// New creates a new migration registry.
|
||||
// The key should be the name of the database key that is used to store
|
||||
// the version of the last successfully applied migration.
|
||||
func New(key string) *Registry {
|
||||
return &Registry{
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds one or more migrations to reg.
|
||||
func (reg *Registry) Add(migrations ...Migration) error {
|
||||
reg.lock.Lock()
|
||||
defer reg.lock.Unlock()
|
||||
for _, m := range migrations {
|
||||
if _, err := version.NewSemver(m.Version); err != nil {
|
||||
return fmt.Errorf("migration %q: invalid version %s: %w", m.Description, m.Version, err)
|
||||
}
|
||||
reg.migrations = append(reg.migrations, m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Migrate migrates the database by executing all registered
|
||||
// migration in order of increasing version numbers. The error
|
||||
// returned, if not nil, is always of type *Diagnostics.
|
||||
func (reg *Registry) Migrate(ctx context.Context) (err error) {
|
||||
reg.lock.Lock()
|
||||
defer reg.lock.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
log.Infof("migration: migration of %s started", reg.key)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Errorf("migration: migration of %s failed after %s: %s", reg.key, time.Since(start), err)
|
||||
} else {
|
||||
log.Infof("migration: migration of %s finished after %s", reg.key, time.Since(start))
|
||||
}
|
||||
}()
|
||||
|
||||
db := database.NewInterface(&database.Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
|
||||
startOfMigration, err := reg.getLatestSuccessfulMigration(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
execPlan, diag, err := reg.getExecutionPlan(startOfMigration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(execPlan) == 0 {
|
||||
return nil
|
||||
}
|
||||
diag.TargetVersion = execPlan[len(execPlan)-1].Version
|
||||
|
||||
// finally, apply our migrations
|
||||
lastAppliedMigration := startOfMigration
|
||||
for _, m := range execPlan {
|
||||
target, _ := version.NewSemver(m.Version) // we can safely ignore the error here
|
||||
|
||||
migrationCtx, tracer := log.AddTracer(ctx)
|
||||
|
||||
if err := m.MigrateFunc(migrationCtx, lastAppliedMigration, target, db); err != nil {
|
||||
diag.Wrapped = err
|
||||
diag.FailedMigration = m.Description
|
||||
tracer.Errorf("migration: migration for %s failed: %s - %s", reg.key, target.String(), m.Description)
|
||||
tracer.Submit()
|
||||
return diag
|
||||
}
|
||||
|
||||
lastAppliedMigration = target
|
||||
diag.LastSuccessfulMigration = lastAppliedMigration.String()
|
||||
|
||||
if err := reg.saveLastSuccessfulMigration(db, target); err != nil {
|
||||
diag.Message = "failed to persist migration status"
|
||||
diag.Wrapped = err
|
||||
diag.FailedMigration = m.Description
|
||||
}
|
||||
tracer.Infof("migration: applied migration for %s: %s - %s", reg.key, target.String(), m.Description)
|
||||
tracer.Submit()
|
||||
}
|
||||
|
||||
// all migrations have been applied successfully, we're done here
|
||||
return nil
|
||||
}
|
||||
|
||||
func (reg *Registry) getLatestSuccessfulMigration(db *database.Interface) (*version.Version, error) {
|
||||
// find the latest version stored in the database
|
||||
rec, err := db.Get(reg.key)
|
||||
if errors.Is(err, database.ErrNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, &Diagnostics{
|
||||
Message: "failed to query database for migration status",
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap the record to get the actual database
|
||||
r, ok := rec.(*record.Wrapper)
|
||||
if !ok {
|
||||
return nil, &Diagnostics{
|
||||
Wrapped: errors.New("expected wrapped database record"),
|
||||
}
|
||||
}
|
||||
|
||||
sv, err := version.NewSemver(string(r.Data))
|
||||
if err != nil {
|
||||
return nil, &Diagnostics{
|
||||
Message: "failed to parse version stored in migration status record",
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
return sv, nil
|
||||
}
|
||||
|
||||
func (reg *Registry) saveLastSuccessfulMigration(db *database.Interface, ver *version.Version) error {
|
||||
r := &record.Wrapper{
|
||||
Data: []byte(ver.String()),
|
||||
Format: dsd.RAW,
|
||||
}
|
||||
r.SetKey(reg.key)
|
||||
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
func (reg *Registry) getExecutionPlan(startOfMigration *version.Version) ([]Migration, *Diagnostics, error) {
|
||||
// create a look-up map for migrations indexed by their semver created a
|
||||
// list of version (sorted by increasing number) that we use as our execution
|
||||
// plan.
|
||||
lm := make(map[string]Migration)
|
||||
versions := make(version.Collection, 0, len(reg.migrations))
|
||||
for _, m := range reg.migrations {
|
||||
ver, err := version.NewSemver(m.Version)
|
||||
if err != nil {
|
||||
return nil, nil, &Diagnostics{
|
||||
Message: "failed to parse version of migration",
|
||||
Wrapped: err,
|
||||
FailedMigration: m.Description,
|
||||
}
|
||||
}
|
||||
lm[ver.String()] = m // use .String() for a normalized string representation
|
||||
versions = append(versions, ver)
|
||||
}
|
||||
sort.Sort(versions)
|
||||
|
||||
diag := new(Diagnostics)
|
||||
if startOfMigration != nil {
|
||||
diag.StartOfMigration = startOfMigration.String()
|
||||
}
|
||||
|
||||
// prepare our diagnostics and the execution plan
|
||||
execPlan := make([]Migration, 0, len(versions))
|
||||
for _, ver := range versions {
|
||||
// skip an migration that has already been applied.
|
||||
if startOfMigration != nil && startOfMigration.GreaterThanOrEqual(ver) {
|
||||
continue
|
||||
}
|
||||
m := lm[ver.String()]
|
||||
diag.ExecutionPlan = append(diag.ExecutionPlan, DiagnosticStep{
|
||||
Description: m.Description,
|
||||
Version: ver.String(),
|
||||
})
|
||||
execPlan = append(execPlan, m)
|
||||
}
|
||||
|
||||
return execPlan, diag, nil
|
||||
}
|
||||
55
base/database/query/README.md
Normal file
55
base/database/query/README.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Query
|
||||
|
||||
## Control Flow
|
||||
|
||||
- Grouping with `(` and `)`
|
||||
- Chaining with `and` and `or`
|
||||
- _NO_ mixing! Be explicit and use grouping.
|
||||
- Negation with `not`
|
||||
- in front of expression for group: `not (...)`
|
||||
- inside expression for clause: `name not matches "^King "`
|
||||
|
||||
## Selectors
|
||||
|
||||
Supported by all feeders:
|
||||
- root level field: `field`
|
||||
- sub level field: `field.sub`
|
||||
- array/slice/map access: `map.0`
|
||||
- array/slice/map length: `map.#`
|
||||
|
||||
Please note that some feeders may have other special characters. It is advised to only use alphanumeric characters for keys.
|
||||
|
||||
## Operators
|
||||
|
||||
| Name | Textual | Req. Type | Internal Type | Compared with |
|
||||
|-------------------------|--------------------|-----------|---------------|---------------------------|
|
||||
| Equals | `==` | int | int64 | `==` |
|
||||
| GreaterThan | `>` | int | int64 | `>` |
|
||||
| GreaterThanOrEqual | `>=` | int | int64 | `>=` |
|
||||
| LessThan | `<` | int | int64 | `<` |
|
||||
| LessThanOrEqual | `<=` | int | int64 | `<=` |
|
||||
| FloatEquals | `f==` | float | float64 | `==` |
|
||||
| FloatGreaterThan | `f>` | float | float64 | `>` |
|
||||
| FloatGreaterThanOrEqual | `f>=` | float | float64 | `>=` |
|
||||
| FloatLessThan | `f<` | float | float64 | `<` |
|
||||
| FloatLessThanOrEqual | `f<=` | float | float64 | `<=` |
|
||||
| SameAs | `sameas`, `s==` | string | string | `==` |
|
||||
| Contains | `contains`, `co` | string | string | `strings.Contains()` |
|
||||
| StartsWith | `startswith`, `sw` | string | string | `strings.HasPrefix()` |
|
||||
| EndsWith | `endswith`, `ew` | string | string | `strings.HasSuffix()` |
|
||||
| In | `in` | string | string | for loop with `==` |
|
||||
| Matches | `matches`, `re` | string | string | `regexp.Regexp.Matches()` |
|
||||
| Is | `is` | bool* | bool | `==` |
|
||||
| Exists | `exists`, `ex` | any | n/a | n/a |
|
||||
|
||||
\*accepts strings: 1, t, T, true, True, TRUE, 0, f, F, false, False, FALSE
|
||||
|
||||
## Escaping
|
||||
|
||||
If you need to use a control character within a value (ie. not for controlling), escape it with `\`.
|
||||
It is recommended to wrap a word into parenthesis instead of escaping control characters, when possible.
|
||||
|
||||
| Location | Characters to be escaped |
|
||||
|---|---|
|
||||
| Within parenthesis (`"`) | `"`, `\` |
|
||||
| Everywhere else | `(`, `)`, `"`, `\`, `\t`, `\r`, `\n`, ` ` (space) |
|
||||
46
base/database/query/condition-and.go
Normal file
46
base/database/query/condition-and.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
// And combines multiple conditions with a logical _AND_ operator.
|
||||
func And(conditions ...Condition) Condition {
|
||||
return &andCond{
|
||||
conditions: conditions,
|
||||
}
|
||||
}
|
||||
|
||||
type andCond struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
func (c *andCond) complies(acc accessor.Accessor) bool {
|
||||
for _, cond := range c.conditions {
|
||||
if !cond.complies(acc) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *andCond) check() (err error) {
|
||||
for _, cond := range c.conditions {
|
||||
err = cond.check()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *andCond) string() string {
|
||||
all := make([]string, 0, len(c.conditions))
|
||||
for _, cond := range c.conditions {
|
||||
all = append(all, cond.string())
|
||||
}
|
||||
return fmt.Sprintf("(%s)", strings.Join(all, " and "))
|
||||
}
|
||||
69
base/database/query/condition-bool.go
Normal file
69
base/database/query/condition-bool.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type boolCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value bool
|
||||
}
|
||||
|
||||
func newBoolCondition(key string, operator uint8, value interface{}) *boolCondition {
|
||||
var parsedValue bool
|
||||
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
parsedValue = v
|
||||
case string:
|
||||
var err error
|
||||
parsedValue, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return &boolCondition{
|
||||
key: fmt.Sprintf("could not parse \"%s\" to bool: %s", v, err),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
default:
|
||||
return &boolCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for int64", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
|
||||
return &boolCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: parsedValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *boolCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetBool(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case Is:
|
||||
return comp == c.value
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *boolCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *boolCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %t", escapeString(c.key), getOpName(c.operator), c.value)
|
||||
}
|
||||
27
base/database/query/condition-error.go
Normal file
27
base/database/query/condition-error.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type errorCondition struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func newErrorCondition(err error) *errorCondition {
|
||||
return &errorCondition{
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *errorCondition) complies(acc accessor.Accessor) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *errorCondition) check() error {
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *errorCondition) string() string {
|
||||
return "[ERROR]"
|
||||
}
|
||||
35
base/database/query/condition-exists.go
Normal file
35
base/database/query/condition-exists.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type existsCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
}
|
||||
|
||||
func newExistsCondition(key string, operator uint8) *existsCondition {
|
||||
return &existsCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *existsCondition) complies(acc accessor.Accessor) bool {
|
||||
return acc.Exists(c.key)
|
||||
}
|
||||
|
||||
func (c *existsCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *existsCondition) string() string {
|
||||
return fmt.Sprintf("%s %s", escapeString(c.key), getOpName(c.operator))
|
||||
}
|
||||
97
base/database/query/condition-float.go
Normal file
97
base/database/query/condition-float.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type floatCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value float64
|
||||
}
|
||||
|
||||
func newFloatCondition(key string, operator uint8, value interface{}) *floatCondition {
|
||||
var parsedValue float64
|
||||
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
parsedValue = float64(v)
|
||||
case int8:
|
||||
parsedValue = float64(v)
|
||||
case int16:
|
||||
parsedValue = float64(v)
|
||||
case int32:
|
||||
parsedValue = float64(v)
|
||||
case int64:
|
||||
parsedValue = float64(v)
|
||||
case uint:
|
||||
parsedValue = float64(v)
|
||||
case uint8:
|
||||
parsedValue = float64(v)
|
||||
case uint16:
|
||||
parsedValue = float64(v)
|
||||
case uint32:
|
||||
parsedValue = float64(v)
|
||||
case float32:
|
||||
parsedValue = float64(v)
|
||||
case float64:
|
||||
parsedValue = v
|
||||
case string:
|
||||
var err error
|
||||
parsedValue, err = strconv.ParseFloat(v, 64)
|
||||
if err != nil {
|
||||
return &floatCondition{
|
||||
key: fmt.Sprintf("could not parse %s to float64: %s", v, err),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
default:
|
||||
return &floatCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for float64", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
|
||||
return &floatCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: parsedValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *floatCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetFloat(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case FloatEquals:
|
||||
return comp == c.value
|
||||
case FloatGreaterThan:
|
||||
return comp > c.value
|
||||
case FloatGreaterThanOrEqual:
|
||||
return comp >= c.value
|
||||
case FloatLessThan:
|
||||
return comp < c.value
|
||||
case FloatLessThanOrEqual:
|
||||
return comp <= c.value
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *floatCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *floatCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %g", escapeString(c.key), getOpName(c.operator), c.value)
|
||||
}
|
||||
93
base/database/query/condition-int.go
Normal file
93
base/database/query/condition-int.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type intCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value int64
|
||||
}
|
||||
|
||||
func newIntCondition(key string, operator uint8, value interface{}) *intCondition {
|
||||
var parsedValue int64
|
||||
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
parsedValue = int64(v)
|
||||
case int8:
|
||||
parsedValue = int64(v)
|
||||
case int16:
|
||||
parsedValue = int64(v)
|
||||
case int32:
|
||||
parsedValue = int64(v)
|
||||
case int64:
|
||||
parsedValue = v
|
||||
case uint:
|
||||
parsedValue = int64(v)
|
||||
case uint8:
|
||||
parsedValue = int64(v)
|
||||
case uint16:
|
||||
parsedValue = int64(v)
|
||||
case uint32:
|
||||
parsedValue = int64(v)
|
||||
case string:
|
||||
var err error
|
||||
parsedValue, err = strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return &intCondition{
|
||||
key: fmt.Sprintf("could not parse %s to int64: %s (hint: use \"sameas\" to compare strings)", v, err),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
default:
|
||||
return &intCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for int64", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
|
||||
return &intCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: parsedValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *intCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetInt(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case Equals:
|
||||
return comp == c.value
|
||||
case GreaterThan:
|
||||
return comp > c.value
|
||||
case GreaterThanOrEqual:
|
||||
return comp >= c.value
|
||||
case LessThan:
|
||||
return comp < c.value
|
||||
case LessThanOrEqual:
|
||||
return comp <= c.value
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *intCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *intCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %d", escapeString(c.key), getOpName(c.operator), c.value)
|
||||
}
|
||||
36
base/database/query/condition-not.go
Normal file
36
base/database/query/condition-not.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
// Not negates the supplied condition.
|
||||
func Not(c Condition) Condition {
|
||||
return ¬Cond{
|
||||
notC: c,
|
||||
}
|
||||
}
|
||||
|
||||
type notCond struct {
|
||||
notC Condition
|
||||
}
|
||||
|
||||
func (c *notCond) complies(acc accessor.Accessor) bool {
|
||||
return !c.notC.complies(acc)
|
||||
}
|
||||
|
||||
func (c *notCond) check() error {
|
||||
return c.notC.check()
|
||||
}
|
||||
|
||||
func (c *notCond) string() string {
|
||||
next := c.notC.string()
|
||||
if strings.HasPrefix(next, "(") {
|
||||
return fmt.Sprintf("not %s", c.notC.string())
|
||||
}
|
||||
splitted := strings.Split(next, " ")
|
||||
return strings.Join(append([]string{splitted[0], "not"}, splitted[1:]...), " ")
|
||||
}
|
||||
46
base/database/query/condition-or.go
Normal file
46
base/database/query/condition-or.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
// Or combines multiple conditions with a logical _OR_ operator.
|
||||
func Or(conditions ...Condition) Condition {
|
||||
return &orCond{
|
||||
conditions: conditions,
|
||||
}
|
||||
}
|
||||
|
||||
type orCond struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
func (c *orCond) complies(acc accessor.Accessor) bool {
|
||||
for _, cond := range c.conditions {
|
||||
if cond.complies(acc) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *orCond) check() (err error) {
|
||||
for _, cond := range c.conditions {
|
||||
err = cond.check()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *orCond) string() string {
|
||||
all := make([]string, 0, len(c.conditions))
|
||||
for _, cond := range c.conditions {
|
||||
all = append(all, cond.string())
|
||||
}
|
||||
return fmt.Sprintf("(%s)", strings.Join(all, " or "))
|
||||
}
|
||||
63
base/database/query/condition-regex.go
Normal file
63
base/database/query/condition-regex.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type regexCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
func newRegexCondition(key string, operator uint8, value interface{}) *regexCondition {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
r, err := regexp.Compile(v)
|
||||
if err != nil {
|
||||
return ®exCondition{
|
||||
key: fmt.Sprintf("could not compile regex \"%s\": %s", v, err),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
return ®exCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
regex: r,
|
||||
}
|
||||
default:
|
||||
return ®exCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for string", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *regexCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetString(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case Matches:
|
||||
return c.regex.MatchString(comp)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *regexCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *regexCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(c.regex.String()))
|
||||
}
|
||||
62
base/database/query/condition-string.go
Normal file
62
base/database/query/condition-string.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type stringCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value string
|
||||
}
|
||||
|
||||
func newStringCondition(key string, operator uint8, value interface{}) *stringCondition {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return &stringCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: v,
|
||||
}
|
||||
default:
|
||||
return &stringCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for string", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *stringCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetString(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case SameAs:
|
||||
return c.value == comp
|
||||
case Contains:
|
||||
return strings.Contains(comp, c.value)
|
||||
case StartsWith:
|
||||
return strings.HasPrefix(comp, c.value)
|
||||
case EndsWith:
|
||||
return strings.HasSuffix(comp, c.value)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *stringCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stringCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(c.value))
|
||||
}
|
||||
69
base/database/query/condition-stringslice.go
Normal file
69
base/database/query/condition-stringslice.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
)
|
||||
|
||||
type stringSliceCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value []string
|
||||
}
|
||||
|
||||
func newStringSliceCondition(key string, operator uint8, value interface{}) *stringSliceCondition {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
parsedValue := strings.Split(v, ",")
|
||||
if len(parsedValue) < 2 {
|
||||
return &stringSliceCondition{
|
||||
key: v,
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
return &stringSliceCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: parsedValue,
|
||||
}
|
||||
case []string:
|
||||
return &stringSliceCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: v,
|
||||
}
|
||||
default:
|
||||
return &stringSliceCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for []string", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *stringSliceCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetString(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case In:
|
||||
return utils.StringInSlice(c.value, comp)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *stringSliceCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return fmt.Errorf("could not parse \"%s\" to []string", c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stringSliceCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(strings.Join(c.value, ",")))
|
||||
}
|
||||
71
base/database/query/condition.go
Normal file
71
base/database/query/condition.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
// Condition is an interface to provide a common api to all condition types.
|
||||
type Condition interface {
|
||||
complies(acc accessor.Accessor) bool
|
||||
check() error
|
||||
string() string
|
||||
}
|
||||
|
||||
// Operators.
|
||||
const (
|
||||
Equals uint8 = iota // int
|
||||
GreaterThan // int
|
||||
GreaterThanOrEqual // int
|
||||
LessThan // int
|
||||
LessThanOrEqual // int
|
||||
FloatEquals // float
|
||||
FloatGreaterThan // float
|
||||
FloatGreaterThanOrEqual // float
|
||||
FloatLessThan // float
|
||||
FloatLessThanOrEqual // float
|
||||
SameAs // string
|
||||
Contains // string
|
||||
StartsWith // string
|
||||
EndsWith // string
|
||||
In // stringSlice
|
||||
Matches // regex
|
||||
Is // bool: accepts 1, t, T, TRUE, true, True, 0, f, F, FALSE
|
||||
Exists // any
|
||||
|
||||
errorPresent uint8 = 255
|
||||
)
|
||||
|
||||
// Where returns a condition to add to a query.
|
||||
func Where(key string, operator uint8, value interface{}) Condition {
|
||||
switch operator {
|
||||
case Equals,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
LessThan,
|
||||
LessThanOrEqual:
|
||||
return newIntCondition(key, operator, value)
|
||||
case FloatEquals,
|
||||
FloatGreaterThan,
|
||||
FloatGreaterThanOrEqual,
|
||||
FloatLessThan,
|
||||
FloatLessThanOrEqual:
|
||||
return newFloatCondition(key, operator, value)
|
||||
case SameAs,
|
||||
Contains,
|
||||
StartsWith,
|
||||
EndsWith:
|
||||
return newStringCondition(key, operator, value)
|
||||
case In:
|
||||
return newStringSliceCondition(key, operator, value)
|
||||
case Matches:
|
||||
return newRegexCondition(key, operator, value)
|
||||
case Is:
|
||||
return newBoolCondition(key, operator, value)
|
||||
case Exists:
|
||||
return newExistsCondition(key, operator)
|
||||
default:
|
||||
return newErrorCondition(fmt.Errorf("no operator with ID %d", operator))
|
||||
}
|
||||
}
|
||||
86
base/database/query/condition_test.go
Normal file
86
base/database/query/condition_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package query
|
||||
|
||||
import "testing"
|
||||
|
||||
func testSuccess(t *testing.T, c Condition) {
|
||||
t.Helper()
|
||||
|
||||
err := c.check()
|
||||
if err != nil {
|
||||
t.Errorf("failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaces(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testSuccess(t, newIntCondition("banana", Equals, uint(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, uint8(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, uint16(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, uint32(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int8(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int16(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int32(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int64(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, "1"))
|
||||
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, uint(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, uint8(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, uint16(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, uint32(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int8(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int16(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int32(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int64(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, float32(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, float64(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, "1.1"))
|
||||
|
||||
testSuccess(t, newStringCondition("banana", SameAs, "coconut"))
|
||||
testSuccess(t, newRegexCondition("banana", Matches, "coconut"))
|
||||
testSuccess(t, newStringSliceCondition("banana", FloatEquals, []string{"banana", "coconut"}))
|
||||
testSuccess(t, newStringSliceCondition("banana", FloatEquals, "banana,coconut"))
|
||||
}
|
||||
|
||||
func testCondError(t *testing.T, c Condition) {
|
||||
t.Helper()
|
||||
|
||||
err := c.check()
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConditionErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// test invalid value types
|
||||
testCondError(t, newBoolCondition("banana", Is, 1))
|
||||
testCondError(t, newFloatCondition("banana", FloatEquals, true))
|
||||
testCondError(t, newIntCondition("banana", Equals, true))
|
||||
testCondError(t, newStringCondition("banana", SameAs, 1))
|
||||
testCondError(t, newRegexCondition("banana", Matches, 1))
|
||||
testCondError(t, newStringSliceCondition("banana", Matches, 1))
|
||||
|
||||
// test error presence
|
||||
testCondError(t, newBoolCondition("banana", errorPresent, true))
|
||||
testCondError(t, And(newBoolCondition("banana", errorPresent, true)))
|
||||
testCondError(t, Or(newBoolCondition("banana", errorPresent, true)))
|
||||
testCondError(t, newExistsCondition("banana", errorPresent))
|
||||
testCondError(t, newFloatCondition("banana", errorPresent, 1.1))
|
||||
testCondError(t, newIntCondition("banana", errorPresent, 1))
|
||||
testCondError(t, newStringCondition("banana", errorPresent, "coconut"))
|
||||
testCondError(t, newRegexCondition("banana", errorPresent, "coconut"))
|
||||
}
|
||||
|
||||
func TestWhere(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := Where("", 254, nil)
|
||||
err := c.check()
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
}
|
||||
53
base/database/query/operators.go
Normal file
53
base/database/query/operators.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package query
|
||||
|
||||
var (
|
||||
operatorNames = map[string]uint8{
|
||||
"==": Equals,
|
||||
">": GreaterThan,
|
||||
">=": GreaterThanOrEqual,
|
||||
"<": LessThan,
|
||||
"<=": LessThanOrEqual,
|
||||
"f==": FloatEquals,
|
||||
"f>": FloatGreaterThan,
|
||||
"f>=": FloatGreaterThanOrEqual,
|
||||
"f<": FloatLessThan,
|
||||
"f<=": FloatLessThanOrEqual,
|
||||
"sameas": SameAs,
|
||||
"s==": SameAs,
|
||||
"contains": Contains,
|
||||
"co": Contains,
|
||||
"startswith": StartsWith,
|
||||
"sw": StartsWith,
|
||||
"endswith": EndsWith,
|
||||
"ew": EndsWith,
|
||||
"in": In,
|
||||
"matches": Matches,
|
||||
"re": Matches,
|
||||
"is": Is,
|
||||
"exists": Exists,
|
||||
"ex": Exists,
|
||||
}
|
||||
|
||||
primaryNames = make(map[uint8]string)
|
||||
)
|
||||
|
||||
func init() {
|
||||
for opName, opID := range operatorNames {
|
||||
name, ok := primaryNames[opID]
|
||||
if ok {
|
||||
if len(name) < len(opName) {
|
||||
primaryNames[opID] = opName
|
||||
}
|
||||
} else {
|
||||
primaryNames[opID] = opName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getOpName(operator uint8) string {
|
||||
name, ok := primaryNames[operator]
|
||||
if ok {
|
||||
return name
|
||||
}
|
||||
return "[unknown]"
|
||||
}
|
||||
11
base/database/query/operators_test.go
Normal file
11
base/database/query/operators_test.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package query
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetOpName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if getOpName(254) != "[unknown]" {
|
||||
t.Error("unexpected output")
|
||||
}
|
||||
}
|
||||
350
base/database/query/parser.go
Normal file
350
base/database/query/parser.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type snippet struct {
|
||||
text string
|
||||
globalPosition int
|
||||
}
|
||||
|
||||
// ParseQuery parses a plaintext query. Special characters (that must be escaped with a '\') are: `\()` and any whitespaces.
|
||||
//
|
||||
//nolint:gocognit
|
||||
func ParseQuery(query string) (*Query, error) {
|
||||
snippets, err := extractSnippets(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
snippetsPos := 0
|
||||
|
||||
getSnippet := func() (*snippet, error) {
|
||||
// order is important, as parseAndOr will always consume one additional snippet.
|
||||
snippetsPos++
|
||||
if snippetsPos > len(snippets) {
|
||||
return nil, fmt.Errorf("unexpected end at position %d", len(query))
|
||||
}
|
||||
return snippets[snippetsPos-1], nil
|
||||
}
|
||||
remainingSnippets := func() int {
|
||||
return len(snippets) - snippetsPos
|
||||
}
|
||||
|
||||
// check for query word
|
||||
queryWord, err := getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if queryWord.text != "query" {
|
||||
return nil, errors.New("queries must start with \"query\"")
|
||||
}
|
||||
|
||||
// get prefix
|
||||
prefix, err := getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := New(prefix.text)
|
||||
|
||||
for remainingSnippets() > 0 {
|
||||
command, err := getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch command.text {
|
||||
case "where":
|
||||
if q.where != nil {
|
||||
return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition)
|
||||
}
|
||||
|
||||
// parse conditions
|
||||
condition, err := parseAndOr(getSnippet, remainingSnippets, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// go one back, as parseAndOr had to check if its done
|
||||
snippetsPos--
|
||||
|
||||
q.Where(condition)
|
||||
case "orderby":
|
||||
if q.orderBy != "" {
|
||||
return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition)
|
||||
}
|
||||
|
||||
orderBySnippet, err := getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.OrderBy(orderBySnippet.text)
|
||||
case "limit":
|
||||
if q.limit != 0 {
|
||||
return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition)
|
||||
}
|
||||
|
||||
limitSnippet, err := getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
limit, err := strconv.ParseUint(limitSnippet.text, 10, 31)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse integer (%s) at position %d", limitSnippet.text, limitSnippet.globalPosition)
|
||||
}
|
||||
|
||||
q.Limit(int(limit))
|
||||
case "offset":
|
||||
if q.offset != 0 {
|
||||
return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition)
|
||||
}
|
||||
|
||||
offsetSnippet, err := getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset, err := strconv.ParseUint(offsetSnippet.text, 10, 31)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse integer (%s) at position %d", offsetSnippet.text, offsetSnippet.globalPosition)
|
||||
}
|
||||
|
||||
q.Offset(int(offset))
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown clause \"%s\" at position %d", command.text, command.globalPosition)
|
||||
}
|
||||
}
|
||||
|
||||
return q.Check()
|
||||
}
|
||||
|
||||
func extractSnippets(text string) (snippets []*snippet, err error) {
|
||||
skip := false
|
||||
start := -1
|
||||
inParenthesis := false
|
||||
var pos int
|
||||
var char rune
|
||||
|
||||
for pos, char = range text {
|
||||
|
||||
// skip
|
||||
if skip {
|
||||
skip = false
|
||||
continue
|
||||
}
|
||||
if char == '\\' {
|
||||
skip = true
|
||||
}
|
||||
|
||||
// wait for parenthesis to be overs
|
||||
if inParenthesis {
|
||||
if char == '"' {
|
||||
snippets = append(snippets, &snippet{
|
||||
text: prepToken(text[start+1 : pos]),
|
||||
globalPosition: start + 1,
|
||||
})
|
||||
start = -1
|
||||
inParenthesis = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// handle segments
|
||||
switch char {
|
||||
case '\t', '\n', '\r', ' ', '(', ')':
|
||||
if start >= 0 {
|
||||
snippets = append(snippets, &snippet{
|
||||
text: prepToken(text[start:pos]),
|
||||
globalPosition: start + 1,
|
||||
})
|
||||
start = -1
|
||||
}
|
||||
default:
|
||||
if start == -1 {
|
||||
start = pos
|
||||
}
|
||||
}
|
||||
|
||||
// handle special segment characters
|
||||
switch char {
|
||||
case '(', ')':
|
||||
snippets = append(snippets, &snippet{
|
||||
text: text[pos : pos+1],
|
||||
globalPosition: pos + 1,
|
||||
})
|
||||
case '"':
|
||||
if start < pos {
|
||||
return nil, fmt.Errorf("parenthesis ('\"') may not be used within words, please escape with '\\' (position: %d)", pos+1)
|
||||
}
|
||||
inParenthesis = true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// add last
|
||||
if start >= 0 {
|
||||
snippets = append(snippets, &snippet{
|
||||
text: prepToken(text[start : pos+1]),
|
||||
globalPosition: start + 1,
|
||||
})
|
||||
}
|
||||
|
||||
return snippets, nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func parseAndOr(getSnippet func() (*snippet, error), remainingSnippets func() int, rootCondition bool) (Condition, error) {
|
||||
var (
|
||||
isOr = false
|
||||
typeSet = false
|
||||
wrapInNot = false
|
||||
expectingMore = true
|
||||
conditions []Condition
|
||||
)
|
||||
|
||||
for {
|
||||
if !expectingMore && rootCondition && remainingSnippets() == 0 {
|
||||
// advance snippetsPos by one, as it will be set back by 1
|
||||
_, _ = getSnippet()
|
||||
if len(conditions) == 1 {
|
||||
return conditions[0], nil
|
||||
}
|
||||
if isOr {
|
||||
return Or(conditions...), nil
|
||||
}
|
||||
return And(conditions...), nil
|
||||
}
|
||||
|
||||
firstSnippet, err := getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !expectingMore && rootCondition {
|
||||
switch firstSnippet.text {
|
||||
case "orderby", "limit", "offset":
|
||||
if len(conditions) == 1 {
|
||||
return conditions[0], nil
|
||||
}
|
||||
if isOr {
|
||||
return Or(conditions...), nil
|
||||
}
|
||||
return And(conditions...), nil
|
||||
}
|
||||
}
|
||||
|
||||
switch firstSnippet.text {
|
||||
case "(":
|
||||
condition, err := parseAndOr(getSnippet, remainingSnippets, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if wrapInNot {
|
||||
conditions = append(conditions, Not(condition))
|
||||
wrapInNot = false
|
||||
} else {
|
||||
conditions = append(conditions, condition)
|
||||
}
|
||||
expectingMore = true
|
||||
case ")":
|
||||
if len(conditions) == 1 {
|
||||
return conditions[0], nil
|
||||
}
|
||||
if isOr {
|
||||
return Or(conditions...), nil
|
||||
}
|
||||
return And(conditions...), nil
|
||||
case "and":
|
||||
if typeSet && isOr {
|
||||
return nil, fmt.Errorf("you may not mix \"and\" and \"or\" (position: %d)", firstSnippet.globalPosition)
|
||||
}
|
||||
isOr = false
|
||||
typeSet = true
|
||||
expectingMore = true
|
||||
case "or":
|
||||
if typeSet && !isOr {
|
||||
return nil, fmt.Errorf("you may not mix \"and\" and \"or\" (position: %d)", firstSnippet.globalPosition)
|
||||
}
|
||||
isOr = true
|
||||
typeSet = true
|
||||
expectingMore = true
|
||||
case "not":
|
||||
wrapInNot = true
|
||||
expectingMore = true
|
||||
default:
|
||||
condition, err := parseCondition(firstSnippet, getSnippet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if wrapInNot {
|
||||
conditions = append(conditions, Not(condition))
|
||||
wrapInNot = false
|
||||
} else {
|
||||
conditions = append(conditions, condition)
|
||||
}
|
||||
expectingMore = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseCondition(firstSnippet *snippet, getSnippet func() (*snippet, error)) (Condition, error) {
|
||||
wrapInNot := false
|
||||
|
||||
// get operator name
|
||||
opName, err := getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// negate?
|
||||
if opName.text == "not" {
|
||||
wrapInNot = true
|
||||
opName, err = getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// get operator
|
||||
operator, ok := operatorNames[opName.text]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown operator at position %d", opName.globalPosition)
|
||||
}
|
||||
|
||||
// don't need a value for "exists"
|
||||
if operator == Exists {
|
||||
if wrapInNot {
|
||||
return Not(Where(firstSnippet.text, operator, nil)), nil
|
||||
}
|
||||
return Where(firstSnippet.text, operator, nil), nil
|
||||
}
|
||||
|
||||
// get value
|
||||
value, err := getSnippet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if wrapInNot {
|
||||
return Not(Where(firstSnippet.text, operator, value.text)), nil
|
||||
}
|
||||
return Where(firstSnippet.text, operator, value.text), nil
|
||||
}
|
||||
|
||||
var escapeReplacer = regexp.MustCompile(`\\([^\\])`)
|
||||
|
||||
// prepToken removes surrounding parenthesis and escape characters.
|
||||
func prepToken(text string) string {
|
||||
return escapeReplacer.ReplaceAllString(strings.Trim(text, "\""), "$1")
|
||||
}
|
||||
|
||||
// escapeString correctly escapes a snippet for printing.
|
||||
func escapeString(token string) string {
|
||||
// check if token contains characters that need to be escaped
|
||||
if strings.ContainsAny(token, "()\"\\\t\r\n ") {
|
||||
// put the token in parenthesis and only escape \ and "
|
||||
return fmt.Sprintf("\"%s\"", strings.ReplaceAll(token, "\"", "\\\""))
|
||||
}
|
||||
return token
|
||||
}
|
||||
177
base/database/query/parser_test.go
Normal file
177
base/database/query/parser_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
func TestExtractSnippets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
text1 := `query test: where ( "bananas" > 100 and monkeys.# <= "12")or(coconuts < 10 "and" area > 50) or name sameas Julian or name matches ^King\ `
|
||||
result1 := []*snippet{
|
||||
{text: "query", globalPosition: 1},
|
||||
{text: "test:", globalPosition: 7},
|
||||
{text: "where", globalPosition: 13},
|
||||
{text: "(", globalPosition: 19},
|
||||
{text: "bananas", globalPosition: 21},
|
||||
{text: ">", globalPosition: 31},
|
||||
{text: "100", globalPosition: 33},
|
||||
{text: "and", globalPosition: 37},
|
||||
{text: "monkeys.#", globalPosition: 41},
|
||||
{text: "<=", globalPosition: 51},
|
||||
{text: "12", globalPosition: 54},
|
||||
{text: ")", globalPosition: 58},
|
||||
{text: "or", globalPosition: 59},
|
||||
{text: "(", globalPosition: 61},
|
||||
{text: "coconuts", globalPosition: 62},
|
||||
{text: "<", globalPosition: 71},
|
||||
{text: "10", globalPosition: 73},
|
||||
{text: "and", globalPosition: 76},
|
||||
{text: "area", globalPosition: 82},
|
||||
{text: ">", globalPosition: 87},
|
||||
{text: "50", globalPosition: 89},
|
||||
{text: ")", globalPosition: 91},
|
||||
{text: "or", globalPosition: 93},
|
||||
{text: "name", globalPosition: 96},
|
||||
{text: "sameas", globalPosition: 101},
|
||||
{text: "Julian", globalPosition: 108},
|
||||
{text: "or", globalPosition: 115},
|
||||
{text: "name", globalPosition: 118},
|
||||
{text: "matches", globalPosition: 123},
|
||||
{text: "^King ", globalPosition: 131},
|
||||
}
|
||||
|
||||
snippets, err := extractSnippets(text1)
|
||||
if err != nil {
|
||||
t.Errorf("failed to extract snippets: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(result1, snippets) {
|
||||
t.Errorf("unexpected results:")
|
||||
for _, el := range snippets {
|
||||
t.Errorf("%+v", el)
|
||||
}
|
||||
}
|
||||
|
||||
// t.Error(spew.Sprintf("%v", treeElement))
|
||||
}
|
||||
|
||||
func testParsing(t *testing.T, queryText string, expectedResult *Query) {
|
||||
t.Helper()
|
||||
|
||||
_, err := expectedResult.Check()
|
||||
if err != nil {
|
||||
t.Errorf("failed to create query: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
q, err := ParseQuery(queryText)
|
||||
if err != nil {
|
||||
t.Errorf("failed to parse query: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if queryText != q.Print() {
|
||||
t.Errorf("string match failed: %s", q.Print())
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(expectedResult, q) {
|
||||
t.Error("deepqual match failed.")
|
||||
t.Error("got:")
|
||||
t.Error(spew.Sdump(q))
|
||||
t.Error("expected:")
|
||||
t.Error(spew.Sdump(expectedResult))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
text1 := `query test: where (bananas > 100 and monkeys.# <= 12) or not (coconuts < 10 and area not > 50) or name sameas Julian or name matches "^King " orderby name limit 10 offset 20`
|
||||
result1 := New("test:").Where(Or(
|
||||
And(
|
||||
Where("bananas", GreaterThan, 100),
|
||||
Where("monkeys.#", LessThanOrEqual, 12),
|
||||
),
|
||||
Not(And(
|
||||
Where("coconuts", LessThan, 10),
|
||||
Not(Where("area", GreaterThan, 50)),
|
||||
)),
|
||||
Where("name", SameAs, "Julian"),
|
||||
Where("name", Matches, "^King "),
|
||||
)).OrderBy("name").Limit(10).Offset(20)
|
||||
testParsing(t, text1, result1)
|
||||
|
||||
testParsing(t, `query test: orderby name`, New("test:").OrderBy("name"))
|
||||
testParsing(t, `query test: limit 10`, New("test:").Limit(10))
|
||||
testParsing(t, `query test: offset 10`, New("test:").Offset(10))
|
||||
testParsing(t, `query test: where banana matches ^ban`, New("test:").Where(Where("banana", Matches, "^ban")))
|
||||
testParsing(t, `query test: where banana exists`, New("test:").Where(Where("banana", Exists, nil)))
|
||||
testParsing(t, `query test: where banana not exists`, New("test:").Where(Not(Where("banana", Exists, nil))))
|
||||
|
||||
// test all operators
|
||||
testParsing(t, `query test: where banana == 1`, New("test:").Where(Where("banana", Equals, 1)))
|
||||
testParsing(t, `query test: where banana > 1`, New("test:").Where(Where("banana", GreaterThan, 1)))
|
||||
testParsing(t, `query test: where banana >= 1`, New("test:").Where(Where("banana", GreaterThanOrEqual, 1)))
|
||||
testParsing(t, `query test: where banana < 1`, New("test:").Where(Where("banana", LessThan, 1)))
|
||||
testParsing(t, `query test: where banana <= 1`, New("test:").Where(Where("banana", LessThanOrEqual, 1)))
|
||||
testParsing(t, `query test: where banana f== 1.1`, New("test:").Where(Where("banana", FloatEquals, 1.1)))
|
||||
testParsing(t, `query test: where banana f> 1.1`, New("test:").Where(Where("banana", FloatGreaterThan, 1.1)))
|
||||
testParsing(t, `query test: where banana f>= 1.1`, New("test:").Where(Where("banana", FloatGreaterThanOrEqual, 1.1)))
|
||||
testParsing(t, `query test: where banana f< 1.1`, New("test:").Where(Where("banana", FloatLessThan, 1.1)))
|
||||
testParsing(t, `query test: where banana f<= 1.1`, New("test:").Where(Where("banana", FloatLessThanOrEqual, 1.1)))
|
||||
testParsing(t, `query test: where banana sameas banana`, New("test:").Where(Where("banana", SameAs, "banana")))
|
||||
testParsing(t, `query test: where banana contains banana`, New("test:").Where(Where("banana", Contains, "banana")))
|
||||
testParsing(t, `query test: where banana startswith banana`, New("test:").Where(Where("banana", StartsWith, "banana")))
|
||||
testParsing(t, `query test: where banana endswith banana`, New("test:").Where(Where("banana", EndsWith, "banana")))
|
||||
testParsing(t, `query test: where banana in banana,coconut`, New("test:").Where(Where("banana", In, []string{"banana", "coconut"})))
|
||||
testParsing(t, `query test: where banana matches banana`, New("test:").Where(Where("banana", Matches, "banana")))
|
||||
testParsing(t, `query test: where banana is true`, New("test:").Where(Where("banana", Is, true)))
|
||||
testParsing(t, `query test: where banana exists`, New("test:").Where(Where("banana", Exists, nil)))
|
||||
|
||||
// special
|
||||
testParsing(t, `query test: where banana not exists`, New("test:").Where(Not(Where("banana", Exists, nil))))
|
||||
}
|
||||
|
||||
func testParseError(t *testing.T, queryText string, expectedErrorString string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := ParseQuery(queryText)
|
||||
if err == nil {
|
||||
t.Errorf("should fail to parse: %s", queryText)
|
||||
return
|
||||
}
|
||||
if err.Error() != expectedErrorString {
|
||||
t.Errorf("unexpected error for query: %s\nwanted: %s\n got: %s", queryText, expectedErrorString, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// syntax
|
||||
testParseError(t, `query`, `unexpected end at position 5`)
|
||||
testParseError(t, `query test: where`, `unexpected end at position 17`)
|
||||
testParseError(t, `query test: where (`, `unexpected end at position 19`)
|
||||
testParseError(t, `query test: where )`, `unknown clause ")" at position 19`)
|
||||
testParseError(t, `query test: where not`, `unexpected end at position 21`)
|
||||
testParseError(t, `query test: where banana`, `unexpected end at position 24`)
|
||||
testParseError(t, `query test: where banana >`, `unexpected end at position 26`)
|
||||
testParseError(t, `query test: where banana nope`, `unknown operator at position 26`)
|
||||
testParseError(t, `query test: where banana exists or`, `unexpected end at position 34`)
|
||||
testParseError(t, `query test: where banana exists and`, `unexpected end at position 35`)
|
||||
testParseError(t, `query test: where banana exists and (`, `unexpected end at position 37`)
|
||||
testParseError(t, `query test: where banana exists and banana is true or`, `you may not mix "and" and "or" (position: 52)`)
|
||||
testParseError(t, `query test: where banana exists or banana is true and`, `you may not mix "and" and "or" (position: 51)`)
|
||||
// testParseError(t, `query test: where banana exists and (`, ``)
|
||||
|
||||
// value parsing error
|
||||
testParseError(t, `query test: where banana == banana`, `could not parse banana to int64: strconv.ParseInt: parsing "banana": invalid syntax (hint: use "sameas" to compare strings)`)
|
||||
testParseError(t, `query test: where banana f== banana`, `could not parse banana to float64: strconv.ParseFloat: parsing "banana": invalid syntax`)
|
||||
testParseError(t, `query test: where banana in banana`, `could not parse "banana" to []string`)
|
||||
testParseError(t, `query test: where banana matches [banana`, "could not compile regex \"[banana\": error parsing regexp: missing closing ]: `[banana`")
|
||||
testParseError(t, `query test: where banana is great`, `could not parse "great" to bool: strconv.ParseBool: parsing "great": invalid syntax`)
|
||||
}
|
||||
170
base/database/query/query.go
Normal file
170
base/database/query/query.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
// Example:
|
||||
// q.New("core:/",
|
||||
// q.Where("a", q.GreaterThan, 0),
|
||||
// q.Where("b", q.Equals, 0),
|
||||
// q.Or(
|
||||
// q.Where("c", q.StartsWith, "x"),
|
||||
// q.Where("d", q.Contains, "y")
|
||||
// )
|
||||
// )
|
||||
|
||||
// Query contains a compiled query.
|
||||
type Query struct {
|
||||
checked bool
|
||||
dbName string
|
||||
dbKeyPrefix string
|
||||
where Condition
|
||||
orderBy string
|
||||
limit int
|
||||
offset int
|
||||
}
|
||||
|
||||
// New creates a new query with the supplied prefix.
|
||||
func New(prefix string) *Query {
|
||||
dbName, dbKeyPrefix := record.ParseKey(prefix)
|
||||
return &Query{
|
||||
dbName: dbName,
|
||||
dbKeyPrefix: dbKeyPrefix,
|
||||
}
|
||||
}
|
||||
|
||||
// Where adds filtering.
|
||||
func (q *Query) Where(condition Condition) *Query {
|
||||
q.where = condition
|
||||
return q
|
||||
}
|
||||
|
||||
// Limit limits the number of returned results.
|
||||
func (q *Query) Limit(limit int) *Query {
|
||||
q.limit = limit
|
||||
return q
|
||||
}
|
||||
|
||||
// Offset sets the query offset.
|
||||
func (q *Query) Offset(offset int) *Query {
|
||||
q.offset = offset
|
||||
return q
|
||||
}
|
||||
|
||||
// OrderBy orders the results by the given key.
|
||||
func (q *Query) OrderBy(key string) *Query {
|
||||
q.orderBy = key
|
||||
return q
|
||||
}
|
||||
|
||||
// Check checks for errors in the query.
|
||||
func (q *Query) Check() (*Query, error) {
|
||||
if q.checked {
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// check condition
|
||||
if q.where != nil {
|
||||
err := q.where.check()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
q.checked = true
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// MustBeValid checks for errors in the query and panics if there is an error.
|
||||
func (q *Query) MustBeValid() *Query {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// IsChecked returns whether they query was checked.
|
||||
func (q *Query) IsChecked() bool {
|
||||
return q.checked
|
||||
}
|
||||
|
||||
// MatchesKey checks whether the query matches the supplied database key (key without database prefix).
|
||||
func (q *Query) MatchesKey(dbKey string) bool {
|
||||
return strings.HasPrefix(dbKey, q.dbKeyPrefix)
|
||||
}
|
||||
|
||||
// MatchesRecord checks whether the query matches the supplied database record (value only).
|
||||
func (q *Query) MatchesRecord(r record.Record) bool {
|
||||
if q.where == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
acc := r.GetAccessor(r)
|
||||
if acc == nil {
|
||||
return false
|
||||
}
|
||||
return q.where.complies(acc)
|
||||
}
|
||||
|
||||
// MatchesAccessor checks whether the query matches the supplied accessor (value only).
|
||||
func (q *Query) MatchesAccessor(acc accessor.Accessor) bool {
|
||||
if q.where == nil {
|
||||
return true
|
||||
}
|
||||
return q.where.complies(acc)
|
||||
}
|
||||
|
||||
// Matches checks whether the query matches the supplied database record.
|
||||
func (q *Query) Matches(r record.Record) bool {
|
||||
if !q.MatchesKey(r.DatabaseKey()) {
|
||||
return false
|
||||
}
|
||||
return q.MatchesRecord(r)
|
||||
}
|
||||
|
||||
// Print returns the string representation of the query.
|
||||
func (q *Query) Print() string {
|
||||
var where string
|
||||
if q.where != nil {
|
||||
where = q.where.string()
|
||||
if where != "" {
|
||||
if strings.HasPrefix(where, "(") {
|
||||
where = where[1 : len(where)-1]
|
||||
}
|
||||
where = fmt.Sprintf(" where %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
var orderBy string
|
||||
if q.orderBy != "" {
|
||||
orderBy = fmt.Sprintf(" orderby %s", q.orderBy)
|
||||
}
|
||||
|
||||
var limit string
|
||||
if q.limit > 0 {
|
||||
limit = fmt.Sprintf(" limit %d", q.limit)
|
||||
}
|
||||
|
||||
var offset string
|
||||
if q.offset > 0 {
|
||||
offset = fmt.Sprintf(" offset %d", q.offset)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("query %s:%s%s%s%s%s", q.dbName, q.dbKeyPrefix, where, orderBy, limit, offset)
|
||||
}
|
||||
|
||||
// DatabaseName returns the name of the database.
|
||||
func (q *Query) DatabaseName() string {
|
||||
return q.dbName
|
||||
}
|
||||
|
||||
// DatabaseKeyPrefix returns the key prefix for the database.
|
||||
func (q *Query) DatabaseKeyPrefix() string {
|
||||
return q.dbKeyPrefix
|
||||
}
|
||||
113
base/database/query/query_test.go
Normal file
113
base/database/query/query_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
//nolint:unparam
|
||||
package query
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
// copied from https://github.com/tidwall/gjson/blob/master/gjson_test.go
|
||||
var testJSON = `{"age":100, "name":{"here":"B\\\"R"},
|
||||
"noop":{"what is a wren?":"a bird"},
|
||||
"happy":true,"immortal":false,
|
||||
"items":[1,2,3,{"tags":[1,2,3],"points":[[1,2],[3,4]]},4,5,6,7],
|
||||
"arr":["1",2,"3",{"hello":"world"},"4",5],
|
||||
"vals":[1,2,3,{"sadf":sdf"asdf"}],"name":{"first":"tom","last":null},
|
||||
"created":"2014-05-16T08:28:06.989Z",
|
||||
"loggy":{
|
||||
"programmers": [
|
||||
{
|
||||
"firstName": "Brett",
|
||||
"lastName": "McLaughlin",
|
||||
"email": "aaaa",
|
||||
"tag": "good"
|
||||
},
|
||||
{
|
||||
"firstName": "Jason",
|
||||
"lastName": "Hunter",
|
||||
"email": "bbbb",
|
||||
"tag": "bad"
|
||||
},
|
||||
{
|
||||
"firstName": "Elliotte",
|
||||
"lastName": "Harold",
|
||||
"email": "cccc",
|
||||
"tag":, "good"
|
||||
},
|
||||
{
|
||||
"firstName": 1002.3,
|
||||
"age": 101
|
||||
}
|
||||
]
|
||||
},
|
||||
"lastly":{"yay":"final"},
|
||||
"temperature": 120.413
|
||||
}`
|
||||
|
||||
func testQuery(t *testing.T, r record.Record, shouldMatch bool, condition Condition) {
|
||||
t.Helper()
|
||||
|
||||
q := New("test:").Where(condition).MustBeValid()
|
||||
// fmt.Printf("%s\n", q.Print())
|
||||
|
||||
matched := q.Matches(r)
|
||||
switch {
|
||||
case !matched && shouldMatch:
|
||||
t.Errorf("should match: %s", q.Print())
|
||||
case matched && !shouldMatch:
|
||||
t.Errorf("should not match: %s", q.Print())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// if !gjson.Valid(testJSON) {
|
||||
// t.Fatal("test json is invalid")
|
||||
// }
|
||||
r, err := record.NewWrapper("", nil, dsd.JSON, []byte(testJSON))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testQuery(t, r, true, Where("age", Equals, 100))
|
||||
testQuery(t, r, true, Where("age", GreaterThan, uint8(99)))
|
||||
testQuery(t, r, true, Where("age", GreaterThanOrEqual, 99))
|
||||
testQuery(t, r, true, Where("age", GreaterThanOrEqual, 100))
|
||||
testQuery(t, r, true, Where("age", LessThan, 101))
|
||||
testQuery(t, r, true, Where("age", LessThanOrEqual, "101"))
|
||||
testQuery(t, r, true, Where("age", LessThanOrEqual, 100))
|
||||
|
||||
testQuery(t, r, true, Where("temperature", FloatEquals, 120.413))
|
||||
testQuery(t, r, true, Where("temperature", FloatGreaterThan, 120))
|
||||
testQuery(t, r, true, Where("temperature", FloatGreaterThanOrEqual, 120))
|
||||
testQuery(t, r, true, Where("temperature", FloatGreaterThanOrEqual, 120.413))
|
||||
testQuery(t, r, true, Where("temperature", FloatLessThan, 121))
|
||||
testQuery(t, r, true, Where("temperature", FloatLessThanOrEqual, "121"))
|
||||
testQuery(t, r, true, Where("temperature", FloatLessThanOrEqual, "120.413"))
|
||||
|
||||
testQuery(t, r, true, Where("lastly.yay", SameAs, "final"))
|
||||
testQuery(t, r, true, Where("lastly.yay", Contains, "ina"))
|
||||
testQuery(t, r, true, Where("lastly.yay", StartsWith, "fin"))
|
||||
testQuery(t, r, true, Where("lastly.yay", EndsWith, "nal"))
|
||||
testQuery(t, r, true, Where("lastly.yay", In, "draft,final"))
|
||||
testQuery(t, r, true, Where("lastly.yay", In, "final,draft"))
|
||||
|
||||
testQuery(t, r, true, Where("happy", Is, true))
|
||||
testQuery(t, r, true, Where("happy", Is, "true"))
|
||||
testQuery(t, r, true, Where("happy", Is, "t"))
|
||||
testQuery(t, r, true, Not(Where("happy", Is, "0")))
|
||||
testQuery(t, r, true, And(
|
||||
Where("happy", Is, "1"),
|
||||
Not(Or(
|
||||
Where("happy", Is, false),
|
||||
Where("happy", Is, "f"),
|
||||
)),
|
||||
))
|
||||
|
||||
testQuery(t, r, true, Where("happy", Exists, nil))
|
||||
|
||||
testQuery(t, r, true, Where("created", Matches, "^2014-[0-9]{2}-[0-9]{2}T"))
|
||||
}
|
||||
156
base/database/record/base.go
Normal file
156
base/database/record/base.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/structures/container"
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
// TODO(ppacher):
|
||||
// we can reduce the record.Record interface a lot by moving
|
||||
// most of those functions that require the Record as it's first
|
||||
// parameter to static package functions
|
||||
// (i.e. Marshal, MarshalRecord, GetAccessor, ...).
|
||||
// We should also consider given Base a GetBase() *Base method
|
||||
// that returns itself. This way we can remove almost all Base
|
||||
// only methods from the record.Record interface. That is, we can
|
||||
// remove all those CreateMeta, UpdateMeta, ... stuff from the
|
||||
// interface definition (not the actual functions!). This would make
|
||||
// the record.Record interface slim and only provide methods that
|
||||
// most users actually need. All those database/storage related methods
|
||||
// can still be accessed by using GetBase().XXX() instead. We can also
|
||||
// expose the dbName and dbKey and meta properties directly which would
|
||||
// make a nice JSON blob when marshalled.
|
||||
|
||||
// Base provides a quick way to comply with the Model interface.
|
||||
type Base struct {
|
||||
dbName string
|
||||
dbKey string
|
||||
meta *Meta
|
||||
}
|
||||
|
||||
// SetKey sets the key on the database record. The key may only be set once and
|
||||
// future calls to SetKey will be ignored. If you want to copy/move the record
|
||||
// to another database key, you will need to create a copy and assign a new key.
|
||||
// A key must be set before the record is used in any database operation.
|
||||
func (b *Base) SetKey(key string) {
|
||||
if !b.KeyIsSet() {
|
||||
b.dbName, b.dbKey = ParseKey(key)
|
||||
} else {
|
||||
log.Errorf("database: key is already set: tried to replace %q with %q", b.Key(), key)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetKey resets the database name and key.
|
||||
// Use with caution!
|
||||
func (b *Base) ResetKey() {
|
||||
b.dbName = ""
|
||||
b.dbKey = ""
|
||||
}
|
||||
|
||||
// Key returns the key of the database record.
|
||||
// As the key must be set before any usage and can only be set once, this
|
||||
// function may be used without locking the record.
|
||||
func (b *Base) Key() string {
|
||||
return b.dbName + ":" + b.dbKey
|
||||
}
|
||||
|
||||
// KeyIsSet returns true if the database key is set.
|
||||
// As the key must be set before any usage and can only be set once, this
|
||||
// function may be used without locking the record.
|
||||
func (b *Base) KeyIsSet() bool {
|
||||
return b.dbName != ""
|
||||
}
|
||||
|
||||
// DatabaseName returns the name of the database.
|
||||
// As the key must be set before any usage and can only be set once, this
|
||||
// function may be used without locking the record.
|
||||
func (b *Base) DatabaseName() string {
|
||||
return b.dbName
|
||||
}
|
||||
|
||||
// DatabaseKey returns the database key of the database record.
|
||||
// As the key must be set before any usage and can only be set once, this
|
||||
// function may be used without locking the record.
|
||||
func (b *Base) DatabaseKey() string {
|
||||
return b.dbKey
|
||||
}
|
||||
|
||||
// Meta returns the metadata object for this record.
|
||||
func (b *Base) Meta() *Meta {
|
||||
return b.meta
|
||||
}
|
||||
|
||||
// CreateMeta sets a default metadata object for this record.
|
||||
func (b *Base) CreateMeta() {
|
||||
b.meta = &Meta{}
|
||||
}
|
||||
|
||||
// UpdateMeta creates the metadata if it does not exist and updates it.
|
||||
func (b *Base) UpdateMeta() {
|
||||
if b.meta == nil {
|
||||
b.CreateMeta()
|
||||
}
|
||||
b.meta.Update()
|
||||
}
|
||||
|
||||
// SetMeta sets the metadata on the database record, it should only be called after loading the record. Use MoveTo to save the record with another key.
|
||||
func (b *Base) SetMeta(meta *Meta) {
|
||||
b.meta = meta
|
||||
}
|
||||
|
||||
// Marshal marshals the object, without the database key or metadata. It returns nil if the record is deleted.
|
||||
func (b *Base) Marshal(self Record, format uint8) ([]byte, error) {
|
||||
if b.Meta() == nil {
|
||||
return nil, errors.New("missing meta")
|
||||
}
|
||||
|
||||
if b.Meta().Deleted > 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
dumped, err := dsd.Dump(self, format)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dumped, nil
|
||||
}
|
||||
|
||||
// MarshalRecord packs the object, including metadata, into a byte array for saving in a database.
|
||||
func (b *Base) MarshalRecord(self Record) ([]byte, error) {
|
||||
if b.Meta() == nil {
|
||||
return nil, errors.New("missing meta")
|
||||
}
|
||||
|
||||
// version
|
||||
c := container.New([]byte{1})
|
||||
|
||||
// meta encoding
|
||||
metaSection, err := dsd.Dump(b.meta, dsd.GenCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.AppendAsBlock(metaSection)
|
||||
|
||||
// data
|
||||
dataSection, err := b.Marshal(self, dsd.JSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Append(dataSection)
|
||||
|
||||
return c.CompileData(), nil
|
||||
}
|
||||
|
||||
// IsWrapped returns whether the record is a Wrapper.
|
||||
func (b *Base) IsWrapped() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetAccessor returns an accessor for this record, if available.
|
||||
func (b *Base) GetAccessor(self Record) accessor.Accessor {
|
||||
return accessor.NewStructAccessor(self)
|
||||
}
|
||||
13
base/database/record/base_test.go
Normal file
13
base/database/record/base_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package record
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBaseRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// check model interface compliance
|
||||
var m Record
|
||||
b := &TestRecord{}
|
||||
m = b
|
||||
_ = m
|
||||
}
|
||||
14
base/database/record/key.go
Normal file
14
base/database/record/key.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseKey splits a key into it's database name and key parts.
|
||||
func ParseKey(key string) (dbName, dbKey string) {
|
||||
splitted := strings.SplitN(key, ":", 2)
|
||||
if len(splitted) < 2 {
|
||||
return splitted[0], ""
|
||||
}
|
||||
return splitted[0], strings.Join(splitted[1:], ":")
|
||||
}
|
||||
348
base/database/record/meta-bench_test.go
Normal file
348
base/database/record/meta-bench_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package record
|
||||
|
||||
// Benchmark:
|
||||
// BenchmarkAllocateBytes-8 2000000000 0.76 ns/op
|
||||
// BenchmarkAllocateStruct1-8 2000000000 0.76 ns/op
|
||||
// BenchmarkAllocateStruct2-8 2000000000 0.79 ns/op
|
||||
// BenchmarkMetaSerializeContainer-8 1000000 1703 ns/op
|
||||
// BenchmarkMetaUnserializeContainer-8 2000000 950 ns/op
|
||||
// BenchmarkMetaSerializeVarInt-8 3000000 457 ns/op
|
||||
// BenchmarkMetaUnserializeVarInt-8 20000000 62.9 ns/op
|
||||
// BenchmarkMetaSerializeWithXDR2-8 1000000 2360 ns/op
|
||||
// BenchmarkMetaUnserializeWithXDR2-8 500000 3189 ns/op
|
||||
// BenchmarkMetaSerializeWithColfer-8 10000000 237 ns/op
|
||||
// BenchmarkMetaUnserializeWithColfer-8 20000000 51.7 ns/op
|
||||
// BenchmarkMetaSerializeWithCodegen-8 50000000 23.7 ns/op
|
||||
// BenchmarkMetaUnserializeWithCodegen-8 100000000 18.9 ns/op
|
||||
// BenchmarkMetaSerializeWithDSDJSON-8 1000000 2398 ns/op
|
||||
// BenchmarkMetaUnserializeWithDSDJSON-8 300000 6264 ns/op
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/structures/container"
|
||||
"github.com/safing/structures/dsd"
|
||||
"github.com/safing/structures/varint"
|
||||
)
|
||||
|
||||
var testMeta = &Meta{
|
||||
Created: time.Now().Unix(),
|
||||
Modified: time.Now().Unix(),
|
||||
Expires: time.Now().Unix(),
|
||||
Deleted: time.Now().Unix(),
|
||||
secret: true,
|
||||
cronjewel: true,
|
||||
}
|
||||
|
||||
func BenchmarkAllocateBytes(b *testing.B) {
|
||||
for range b.N {
|
||||
_ = make([]byte, 33)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAllocateStruct1(b *testing.B) {
|
||||
for range b.N {
|
||||
var newMeta Meta
|
||||
_ = newMeta
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAllocateStruct2(b *testing.B) {
|
||||
for range b.N {
|
||||
_ = Meta{}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetaSerializeContainer(b *testing.B) {
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
c := container.New()
|
||||
c.AppendNumber(uint64(testMeta.Created))
|
||||
c.AppendNumber(uint64(testMeta.Modified))
|
||||
c.AppendNumber(uint64(testMeta.Expires))
|
||||
c.AppendNumber(uint64(testMeta.Deleted))
|
||||
switch {
|
||||
case testMeta.secret && testMeta.cronjewel:
|
||||
c.AppendNumber(3)
|
||||
case testMeta.secret:
|
||||
c.AppendNumber(1)
|
||||
case testMeta.cronjewel:
|
||||
c.AppendNumber(2)
|
||||
default:
|
||||
c.AppendNumber(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetaUnserializeContainer(b *testing.B) {
|
||||
// Setup
|
||||
c := container.New()
|
||||
c.AppendNumber(uint64(testMeta.Created))
|
||||
c.AppendNumber(uint64(testMeta.Modified))
|
||||
c.AppendNumber(uint64(testMeta.Expires))
|
||||
c.AppendNumber(uint64(testMeta.Deleted))
|
||||
switch {
|
||||
case testMeta.secret && testMeta.cronjewel:
|
||||
c.AppendNumber(3)
|
||||
case testMeta.secret:
|
||||
c.AppendNumber(1)
|
||||
case testMeta.cronjewel:
|
||||
c.AppendNumber(2)
|
||||
default:
|
||||
c.AppendNumber(0)
|
||||
}
|
||||
encodedData := c.CompileData()
|
||||
|
||||
// Reset timer for precise results
|
||||
b.ResetTimer()
|
||||
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
var newMeta Meta
|
||||
var err error
|
||||
var num uint64
|
||||
c := container.New(encodedData)
|
||||
num, err = c.GetNextN64()
|
||||
newMeta.Created = int64(num)
|
||||
if err != nil {
|
||||
b.Errorf("could not decode: %s", err)
|
||||
return
|
||||
}
|
||||
num, err = c.GetNextN64()
|
||||
newMeta.Modified = int64(num)
|
||||
if err != nil {
|
||||
b.Errorf("could not decode: %s", err)
|
||||
return
|
||||
}
|
||||
num, err = c.GetNextN64()
|
||||
newMeta.Expires = int64(num)
|
||||
if err != nil {
|
||||
b.Errorf("could not decode: %s", err)
|
||||
return
|
||||
}
|
||||
num, err = c.GetNextN64()
|
||||
newMeta.Deleted = int64(num)
|
||||
if err != nil {
|
||||
b.Errorf("could not decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
flags, err := c.GetNextN8()
|
||||
if err != nil {
|
||||
b.Errorf("could not decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
switch flags {
|
||||
case 3:
|
||||
newMeta.secret = true
|
||||
newMeta.cronjewel = true
|
||||
case 2:
|
||||
newMeta.cronjewel = true
|
||||
case 1:
|
||||
newMeta.secret = true
|
||||
case 0:
|
||||
default:
|
||||
b.Errorf("invalid flag value: %d", flags)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetaSerializeVarInt(b *testing.B) {
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
encoded := make([]byte, 33)
|
||||
offset := 0
|
||||
data := varint.Pack64(uint64(testMeta.Created))
|
||||
for _, part := range data {
|
||||
encoded[offset] = part
|
||||
offset++
|
||||
}
|
||||
data = varint.Pack64(uint64(testMeta.Modified))
|
||||
for _, part := range data {
|
||||
encoded[offset] = part
|
||||
offset++
|
||||
}
|
||||
data = varint.Pack64(uint64(testMeta.Expires))
|
||||
for _, part := range data {
|
||||
encoded[offset] = part
|
||||
offset++
|
||||
}
|
||||
data = varint.Pack64(uint64(testMeta.Deleted))
|
||||
for _, part := range data {
|
||||
encoded[offset] = part
|
||||
offset++
|
||||
}
|
||||
|
||||
switch {
|
||||
case testMeta.secret && testMeta.cronjewel:
|
||||
encoded[offset] = 3
|
||||
case testMeta.secret:
|
||||
encoded[offset] = 1
|
||||
case testMeta.cronjewel:
|
||||
encoded[offset] = 2
|
||||
default:
|
||||
encoded[offset] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetaUnserializeVarInt(b *testing.B) {
|
||||
// Setup
|
||||
encoded := make([]byte, 33)
|
||||
offset := 0
|
||||
data := varint.Pack64(uint64(testMeta.Created))
|
||||
for _, part := range data {
|
||||
encoded[offset] = part
|
||||
offset++
|
||||
}
|
||||
data = varint.Pack64(uint64(testMeta.Modified))
|
||||
for _, part := range data {
|
||||
encoded[offset] = part
|
||||
offset++
|
||||
}
|
||||
data = varint.Pack64(uint64(testMeta.Expires))
|
||||
for _, part := range data {
|
||||
encoded[offset] = part
|
||||
offset++
|
||||
}
|
||||
data = varint.Pack64(uint64(testMeta.Deleted))
|
||||
for _, part := range data {
|
||||
encoded[offset] = part
|
||||
offset++
|
||||
}
|
||||
|
||||
switch {
|
||||
case testMeta.secret && testMeta.cronjewel:
|
||||
encoded[offset] = 3
|
||||
case testMeta.secret:
|
||||
encoded[offset] = 1
|
||||
case testMeta.cronjewel:
|
||||
encoded[offset] = 2
|
||||
default:
|
||||
encoded[offset] = 0
|
||||
}
|
||||
offset++
|
||||
encodedData := encoded[:offset]
|
||||
|
||||
// Reset timer for precise results
|
||||
b.ResetTimer()
|
||||
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
var newMeta Meta
|
||||
offset = 0
|
||||
|
||||
num, n, err := varint.Unpack64(encodedData)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
return
|
||||
}
|
||||
testMeta.Created = int64(num)
|
||||
offset += n
|
||||
|
||||
num, n, err = varint.Unpack64(encodedData[offset:])
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
return
|
||||
}
|
||||
testMeta.Modified = int64(num)
|
||||
offset += n
|
||||
|
||||
num, n, err = varint.Unpack64(encodedData[offset:])
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
return
|
||||
}
|
||||
testMeta.Expires = int64(num)
|
||||
offset += n
|
||||
|
||||
num, n, err = varint.Unpack64(encodedData[offset:])
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
return
|
||||
}
|
||||
testMeta.Deleted = int64(num)
|
||||
offset += n
|
||||
|
||||
switch encodedData[offset] {
|
||||
case 3:
|
||||
newMeta.secret = true
|
||||
newMeta.cronjewel = true
|
||||
case 2:
|
||||
newMeta.cronjewel = true
|
||||
case 1:
|
||||
newMeta.secret = true
|
||||
case 0:
|
||||
default:
|
||||
b.Errorf("invalid flag value: %d", encodedData[offset])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetaSerializeWithCodegen(b *testing.B) {
|
||||
for range b.N {
|
||||
_, err := testMeta.GenCodeMarshal(nil)
|
||||
if err != nil {
|
||||
b.Errorf("failed to serialize with codegen: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetaUnserializeWithCodegen(b *testing.B) {
|
||||
// Setup
|
||||
encodedData, err := testMeta.GenCodeMarshal(nil)
|
||||
if err != nil {
|
||||
b.Errorf("failed to serialize with codegen: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Reset timer for precise results
|
||||
b.ResetTimer()
|
||||
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
var newMeta Meta
|
||||
_, err := newMeta.GenCodeUnmarshal(encodedData)
|
||||
if err != nil {
|
||||
b.Errorf("failed to unserialize with codegen: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) {
|
||||
for range b.N {
|
||||
_, err := dsd.Dump(testMeta, dsd.JSON)
|
||||
if err != nil {
|
||||
b.Errorf("failed to serialize with DSD/JSON: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) {
|
||||
// Setup
|
||||
encodedData, err := dsd.Dump(testMeta, dsd.JSON)
|
||||
if err != nil {
|
||||
b.Errorf("failed to serialize with DSD/JSON: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Reset timer for precise results
|
||||
b.ResetTimer()
|
||||
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
var newMeta Meta
|
||||
_, err := dsd.Load(encodedData, &newMeta)
|
||||
if err != nil {
|
||||
b.Errorf("failed to unserialize with DSD/JSON: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
145
base/database/record/meta-gencode.go
Normal file
145
base/database/record/meta-gencode.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GenCodeSize returns the size of the gencode marshalled byte slice.
|
||||
func (m *Meta) GenCodeSize() (s int) {
|
||||
s += 34
|
||||
return
|
||||
}
|
||||
|
||||
// GenCodeMarshal gencode marshalls Meta into the given byte array, or a new one if its too small.
|
||||
func (m *Meta) GenCodeMarshal(buf []byte) ([]byte, error) {
|
||||
size := m.GenCodeSize()
|
||||
{
|
||||
if cap(buf) >= size {
|
||||
buf = buf[:size]
|
||||
} else {
|
||||
buf = make([]byte, size)
|
||||
}
|
||||
}
|
||||
i := uint64(0)
|
||||
|
||||
{
|
||||
|
||||
buf[0+0] = byte(m.Created >> 0)
|
||||
|
||||
buf[1+0] = byte(m.Created >> 8)
|
||||
|
||||
buf[2+0] = byte(m.Created >> 16)
|
||||
|
||||
buf[3+0] = byte(m.Created >> 24)
|
||||
|
||||
buf[4+0] = byte(m.Created >> 32)
|
||||
|
||||
buf[5+0] = byte(m.Created >> 40)
|
||||
|
||||
buf[6+0] = byte(m.Created >> 48)
|
||||
|
||||
buf[7+0] = byte(m.Created >> 56)
|
||||
|
||||
}
|
||||
{
|
||||
|
||||
buf[0+8] = byte(m.Modified >> 0)
|
||||
|
||||
buf[1+8] = byte(m.Modified >> 8)
|
||||
|
||||
buf[2+8] = byte(m.Modified >> 16)
|
||||
|
||||
buf[3+8] = byte(m.Modified >> 24)
|
||||
|
||||
buf[4+8] = byte(m.Modified >> 32)
|
||||
|
||||
buf[5+8] = byte(m.Modified >> 40)
|
||||
|
||||
buf[6+8] = byte(m.Modified >> 48)
|
||||
|
||||
buf[7+8] = byte(m.Modified >> 56)
|
||||
|
||||
}
|
||||
{
|
||||
|
||||
buf[0+16] = byte(m.Expires >> 0)
|
||||
|
||||
buf[1+16] = byte(m.Expires >> 8)
|
||||
|
||||
buf[2+16] = byte(m.Expires >> 16)
|
||||
|
||||
buf[3+16] = byte(m.Expires >> 24)
|
||||
|
||||
buf[4+16] = byte(m.Expires >> 32)
|
||||
|
||||
buf[5+16] = byte(m.Expires >> 40)
|
||||
|
||||
buf[6+16] = byte(m.Expires >> 48)
|
||||
|
||||
buf[7+16] = byte(m.Expires >> 56)
|
||||
|
||||
}
|
||||
{
|
||||
|
||||
buf[0+24] = byte(m.Deleted >> 0)
|
||||
|
||||
buf[1+24] = byte(m.Deleted >> 8)
|
||||
|
||||
buf[2+24] = byte(m.Deleted >> 16)
|
||||
|
||||
buf[3+24] = byte(m.Deleted >> 24)
|
||||
|
||||
buf[4+24] = byte(m.Deleted >> 32)
|
||||
|
||||
buf[5+24] = byte(m.Deleted >> 40)
|
||||
|
||||
buf[6+24] = byte(m.Deleted >> 48)
|
||||
|
||||
buf[7+24] = byte(m.Deleted >> 56)
|
||||
|
||||
}
|
||||
{
|
||||
if m.secret {
|
||||
buf[32] = 1
|
||||
} else {
|
||||
buf[32] = 0
|
||||
}
|
||||
}
|
||||
{
|
||||
if m.cronjewel {
|
||||
buf[33] = 1
|
||||
} else {
|
||||
buf[33] = 0
|
||||
}
|
||||
}
|
||||
return buf[:i+34], nil
|
||||
}
|
||||
|
||||
// GenCodeUnmarshal gencode unmarshalls Meta and returns the bytes read.
|
||||
func (m *Meta) GenCodeUnmarshal(buf []byte) (uint64, error) {
|
||||
if len(buf) < m.GenCodeSize() {
|
||||
return 0, fmt.Errorf("insufficient data: got %d out of %d bytes", len(buf), m.GenCodeSize())
|
||||
}
|
||||
|
||||
i := uint64(0)
|
||||
|
||||
{
|
||||
m.Created = 0 | (int64(buf[0+0]) << 0) | (int64(buf[1+0]) << 8) | (int64(buf[2+0]) << 16) | (int64(buf[3+0]) << 24) | (int64(buf[4+0]) << 32) | (int64(buf[5+0]) << 40) | (int64(buf[6+0]) << 48) | (int64(buf[7+0]) << 56)
|
||||
}
|
||||
{
|
||||
m.Modified = 0 | (int64(buf[0+8]) << 0) | (int64(buf[1+8]) << 8) | (int64(buf[2+8]) << 16) | (int64(buf[3+8]) << 24) | (int64(buf[4+8]) << 32) | (int64(buf[5+8]) << 40) | (int64(buf[6+8]) << 48) | (int64(buf[7+8]) << 56)
|
||||
}
|
||||
{
|
||||
m.Expires = 0 | (int64(buf[0+16]) << 0) | (int64(buf[1+16]) << 8) | (int64(buf[2+16]) << 16) | (int64(buf[3+16]) << 24) | (int64(buf[4+16]) << 32) | (int64(buf[5+16]) << 40) | (int64(buf[6+16]) << 48) | (int64(buf[7+16]) << 56)
|
||||
}
|
||||
{
|
||||
m.Deleted = 0 | (int64(buf[0+24]) << 0) | (int64(buf[1+24]) << 8) | (int64(buf[2+24]) << 16) | (int64(buf[3+24]) << 24) | (int64(buf[4+24]) << 32) | (int64(buf[5+24]) << 40) | (int64(buf[6+24]) << 48) | (int64(buf[7+24]) << 56)
|
||||
}
|
||||
{
|
||||
m.secret = buf[32] == 1
|
||||
}
|
||||
{
|
||||
m.cronjewel = buf[33] == 1
|
||||
}
|
||||
return i + 34, nil
|
||||
}
|
||||
35
base/database/record/meta-gencode_test.go
Normal file
35
base/database/record/meta-gencode_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var genCodeTestMeta = &Meta{
|
||||
Created: time.Now().Unix(),
|
||||
Modified: time.Now().Unix(),
|
||||
Expires: time.Now().Unix(),
|
||||
Deleted: time.Now().Unix(),
|
||||
secret: true,
|
||||
cronjewel: true,
|
||||
}
|
||||
|
||||
func TestGenCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
encoded, err := genCodeTestMeta.GenCodeMarshal(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newMeta := &Meta{}
|
||||
_, err = newMeta.GenCodeUnmarshal(encoded)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(genCodeTestMeta, newMeta) {
|
||||
t.Errorf("objects are not equal, got: %v", newMeta)
|
||||
}
|
||||
}
|
||||
10
base/database/record/meta.colf
Normal file
10
base/database/record/meta.colf
Normal file
@@ -0,0 +1,10 @@
|
||||
package record
|
||||
|
||||
type course struct {
|
||||
Created int64
|
||||
Modified int64
|
||||
Expires int64
|
||||
Deleted int64
|
||||
Secret bool
|
||||
Cronjewel bool
|
||||
}
|
||||
8
base/database/record/meta.gencode
Normal file
8
base/database/record/meta.gencode
Normal file
@@ -0,0 +1,8 @@
|
||||
struct Meta {
|
||||
Created int64
|
||||
Modified int64
|
||||
Expires int64
|
||||
Deleted int64
|
||||
Secret bool
|
||||
Cronjewel bool
|
||||
}
|
||||
129
base/database/record/meta.go
Normal file
129
base/database/record/meta.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package record
|
||||
|
||||
import "time"
|
||||
|
||||
// Meta holds metadata about the record.
|
||||
type Meta struct {
|
||||
Created int64
|
||||
Modified int64
|
||||
Expires int64
|
||||
Deleted int64
|
||||
secret bool // secrets must not be sent to the UI, only synced between nodes
|
||||
cronjewel bool // crownjewels must never leave the instance, but may be read by the UI
|
||||
}
|
||||
|
||||
// SetAbsoluteExpiry sets an absolute expiry time (in seconds), that is not affected when the record is updated.
|
||||
func (m *Meta) SetAbsoluteExpiry(seconds int64) {
|
||||
m.Expires = seconds
|
||||
m.Deleted = 0
|
||||
}
|
||||
|
||||
// SetRelativateExpiry sets a relative expiry time (ie. TTL in seconds) that is automatically updated whenever the record is updated/saved.
|
||||
func (m *Meta) SetRelativateExpiry(seconds int64) {
|
||||
if seconds >= 0 {
|
||||
m.Deleted = -seconds
|
||||
}
|
||||
}
|
||||
|
||||
// GetAbsoluteExpiry returns the absolute expiry time.
|
||||
func (m *Meta) GetAbsoluteExpiry() int64 {
|
||||
return m.Expires
|
||||
}
|
||||
|
||||
// GetRelativeExpiry returns the current relative expiry time - ie. seconds until expiry.
|
||||
// A negative value signifies that the record does not expire.
|
||||
func (m *Meta) GetRelativeExpiry() int64 {
|
||||
if m.Expires == 0 {
|
||||
return -1
|
||||
}
|
||||
|
||||
abs := m.Expires - time.Now().Unix()
|
||||
if abs < 0 {
|
||||
return 0
|
||||
}
|
||||
return abs
|
||||
}
|
||||
|
||||
// MakeCrownJewel marks the database records as a crownjewel, meaning that it will not be sent/synced to other devices.
|
||||
func (m *Meta) MakeCrownJewel() {
|
||||
m.cronjewel = true
|
||||
}
|
||||
|
||||
// MakeSecret sets the database record as secret, meaning that it may only be used internally, and not by interfacing processes, such as the UI.
|
||||
func (m *Meta) MakeSecret() {
|
||||
m.secret = true
|
||||
}
|
||||
|
||||
// Update updates the internal meta states and should be called before writing the record to the database.
|
||||
func (m *Meta) Update() {
|
||||
now := time.Now().Unix()
|
||||
m.Modified = now
|
||||
if m.Created == 0 {
|
||||
m.Created = now
|
||||
}
|
||||
if m.Deleted < 0 {
|
||||
m.Expires = now - m.Deleted
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets all metadata, except for the secret and crownjewel status.
|
||||
func (m *Meta) Reset() {
|
||||
m.Created = 0
|
||||
m.Modified = 0
|
||||
m.Expires = 0
|
||||
m.Deleted = 0
|
||||
}
|
||||
|
||||
// Delete marks the record as deleted.
|
||||
func (m *Meta) Delete() {
|
||||
m.Deleted = time.Now().Unix()
|
||||
}
|
||||
|
||||
// IsDeleted returns whether the record is deleted.
|
||||
func (m *Meta) IsDeleted() bool {
|
||||
return m.Deleted > 0
|
||||
}
|
||||
|
||||
// CheckValidity checks whether the database record is valid.
|
||||
func (m *Meta) CheckValidity() (valid bool) {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch {
|
||||
case m.Deleted > 0:
|
||||
return false
|
||||
case m.Expires > 0 && m.Expires < time.Now().Unix():
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// CheckPermission checks whether the database record may be accessed with the following scope.
|
||||
func (m *Meta) CheckPermission(local, internal bool) (permitted bool) {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch {
|
||||
case !local && m.cronjewel:
|
||||
return false
|
||||
case !internal && m.secret:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Duplicate returns a new copy of Meta.
|
||||
func (m *Meta) Duplicate() *Meta {
|
||||
return &Meta{
|
||||
Created: m.Created,
|
||||
Modified: m.Modified,
|
||||
Expires: m.Expires,
|
||||
Deleted: m.Deleted,
|
||||
secret: m.secret,
|
||||
cronjewel: m.cronjewel,
|
||||
}
|
||||
}
|
||||
32
base/database/record/record.go
Normal file
32
base/database/record/record.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
// Record provides an interface for uniformally handling database records.
|
||||
type Record interface {
|
||||
SetKey(key string) // test:config
|
||||
Key() string // test:config
|
||||
KeyIsSet() bool
|
||||
DatabaseName() string // test
|
||||
DatabaseKey() string // config
|
||||
|
||||
// Metadata.
|
||||
Meta() *Meta
|
||||
SetMeta(meta *Meta)
|
||||
CreateMeta()
|
||||
UpdateMeta()
|
||||
|
||||
// Serialization.
|
||||
Marshal(self Record, format uint8) ([]byte, error)
|
||||
MarshalRecord(self Record) ([]byte, error)
|
||||
GetAccessor(self Record) accessor.Accessor
|
||||
|
||||
// Locking.
|
||||
Lock()
|
||||
Unlock()
|
||||
|
||||
// Wrapping.
|
||||
IsWrapped() bool
|
||||
}
|
||||
10
base/database/record/record_test.go
Normal file
10
base/database/record/record_test.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type TestRecord struct {
|
||||
Base
|
||||
sync.Mutex
|
||||
}
|
||||
160
base/database/record/wrapper.go
Normal file
160
base/database/record/wrapper.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
"github.com/safing/structures/container"
|
||||
"github.com/safing/structures/dsd"
|
||||
"github.com/safing/structures/varint"
|
||||
)
|
||||
|
||||
// Wrapper wraps raw data and implements the Record interface.
|
||||
type Wrapper struct {
|
||||
Base
|
||||
sync.Mutex
|
||||
|
||||
Format uint8
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// NewRawWrapper returns a record wrapper for the given data, including metadata. This is normally only used by storage backends when loading records.
|
||||
func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) {
|
||||
version, offset, err := varint.Unpack8(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != 1 {
|
||||
return nil, fmt.Errorf("incompatible record version: %d", version)
|
||||
}
|
||||
|
||||
metaSection, n, err := varint.GetNextBlock(data[offset:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get meta section: %w", err)
|
||||
}
|
||||
offset += n
|
||||
|
||||
newMeta := &Meta{}
|
||||
_, err = dsd.Load(metaSection, newMeta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal meta section: %w", err)
|
||||
}
|
||||
|
||||
var format uint8 = dsd.RAW
|
||||
if !newMeta.IsDeleted() {
|
||||
format, n, err = varint.Unpack8(data[offset:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get dsd format: %w", err)
|
||||
}
|
||||
offset += n
|
||||
}
|
||||
|
||||
return &Wrapper{
|
||||
Base{
|
||||
database,
|
||||
key,
|
||||
newMeta,
|
||||
},
|
||||
sync.Mutex{},
|
||||
format,
|
||||
data[offset:],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewWrapper returns a new record wrapper for the given data.
|
||||
func NewWrapper(key string, meta *Meta, format uint8, data []byte) (*Wrapper, error) {
|
||||
dbName, dbKey := ParseKey(key)
|
||||
|
||||
return &Wrapper{
|
||||
Base{
|
||||
dbName: dbName,
|
||||
dbKey: dbKey,
|
||||
meta: meta,
|
||||
},
|
||||
sync.Mutex{},
|
||||
format,
|
||||
data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Marshal marshals the object, without the database key or metadata.
|
||||
func (w *Wrapper) Marshal(r Record, format uint8) ([]byte, error) {
|
||||
if w.Meta() == nil {
|
||||
return nil, errors.New("missing meta")
|
||||
}
|
||||
|
||||
if w.Meta().Deleted > 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if format != dsd.AUTO && format != w.Format {
|
||||
return nil, errors.New("could not dump model, wrapped object format mismatch")
|
||||
}
|
||||
|
||||
data := make([]byte, len(w.Data)+1)
|
||||
data[0] = w.Format
|
||||
copy(data[1:], w.Data)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// MarshalRecord packs the object, including metadata, into a byte array for saving in a database.
|
||||
func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) {
|
||||
// Duplication necessary, as the version from Base would call Base.Marshal instead of Wrapper.Marshal
|
||||
|
||||
if w.Meta() == nil {
|
||||
return nil, errors.New("missing meta")
|
||||
}
|
||||
|
||||
// version
|
||||
c := container.New([]byte{1})
|
||||
|
||||
// meta
|
||||
metaSection, err := dsd.Dump(w.meta, dsd.GenCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.AppendAsBlock(metaSection)
|
||||
|
||||
// data
|
||||
dataSection, err := w.Marshal(r, dsd.AUTO)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Append(dataSection)
|
||||
|
||||
return c.CompileData(), nil
|
||||
}
|
||||
|
||||
// IsWrapped returns whether the record is a Wrapper.
|
||||
func (w *Wrapper) IsWrapped() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Unwrap unwraps data into a record.
|
||||
func Unwrap(wrapped, r Record) error {
|
||||
wrapper, ok := wrapped.(*Wrapper)
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot unwrap %T", wrapped)
|
||||
}
|
||||
|
||||
err := dsd.LoadAsFormat(wrapper.Data, wrapper.Format, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unwrap %T: %w", r, err)
|
||||
}
|
||||
|
||||
r.SetKey(wrapped.Key())
|
||||
r.SetMeta(wrapped.Meta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccessor returns an accessor for this record, if available.
|
||||
func (w *Wrapper) GetAccessor(self Record) accessor.Accessor {
|
||||
if w.Format == dsd.JSON && len(w.Data) > 0 {
|
||||
return accessor.NewJSONBytesAccessor(&w.Data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
57
base/database/record/wrapper_test.go
Normal file
57
base/database/record/wrapper_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
func TestWrapper(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// check model interface compliance
|
||||
var m Record
|
||||
w := &Wrapper{}
|
||||
m = w
|
||||
_ = m
|
||||
|
||||
// create test data
|
||||
testData := []byte(`{"a": "b"}`)
|
||||
encodedTestData := []byte(`J{"a": "b"}`)
|
||||
|
||||
// test wrapper
|
||||
wrapper, err := NewWrapper("test:a", &Meta{}, dsd.JSON, testData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if wrapper.Format != dsd.JSON {
|
||||
t.Error("format mismatch")
|
||||
}
|
||||
if !bytes.Equal(testData, wrapper.Data) {
|
||||
t.Error("data mismatch")
|
||||
}
|
||||
|
||||
encoded, err := wrapper.Marshal(wrapper, dsd.JSON)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(encodedTestData, encoded) {
|
||||
t.Error("marshal mismatch")
|
||||
}
|
||||
|
||||
wrapper.SetMeta(&Meta{})
|
||||
wrapper.meta.Update()
|
||||
raw, err := wrapper.MarshalRecord(wrapper)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrapper2, err := NewRawWrapper("test", "a", raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(testData, wrapper2.Data) {
|
||||
t.Error("marshal mismatch")
|
||||
}
|
||||
}
|
||||
67
base/database/registry.go
Normal file
67
base/database/registry.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
registry = make(map[string]*Database)
|
||||
registryLock sync.Mutex
|
||||
|
||||
nameConstraint = regexp.MustCompile("^[A-Za-z0-9_-]{3,}$")
|
||||
)
|
||||
|
||||
// Register registers a new database.
|
||||
// If the database is already registered, only
|
||||
// the description and the primary API will be
|
||||
// updated and the effective object will be returned.
|
||||
func Register(db *Database) (*Database, error) {
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
registeredDB, ok := registry[db.Name]
|
||||
|
||||
if ok {
|
||||
// update database
|
||||
if registeredDB.Description != db.Description {
|
||||
registeredDB.Description = db.Description
|
||||
}
|
||||
if registeredDB.ShadowDelete != db.ShadowDelete {
|
||||
registeredDB.ShadowDelete = db.ShadowDelete
|
||||
}
|
||||
} else {
|
||||
// register new database
|
||||
if !nameConstraint.MatchString(db.Name) {
|
||||
return nil, errors.New("database name must only contain alphanumeric and `_-` characters and must be at least 3 characters long")
|
||||
}
|
||||
|
||||
now := time.Now().Round(time.Second)
|
||||
db.Registered = now
|
||||
db.LastUpdated = now
|
||||
db.LastLoaded = time.Time{}
|
||||
|
||||
registry[db.Name] = db
|
||||
}
|
||||
|
||||
if ok {
|
||||
return registeredDB, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func getDatabase(name string) (*Database, error) {
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
registeredDB, ok := registry[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`database "%s" not registered`, name)
|
||||
}
|
||||
registeredDB.Loaded()
|
||||
|
||||
return registeredDB, nil
|
||||
}
|
||||
231
base/database/storage/badger/badger.go
Normal file
231
base/database/storage/badger/badger.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package badger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/dgraph-io/badger"
|
||||
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
// Badger database made pluggable for portbase.
|
||||
type Badger struct {
|
||||
name string
|
||||
db *badger.DB
|
||||
}
|
||||
|
||||
func init() {
|
||||
_ = storage.Register("badger", NewBadger)
|
||||
}
|
||||
|
||||
// NewBadger opens/creates a badger database.
|
||||
func NewBadger(name, location string) (storage.Interface, error) {
|
||||
opts := badger.DefaultOptions(location)
|
||||
|
||||
db, err := badger.Open(opts)
|
||||
if errors.Is(err, badger.ErrTruncateNeeded) {
|
||||
// clean up after crash
|
||||
log.Warningf("database/storage: truncating corrupted value log of badger database %s: this may cause data loss", name)
|
||||
opts.Truncate = true
|
||||
db, err = badger.Open(opts)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Badger{
|
||||
name: name,
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get returns a database record.
|
||||
func (b *Badger) Get(key string) (record.Record, error) {
|
||||
var item *badger.Item
|
||||
|
||||
err := b.db.View(func(txn *badger.Txn) error {
|
||||
var err error
|
||||
item, err = txn.Get([]byte(key))
|
||||
if err != nil {
|
||||
if errors.Is(err, badger.ErrKeyNotFound) {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// return err if deleted or expired
|
||||
if item.IsDeletedOrExpired() {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
data, err := item.ValueCopy(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := record.NewRawWrapper(b.name, string(item.Key()), data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GetMeta returns the metadata of a database record.
|
||||
func (b *Badger) GetMeta(key string) (*record.Meta, error) {
|
||||
// TODO: Replace with more performant variant.
|
||||
|
||||
r, err := b.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.Meta(), nil
|
||||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (b *Badger) Put(r record.Record) (record.Record, error) {
|
||||
data, err := r.MarshalRecord(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = b.db.Update(func(txn *badger.Txn) error {
|
||||
return txn.Set([]byte(r.DatabaseKey()), data)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
func (b *Badger) Delete(key string) error {
|
||||
return b.db.Update(func(txn *badger.Txn) error {
|
||||
err := txn.Delete([]byte(key))
|
||||
if err != nil && !errors.Is(err, badger.ErrKeyNotFound) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Query returns a an iterator for the supplied query.
|
||||
func (b *Badger) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid query: %w", err)
|
||||
}
|
||||
|
||||
queryIter := iterator.New()
|
||||
|
||||
go b.queryExecutor(queryIter, q, local, internal)
|
||||
return queryIter, nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) {
|
||||
err := b.db.View(func(txn *badger.Txn) error {
|
||||
it := txn.NewIterator(badger.DefaultIteratorOptions)
|
||||
defer it.Close()
|
||||
prefix := []byte(q.DatabaseKeyPrefix())
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
|
||||
item := it.Item()
|
||||
|
||||
var data []byte
|
||||
err := item.Value(func(val []byte) error {
|
||||
data = val
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r, err := record.NewRawWrapper(b.name, string(item.Key()), data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !r.Meta().CheckValidity() {
|
||||
continue
|
||||
}
|
||||
if !r.Meta().CheckPermission(local, internal) {
|
||||
continue
|
||||
}
|
||||
|
||||
if q.MatchesRecord(r) {
|
||||
copiedData, err := item.ValueCopy(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newWrapper, err := record.NewRawWrapper(b.name, r.DatabaseKey(), copiedData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-queryIter.Done:
|
||||
return nil
|
||||
case queryIter.Next <- newWrapper:
|
||||
default:
|
||||
select {
|
||||
case queryIter.Next <- newWrapper:
|
||||
case <-queryIter.Done:
|
||||
return nil
|
||||
case <-time.After(1 * time.Minute):
|
||||
return errors.New("query timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
queryIter.Finish(err)
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the database is read only.
|
||||
func (b *Badger) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Injected returns whether the database is injected.
|
||||
func (b *Badger) Injected() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Maintain runs a light maintenance operation on the database.
|
||||
func (b *Badger) Maintain(_ context.Context) error {
|
||||
_ = b.db.RunValueLogGC(0.7)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaintainThorough runs a thorough maintenance operation on the database.
|
||||
func (b *Badger) MaintainThorough(_ context.Context) (err error) {
|
||||
for err == nil {
|
||||
err = b.db.RunValueLogGC(0.7)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaintainRecordStates maintains records states in the database.
|
||||
func (b *Badger) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error {
|
||||
// TODO: implement MaintainRecordStates
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the database.
|
||||
func (b *Badger) Shutdown() error {
|
||||
return b.db.Close()
|
||||
}
|
||||
148
base/database/storage/badger/badger_test.go
Normal file
148
base/database/storage/badger/badger_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package badger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
// Compile time interface checks.
|
||||
_ storage.Interface = &Badger{}
|
||||
_ storage.Maintainer = &Badger{}
|
||||
)
|
||||
|
||||
type TestRecord struct { //nolint:maligned
|
||||
record.Base
|
||||
sync.Mutex
|
||||
S string
|
||||
I int
|
||||
I8 int8
|
||||
I16 int16
|
||||
I32 int32
|
||||
I64 int64
|
||||
UI uint
|
||||
UI8 uint8
|
||||
UI16 uint16
|
||||
UI32 uint32
|
||||
UI64 uint64
|
||||
F32 float32
|
||||
F64 float64
|
||||
B bool
|
||||
}
|
||||
|
||||
func TestBadger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir, err := os.MkdirTemp("", "testing-")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = os.RemoveAll(testDir) // clean up
|
||||
}()
|
||||
|
||||
// start
|
||||
db, err := NewBadger("test", testDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
a := &TestRecord{
|
||||
S: "banana",
|
||||
I: 42,
|
||||
I8: 42,
|
||||
I16: 42,
|
||||
I32: 42,
|
||||
I64: 42,
|
||||
UI: 42,
|
||||
UI8: 42,
|
||||
UI16: 42,
|
||||
UI32: 42,
|
||||
UI64: 42,
|
||||
F32: 42.42,
|
||||
F64: 42.42,
|
||||
B: true,
|
||||
}
|
||||
a.SetMeta(&record.Meta{})
|
||||
a.Meta().Update()
|
||||
a.SetKey("test:A")
|
||||
|
||||
// put record
|
||||
_, err = db.Put(a)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// get and compare
|
||||
r1, err := db.Get("A")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
a1 := &TestRecord{}
|
||||
err = record.Unwrap(r1, a1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(a, a1) {
|
||||
t.Fatalf("mismatch, got %v", a1)
|
||||
}
|
||||
|
||||
// test query
|
||||
q := query.New("").MustBeValid()
|
||||
it, err := db.Query(q, true, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cnt := 0
|
||||
for range it.Next {
|
||||
cnt++
|
||||
}
|
||||
if it.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cnt != 1 {
|
||||
t.Fatalf("unexpected query result count: %d", cnt)
|
||||
}
|
||||
|
||||
// delete
|
||||
err = db.Delete("A")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check if its gone
|
||||
_, err = db.Get("A")
|
||||
if err == nil {
|
||||
t.Fatal("should fail")
|
||||
}
|
||||
|
||||
// maintenance
|
||||
maintainer, ok := db.(storage.Maintainer)
|
||||
if ok {
|
||||
err = maintainer.Maintain(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = maintainer.MaintainThorough(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal("should implement Maintainer")
|
||||
}
|
||||
|
||||
// shutdown
|
||||
err = db.Shutdown()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
427
base/database/storage/bbolt/bbolt.go
Normal file
427
base/database/storage/bbolt/bbolt.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package bbolt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"go.etcd.io/bbolt"
|
||||
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
var bucketName = []byte{0}
|
||||
|
||||
// BBolt database made pluggable for portbase.
|
||||
type BBolt struct {
|
||||
name string
|
||||
db *bbolt.DB
|
||||
}
|
||||
|
||||
func init() {
|
||||
_ = storage.Register("bbolt", NewBBolt)
|
||||
}
|
||||
|
||||
// NewBBolt opens/creates a bbolt database.
|
||||
func NewBBolt(name, location string) (storage.Interface, error) {
|
||||
// Create options for bbolt database.
|
||||
dbFile := filepath.Join(location, "db.bbolt")
|
||||
dbOptions := &bbolt.Options{
|
||||
Timeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
// Open/Create database, retry if there is a timeout.
|
||||
db, err := bbolt.Open(dbFile, 0o0600, dbOptions)
|
||||
for i := 0; i < 5 && err != nil; i++ {
|
||||
// Try again if there is an error.
|
||||
db, err = bbolt.Open(dbFile, 0o0600, dbOptions)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create bucket
|
||||
err = db.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists(bucketName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &BBolt{
|
||||
name: name,
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get returns a database record.
|
||||
func (b *BBolt) Get(key string) (record.Record, error) {
|
||||
var r record.Record
|
||||
|
||||
err := b.db.View(func(tx *bbolt.Tx) error {
|
||||
// get value from db
|
||||
value := tx.Bucket(bucketName).Get([]byte(key))
|
||||
if value == nil {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
|
||||
// copy data
|
||||
duplicate := make([]byte, len(value))
|
||||
copy(duplicate, value)
|
||||
|
||||
// create record
|
||||
var txErr error
|
||||
r, txErr = record.NewRawWrapper(b.name, key, duplicate)
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// GetMeta returns the metadata of a database record.
|
||||
func (b *BBolt) GetMeta(key string) (*record.Meta, error) {
|
||||
// TODO: Replace with more performant variant.
|
||||
|
||||
r, err := b.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.Meta(), nil
|
||||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (b *BBolt) Put(r record.Record) (record.Record, error) {
|
||||
data, err := r.MarshalRecord(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = b.db.Update(func(tx *bbolt.Tx) error {
|
||||
txErr := tx.Bucket(bucketName).Put([]byte(r.DatabaseKey()), data)
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
func (b *BBolt) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error) {
|
||||
batch := make(chan record.Record, 100)
|
||||
errs := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := b.db.Batch(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(bucketName)
|
||||
for r := range batch {
|
||||
txErr := b.batchPutOrDelete(bucket, shadowDelete, r)
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
errs <- err
|
||||
}()
|
||||
|
||||
return batch, errs
|
||||
}
|
||||
|
||||
func (b *BBolt) batchPutOrDelete(bucket *bbolt.Bucket, shadowDelete bool, r record.Record) (err error) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
if !shadowDelete && r.Meta().IsDeleted() {
|
||||
// Immediate delete.
|
||||
err = bucket.Delete([]byte(r.DatabaseKey()))
|
||||
} else {
|
||||
// Put or shadow delete.
|
||||
var data []byte
|
||||
data, err = r.MarshalRecord(r)
|
||||
if err == nil {
|
||||
err = bucket.Put([]byte(r.DatabaseKey()), data)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
func (b *BBolt) Delete(key string) error {
|
||||
err := b.db.Update(func(tx *bbolt.Tx) error {
|
||||
txErr := tx.Bucket(bucketName).Delete([]byte(key))
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query returns a an iterator for the supplied query.
|
||||
func (b *BBolt) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid query: %w", err)
|
||||
}
|
||||
|
||||
queryIter := iterator.New()
|
||||
|
||||
go b.queryExecutor(queryIter, q, local, internal)
|
||||
return queryIter, nil
|
||||
}
|
||||
|
||||
func (b *BBolt) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) {
|
||||
prefix := []byte(q.DatabaseKeyPrefix())
|
||||
err := b.db.View(func(tx *bbolt.Tx) error {
|
||||
// Create a cursor for iteration.
|
||||
c := tx.Bucket(bucketName).Cursor()
|
||||
|
||||
// Iterate over items in sorted key order. This starts from the
|
||||
// first key/value pair and updates the k/v variables to the
|
||||
// next key/value on each iteration.
|
||||
//
|
||||
// The loop finishes at the end of the cursor when a nil key is returned.
|
||||
for key, value := c.Seek(prefix); key != nil; key, value = c.Next() {
|
||||
|
||||
// if we don't match the prefix anymore, exit
|
||||
if !bytes.HasPrefix(key, prefix) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrap value
|
||||
iterWrapper, err := record.NewRawWrapper(b.name, string(key), value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check validity / access
|
||||
if !iterWrapper.Meta().CheckValidity() {
|
||||
continue
|
||||
}
|
||||
if !iterWrapper.Meta().CheckPermission(local, internal) {
|
||||
continue
|
||||
}
|
||||
|
||||
// check if matches & send
|
||||
if q.MatchesRecord(iterWrapper) {
|
||||
// copy data
|
||||
duplicate := make([]byte, len(value))
|
||||
copy(duplicate, value)
|
||||
|
||||
newWrapper, err := record.NewRawWrapper(b.name, iterWrapper.DatabaseKey(), duplicate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-queryIter.Done:
|
||||
return nil
|
||||
case queryIter.Next <- newWrapper:
|
||||
default:
|
||||
select {
|
||||
case <-queryIter.Done:
|
||||
return nil
|
||||
case queryIter.Next <- newWrapper:
|
||||
case <-time.After(1 * time.Second):
|
||||
return errors.New("query timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
queryIter.Finish(err)
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the database is read only.
|
||||
func (b *BBolt) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Injected returns whether the database is injected.
|
||||
func (b *BBolt) Injected() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// MaintainRecordStates maintains records states in the database.
|
||||
func (b *BBolt) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error { //nolint:gocognit
|
||||
now := time.Now().Unix()
|
||||
purgeThreshold := purgeDeletedBefore.Unix()
|
||||
|
||||
return b.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(bucketName)
|
||||
// Create a cursor for iteration.
|
||||
c := bucket.Cursor()
|
||||
for key, value := c.First(); key != nil; key, value = c.Next() {
|
||||
// check if context is cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// wrap value
|
||||
wrapper, err := record.NewRawWrapper(b.name, string(key), value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check if we need to do maintenance
|
||||
meta := wrapper.Meta()
|
||||
switch {
|
||||
case meta.Deleted == 0 && meta.Expires > 0 && meta.Expires < now:
|
||||
if shadowDelete {
|
||||
// mark as deleted
|
||||
meta.Deleted = meta.Expires
|
||||
deleted, err := wrapper.MarshalRecord(wrapper)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = bucket.Put(key, deleted)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cursor repositioning is required after modifying data.
|
||||
// While the documentation states that this is also required after a
|
||||
// delete, this actually makes the cursor skip a record with the
|
||||
// following c.Next() call of the loop.
|
||||
// Docs/Issue: https://github.com/boltdb/bolt/issues/426#issuecomment-141982984
|
||||
c.Seek(key)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Immediately delete expired entries if shadowDelete is disabled.
|
||||
fallthrough
|
||||
case meta.Deleted > 0 && (!shadowDelete || meta.Deleted < purgeThreshold):
|
||||
// delete from storage
|
||||
err = c.Delete()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Purge deletes all records that match the given query. It returns the number of successful deletes and an error.
|
||||
func (b *BBolt) Purge(ctx context.Context, q *query.Query, local, internal, shadowDelete bool) (int, error) { //nolint:gocognit
|
||||
prefix := []byte(q.DatabaseKeyPrefix())
|
||||
|
||||
var cnt int
|
||||
var done bool
|
||||
for !done {
|
||||
err := b.db.Update(func(tx *bbolt.Tx) error {
|
||||
// Create a cursor for iteration.
|
||||
bucket := tx.Bucket(bucketName)
|
||||
c := bucket.Cursor()
|
||||
for key, value := c.Seek(prefix); key != nil; key, value = c.Next() {
|
||||
// Check if context has been cancelled.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
done = true
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Check if we still match the key prefix, if not, exit.
|
||||
if !bytes.HasPrefix(key, prefix) {
|
||||
done = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wrap the value in a new wrapper to access the metadata.
|
||||
wrapper, err := record.NewRawWrapper(b.name, string(key), value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we have permission for this record.
|
||||
if !wrapper.Meta().CheckPermission(local, internal) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if record is already deleted.
|
||||
if wrapper.Meta().IsDeleted() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if the query matches this record.
|
||||
if !q.MatchesRecord(wrapper) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Delete record.
|
||||
if shadowDelete {
|
||||
// Shadow delete.
|
||||
wrapper.Meta().Delete()
|
||||
deleted, err := wrapper.MarshalRecord(wrapper)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = bucket.Put(key, deleted)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cursor repositioning is required after modifying data.
|
||||
// While the documentation states that this is also required after a
|
||||
// delete, this actually makes the cursor skip a record with the
|
||||
// following c.Next() call of the loop.
|
||||
// Docs/Issue: https://github.com/boltdb/bolt/issues/426#issuecomment-141982984
|
||||
c.Seek(key)
|
||||
|
||||
} else {
|
||||
// Immediate delete.
|
||||
err = c.Delete()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Work in batches of 1000 changes in order to enable other operations in between.
|
||||
cnt++
|
||||
if cnt%1000 == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
done = true
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the database.
|
||||
func (b *BBolt) Shutdown() error {
|
||||
return b.db.Close()
|
||||
}
|
||||
206
base/database/storage/bbolt/bbolt_test.go
Normal file
206
base/database/storage/bbolt/bbolt_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package bbolt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
// Compile time interface checks.
|
||||
_ storage.Interface = &BBolt{}
|
||||
_ storage.Batcher = &BBolt{}
|
||||
_ storage.Purger = &BBolt{}
|
||||
)
|
||||
|
||||
type TestRecord struct { //nolint:maligned
|
||||
record.Base
|
||||
sync.Mutex
|
||||
S string
|
||||
I int
|
||||
I8 int8
|
||||
I16 int16
|
||||
I32 int32
|
||||
I64 int64
|
||||
UI uint
|
||||
UI8 uint8
|
||||
UI16 uint16
|
||||
UI32 uint32
|
||||
UI64 uint64
|
||||
F32 float32
|
||||
F64 float64
|
||||
B bool
|
||||
}
|
||||
|
||||
func TestBBolt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir, err := os.MkdirTemp("", "testing-")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = os.RemoveAll(testDir) // clean up
|
||||
}()
|
||||
|
||||
// start
|
||||
db, err := NewBBolt("test", testDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
a := &TestRecord{
|
||||
S: "banana",
|
||||
I: 42,
|
||||
I8: 42,
|
||||
I16: 42,
|
||||
I32: 42,
|
||||
I64: 42,
|
||||
UI: 42,
|
||||
UI8: 42,
|
||||
UI16: 42,
|
||||
UI32: 42,
|
||||
UI64: 42,
|
||||
F32: 42.42,
|
||||
F64: 42.42,
|
||||
B: true,
|
||||
}
|
||||
a.SetMeta(&record.Meta{})
|
||||
a.Meta().Update()
|
||||
a.SetKey("test:A")
|
||||
|
||||
// put record
|
||||
_, err = db.Put(a)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// get and compare
|
||||
r1, err := db.Get("A")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
a1 := &TestRecord{}
|
||||
err = record.Unwrap(r1, a1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(a, a1) {
|
||||
t.Fatalf("mismatch, got %v", a1)
|
||||
}
|
||||
|
||||
// setup query test records
|
||||
qA := &TestRecord{}
|
||||
qA.SetKey("test:path/to/A")
|
||||
qA.CreateMeta()
|
||||
qB := &TestRecord{}
|
||||
qB.SetKey("test:path/to/B")
|
||||
qB.CreateMeta()
|
||||
qC := &TestRecord{}
|
||||
qC.SetKey("test:path/to/C")
|
||||
qC.CreateMeta()
|
||||
qZ := &TestRecord{}
|
||||
qZ.SetKey("test:z")
|
||||
qZ.CreateMeta()
|
||||
// put
|
||||
_, err = db.Put(qA)
|
||||
if err == nil {
|
||||
_, err = db.Put(qB)
|
||||
}
|
||||
if err == nil {
|
||||
_, err = db.Put(qC)
|
||||
}
|
||||
if err == nil {
|
||||
_, err = db.Put(qZ)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test query
|
||||
q := query.New("test:path/to/").MustBeValid()
|
||||
it, err := db.Query(q, true, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cnt := 0
|
||||
for range it.Next {
|
||||
cnt++
|
||||
}
|
||||
if it.Err() != nil {
|
||||
t.Fatal(it.Err())
|
||||
}
|
||||
if cnt != 3 {
|
||||
t.Fatalf("unexpected query result count: %d", cnt)
|
||||
}
|
||||
|
||||
// delete
|
||||
err = db.Delete("A")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check if its gone
|
||||
_, err = db.Get("A")
|
||||
if err == nil {
|
||||
t.Fatal("should fail")
|
||||
}
|
||||
|
||||
// maintenance
|
||||
err = db.MaintainRecordStates(context.TODO(), time.Now(), true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// maintenance
|
||||
err = db.MaintainRecordStates(context.TODO(), time.Now(), false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// purging
|
||||
purger, ok := db.(storage.Purger)
|
||||
if ok {
|
||||
n, err := purger.Purge(context.TODO(), query.New("test:path/to/").MustBeValid(), true, true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 3 {
|
||||
t.Fatalf("unexpected purge delete count: %d", n)
|
||||
}
|
||||
} else {
|
||||
t.Fatal("should implement Purger")
|
||||
}
|
||||
|
||||
// test query
|
||||
q = query.New("test").MustBeValid()
|
||||
it, err = db.Query(q, true, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cnt = 0
|
||||
for range it.Next {
|
||||
cnt++
|
||||
}
|
||||
if it.Err() != nil {
|
||||
t.Fatal(it.Err())
|
||||
}
|
||||
if cnt != 1 {
|
||||
t.Fatalf("unexpected query result count: %d", cnt)
|
||||
}
|
||||
|
||||
// shutdown
|
||||
err = db.Shutdown()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
8
base/database/storage/errors.go
Normal file
8
base/database/storage/errors.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package storage
|
||||
|
||||
import "errors"
|
||||
|
||||
// Errors for storages.
|
||||
var (
|
||||
ErrNotFound = errors.New("storage entry not found")
|
||||
)
|
||||
302
base/database/storage/fstree/fstree.go
Normal file
302
base/database/storage/fstree/fstree.go
Normal file
@@ -0,0 +1,302 @@
|
||||
/*
|
||||
Package fstree provides a dead simple file-based database storage backend.
|
||||
It is primarily meant for easy testing or storing big files that can easily be accesses directly, without datastore.
|
||||
*/
|
||||
package fstree
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
"github.com/safing/portmaster/base/utils/renameio"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFileMode = os.FileMode(0o0644)
|
||||
defaultDirMode = os.FileMode(0o0755)
|
||||
onWindows = runtime.GOOS == "windows"
|
||||
)
|
||||
|
||||
// FSTree database storage.
|
||||
type FSTree struct {
|
||||
name string
|
||||
basePath string
|
||||
}
|
||||
|
||||
func init() {
|
||||
_ = storage.Register("fstree", NewFSTree)
|
||||
}
|
||||
|
||||
// NewFSTree returns a (new) FSTree database.
|
||||
func NewFSTree(name, location string) (storage.Interface, error) {
|
||||
basePath, err := filepath.Abs(location)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fstree: failed to validate path %s: %w", location, err)
|
||||
}
|
||||
|
||||
file, err := os.Stat(basePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
err = os.MkdirAll(basePath, defaultDirMode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fstree: failed to create directory %s: %w", basePath, err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("fstree: failed to stat path %s: %w", basePath, err)
|
||||
}
|
||||
} else {
|
||||
if !file.IsDir() {
|
||||
return nil, fmt.Errorf("fstree: provided database path (%s) is a file", basePath)
|
||||
}
|
||||
}
|
||||
|
||||
return &FSTree{
|
||||
name: name,
|
||||
basePath: basePath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (fst *FSTree) buildFilePath(key string, checkKeyLength bool) (string, error) {
|
||||
// check key length
|
||||
if checkKeyLength && len(key) < 1 {
|
||||
return "", fmt.Errorf("fstree: key too short: %s", key)
|
||||
}
|
||||
// build filepath
|
||||
dstPath := filepath.Join(fst.basePath, key) // Join also calls Clean()
|
||||
if !strings.HasPrefix(dstPath, fst.basePath) {
|
||||
return "", fmt.Errorf("fstree: key integrity check failed, compiled path is %s", dstPath)
|
||||
}
|
||||
// return
|
||||
return dstPath, nil
|
||||
}
|
||||
|
||||
// Get returns a database record.
|
||||
func (fst *FSTree) Get(key string) (record.Record, error) {
|
||||
dstPath, err := fst.buildFilePath(key, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(dstPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("fstree: failed to read file %s: %w", dstPath, err)
|
||||
}
|
||||
|
||||
r, err := record.NewRawWrapper(fst.name, key, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// GetMeta returns the metadata of a database record.
|
||||
func (fst *FSTree) GetMeta(key string) (*record.Meta, error) {
|
||||
// TODO: Replace with more performant variant.
|
||||
|
||||
r, err := fst.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.Meta(), nil
|
||||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (fst *FSTree) Put(r record.Record) (record.Record, error) {
|
||||
dstPath, err := fst.buildFilePath(r.DatabaseKey(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := r.MarshalRecord(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = writeFile(dstPath, data, defaultFileMode)
|
||||
if err != nil {
|
||||
// create dir and try again
|
||||
err = os.MkdirAll(filepath.Dir(dstPath), defaultDirMode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fstree: failed to create directory %s: %w", filepath.Dir(dstPath), err)
|
||||
}
|
||||
err = writeFile(dstPath, data, defaultFileMode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fstree: could not write file %s: %w", dstPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
func (fst *FSTree) Delete(key string) error {
|
||||
dstPath, err := fst.buildFilePath(key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove entry
|
||||
err = os.Remove(dstPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fstree: could not delete %s: %w", dstPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query returns a an iterator for the supplied query.
|
||||
func (fst *FSTree) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid query: %w", err)
|
||||
}
|
||||
|
||||
walkPrefix, err := fst.buildFilePath(q.DatabaseKeyPrefix(), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fileInfo, err := os.Stat(walkPrefix)
|
||||
var walkRoot string
|
||||
switch {
|
||||
case err == nil && fileInfo.IsDir():
|
||||
walkRoot = walkPrefix
|
||||
case err == nil:
|
||||
walkRoot = filepath.Dir(walkPrefix)
|
||||
case errors.Is(err, fs.ErrNotExist):
|
||||
walkRoot = filepath.Dir(walkPrefix)
|
||||
default: // err != nil
|
||||
return nil, fmt.Errorf("fstree: could not stat query root %s: %w", walkPrefix, err)
|
||||
}
|
||||
|
||||
queryIter := iterator.New()
|
||||
|
||||
go fst.queryExecutor(walkRoot, queryIter, q, local, internal)
|
||||
return queryIter, nil
|
||||
}
|
||||
|
||||
func (fst *FSTree) queryExecutor(walkRoot string, queryIter *iterator.Iterator, q *query.Query, local, internal bool) {
|
||||
err := filepath.Walk(walkRoot, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("fstree: error in walking fs: %w", err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// skip dir if not in scope
|
||||
if !strings.HasPrefix(path, fst.basePath) {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
// continue
|
||||
return nil
|
||||
}
|
||||
|
||||
// still in scope?
|
||||
if !strings.HasPrefix(path, fst.basePath) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// read file
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("fstree: failed to read file %s: %w", path, err)
|
||||
}
|
||||
|
||||
// parse
|
||||
key, err := filepath.Rel(fst.basePath, path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fstree: failed to extract key from filepath %s: %w", path, err)
|
||||
}
|
||||
r, err := record.NewRawWrapper(fst.name, key, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fstree: failed to load file %s: %w", path, err)
|
||||
}
|
||||
|
||||
if !r.Meta().CheckValidity() {
|
||||
// record is not valid
|
||||
return nil
|
||||
}
|
||||
|
||||
if !r.Meta().CheckPermission(local, internal) {
|
||||
// no permission to access
|
||||
return nil
|
||||
}
|
||||
|
||||
// check if matches, then send
|
||||
if q.MatchesRecord(r) {
|
||||
select {
|
||||
case queryIter.Next <- r:
|
||||
case <-queryIter.Done:
|
||||
case <-time.After(1 * time.Second):
|
||||
return errors.New("fstree: query buffer full, timeout")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
queryIter.Finish(err)
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the database is read only.
|
||||
func (fst *FSTree) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Injected returns whether the database is injected.
|
||||
func (fst *FSTree) Injected() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// MaintainRecordStates maintains records states in the database.
|
||||
func (fst *FSTree) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error {
|
||||
// TODO: implement MaintainRecordStates
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the database.
|
||||
func (fst *FSTree) Shutdown() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeFile mirrors os.WriteFile, replacing an existing file with the same
|
||||
// name atomically. This is not atomic on Windows, but still an improvement.
|
||||
// TODO: Replace with github.com/google/renamio.WriteFile as soon as it is fixed on Windows.
|
||||
// TODO: This has become a wont-fix. Explore other options.
|
||||
// This function is forked from https://github.com/google/renameio/blob/a368f9987532a68a3d676566141654a81aa8100b/writefile.go.
|
||||
func writeFile(filename string, data []byte, perm os.FileMode) error {
|
||||
t, err := renameio.TempFile("", filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer t.Cleanup() //nolint:errcheck
|
||||
|
||||
// Set permissions before writing data, in case the data is sensitive.
|
||||
if !onWindows {
|
||||
if err := t.Chmod(perm); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := t.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return t.CloseAtomicallyReplace()
|
||||
}
|
||||
6
base/database/storage/fstree/fstree_test.go
Normal file
6
base/database/storage/fstree/fstree_test.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package fstree
|
||||
|
||||
import "github.com/safing/portmaster/base/database/storage"
|
||||
|
||||
// Compile time interface checks.
|
||||
var _ storage.Interface = &FSTree{}
|
||||
216
base/database/storage/hashmap/map.go
Normal file
216
base/database/storage/hashmap/map.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package hashmap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
// HashMap storage.
|
||||
type HashMap struct {
|
||||
name string
|
||||
db map[string]record.Record
|
||||
dbLock sync.RWMutex
|
||||
}
|
||||
|
||||
func init() {
|
||||
_ = storage.Register("hashmap", NewHashMap)
|
||||
}
|
||||
|
||||
// NewHashMap creates a hashmap database.
|
||||
func NewHashMap(name, location string) (storage.Interface, error) {
|
||||
return &HashMap{
|
||||
name: name,
|
||||
db: make(map[string]record.Record),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get returns a database record.
|
||||
func (hm *HashMap) Get(key string) (record.Record, error) {
|
||||
hm.dbLock.RLock()
|
||||
defer hm.dbLock.RUnlock()
|
||||
|
||||
r, ok := hm.db[key]
|
||||
if !ok {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// GetMeta returns the metadata of a database record.
|
||||
func (hm *HashMap) GetMeta(key string) (*record.Meta, error) {
|
||||
// TODO: Replace with more performant variant.
|
||||
|
||||
r, err := hm.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.Meta(), nil
|
||||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (hm *HashMap) Put(r record.Record) (record.Record, error) {
|
||||
hm.dbLock.Lock()
|
||||
defer hm.dbLock.Unlock()
|
||||
|
||||
hm.db[r.DatabaseKey()] = r
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
func (hm *HashMap) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error) {
|
||||
hm.dbLock.Lock()
|
||||
defer hm.dbLock.Unlock()
|
||||
// we could lock for every record, but we want to have the same behaviour
|
||||
// as the other storage backends, especially for testing.
|
||||
|
||||
batch := make(chan record.Record, 100)
|
||||
errs := make(chan error, 1)
|
||||
|
||||
// start handler
|
||||
go func() {
|
||||
for r := range batch {
|
||||
hm.batchPutOrDelete(shadowDelete, r)
|
||||
}
|
||||
errs <- nil
|
||||
}()
|
||||
|
||||
return batch, errs
|
||||
}
|
||||
|
||||
func (hm *HashMap) batchPutOrDelete(shadowDelete bool, r record.Record) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
hm.dbLock.Lock()
|
||||
defer hm.dbLock.Unlock()
|
||||
|
||||
if !shadowDelete && r.Meta().IsDeleted() {
|
||||
delete(hm.db, r.DatabaseKey())
|
||||
} else {
|
||||
hm.db[r.DatabaseKey()] = r
|
||||
}
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
func (hm *HashMap) Delete(key string) error {
|
||||
hm.dbLock.Lock()
|
||||
defer hm.dbLock.Unlock()
|
||||
|
||||
delete(hm.db, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query returns a an iterator for the supplied query.
|
||||
func (hm *HashMap) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid query: %w", err)
|
||||
}
|
||||
|
||||
queryIter := iterator.New()
|
||||
|
||||
go hm.queryExecutor(queryIter, q, local, internal)
|
||||
return queryIter, nil
|
||||
}
|
||||
|
||||
func (hm *HashMap) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) {
|
||||
hm.dbLock.RLock()
|
||||
defer hm.dbLock.RUnlock()
|
||||
|
||||
var err error
|
||||
|
||||
mapLoop:
|
||||
for key, record := range hm.db {
|
||||
record.Lock()
|
||||
if !q.MatchesKey(key) ||
|
||||
!q.MatchesRecord(record) ||
|
||||
!record.Meta().CheckValidity() ||
|
||||
!record.Meta().CheckPermission(local, internal) {
|
||||
|
||||
record.Unlock()
|
||||
continue
|
||||
}
|
||||
record.Unlock()
|
||||
|
||||
select {
|
||||
case <-queryIter.Done:
|
||||
break mapLoop
|
||||
case queryIter.Next <- record:
|
||||
default:
|
||||
select {
|
||||
case <-queryIter.Done:
|
||||
break mapLoop
|
||||
case queryIter.Next <- record:
|
||||
case <-time.After(1 * time.Second):
|
||||
err = errors.New("query timeout")
|
||||
break mapLoop
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
queryIter.Finish(err)
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the database is read only.
|
||||
func (hm *HashMap) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Injected returns whether the database is injected.
|
||||
func (hm *HashMap) Injected() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// MaintainRecordStates maintains records states in the database.
|
||||
func (hm *HashMap) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error {
|
||||
hm.dbLock.Lock()
|
||||
defer hm.dbLock.Unlock()
|
||||
|
||||
now := time.Now().Unix()
|
||||
purgeThreshold := purgeDeletedBefore.Unix()
|
||||
|
||||
for key, record := range hm.db {
|
||||
// check if context is cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
meta := record.Meta()
|
||||
switch {
|
||||
case meta.Deleted == 0 && meta.Expires > 0 && meta.Expires < now:
|
||||
if shadowDelete {
|
||||
// mark as deleted
|
||||
record.Lock()
|
||||
meta.Deleted = meta.Expires
|
||||
record.Unlock()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Immediately delete expired entries if shadowDelete is disabled.
|
||||
fallthrough
|
||||
case meta.Deleted > 0 && (!shadowDelete || meta.Deleted < purgeThreshold):
|
||||
// delete from storage
|
||||
delete(hm.db, key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the database.
|
||||
func (hm *HashMap) Shutdown() error {
|
||||
return nil
|
||||
}
|
||||
145
base/database/storage/hashmap/map_test.go
Normal file
145
base/database/storage/hashmap/map_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package hashmap
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
// Compile time interface checks.
|
||||
_ storage.Interface = &HashMap{}
|
||||
_ storage.Batcher = &HashMap{}
|
||||
)
|
||||
|
||||
type TestRecord struct { //nolint:maligned
|
||||
record.Base
|
||||
sync.Mutex
|
||||
S string
|
||||
I int
|
||||
I8 int8
|
||||
I16 int16
|
||||
I32 int32
|
||||
I64 int64
|
||||
UI uint
|
||||
UI8 uint8
|
||||
UI16 uint16
|
||||
UI32 uint32
|
||||
UI64 uint64
|
||||
F32 float32
|
||||
F64 float64
|
||||
B bool
|
||||
}
|
||||
|
||||
func TestHashMap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// start
|
||||
db, err := NewHashMap("test", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
a := &TestRecord{
|
||||
S: "banana",
|
||||
I: 42,
|
||||
I8: 42,
|
||||
I16: 42,
|
||||
I32: 42,
|
||||
I64: 42,
|
||||
UI: 42,
|
||||
UI8: 42,
|
||||
UI16: 42,
|
||||
UI32: 42,
|
||||
UI64: 42,
|
||||
F32: 42.42,
|
||||
F64: 42.42,
|
||||
B: true,
|
||||
}
|
||||
a.SetMeta(&record.Meta{})
|
||||
a.Meta().Update()
|
||||
a.SetKey("test:A")
|
||||
|
||||
// put record
|
||||
_, err = db.Put(a)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// get and compare
|
||||
a1, err := db.Get("A")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(a, a1) {
|
||||
t.Fatalf("mismatch, got %v", a1)
|
||||
}
|
||||
|
||||
// setup query test records
|
||||
qA := &TestRecord{}
|
||||
qA.SetKey("test:path/to/A")
|
||||
qA.CreateMeta()
|
||||
qB := &TestRecord{}
|
||||
qB.SetKey("test:path/to/B")
|
||||
qB.CreateMeta()
|
||||
qC := &TestRecord{}
|
||||
qC.SetKey("test:path/to/C")
|
||||
qC.CreateMeta()
|
||||
qZ := &TestRecord{}
|
||||
qZ.SetKey("test:z")
|
||||
qZ.CreateMeta()
|
||||
// put
|
||||
_, err = db.Put(qA)
|
||||
if err == nil {
|
||||
_, err = db.Put(qB)
|
||||
}
|
||||
if err == nil {
|
||||
_, err = db.Put(qC)
|
||||
}
|
||||
if err == nil {
|
||||
_, err = db.Put(qZ)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test query
|
||||
q := query.New("test:path/to/").MustBeValid()
|
||||
it, err := db.Query(q, true, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cnt := 0
|
||||
for range it.Next {
|
||||
cnt++
|
||||
}
|
||||
if it.Err() != nil {
|
||||
t.Fatal(it.Err())
|
||||
}
|
||||
if cnt != 3 {
|
||||
t.Fatalf("unexpected query result count: %d", cnt)
|
||||
}
|
||||
|
||||
// delete
|
||||
err = db.Delete("A")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check if its gone
|
||||
_, err = db.Get("A")
|
||||
if err == nil {
|
||||
t.Fatal("should fail")
|
||||
}
|
||||
|
||||
// shutdown
|
||||
err = db.Shutdown()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
60
base/database/storage/injectbase.go
Normal file
60
base/database/storage/injectbase.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
// ErrNotImplemented is returned when a function is not implemented by a storage.
|
||||
var ErrNotImplemented = errors.New("not implemented")
|
||||
|
||||
// InjectBase is a dummy base structure to reduce boilerplate code for injected storage interfaces.
|
||||
type InjectBase struct{}
|
||||
|
||||
// Compile time interface check.
|
||||
var _ Interface = &InjectBase{}
|
||||
|
||||
// Get returns a database record.
|
||||
func (i *InjectBase) Get(key string) (record.Record, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (i *InjectBase) Put(m record.Record) (record.Record, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
func (i *InjectBase) Delete(key string) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// Query returns a an iterator for the supplied query.
|
||||
func (i *InjectBase) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the database is read only.
|
||||
func (i *InjectBase) ReadOnly() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Injected returns whether the database is injected.
|
||||
func (i *InjectBase) Injected() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// MaintainRecordStates maintains records states in the database.
|
||||
func (i *InjectBase) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the database.
|
||||
func (i *InjectBase) Shutdown() error {
|
||||
return nil
|
||||
}
|
||||
48
base/database/storage/interface.go
Normal file
48
base/database/storage/interface.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
// Interface defines the database storage API.
|
||||
type Interface interface {
|
||||
// Primary Interface
|
||||
Get(key string) (record.Record, error)
|
||||
Put(m record.Record) (record.Record, error)
|
||||
Delete(key string) error
|
||||
Query(q *query.Query, local, internal bool) (*iterator.Iterator, error)
|
||||
|
||||
// Information and Control
|
||||
ReadOnly() bool
|
||||
Injected() bool
|
||||
Shutdown() error
|
||||
|
||||
// Mandatory Record Maintenance
|
||||
MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error
|
||||
}
|
||||
|
||||
// MetaHandler defines the database storage API for backends that support optimized fetching of only the metadata.
|
||||
type MetaHandler interface {
|
||||
GetMeta(key string) (*record.Meta, error)
|
||||
}
|
||||
|
||||
// Maintainer defines the database storage API for backends that require regular maintenance.
|
||||
type Maintainer interface {
|
||||
Maintain(ctx context.Context) error
|
||||
MaintainThorough(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Batcher defines the database storage API for backends that support batch operations.
|
||||
type Batcher interface {
|
||||
PutMany(shadowDelete bool) (batch chan<- record.Record, errs <-chan error)
|
||||
}
|
||||
|
||||
// Purger defines the database storage API for backends that support the purge operation.
|
||||
type Purger interface {
|
||||
Purge(ctx context.Context, q *query.Query, local, internal, shadowDelete bool) (int, error)
|
||||
}
|
||||
111
base/database/storage/sinkhole/sinkhole.go
Normal file
111
base/database/storage/sinkhole/sinkhole.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package sinkhole
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
// Sinkhole is a dummy storage.
|
||||
type Sinkhole struct {
|
||||
name string
|
||||
}
|
||||
|
||||
var (
|
||||
// Compile time interface checks.
|
||||
_ storage.Interface = &Sinkhole{}
|
||||
_ storage.Maintainer = &Sinkhole{}
|
||||
_ storage.Batcher = &Sinkhole{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
_ = storage.Register("sinkhole", NewSinkhole)
|
||||
}
|
||||
|
||||
// NewSinkhole creates a dummy database.
|
||||
func NewSinkhole(name, location string) (storage.Interface, error) {
|
||||
return &Sinkhole{
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Exists returns whether an entry with the given key exists.
|
||||
func (s *Sinkhole) Exists(key string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Get returns a database record.
|
||||
func (s *Sinkhole) Get(key string) (record.Record, error) {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
// GetMeta returns the metadata of a database record.
|
||||
func (s *Sinkhole) GetMeta(key string) (*record.Meta, error) {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (s *Sinkhole) Put(r record.Record) (record.Record, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
func (s *Sinkhole) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error) {
|
||||
batch := make(chan record.Record, 100)
|
||||
errs := make(chan error, 1)
|
||||
|
||||
// start handler
|
||||
go func() {
|
||||
for range batch {
|
||||
// discard everything
|
||||
}
|
||||
errs <- nil
|
||||
}()
|
||||
|
||||
return batch, errs
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
func (s *Sinkhole) Delete(key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query returns a an iterator for the supplied query.
|
||||
func (s *Sinkhole) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
return nil, errors.New("query not implemented by sinkhole")
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the database is read only.
|
||||
func (s *Sinkhole) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Injected returns whether the database is injected.
|
||||
func (s *Sinkhole) Injected() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Maintain runs a light maintenance operation on the database.
|
||||
func (s *Sinkhole) Maintain(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaintainThorough runs a thorough maintenance operation on the database.
|
||||
func (s *Sinkhole) MaintainThorough(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaintainRecordStates maintains records states in the database.
|
||||
func (s *Sinkhole) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the database.
|
||||
func (s *Sinkhole) Shutdown() error {
|
||||
return nil
|
||||
}
|
||||
47
base/database/storage/storages.go
Normal file
47
base/database/storage/storages.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A Factory creates a new database of it's type.
|
||||
type Factory func(name, location string) (Interface, error)
|
||||
|
||||
var (
|
||||
storages = make(map[string]Factory)
|
||||
storagesLock sync.Mutex
|
||||
)
|
||||
|
||||
// Register registers a new storage type.
|
||||
func Register(name string, factory Factory) error {
|
||||
storagesLock.Lock()
|
||||
defer storagesLock.Unlock()
|
||||
|
||||
_, ok := storages[name]
|
||||
if ok {
|
||||
return errors.New("factory for this type already exists")
|
||||
}
|
||||
|
||||
storages[name] = factory
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateDatabase starts a new database with the given name and storageType at location.
|
||||
func CreateDatabase(name, storageType, location string) (Interface, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// StartDatabase starts a new database with the given name and storageType at location.
|
||||
func StartDatabase(name, storageType, location string) (Interface, error) {
|
||||
storagesLock.Lock()
|
||||
defer storagesLock.Unlock()
|
||||
|
||||
factory, ok := storages[storageType]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("storage type %s not registered", storageType)
|
||||
}
|
||||
|
||||
return factory(name, location)
|
||||
}
|
||||
35
base/database/subscription.go
Normal file
35
base/database/subscription.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
// Subscription is a database subscription for updates.
|
||||
type Subscription struct {
|
||||
q *query.Query
|
||||
local bool
|
||||
internal bool
|
||||
|
||||
Feed chan record.Record
|
||||
}
|
||||
|
||||
// Cancel cancels the subscription.
|
||||
func (s *Subscription) Cancel() error {
|
||||
c, err := getController(s.q.DatabaseName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.subscriptionLock.Lock()
|
||||
defer c.subscriptionLock.Unlock()
|
||||
|
||||
for key, sub := range c.subscriptions {
|
||||
if sub.q == s.q {
|
||||
c.subscriptions = append(c.subscriptions[:key], c.subscriptions[key+1:]...)
|
||||
close(s.Feed) // this close is guarded by the controllers subscriptionLock.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user