wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
483
service/netquery/orm/decoder.go
Normal file
483
service/netquery/orm/decoder.go
Normal file
@@ -0,0 +1,483 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// Commonly used error messages when working with orm.
|
||||
var (
|
||||
errStructExpected = errors.New("encode: can only encode structs to maps")
|
||||
errStructPointerExpected = errors.New("decode: result must be pointer to a struct type or map[string]interface{}")
|
||||
errUnexpectedColumnType = errors.New("decode: unexpected column type")
|
||||
)
|
||||
|
||||
// constants used when transforming data to and from sqlite.
|
||||
var (
|
||||
// sqliteTimeFromat defines the string representation that is
|
||||
// expected by SQLite DATETIME functions.
|
||||
// Note that SQLite itself does not include support for a DATETIME
|
||||
// column type. Instead, dates and times are stored either as INTEGER,
|
||||
// TEXT or REAL.
|
||||
// This package provides support for time.Time being stored as TEXT (using a
|
||||
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
|
||||
// unixepoch and unixnano-epoch where the nano variant is not officially supported by
|
||||
// SQLITE).
|
||||
SqliteTimeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
type (
|
||||
|
||||
// Stmt describes the interface that must be implemented in order to
|
||||
// be decodable to a struct type using DecodeStmt. This interface is implemented
|
||||
// by *sqlite.Stmt.
|
||||
Stmt interface {
|
||||
ColumnCount() int
|
||||
ColumnName(int) string
|
||||
ColumnType(int) sqlite.ColumnType
|
||||
ColumnText(int) string
|
||||
ColumnBool(int) bool
|
||||
ColumnFloat(int) float64
|
||||
ColumnInt(int) int
|
||||
ColumnReader(int) *bytes.Reader
|
||||
}
|
||||
|
||||
// DecodeFunc is called for each non-basic type during decoding.
|
||||
DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)
|
||||
|
||||
// DecodeConfig holds decoding functions.
|
||||
DecodeConfig struct {
|
||||
DecodeHooks []DecodeFunc
|
||||
}
|
||||
)
|
||||
|
||||
// DecodeStmt decodes the current result row loaded in Stmt into the struct or map type result.
|
||||
// Decoding hooks configured in cfg are executed before trying to decode basic types and may
|
||||
// be specified to provide support for special types.
|
||||
// See DatetimeDecoder() for an example of a DecodeHook that handles graceful time.Time conversion.
|
||||
func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result interface{}, cfg DecodeConfig) error {
|
||||
// make sure we got something to decode into ...
|
||||
if result == nil {
|
||||
return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
|
||||
}
|
||||
|
||||
// fast path for decoding into a map
|
||||
if mp, ok := result.(*map[string]interface{}); ok {
|
||||
return decodeIntoMap(ctx, schema, stmt, mp, cfg)
|
||||
}
|
||||
|
||||
// make sure we got a pointer in result
|
||||
if reflect.TypeOf(result).Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
|
||||
}
|
||||
|
||||
// make sure it's a poiter to a struct type
|
||||
t := reflect.ValueOf(result).Elem().Type()
|
||||
if t.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
|
||||
}
|
||||
|
||||
// if result is a nil pointer make sure to allocate some space
|
||||
// for the resulting struct
|
||||
resultValue := reflect.ValueOf(result)
|
||||
if resultValue.IsNil() {
|
||||
resultValue.Set(
|
||||
reflect.New(t),
|
||||
)
|
||||
}
|
||||
|
||||
// we need access to the struct directly and not to the
|
||||
// pointer.
|
||||
target := reflect.Indirect(resultValue)
|
||||
|
||||
// create a lookup map from field name (or sqlite:"" tag)
|
||||
// to the field name
|
||||
lm := make(map[string]string)
|
||||
for i := 0; i < target.NumField(); i++ {
|
||||
fieldType := t.Field(i)
|
||||
|
||||
// skip unexported fields
|
||||
if !fieldType.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
lm[sqlColumnName(fieldType)] = fieldType.Name
|
||||
}
|
||||
|
||||
// iterate over all columns and assign them to the correct
|
||||
// fields
|
||||
for i := 0; i < stmt.ColumnCount(); i++ {
|
||||
colName := stmt.ColumnName(i)
|
||||
fieldName, ok := lm[colName]
|
||||
if !ok {
|
||||
// there's no target field for this column
|
||||
// so we can skip it
|
||||
continue
|
||||
}
|
||||
fieldType, _ := t.FieldByName(fieldName)
|
||||
|
||||
value := target.FieldByName(fieldName)
|
||||
|
||||
colType := stmt.ColumnType(i)
|
||||
|
||||
// if the column is reported as NULL we keep
|
||||
// the field as it is.
|
||||
// TODO(ppacher): should we set it to nil here?
|
||||
if colType == sqlite.TypeNull {
|
||||
continue
|
||||
}
|
||||
|
||||
// if value is a nil pointer we need to allocate some memory
|
||||
// first
|
||||
if getKind(value) == reflect.Ptr && value.IsNil() {
|
||||
storage := reflect.New(fieldType.Type.Elem())
|
||||
|
||||
value.Set(storage)
|
||||
|
||||
// make sure value actually points the
|
||||
// dereferenced target storage
|
||||
value = storage.Elem()
|
||||
}
|
||||
|
||||
colDef := schema.GetColumnDef(colName)
|
||||
|
||||
// execute all decode hooks but make sure we use decodeBasic() as the
|
||||
// last one.
|
||||
columnValue, err := runDecodeHooks(
|
||||
i,
|
||||
colDef,
|
||||
stmt,
|
||||
fieldType,
|
||||
value,
|
||||
append(cfg.DecodeHooks, decodeBasic()),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if we don't have a converted value now we try to
|
||||
// decode basic types
|
||||
if columnValue == nil {
|
||||
return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
|
||||
}
|
||||
|
||||
// Debugging:
|
||||
// log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue)
|
||||
|
||||
// convert it to the target type if conversion is possible
|
||||
newValue := reflect.ValueOf(columnValue)
|
||||
if newValue.Type().ConvertibleTo(value.Type()) {
|
||||
newValue = newValue.Convert(value.Type())
|
||||
}
|
||||
|
||||
// assign the new value to the struct field.
|
||||
value.Set(newValue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DatetimeDecoder is capable of decoding sqlite INTEGER or TEXT storage classes into
|
||||
// time.Time. For INTEGER storage classes, it supports 'unixnano' struct tag value to
|
||||
// decide between Unix or UnixNano epoch timestamps.
|
||||
//
|
||||
// TODO(ppacher): update comment about loc parameter and TEXT storage class parsing.
|
||||
func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
||||
return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
|
||||
// if we have the column definition available we
|
||||
// use the target go type from there.
|
||||
outType := outval.Type()
|
||||
|
||||
if colDef != nil {
|
||||
outType = colDef.GoType
|
||||
}
|
||||
|
||||
// we only care about "time.Time" here
|
||||
if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
|
||||
// log.Printf("not decoding %s %v", outType, colDef)
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
switch stmt.ColumnType(colIdx) { //nolint:exhaustive // Only selecting specific types.
|
||||
case sqlite.TypeInteger:
|
||||
// stored as unix-epoch, if unixnano is set in the struct field tag
|
||||
// we parse it with nano-second resolution
|
||||
// TODO(ppacher): actually split the tag value at "," and search
|
||||
// the slice for "unixnano"
|
||||
if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") {
|
||||
return time.Unix(0, int64(stmt.ColumnInt(colIdx))), true, nil
|
||||
}
|
||||
|
||||
return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), true, nil
|
||||
|
||||
case sqlite.TypeText:
|
||||
// stored ISO8601 but does not have any timezone information
|
||||
// assigned so we always treat it as loc here.
|
||||
t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
|
||||
}
|
||||
|
||||
return t, true, nil
|
||||
|
||||
case sqlite.TypeFloat:
|
||||
// stored as Julian day numbers
|
||||
return nil, false, fmt.Errorf("REAL storage type not support for time.Time")
|
||||
|
||||
case sqlite.TypeNull:
|
||||
return nil, true, nil
|
||||
|
||||
default:
|
||||
return nil, false, fmt.Errorf("unsupported storage type for time.Time: %s", stmt.ColumnType(colIdx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeIntoMap(_ context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
|
||||
if *mp == nil {
|
||||
*mp = make(map[string]interface{})
|
||||
}
|
||||
|
||||
for i := 0; i < stmt.ColumnCount(); i++ {
|
||||
var x interface{}
|
||||
|
||||
colDef := schema.GetColumnDef(stmt.ColumnName(i))
|
||||
|
||||
outVal := reflect.ValueOf(&x).Elem()
|
||||
fieldType := reflect.StructField{}
|
||||
if colDef != nil {
|
||||
outVal = reflect.New(colDef.GoType).Elem()
|
||||
fieldType = reflect.StructField{
|
||||
Type: colDef.GoType,
|
||||
}
|
||||
}
|
||||
|
||||
val, err := runDecodeHooks(
|
||||
i,
|
||||
colDef,
|
||||
stmt,
|
||||
fieldType,
|
||||
outVal,
|
||||
append(cfg.DecodeHooks, decodeBasic()),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode column %s: %w", stmt.ColumnName(i), err)
|
||||
}
|
||||
|
||||
(*mp)[stmt.ColumnName(i)] = val
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeBasic() DecodeFunc {
|
||||
return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (result interface{}, handled bool, err error) {
|
||||
valueKind := getKind(outval)
|
||||
colType := stmt.ColumnType(colIdx)
|
||||
colName := stmt.ColumnName(colIdx)
|
||||
|
||||
errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type())
|
||||
|
||||
// if we have the column definition available we
|
||||
// use the target go type from there.
|
||||
if colDef != nil {
|
||||
valueKind = NormalizeKind(colDef.GoType.Kind())
|
||||
|
||||
// if we have a column definition we try to convert the value to
|
||||
// the actual Go-type that was used in the model.
|
||||
// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
|
||||
// or that type aliases like (type myInt int) are decoded into myInt instead of int
|
||||
defer func() {
|
||||
if handled {
|
||||
t := reflect.New(colDef.GoType).Elem()
|
||||
|
||||
if result == nil || reflect.ValueOf(result).IsZero() {
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.ValueOf(result).Type().ConvertibleTo(colDef.GoType) {
|
||||
result = reflect.ValueOf(result).Convert(colDef.GoType).Interface()
|
||||
}
|
||||
t.Set(reflect.ValueOf(result))
|
||||
|
||||
result = t.Interface()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// log.Printf("decoding %s into kind %s", colName, valueKind)
|
||||
|
||||
if colType == sqlite.TypeNull {
|
||||
if colDef != nil && colDef.Nullable {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
if colDef != nil && !colDef.Nullable {
|
||||
return reflect.New(colDef.GoType).Elem().Interface(), true, nil
|
||||
}
|
||||
|
||||
if outval.Kind() == reflect.Ptr {
|
||||
return nil, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
switch valueKind { //nolint:exhaustive
|
||||
case reflect.String:
|
||||
if colType != sqlite.TypeText {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
return stmt.ColumnText(colIdx), true, nil
|
||||
|
||||
case reflect.Bool:
|
||||
// sqlite does not have a BOOL type, it rather stores a 1/0 in a column
|
||||
// with INTEGER affinity.
|
||||
if colType != sqlite.TypeInteger {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
return stmt.ColumnBool(colIdx), true, nil
|
||||
|
||||
case reflect.Float64:
|
||||
if colType != sqlite.TypeFloat {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
return stmt.ColumnFloat(colIdx), true, nil
|
||||
|
||||
case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ...
|
||||
if colType != sqlite.TypeInteger {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
|
||||
return stmt.ColumnInt(colIdx), true, nil
|
||||
|
||||
case reflect.Slice:
|
||||
if outval.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, false, fmt.Errorf("slices other than []byte for BLOB are not supported")
|
||||
}
|
||||
|
||||
if colType != sqlite.TypeBlob {
|
||||
return nil, false, errInvalidType
|
||||
}
|
||||
|
||||
columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
|
||||
}
|
||||
|
||||
return columnValue, true, nil
|
||||
|
||||
case reflect.Interface:
|
||||
var (
|
||||
t reflect.Type
|
||||
x interface{}
|
||||
)
|
||||
switch colType {
|
||||
case sqlite.TypeBlob:
|
||||
t = reflect.TypeOf([]byte{})
|
||||
columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
|
||||
}
|
||||
x = columnValue
|
||||
|
||||
case sqlite.TypeFloat:
|
||||
t = reflect.TypeOf(float64(0))
|
||||
x = stmt.ColumnFloat(colIdx)
|
||||
|
||||
case sqlite.TypeInteger:
|
||||
t = reflect.TypeOf(int(0))
|
||||
x = stmt.ColumnInt(colIdx)
|
||||
|
||||
case sqlite.TypeText:
|
||||
t = reflect.TypeOf(string(""))
|
||||
x = stmt.ColumnText(colIdx)
|
||||
|
||||
case sqlite.TypeNull:
|
||||
t = nil
|
||||
x = nil
|
||||
|
||||
default:
|
||||
return nil, false, fmt.Errorf("unsupported column type %s", colType)
|
||||
}
|
||||
|
||||
if t == nil {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
target := reflect.New(t).Elem()
|
||||
target.Set(reflect.ValueOf(x))
|
||||
|
||||
return target.Interface(), true, nil
|
||||
|
||||
default:
|
||||
return nil, false, fmt.Errorf("cannot decode into %s", valueKind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sqlColumnName(fieldType reflect.StructField) string {
|
||||
tagValue, hasTag := fieldType.Tag.Lookup("sqlite")
|
||||
if !hasTag {
|
||||
return fieldType.Name
|
||||
}
|
||||
|
||||
parts := strings.Split(tagValue, ",")
|
||||
if parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return fieldType.Name
|
||||
}
|
||||
|
||||
// runDecodeHooks tries to decode the column value of stmt at index colIdx into outval by running all decode hooks.
|
||||
// The first hook that returns a non-nil interface wins, other hooks will not be executed. If an error is
|
||||
// returned by a decode hook runDecodeHooks stops the error is returned to the caller.
|
||||
func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
|
||||
for _, fn := range hooks {
|
||||
res, end, err := fn(colIdx, colDef, stmt, fieldDef, outval)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if end {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getKind returns the kind of value but normalized Int, Uint and Float variants.
|
||||
// to their base type.
|
||||
func getKind(val reflect.Value) reflect.Kind {
|
||||
kind := val.Kind()
|
||||
return NormalizeKind(kind)
|
||||
}
|
||||
|
||||
// NormalizeKind returns a normalized kind of the given kind.
|
||||
func NormalizeKind(kind reflect.Kind) reflect.Kind {
|
||||
switch {
|
||||
case kind >= reflect.Int && kind <= reflect.Int64:
|
||||
return reflect.Int
|
||||
case kind >= reflect.Uint && kind <= reflect.Uint64:
|
||||
return reflect.Uint
|
||||
case kind >= reflect.Float32 && kind <= reflect.Float64:
|
||||
return reflect.Float64
|
||||
default:
|
||||
return kind
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultDecodeConfig holds the default decoding configuration.
|
||||
var DefaultDecodeConfig = DecodeConfig{
|
||||
DecodeHooks: []DecodeFunc{
|
||||
DatetimeDecoder(time.UTC),
|
||||
},
|
||||
}
|
||||
570
service/netquery/orm/decoder_test.go
Normal file
570
service/netquery/orm/decoder_test.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
type testStmt struct {
|
||||
columns []string
|
||||
values []interface{}
|
||||
types []sqlite.ColumnType
|
||||
}
|
||||
|
||||
func (ts testStmt) ColumnCount() int { return len(ts.columns) }
|
||||
func (ts testStmt) ColumnName(i int) string { return ts.columns[i] }
|
||||
func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] }
|
||||
|
||||
// Compile time check.
|
||||
var _ Stmt = new(testStmt)
|
||||
|
||||
type exampleFieldTypes struct {
|
||||
S string
|
||||
I int
|
||||
F float64
|
||||
B bool
|
||||
}
|
||||
|
||||
type examplePointerTypes struct {
|
||||
S *string
|
||||
I *int
|
||||
F *float64
|
||||
B *bool
|
||||
}
|
||||
|
||||
type exampleStructTags struct {
|
||||
S string `sqlite:"col_string"`
|
||||
I int `sqlite:"col_int"`
|
||||
}
|
||||
|
||||
type exampleIntConv struct {
|
||||
I8 int8
|
||||
I16 int16
|
||||
I32 int32
|
||||
I64 int64
|
||||
I int
|
||||
}
|
||||
|
||||
type exampleBlobTypes struct {
|
||||
B []byte
|
||||
}
|
||||
|
||||
type exampleJSONRawTypes struct {
|
||||
B json.RawMessage
|
||||
}
|
||||
|
||||
type exampleTimeTypes struct {
|
||||
T time.Time
|
||||
TP *time.Time
|
||||
}
|
||||
|
||||
type exampleInterface struct {
|
||||
I interface{}
|
||||
IP *interface{}
|
||||
}
|
||||
|
||||
func (ett *exampleTimeTypes) Equal(other interface{}) bool {
|
||||
oett, ok := other.(*exampleTimeTypes)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return ett.T.Equal(oett.T) && (ett.TP != nil && oett.TP != nil && ett.TP.Equal(*oett.TP)) || (ett.TP == nil && oett.TP == nil)
|
||||
}
|
||||
|
||||
type myInt int
|
||||
|
||||
type exampleTimeNano struct {
|
||||
T time.Time `sqlite:",unixnano"`
|
||||
}
|
||||
|
||||
func (etn *exampleTimeNano) Equal(other interface{}) bool {
|
||||
oetn, ok := other.(*exampleTimeNano)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return etn.T.Equal(oetn.T)
|
||||
}
|
||||
|
||||
func TestDecoder(t *testing.T) { //nolint:maintidx,tparallel
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
Stmt testStmt
|
||||
ColumnDef []ColumnDef
|
||||
Result interface{}
|
||||
Expected interface{}
|
||||
}{
|
||||
{
|
||||
"Decoding into nil is not allowed",
|
||||
testStmt{
|
||||
columns: nil,
|
||||
values: nil,
|
||||
types: nil,
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Decoding into basic types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
1,
|
||||
1.2,
|
||||
true,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
S: "string value",
|
||||
I: 1,
|
||||
F: 1.2,
|
||||
B: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into basic types with different order",
|
||||
testStmt{
|
||||
columns: []string{"I", "S", "B", "F"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
},
|
||||
values: []interface{}{
|
||||
1,
|
||||
"string value",
|
||||
true,
|
||||
1.2,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
S: "string value",
|
||||
I: 1,
|
||||
F: 1.2,
|
||||
B: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into basic types with missing values",
|
||||
testStmt{
|
||||
columns: []string{"F", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
1.2,
|
||||
true,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
F: 1.2,
|
||||
B: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into pointer types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
1,
|
||||
1.2,
|
||||
true,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&examplePointerTypes{},
|
||||
func() interface{} {
|
||||
s := "string value"
|
||||
i := 1
|
||||
f := 1.2
|
||||
b := true
|
||||
|
||||
return &examplePointerTypes{
|
||||
S: &s,
|
||||
I: &i,
|
||||
F: &f,
|
||||
B: &b,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into pointer types with missing values",
|
||||
testStmt{
|
||||
columns: []string{"S", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
true,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&examplePointerTypes{},
|
||||
func() interface{} {
|
||||
s := "string value"
|
||||
b := true
|
||||
|
||||
return &examplePointerTypes{
|
||||
S: &s,
|
||||
B: &b,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into fields with struct tags",
|
||||
testStmt{
|
||||
columns: []string{"col_string", "col_int"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
1,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleStructTags{},
|
||||
&exampleStructTags{
|
||||
S: "string value",
|
||||
I: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into correct int type",
|
||||
testStmt{
|
||||
columns: []string{"I8", "I16", "I32", "I64", "I"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleIntConv{},
|
||||
&exampleIntConv{
|
||||
1, 1, 1, 1, 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling NULL values for basic types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeFloat,
|
||||
},
|
||||
values: []interface{}{
|
||||
// we use nil here but actually that does not matter
|
||||
nil,
|
||||
nil,
|
||||
1.0,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
F: 1.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling NULL values for pointer types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeFloat,
|
||||
},
|
||||
values: []interface{}{
|
||||
// we use nil here but actually that does not matter
|
||||
nil,
|
||||
nil,
|
||||
1.0,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&examplePointerTypes{},
|
||||
func() interface{} {
|
||||
f := 1.0
|
||||
|
||||
return &examplePointerTypes{F: &f}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling blob types",
|
||||
testStmt{
|
||||
columns: []string{"B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
([]byte)("hello world"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleBlobTypes{},
|
||||
&exampleBlobTypes{
|
||||
B: ([]byte)("hello world"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling blob types as json.RawMessage",
|
||||
testStmt{
|
||||
columns: []string{"B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
([]byte)("hello world"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleJSONRawTypes{},
|
||||
&exampleJSONRawTypes{
|
||||
B: (json.RawMessage)("hello world"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling time.Time and pointers to it",
|
||||
testStmt{
|
||||
columns: []string{"T", "TP"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
int(refTime.Unix()),
|
||||
int(refTime.Unix()),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleTimeTypes{},
|
||||
&exampleTimeTypes{
|
||||
T: refTime,
|
||||
TP: &refTime,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling time.Time in nano-second resolution (struct tags)",
|
||||
testStmt{
|
||||
columns: []string{"T", "TP"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
int(refTime.UnixNano()),
|
||||
int(refTime.UnixNano()),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleTimeNano{},
|
||||
&exampleTimeNano{
|
||||
T: refTime,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into interface",
|
||||
testStmt{
|
||||
columns: []string{"I", "IP"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeText,
|
||||
},
|
||||
values: []interface{}{
|
||||
"value1",
|
||||
"value2",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
&exampleInterface{},
|
||||
func() interface{} {
|
||||
var x interface{} = "value2"
|
||||
|
||||
return &exampleInterface{
|
||||
I: "value1",
|
||||
IP: &x,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into map[string]interface{}",
|
||||
testStmt{
|
||||
columns: []string{"I", "F", "S", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
1,
|
||||
1.1,
|
||||
"string value",
|
||||
[]byte("blob value"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
new(map[string]interface{}),
|
||||
&map[string]interface{}{
|
||||
"I": 1,
|
||||
"F": 1.1,
|
||||
"S": "string value",
|
||||
"B": []byte("blob value"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding using type-hints",
|
||||
testStmt{
|
||||
columns: []string{"B", "T"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeText,
|
||||
},
|
||||
values: []interface{}{
|
||||
true,
|
||||
refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
},
|
||||
[]ColumnDef{
|
||||
{
|
||||
Name: "B",
|
||||
Type: sqlite.TypeInteger,
|
||||
GoType: reflect.TypeOf(true),
|
||||
},
|
||||
{
|
||||
Name: "T",
|
||||
Type: sqlite.TypeText,
|
||||
GoType: reflect.TypeOf(time.Time{}),
|
||||
IsTime: true,
|
||||
},
|
||||
},
|
||||
new(map[string]interface{}),
|
||||
&map[string]interface{}{
|
||||
"B": true,
|
||||
"T": refTime,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into type aliases",
|
||||
testStmt{
|
||||
columns: []string{"B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
[]byte(`{"foo": "bar}`),
|
||||
},
|
||||
},
|
||||
[]ColumnDef{
|
||||
{
|
||||
Name: "B",
|
||||
Type: sqlite.TypeBlob,
|
||||
GoType: reflect.TypeOf(json.RawMessage(`{"foo": "bar}`)),
|
||||
},
|
||||
},
|
||||
new(map[string]interface{}),
|
||||
&map[string]interface{}{
|
||||
"B": json.RawMessage(`{"foo": "bar}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into type aliases #2",
|
||||
testStmt{
|
||||
columns: []string{"I"},
|
||||
types: []sqlite.ColumnType{sqlite.TypeInteger},
|
||||
values: []interface{}{
|
||||
10,
|
||||
},
|
||||
},
|
||||
[]ColumnDef{
|
||||
{
|
||||
Name: "I",
|
||||
Type: sqlite.TypeInteger,
|
||||
GoType: reflect.TypeOf(myInt(0)),
|
||||
},
|
||||
},
|
||||
new(map[string]interface{}),
|
||||
&map[string]interface{}{
|
||||
"I": myInt(10),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
// log.Println(c.Desc)
|
||||
err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
|
||||
if fn, ok := c.Expected.(func() interface{}); ok {
|
||||
c.Expected = fn()
|
||||
}
|
||||
|
||||
if c.Expected == nil {
|
||||
assert.Error(t, err, c.Desc)
|
||||
} else {
|
||||
assert.NoError(t, err, c.Desc)
|
||||
|
||||
if equaler, ok := c.Expected.(interface{ Equal(x interface{}) bool }); ok {
|
||||
assert.True(t, equaler.Equal(c.Result))
|
||||
} else {
|
||||
assert.Equal(t, c.Expected, c.Result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
258
service/netquery/orm/encoder.go
Normal file
258
service/netquery/orm/encoder.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
type (
|
||||
// EncodeFunc is called for each non-basic type during encoding.
|
||||
EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)
|
||||
|
||||
// EncodeConfig holds encoding functions.
|
||||
EncodeConfig struct {
|
||||
EncodeHooks []EncodeFunc
|
||||
}
|
||||
)
|
||||
|
||||
// ToParamMap returns a map that contains the sqlite compatible value of each struct field of
|
||||
// r using the sqlite column name as a map key. It either uses the name of the
|
||||
// exported struct field or the value of the "sqlite" tag.
|
||||
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) {
|
||||
// make sure we work on a struct type
|
||||
val := reflect.Indirect(reflect.ValueOf(r))
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("%w, got %T", errStructExpected, r)
|
||||
}
|
||||
|
||||
res := make(map[string]interface{}, val.NumField())
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
fieldType := val.Type().Field(i)
|
||||
field := val.Field(i)
|
||||
|
||||
// skip unexported fields
|
||||
if !fieldType.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
colDef, err := getColumnDef(fieldType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
|
||||
}
|
||||
|
||||
if slices.Contains(skipFields, colDef.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
x, found, err := runEncodeHooks(
|
||||
colDef,
|
||||
fieldType.Type,
|
||||
field,
|
||||
append(
|
||||
cfg.EncodeHooks,
|
||||
encodeBasic(),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to run encode hooks: %w", err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
if reflect.Indirect(field).IsValid() {
|
||||
x = reflect.Indirect(field).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
res[keyPrefix+sqlColumnName(fieldType)] = x
|
||||
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// EncodeValue encodes the given value.
|
||||
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) {
|
||||
fieldValue := reflect.ValueOf(val)
|
||||
fieldType := reflect.TypeOf(val)
|
||||
|
||||
x, found, err := runEncodeHooks(
|
||||
colDef,
|
||||
fieldType,
|
||||
fieldValue,
|
||||
append(
|
||||
cfg.EncodeHooks,
|
||||
encodeBasic(),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to run encode hooks: %w", err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
if reflect.Indirect(fieldValue).IsValid() {
|
||||
x = reflect.Indirect(fieldValue).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return x, nil
|
||||
}
|
||||
|
||||
func encodeBasic() EncodeFunc {
|
||||
return func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
||||
kind := valType.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
valType = valType.Elem()
|
||||
kind = valType.Kind()
|
||||
|
||||
if val.IsNil() {
|
||||
if !col.Nullable {
|
||||
// we need to set the zero value here since the column
|
||||
// is not marked as nullable
|
||||
return reflect.New(valType).Elem().Interface(), true, nil
|
||||
}
|
||||
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
switch NormalizeKind(kind) { //nolint:exhaustive
|
||||
case reflect.String,
|
||||
reflect.Float64,
|
||||
reflect.Bool,
|
||||
reflect.Int,
|
||||
reflect.Uint:
|
||||
// sqlite package handles conversion of those types
|
||||
// already
|
||||
return val.Interface(), true, nil
|
||||
|
||||
case reflect.Slice:
|
||||
if valType.Elem().Kind() == reflect.Uint8 {
|
||||
// this is []byte
|
||||
return val.Interface(), true, nil
|
||||
}
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
return nil, false, fmt.Errorf("cannot convert value of kind %s for use in SQLite", kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DatetimeEncoder returns a new datetime encoder for the given time zone.
|
||||
func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||
return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
||||
// if fieldType holds a pointer we need to dereference the value
|
||||
ft := valType.String()
|
||||
if valType.Kind() == reflect.Ptr {
|
||||
ft = valType.Elem().String()
|
||||
val = reflect.Indirect(val)
|
||||
}
|
||||
|
||||
normalizedKind := NormalizeKind(valType.Kind())
|
||||
|
||||
// we only care about "time.Time" here
|
||||
var t time.Time
|
||||
switch {
|
||||
case ft == "time.Time":
|
||||
// handle the zero time as a NULL.
|
||||
if !val.IsValid() || val.IsZero() {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
var ok bool
|
||||
valInterface := val.Interface()
|
||||
t, ok = valInterface.(time.Time)
|
||||
if !ok {
|
||||
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
|
||||
}
|
||||
|
||||
case valType.Kind() == reflect.String && colDef.IsTime:
|
||||
var err error
|
||||
t, err = time.Parse(time.RFC3339, val.String())
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
|
||||
}
|
||||
|
||||
case (normalizedKind == reflect.Int || normalizedKind == reflect.Uint || normalizedKind == reflect.Float64) && colDef.IsTime:
|
||||
seconds := int64(0)
|
||||
switch normalizedKind { //nolint:exhaustive // Previous switch case assures these types.
|
||||
case reflect.Int:
|
||||
seconds = val.Int()
|
||||
case reflect.Uint:
|
||||
seconds = int64(val.Uint())
|
||||
case reflect.Float64:
|
||||
seconds = int64(val.Float())
|
||||
}
|
||||
|
||||
t = time.Unix(seconds, 0)
|
||||
|
||||
default:
|
||||
// we don't care ...
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
switch colDef.Type { //nolint:exhaustive
|
||||
case sqlite.TypeInteger:
|
||||
if colDef.UnixNano {
|
||||
return t.UnixNano(), true, nil
|
||||
}
|
||||
return t.Unix(), true, nil
|
||||
|
||||
case sqlite.TypeText:
|
||||
str := t.In(loc).Format(SqliteTimeFormat)
|
||||
|
||||
return str, true, nil
|
||||
}
|
||||
|
||||
return nil, false, fmt.Errorf("cannot store time.Time in %s", colDef.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
|
||||
if valType == nil {
|
||||
if !colDef.Nullable {
|
||||
switch colDef.Type { //nolint:exhaustive
|
||||
case sqlite.TypeBlob:
|
||||
return []byte{}, true, nil
|
||||
case sqlite.TypeFloat:
|
||||
return 0.0, true, nil
|
||||
case sqlite.TypeText:
|
||||
return "", true, nil
|
||||
case sqlite.TypeInteger:
|
||||
return 0, true, nil
|
||||
default:
|
||||
return nil, false, fmt.Errorf("unsupported sqlite data type: %s", colDef.Type)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
for _, fn := range hooks {
|
||||
res, end, err := fn(colDef, valType, val)
|
||||
if err != nil {
|
||||
return res, false, err
|
||||
}
|
||||
|
||||
if end {
|
||||
return res, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// DefaultEncodeConfig holds the default encoding configuration.
|
||||
var DefaultEncodeConfig = EncodeConfig{
|
||||
EncodeHooks: []EncodeFunc{
|
||||
DatetimeEncoder(time.UTC),
|
||||
},
|
||||
}
|
||||
260
service/netquery/orm/encoder_test.go
Normal file
260
service/netquery/orm/encoder_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
Input interface{}
|
||||
Expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
"Encode basic types",
|
||||
struct {
|
||||
I int
|
||||
F float64
|
||||
S string
|
||||
B []byte
|
||||
}{
|
||||
I: 1,
|
||||
F: 1.2,
|
||||
S: "string",
|
||||
B: ([]byte)("bytes"),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"I": 1,
|
||||
"F": 1.2,
|
||||
"S": "string",
|
||||
"B": ([]byte)("bytes"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Encode using struct tags",
|
||||
struct {
|
||||
I int `sqlite:"col_int"`
|
||||
S string `sqlite:"col_string"`
|
||||
}{
|
||||
I: 1,
|
||||
S: "string value",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"col_int": 1,
|
||||
"col_string": "string value",
|
||||
},
|
||||
},
|
||||
{
|
||||
"Ignore Private fields",
|
||||
struct {
|
||||
I int
|
||||
s string
|
||||
}{
|
||||
I: 1,
|
||||
s: "string value",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"I": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handle Pointers",
|
||||
struct {
|
||||
I *int
|
||||
S *string
|
||||
}{
|
||||
I: new(int),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"I": 0,
|
||||
"S": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handle time.Time types",
|
||||
struct {
|
||||
TinInt time.Time `sqlite:",integer,unixnano"`
|
||||
TinString time.Time `sqlite:",text"`
|
||||
}{
|
||||
TinInt: refTime,
|
||||
TinString: refTime,
|
||||
},
|
||||
map[string]interface{}{
|
||||
"TinInt": refTime.UnixNano(),
|
||||
"TinString": refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handle time.Time pointer types",
|
||||
struct {
|
||||
TinInt *time.Time `sqlite:",integer,unixnano"`
|
||||
TinString *time.Time `sqlite:",text"`
|
||||
Tnil1 *time.Time `sqlite:",text"`
|
||||
Tnil2 *time.Time `sqlite:",text"`
|
||||
}{
|
||||
TinInt: &refTime,
|
||||
TinString: &refTime,
|
||||
Tnil1: nil,
|
||||
Tnil2: (*time.Time)(nil),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"TinInt": refTime.UnixNano(),
|
||||
"TinString": refTime.Format(SqliteTimeFormat),
|
||||
"Tnil1": nil,
|
||||
"Tnil2": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.Expected, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeValue(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
Column ColumnDef
|
||||
Input interface{}
|
||||
Output interface{}
|
||||
}{
|
||||
{
|
||||
"Special value time.Time as text",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
refTime,
|
||||
refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
{
|
||||
"Special value time.Time as unix-epoch",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeInteger,
|
||||
},
|
||||
refTime,
|
||||
refTime.Unix(),
|
||||
},
|
||||
{
|
||||
"Special value time.Time as unixnano-epoch",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeInteger,
|
||||
UnixNano: true,
|
||||
},
|
||||
refTime,
|
||||
refTime.UnixNano(),
|
||||
},
|
||||
{
|
||||
"Special value zero time",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
time.Time{},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Special value zero time pointer",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
new(time.Time),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Special value *time.Time as text",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
&refTime,
|
||||
refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
{
|
||||
"Special value untyped nil",
|
||||
ColumnDef{
|
||||
Nullable: true,
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Special value typed nil",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
(*time.Time)(nil),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Time formated as string",
|
||||
ColumnDef{
|
||||
IsTime: true,
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
refTime.In(time.Local).Format(time.RFC3339),
|
||||
refTime.Format(SqliteTimeFormat),
|
||||
},
|
||||
{
|
||||
"Nullable integer",
|
||||
ColumnDef{
|
||||
Type: sqlite.TypeInteger,
|
||||
Nullable: true,
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Not-Null integer",
|
||||
ColumnDef{
|
||||
Name: "test",
|
||||
Type: sqlite.TypeInteger,
|
||||
},
|
||||
nil,
|
||||
0,
|
||||
},
|
||||
{
|
||||
"Not-Null string",
|
||||
ColumnDef{
|
||||
Type: sqlite.TypeText,
|
||||
},
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.Output, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
186
service/netquery/orm/query_runner.go
Normal file
186
service/netquery/orm/query_runner.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
)
|
||||
|
||||
type (
|
||||
// QueryOption can be specified at RunQuery to alter the behavior
|
||||
// of the executed query.
|
||||
QueryOption func(opts *queryOpts)
|
||||
|
||||
queryOpts struct {
|
||||
Transient bool
|
||||
Args []interface{}
|
||||
NamedArgs map[string]interface{}
|
||||
Result interface{}
|
||||
DecodeConfig DecodeConfig
|
||||
Schema TableSchema
|
||||
}
|
||||
)
|
||||
|
||||
// WithTransient marks the query as transient.
|
||||
//
|
||||
// Transient queries will not be cached for later
|
||||
// re-use after they have been prepared.
|
||||
func WithTransient() QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Transient = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithArgs adds a list of arguments for the query. Arguments
|
||||
// are applied in order.
|
||||
//
|
||||
// See SQL Language Expression documentation of SQLite for
|
||||
// details: https://sqlite.org/lang_expr.html
|
||||
func WithArgs(args ...interface{}) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Args = args
|
||||
}
|
||||
}
|
||||
|
||||
// WithNamedArgs adds args to the query. The query must used
|
||||
// named argument placeholders. According to the SQLite spec,
|
||||
// arguments must either start with ':', '@' or '$'.
|
||||
//
|
||||
// See SQL Language Expression documentation of SQLite for
|
||||
// details: https://sqlite.org/lang_expr.html
|
||||
func WithNamedArgs(args map[string]interface{}) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.NamedArgs = args
|
||||
}
|
||||
}
|
||||
|
||||
// WithSchema returns a query option that adds the given table
|
||||
// schema to the query.
|
||||
func WithSchema(tbl TableSchema) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Schema = tbl
|
||||
}
|
||||
}
|
||||
|
||||
// WithResult sets the result receiver. result is expected to
|
||||
// be a pointer to a slice of struct or map types.
|
||||
//
|
||||
// For decoding DecodeStmt is used to decode each
|
||||
// row into a new slice element. It thus supports special values
|
||||
// like time.Time. See DecodeStmt() and WithDecodeConfig() for
|
||||
// more information.
|
||||
func WithResult(result interface{}) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Result = result
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecodeConfig configures the DecodeConfig to use when
|
||||
// calling DecodeStmt to decode each row into the result slice.
|
||||
//
|
||||
// If not specified, DefaultDecodeConfig will be used.
|
||||
func WithDecodeConfig(cfg DecodeConfig) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.DecodeConfig = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// RunQuery executes the query stored in sql against the databased opened in
|
||||
// conn. Please refer to the documentation of QueryOption, especially WithResult()
|
||||
// for more information on how to retrieve the resulting rows.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var result []struct{
|
||||
// Count int `sqlite:"rowCount"`
|
||||
// }
|
||||
//
|
||||
// err := RunQuery(ctx, conn, "SELECT COUNT(*) AS rowCount FROM table", WithResult(&result))
|
||||
// fmt.Println(result[0].Count)
|
||||
func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...QueryOption) error {
|
||||
args := queryOpts{
|
||||
DecodeConfig: DefaultDecodeConfig,
|
||||
}
|
||||
|
||||
for _, fn := range modifiers {
|
||||
fn(&args)
|
||||
}
|
||||
|
||||
opts := &sqlitex.ExecOptions{
|
||||
Args: args.Args,
|
||||
Named: args.NamedArgs,
|
||||
}
|
||||
|
||||
var (
|
||||
sliceVal reflect.Value
|
||||
valElemType reflect.Type
|
||||
)
|
||||
|
||||
if args.Result != nil {
|
||||
target := args.Result
|
||||
outVal := reflect.ValueOf(target)
|
||||
if outVal.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("target must be a pointer, got %T", target)
|
||||
}
|
||||
|
||||
sliceVal = reflect.Indirect(outVal)
|
||||
if !sliceVal.IsValid() || sliceVal.IsNil() {
|
||||
newVal := reflect.Zero(outVal.Type().Elem())
|
||||
sliceVal.Set(newVal)
|
||||
}
|
||||
|
||||
kind := sliceVal.Kind()
|
||||
if kind != reflect.Slice {
|
||||
return fmt.Errorf("target but be pointer to slice, got %T", target)
|
||||
}
|
||||
valType := sliceVal.Type()
|
||||
valElemType = valType.Elem()
|
||||
|
||||
opts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
currentField := reflect.New(valElemType)
|
||||
|
||||
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
|
||||
resultDump := make(map[string]any)
|
||||
|
||||
for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ {
|
||||
name := stmt.ColumnName(colIdx)
|
||||
|
||||
switch stmt.ColumnType(colIdx) { //nolint:exhaustive // TODO: handle type BLOB?
|
||||
case sqlite.TypeText:
|
||||
resultDump[name] = stmt.ColumnText(colIdx)
|
||||
case sqlite.TypeFloat:
|
||||
resultDump[name] = stmt.ColumnFloat(colIdx)
|
||||
case sqlite.TypeInteger:
|
||||
resultDump[name] = stmt.ColumnInt(colIdx)
|
||||
case sqlite.TypeNull:
|
||||
resultDump[name] = "<null>"
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%w: %+v", err, resultDump)
|
||||
}
|
||||
|
||||
sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField))
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if args.Transient {
|
||||
err = sqlitex.ExecuteTransient(conn, sql, opts)
|
||||
} else {
|
||||
err = sqlitex.Execute(conn, sql, opts)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if args.Result != nil {
|
||||
reflect.Indirect(reflect.ValueOf(args.Result)).Set(sliceVal)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
288
service/netquery/orm/schema_builder.go
Normal file
288
service/netquery/orm/schema_builder.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
var errSkipStructField = errors.New("struct field should be skipped")
|
||||
|
||||
// Struct Tags.
|
||||
var (
|
||||
TagUnixNano = "unixnano"
|
||||
TagPrimaryKey = "primary"
|
||||
TagAutoIncrement = "autoincrement"
|
||||
TagTime = "time"
|
||||
TagNotNull = "not-null"
|
||||
TagNullable = "nullable"
|
||||
TagTypeInt = "integer"
|
||||
TagTypeText = "text"
|
||||
TagTypePrefixVarchar = "varchar"
|
||||
TagTypeBlob = "blob"
|
||||
TagTypeFloat = "float"
|
||||
TagTypePrefixDefault = "default="
|
||||
)
|
||||
|
||||
var sqlTypeMap = map[sqlite.ColumnType]string{
|
||||
sqlite.TypeBlob: "BLOB",
|
||||
sqlite.TypeFloat: "REAL",
|
||||
sqlite.TypeInteger: "INTEGER",
|
||||
sqlite.TypeText: "TEXT",
|
||||
}
|
||||
|
||||
type (
|
||||
// TableSchema defines a SQL table schema.
|
||||
TableSchema struct {
|
||||
Name string
|
||||
Columns []ColumnDef
|
||||
}
|
||||
|
||||
// ColumnDef defines a SQL column.
|
||||
ColumnDef struct { //nolint:maligned
|
||||
Name string
|
||||
Nullable bool
|
||||
Type sqlite.ColumnType
|
||||
GoType reflect.Type
|
||||
Length int
|
||||
PrimaryKey bool
|
||||
AutoIncrement bool
|
||||
UnixNano bool
|
||||
IsTime bool
|
||||
Default any
|
||||
}
|
||||
)
|
||||
|
||||
// GetColumnDef returns the column definition with the given name.
|
||||
func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
|
||||
for _, def := range ts.Columns {
|
||||
if def.Name == name {
|
||||
return &def
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateStatement build the CREATE SQL statement for the table.
|
||||
func (ts TableSchema) CreateStatement(databaseName string, ifNotExists bool) string {
|
||||
sql := "CREATE TABLE"
|
||||
if ifNotExists {
|
||||
sql += " IF NOT EXISTS"
|
||||
}
|
||||
name := ts.Name
|
||||
if databaseName != "" {
|
||||
name = databaseName + "." + ts.Name
|
||||
}
|
||||
|
||||
sql += " " + name + " ( "
|
||||
|
||||
for idx, col := range ts.Columns {
|
||||
sql += col.AsSQL()
|
||||
if idx < len(ts.Columns)-1 {
|
||||
sql += ", "
|
||||
}
|
||||
}
|
||||
|
||||
sql += " );"
|
||||
return sql
|
||||
}
|
||||
|
||||
// AsSQL builds the SQL column definition.
|
||||
func (def ColumnDef) AsSQL() string {
|
||||
sql := def.Name + " "
|
||||
|
||||
if def.Type == sqlite.TypeText && def.Length > 0 {
|
||||
sql += fmt.Sprintf("VARCHAR(%d)", def.Length)
|
||||
} else {
|
||||
sql += sqlTypeMap[def.Type]
|
||||
}
|
||||
|
||||
if def.PrimaryKey {
|
||||
sql += " PRIMARY KEY"
|
||||
}
|
||||
if def.AutoIncrement {
|
||||
sql += " AUTOINCREMENT"
|
||||
}
|
||||
if def.Default != nil {
|
||||
sql += " DEFAULT "
|
||||
switch def.Type { //nolint:exhaustive // TODO: handle types BLOB, NULL?
|
||||
case sqlite.TypeFloat:
|
||||
sql += strconv.FormatFloat(def.Default.(float64), 'b', 0, 64) //nolint:forcetypeassert
|
||||
case sqlite.TypeInteger:
|
||||
sql += strconv.FormatInt(def.Default.(int64), 10) //nolint:forcetypeassert
|
||||
case sqlite.TypeText:
|
||||
sql += fmt.Sprintf("%q", def.Default.(string)) //nolint:forcetypeassert
|
||||
default:
|
||||
log.Errorf("unsupported default value: %q %q", def.Type, def.Default)
|
||||
sql = strings.TrimSuffix(sql, " DEFAULT ")
|
||||
}
|
||||
sql += " "
|
||||
}
|
||||
if !def.Nullable {
|
||||
sql += " NOT NULL"
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
// GenerateTableSchema generates a table schema from the given struct.
|
||||
func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) {
|
||||
ts := &TableSchema{
|
||||
Name: name,
|
||||
}
|
||||
|
||||
val := reflect.Indirect(reflect.ValueOf(d))
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("%w, got %T", errStructExpected, d)
|
||||
}
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
fieldType := val.Type().Field(i)
|
||||
if !fieldType.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
def, err := getColumnDef(fieldType)
|
||||
if err != nil {
|
||||
if errors.Is(err, errSkipStructField) {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("struct field %s: %w", fieldType.Name, err)
|
||||
}
|
||||
|
||||
ts.Columns = append(ts.Columns, *def)
|
||||
}
|
||||
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
|
||||
def := &ColumnDef{
|
||||
Name: fieldType.Name,
|
||||
Nullable: fieldType.Type.Kind() == reflect.Ptr,
|
||||
}
|
||||
|
||||
ft := fieldType.Type
|
||||
|
||||
if fieldType.Type.Kind() == reflect.Ptr {
|
||||
ft = fieldType.Type.Elem()
|
||||
}
|
||||
|
||||
def.GoType = ft
|
||||
kind := NormalizeKind(ft.Kind())
|
||||
|
||||
switch kind { //nolint:exhaustive
|
||||
case reflect.Int, reflect.Uint:
|
||||
def.Type = sqlite.TypeInteger
|
||||
|
||||
case reflect.Float64:
|
||||
def.Type = sqlite.TypeFloat
|
||||
|
||||
case reflect.String:
|
||||
def.Type = sqlite.TypeText
|
||||
|
||||
case reflect.Slice:
|
||||
// only []byte/[]uint8 is supported
|
||||
if ft.Elem().Kind() != reflect.Uint8 {
|
||||
return nil, fmt.Errorf("slices of type %s is not supported", ft.Elem())
|
||||
}
|
||||
|
||||
def.Type = sqlite.TypeBlob
|
||||
}
|
||||
|
||||
if err := applyStructFieldTag(fieldType, def); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return def, nil
|
||||
}
|
||||
|
||||
// applyStructFieldTag parses the sqlite:"" struct field tag and update the column
|
||||
// definition def accordingly.
|
||||
func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
|
||||
parts := strings.Split(fieldType.Tag.Get("sqlite"), ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
if parts[0] == "-" {
|
||||
return errSkipStructField
|
||||
}
|
||||
|
||||
def.Name = parts[0]
|
||||
}
|
||||
|
||||
if len(parts) > 1 {
|
||||
for _, k := range parts[1:] {
|
||||
switch k {
|
||||
// column modifiers
|
||||
case TagPrimaryKey:
|
||||
def.PrimaryKey = true
|
||||
case TagAutoIncrement:
|
||||
def.AutoIncrement = true
|
||||
case TagNotNull:
|
||||
def.Nullable = false
|
||||
case TagNullable:
|
||||
def.Nullable = true
|
||||
case TagUnixNano:
|
||||
def.UnixNano = true
|
||||
case TagTime:
|
||||
def.IsTime = true
|
||||
|
||||
// basic column types
|
||||
case TagTypeInt:
|
||||
def.Type = sqlite.TypeInteger
|
||||
case TagTypeText:
|
||||
def.Type = sqlite.TypeText
|
||||
case TagTypeFloat:
|
||||
def.Type = sqlite.TypeFloat
|
||||
case TagTypeBlob:
|
||||
def.Type = sqlite.TypeBlob
|
||||
|
||||
// advanced column types
|
||||
default:
|
||||
if strings.HasPrefix(k, TagTypePrefixVarchar) {
|
||||
lenStr := strings.TrimSuffix(strings.TrimPrefix(k, TagTypePrefixVarchar+"("), ")")
|
||||
length, err := strconv.ParseInt(lenStr, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse varchar length %q: %w", lenStr, err)
|
||||
}
|
||||
|
||||
def.Type = sqlite.TypeText
|
||||
def.Length = int(length)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(k, TagTypePrefixDefault) {
|
||||
defaultValue := strings.TrimPrefix(k, TagTypePrefixDefault)
|
||||
switch def.Type { //nolint:exhaustive
|
||||
case sqlite.TypeFloat:
|
||||
fv, err := strconv.ParseFloat(defaultValue, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse default value as float %q: %w", defaultValue, err)
|
||||
}
|
||||
def.Default = fv
|
||||
case sqlite.TypeInteger:
|
||||
fv, err := strconv.ParseInt(defaultValue, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse default value as int %q: %w", defaultValue, err)
|
||||
}
|
||||
def.Default = fv
|
||||
case sqlite.TypeText:
|
||||
def.Default = defaultValue
|
||||
case sqlite.TypeBlob:
|
||||
return fmt.Errorf("default values for TypeBlob not yet supported")
|
||||
default:
|
||||
return fmt.Errorf("failed to apply default value for unknown sqlite column type %s", def.Type)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
43
service/netquery/orm/schema_builder_test.go
Normal file
43
service/netquery/orm/schema_builder_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSchemaBuilder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
Model interface{}
|
||||
ExpectedSQL string
|
||||
}{
|
||||
{
|
||||
"Simple",
|
||||
struct {
|
||||
ID int `sqlite:"id,primary,autoincrement"`
|
||||
Text string `sqlite:"text,nullable"`
|
||||
Int *int `sqlite:",not-null"`
|
||||
Float interface{} `sqlite:",float,nullable"`
|
||||
}{},
|
||||
`CREATE TABLE main.Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
|
||||
},
|
||||
{
|
||||
"Varchar",
|
||||
struct {
|
||||
S string `sqlite:",varchar(10)"`
|
||||
}{},
|
||||
`CREATE TABLE main.Varchar ( S VARCHAR(10) NOT NULL );`,
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases {
|
||||
c := cases[idx]
|
||||
|
||||
res, err := GenerateTableSchema(c.Name, c.Model)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.ExpectedSQL, res.CreateStatement("main", false))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user