Add basic, tailored SQL ORM mapper
This commit is contained in:
404
netquery/orm/decoder.go
Normal file
404
netquery/orm/decoder.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// Commonly used error messages when working with orm.
|
||||
var (
|
||||
errStructExpected = errors.New("encode: can only encode structs to maps")
|
||||
errStructPointerExpected = errors.New("decode: result must be pointer to a struct type or map[string]interface{}")
|
||||
errUnexpectedColumnType = errors.New("decode: unexpected column type")
|
||||
)
|
||||
|
||||
// constants used when transforming data to and from sqlite.
|
||||
var (
|
||||
// sqliteTimeFromat defines the string representation that is
|
||||
// expected by SQLite DATETIME functions.
|
||||
// Note that SQLite itself does not include support for a DATETIME
|
||||
// column type. Instead, dates and times are stored either as INTEGER,
|
||||
// TEXT or REAL.
|
||||
// This package provides support for time.Time being stored as TEXT (using a
|
||||
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
|
||||
// unixepoch and unixnano-epoch where the nano variant is not offically supported by
|
||||
// SQLITE).
|
||||
sqliteTimeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
type (
|
||||
|
||||
// Stmt describes the interface that must be implemented in order to
|
||||
// be decodable to a struct type using DecodeStmt. This interface is implemented
|
||||
// by *sqlite.Stmt.
|
||||
Stmt interface {
|
||||
ColumnCount() int
|
||||
ColumnName(int) string
|
||||
ColumnType(int) sqlite.ColumnType
|
||||
ColumnText(int) string
|
||||
ColumnBool(int) bool
|
||||
ColumnFloat(int) float64
|
||||
ColumnInt(int) int
|
||||
ColumnReader(int) *bytes.Reader
|
||||
}
|
||||
|
||||
// DecodeFunc is called for each non-basic type during decoding.
|
||||
DecodeFunc func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error)
|
||||
|
||||
DecodeConfig struct {
|
||||
DecodeHooks []DecodeFunc
|
||||
}
|
||||
)
|
||||
|
||||
// DecodeStmt decodes the current result row loaded in Stmt into the struct or map type result.
|
||||
// Decoding hooks configured in cfg are executed before trying to decode basic types and may
|
||||
// be specified to provide support for special types.
|
||||
// See DatetimeDecoder() for an example of a DecodeHook that handles graceful time.Time conversion.
|
||||
//
|
||||
func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeConfig) error {
|
||||
// make sure we got something to decode into ...
|
||||
if result == nil {
|
||||
return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
|
||||
}
|
||||
|
||||
// fast path for decoding into a map
|
||||
if mp, ok := result.(*map[string]interface{}); ok {
|
||||
return decodeIntoMap(ctx, stmt, mp)
|
||||
}
|
||||
|
||||
// make sure we got a pointer in result
|
||||
if reflect.TypeOf(result).Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
|
||||
}
|
||||
|
||||
// make sure it's a poiter to a struct type
|
||||
t := reflect.ValueOf(result).Elem().Type()
|
||||
if t.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
|
||||
}
|
||||
|
||||
// if result is a nil pointer make sure to allocate some space
|
||||
// for the resulting struct
|
||||
resultValue := reflect.ValueOf(result)
|
||||
if resultValue.IsNil() {
|
||||
resultValue.Set(
|
||||
reflect.New(t),
|
||||
)
|
||||
}
|
||||
|
||||
// we need access to the struct directly and not to the
|
||||
// pointer.
|
||||
target := reflect.Indirect(resultValue)
|
||||
|
||||
// create a lookup map from field name (or sqlite:"" tag)
|
||||
// to the field name
|
||||
lm := make(map[string]string)
|
||||
for i := 0; i < target.NumField(); i++ {
|
||||
fieldType := t.Field(i)
|
||||
|
||||
// skip unexported fields
|
||||
if !fieldType.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
lm[sqlColumnName(fieldType)] = fieldType.Name
|
||||
}
|
||||
|
||||
// iterate over all columns and assign them to the correct
|
||||
// fields
|
||||
for i := 0; i < stmt.ColumnCount(); i++ {
|
||||
colName := stmt.ColumnName(i)
|
||||
fieldName, ok := lm[colName]
|
||||
if !ok {
|
||||
// there's no target field for this column
|
||||
// so we can skip it
|
||||
continue
|
||||
}
|
||||
fieldType, _ := t.FieldByName(fieldName)
|
||||
|
||||
value := target.FieldByName(fieldName)
|
||||
|
||||
colType := stmt.ColumnType(i)
|
||||
|
||||
// if the column is reported as NULL we keep
|
||||
// the field as it is.
|
||||
// TODO(ppacher): should we set it to nil here?
|
||||
if colType == sqlite.TypeNull {
|
||||
continue
|
||||
}
|
||||
|
||||
// if value is a nil pointer we need to allocate some memory
|
||||
// first
|
||||
if getKind(value) == reflect.Ptr && value.IsNil() {
|
||||
storage := reflect.New(fieldType.Type.Elem())
|
||||
|
||||
value.Set(storage)
|
||||
|
||||
// make sure value actually points the
|
||||
// dereferenced target storage
|
||||
value = storage.Elem()
|
||||
}
|
||||
|
||||
// execute all decode hooks but make sure we use decodeBasic() as the
|
||||
// last one.
|
||||
columnValue, err := runDecodeHooks(
|
||||
i,
|
||||
stmt,
|
||||
fieldType,
|
||||
value,
|
||||
append(cfg.DecodeHooks, decodeBasic()),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if we don't have a converted value now we try to
|
||||
// decode basic types
|
||||
if columnValue == nil {
|
||||
return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
|
||||
}
|
||||
|
||||
//log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue)
|
||||
|
||||
// convert it to the target type if conversion is possible
|
||||
newValue := reflect.ValueOf(columnValue)
|
||||
if newValue.Type().ConvertibleTo(value.Type()) {
|
||||
newValue = newValue.Convert(value.Type())
|
||||
}
|
||||
|
||||
// assign the new value to the struct field.
|
||||
value.Set(newValue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DatetimeDecoder is capable of decoding sqlite INTEGER or TEXT storage classes into
|
||||
// time.Time. For INTEGER storage classes, it supports 'unixnano' struct tag value to
|
||||
// decide between Unix or UnixNano epoch timestamps.
|
||||
//
|
||||
// FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing
|
||||
//
|
||||
func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
||||
return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) {
|
||||
// we only care about "time.Time" here
|
||||
if outval.Type().String() != "time.Time" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch stmt.ColumnType(colIdx) {
|
||||
case sqlite.TypeInteger:
|
||||
// stored as unix-epoch, if unixnano is set in the struct field tag
|
||||
// we parse it with nano-second resolution
|
||||
// TODO(ppacher): actually split the tag value at "," and search
|
||||
// the slice for "unixnano"
|
||||
if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") {
|
||||
return time.Unix(0, int64(stmt.ColumnInt(colIdx))), nil
|
||||
}
|
||||
|
||||
return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), nil
|
||||
|
||||
case sqlite.TypeText:
|
||||
// stored ISO8601 but does not have any timezone information
|
||||
// assigned so we always treat it as loc here.
|
||||
t, err := time.ParseInLocation(sqliteTimeFormat, stmt.ColumnText(colIdx), loc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
|
||||
}
|
||||
|
||||
return t, nil
|
||||
|
||||
case sqlite.TypeFloat:
|
||||
// stored as Julian day numbers
|
||||
return nil, fmt.Errorf("REAL storage type not support for time.Time")
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported storage type for time.Time: %s", outval.Type())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) error {
|
||||
if *mp == nil {
|
||||
*mp = make(map[string]interface{})
|
||||
}
|
||||
|
||||
for i := 0; i < stmt.ColumnCount(); i++ {
|
||||
var x interface{}
|
||||
val, err := decodeBasic()(i, stmt, reflect.StructField{}, reflect.ValueOf(&x).Elem())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode column %s: %w", stmt.ColumnName(i), err)
|
||||
}
|
||||
|
||||
(*mp)[stmt.ColumnName(i)] = val
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeBasic() DecodeFunc {
|
||||
return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) {
|
||||
valueKind := getKind(outval)
|
||||
colType := stmt.ColumnType(colIdx)
|
||||
colName := stmt.ColumnName(colIdx)
|
||||
|
||||
errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type())
|
||||
|
||||
switch valueKind {
|
||||
case reflect.String:
|
||||
if colType != sqlite.TypeText {
|
||||
return nil, errInvalidType
|
||||
}
|
||||
return stmt.ColumnText(colIdx), nil
|
||||
|
||||
case reflect.Bool:
|
||||
// sqlite does not have a BOOL type, it rather stores a 1/0 in a column
|
||||
// with INTEGER affinity.
|
||||
if colType != sqlite.TypeInteger {
|
||||
return nil, errInvalidType
|
||||
}
|
||||
return stmt.ColumnBool(colIdx), nil
|
||||
|
||||
case reflect.Float64:
|
||||
if colType != sqlite.TypeFloat {
|
||||
return nil, errInvalidType
|
||||
}
|
||||
return stmt.ColumnFloat(colIdx), nil
|
||||
|
||||
case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ...
|
||||
if colType != sqlite.TypeInteger {
|
||||
return nil, errInvalidType
|
||||
}
|
||||
|
||||
return stmt.ColumnInt(colIdx), nil
|
||||
|
||||
case reflect.Slice:
|
||||
if outval.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, fmt.Errorf("slices other than []byte for BLOB are not supported")
|
||||
}
|
||||
|
||||
if colType != sqlite.TypeBlob {
|
||||
return nil, errInvalidType
|
||||
}
|
||||
|
||||
columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
|
||||
}
|
||||
|
||||
return columnValue, nil
|
||||
|
||||
case reflect.Interface:
|
||||
var (
|
||||
t reflect.Type
|
||||
x interface{}
|
||||
)
|
||||
switch colType {
|
||||
case sqlite.TypeBlob:
|
||||
t = reflect.TypeOf([]byte{})
|
||||
columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
|
||||
}
|
||||
x = columnValue
|
||||
|
||||
case sqlite.TypeFloat:
|
||||
t = reflect.TypeOf(float64(0))
|
||||
x = stmt.ColumnFloat(colIdx)
|
||||
|
||||
case sqlite.TypeInteger:
|
||||
t = reflect.TypeOf(int(0))
|
||||
x = stmt.ColumnInt(colIdx)
|
||||
|
||||
case sqlite.TypeText:
|
||||
t = reflect.TypeOf(string(""))
|
||||
x = stmt.ColumnText(colIdx)
|
||||
|
||||
case sqlite.TypeNull:
|
||||
t = nil
|
||||
x = nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported column type %s", colType)
|
||||
}
|
||||
|
||||
if t == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
target := reflect.New(t).Elem()
|
||||
target.Set(reflect.ValueOf(x))
|
||||
|
||||
return target.Interface(), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("cannot decode into %s", valueKind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sqlColumnName(fieldType reflect.StructField) string {
|
||||
tagValue, hasTag := fieldType.Tag.Lookup("sqlite")
|
||||
if !hasTag {
|
||||
return fieldType.Name
|
||||
}
|
||||
|
||||
parts := strings.Split(tagValue, ",")
|
||||
if parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return fieldType.Name
|
||||
}
|
||||
|
||||
// runDecodeHooks tries to decode the column value of stmt at index colIdx into outval by running all decode hooks.
|
||||
// The first hook that returns a non-nil interface wins, other hooks will not be executed. If an error is
|
||||
// returned by a decode hook runDecodeHooks stops the error is returned to the caller.
|
||||
func runDecodeHooks(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
|
||||
for _, fn := range hooks {
|
||||
res, err := fn(colIdx, stmt, fieldDef, outval)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if res != nil {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getKind returns the kind of value but normalized Int, Uint and Float varaints
|
||||
// to their base type.
|
||||
func getKind(val reflect.Value) reflect.Kind {
|
||||
kind := val.Kind()
|
||||
return normalizeKind(kind)
|
||||
}
|
||||
|
||||
func normalizeKind(kind reflect.Kind) reflect.Kind {
|
||||
switch {
|
||||
case kind >= reflect.Int && kind <= reflect.Int64:
|
||||
return reflect.Int
|
||||
case kind >= reflect.Uint && kind <= reflect.Uint64:
|
||||
return reflect.Uint
|
||||
case kind >= reflect.Float32 && kind <= reflect.Float64:
|
||||
return reflect.Float64
|
||||
default:
|
||||
return kind
|
||||
}
|
||||
}
|
||||
|
||||
var DefaultDecodeConfig = DecodeConfig{
|
||||
DecodeHooks: []DecodeFunc{
|
||||
DatetimeDecoder(time.UTC),
|
||||
},
|
||||
}
|
||||
475
netquery/orm/decoder_test.go
Normal file
475
netquery/orm/decoder_test.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
type testStmt struct {
|
||||
columns []string
|
||||
values []interface{}
|
||||
types []sqlite.ColumnType
|
||||
}
|
||||
|
||||
func (ts testStmt) ColumnCount() int { return len(ts.columns) }
|
||||
func (ts testStmt) ColumnName(i int) string { return ts.columns[i] }
|
||||
func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) }
|
||||
func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) }
|
||||
func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) }
|
||||
func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) }
|
||||
func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) }
|
||||
func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] }
|
||||
|
||||
// compile time check
|
||||
var _ Stmt = new(testStmt)
|
||||
|
||||
type exampleFieldTypes struct {
|
||||
S string
|
||||
I int
|
||||
F float64
|
||||
B bool
|
||||
}
|
||||
|
||||
type examplePointerTypes struct {
|
||||
S *string
|
||||
I *int
|
||||
F *float64
|
||||
B *bool
|
||||
}
|
||||
|
||||
type exampleStructTags struct {
|
||||
S string `sqlite:"col_string"`
|
||||
I int `sqlite:"col_int"`
|
||||
}
|
||||
|
||||
type exampleIntConv struct {
|
||||
I8 int8
|
||||
I16 int16
|
||||
I32 int32
|
||||
I64 int64
|
||||
I int
|
||||
}
|
||||
|
||||
type exampleBlobTypes struct {
|
||||
B []byte
|
||||
}
|
||||
|
||||
type exampleJSONRawTypes struct {
|
||||
B json.RawMessage
|
||||
}
|
||||
|
||||
type exampleTimeTypes struct {
|
||||
T time.Time
|
||||
TP *time.Time
|
||||
}
|
||||
|
||||
type exampleInterface struct {
|
||||
I interface{}
|
||||
IP *interface{}
|
||||
}
|
||||
|
||||
func (ett *exampleTimeTypes) Equal(other interface{}) bool {
|
||||
oett, ok := other.(*exampleTimeTypes)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return ett.T.Equal(oett.T) && (ett.TP != nil && oett.TP != nil && ett.TP.Equal(*oett.TP)) || (ett.TP == nil && oett.TP == nil)
|
||||
}
|
||||
|
||||
type exampleTimeNano struct {
|
||||
T time.Time `sqlite:",unixnano"`
|
||||
}
|
||||
|
||||
func (etn *exampleTimeNano) Equal(other interface{}) bool {
|
||||
oetn, ok := other.(*exampleTimeNano)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return etn.T.Equal(oetn.T)
|
||||
}
|
||||
|
||||
func Test_Decoder(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
Stmt testStmt
|
||||
Result interface{}
|
||||
Expected interface{}
|
||||
}{
|
||||
{
|
||||
"Decoding into nil is not allowed",
|
||||
testStmt{
|
||||
columns: nil,
|
||||
values: nil,
|
||||
types: nil,
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Decoding into basic types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
1,
|
||||
1.2,
|
||||
true,
|
||||
},
|
||||
},
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
S: "string value",
|
||||
I: 1,
|
||||
F: 1.2,
|
||||
B: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into basic types with different order",
|
||||
testStmt{
|
||||
columns: []string{"I", "S", "B", "F"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
},
|
||||
values: []interface{}{
|
||||
1,
|
||||
"string value",
|
||||
true,
|
||||
1.2,
|
||||
},
|
||||
},
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
S: "string value",
|
||||
I: 1,
|
||||
F: 1.2,
|
||||
B: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into basic types with missing values",
|
||||
testStmt{
|
||||
columns: []string{"F", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
1.2,
|
||||
true,
|
||||
},
|
||||
},
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
F: 1.2,
|
||||
B: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into pointer types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
1,
|
||||
1.2,
|
||||
true,
|
||||
},
|
||||
},
|
||||
&examplePointerTypes{},
|
||||
func() interface{} {
|
||||
s := "string value"
|
||||
i := 1
|
||||
f := 1.2
|
||||
b := true
|
||||
|
||||
return &examplePointerTypes{
|
||||
S: &s,
|
||||
I: &i,
|
||||
F: &f,
|
||||
B: &b,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into pointer types with missing values",
|
||||
testStmt{
|
||||
columns: []string{"S", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
true,
|
||||
},
|
||||
},
|
||||
&examplePointerTypes{},
|
||||
func() interface{} {
|
||||
s := "string value"
|
||||
b := true
|
||||
|
||||
return &examplePointerTypes{
|
||||
S: &s,
|
||||
B: &b,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into fields with struct tags",
|
||||
testStmt{
|
||||
columns: []string{"col_string", "col_int"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
"string value",
|
||||
1,
|
||||
},
|
||||
},
|
||||
&exampleStructTags{},
|
||||
&exampleStructTags{
|
||||
S: "string value",
|
||||
I: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into correct int type",
|
||||
testStmt{
|
||||
columns: []string{"I8", "I16", "I32", "I64", "I"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
},
|
||||
},
|
||||
&exampleIntConv{},
|
||||
&exampleIntConv{
|
||||
1, 1, 1, 1, 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling NULL values for basic types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeFloat,
|
||||
},
|
||||
values: []interface{}{
|
||||
// we use nil here but actually that does not matter
|
||||
nil,
|
||||
nil,
|
||||
1.0,
|
||||
},
|
||||
},
|
||||
&exampleFieldTypes{},
|
||||
&exampleFieldTypes{
|
||||
F: 1.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling NULL values for pointer types",
|
||||
testStmt{
|
||||
columns: []string{"S", "I", "F"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeNull,
|
||||
sqlite.TypeFloat,
|
||||
},
|
||||
values: []interface{}{
|
||||
// we use nil here but actually that does not matter
|
||||
nil,
|
||||
nil,
|
||||
1.0,
|
||||
},
|
||||
},
|
||||
&examplePointerTypes{},
|
||||
func() interface{} {
|
||||
f := 1.0
|
||||
|
||||
return &examplePointerTypes{F: &f}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling blob types",
|
||||
testStmt{
|
||||
columns: []string{"B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
([]byte)("hello world"),
|
||||
},
|
||||
},
|
||||
&exampleBlobTypes{},
|
||||
&exampleBlobTypes{
|
||||
B: ([]byte)("hello world"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling blob types as json.RawMessage",
|
||||
testStmt{
|
||||
columns: []string{"B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
([]byte)("hello world"),
|
||||
},
|
||||
},
|
||||
&exampleJSONRawTypes{},
|
||||
&exampleJSONRawTypes{
|
||||
B: (json.RawMessage)("hello world"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling time.Time and pointers to it",
|
||||
testStmt{
|
||||
columns: []string{"T", "TP"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
int(refTime.Unix()),
|
||||
int(refTime.Unix()),
|
||||
},
|
||||
},
|
||||
&exampleTimeTypes{},
|
||||
&exampleTimeTypes{
|
||||
T: refTime,
|
||||
TP: &refTime,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handling time.Time in nano-second resolution (struct tags)",
|
||||
testStmt{
|
||||
columns: []string{"T", "TP"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeInteger,
|
||||
},
|
||||
values: []interface{}{
|
||||
int(refTime.UnixNano()),
|
||||
int(refTime.UnixNano()),
|
||||
},
|
||||
},
|
||||
&exampleTimeNano{},
|
||||
&exampleTimeNano{
|
||||
T: refTime,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into interface",
|
||||
testStmt{
|
||||
columns: []string{"I", "IP"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeText,
|
||||
},
|
||||
values: []interface{}{
|
||||
"value1",
|
||||
"value2",
|
||||
},
|
||||
},
|
||||
&exampleInterface{},
|
||||
func() interface{} {
|
||||
var x interface{}
|
||||
x = "value2"
|
||||
|
||||
return &exampleInterface{
|
||||
I: "value1",
|
||||
IP: &x,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"Decoding into map[string]interface{}",
|
||||
testStmt{
|
||||
columns: []string{"I", "F", "S", "B"},
|
||||
types: []sqlite.ColumnType{
|
||||
sqlite.TypeInteger,
|
||||
sqlite.TypeFloat,
|
||||
sqlite.TypeText,
|
||||
sqlite.TypeBlob,
|
||||
},
|
||||
values: []interface{}{
|
||||
1,
|
||||
1.1,
|
||||
"string value",
|
||||
[]byte("blob value"),
|
||||
},
|
||||
},
|
||||
new(map[string]interface{}),
|
||||
&map[string]interface{}{
|
||||
"I": 1,
|
||||
"F": 1.1,
|
||||
"S": "string value",
|
||||
"B": []byte("blob value"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases {
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := DecodeStmt(ctx, c.Stmt, c.Result, DefaultDecodeConfig)
|
||||
if fn, ok := c.Expected.(func() interface{}); ok {
|
||||
c.Expected = fn()
|
||||
}
|
||||
|
||||
if c.Expected == nil {
|
||||
assert.Error(t, err, c.Desc)
|
||||
} else {
|
||||
assert.NoError(t, err, c.Desc)
|
||||
|
||||
if equaler, ok := c.Expected.(interface{ Equal(x interface{}) bool }); ok {
|
||||
assert.True(t, equaler.Equal(c.Result))
|
||||
} else {
|
||||
assert.Equal(t, c.Expected, c.Result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
124
netquery/orm/encoder.go
Normal file
124
netquery/orm/encoder.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
type (
|
||||
EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)
|
||||
|
||||
EncodeConfig struct {
|
||||
EncodeHooks []EncodeFunc
|
||||
}
|
||||
)
|
||||
|
||||
// EncodeAsMap returns a map that contains the value of each struct field of
|
||||
// r using the sqlite column name as a map key. It either uses the name of the
|
||||
// exported struct field or the value of the "sqlite" tag.
|
||||
func EncodeAsMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) {
|
||||
// make sure we work on a struct type
|
||||
val := reflect.Indirect(reflect.ValueOf(r))
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("%w, got %T", errStructExpected, r)
|
||||
}
|
||||
|
||||
res := make(map[string]interface{}, val.NumField())
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
fieldType := val.Type().Field(i)
|
||||
field := val.Field(i)
|
||||
|
||||
// skip unexported fields
|
||||
if !fieldType.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
colDev, err := getColumnDef(fieldType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
|
||||
}
|
||||
|
||||
x, found, err := runEncodeHooks(colDev, fieldType.Type, field, cfg.EncodeHooks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to run encode hooks: %w", err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
if reflect.Indirect(field).IsValid() {
|
||||
x = reflect.Indirect(field).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
res[keyPrefix+sqlColumnName(fieldType)] = x
|
||||
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||
return func(colDev *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
||||
// if fieldType holds a pointer we need to dereference the value
|
||||
ft := valType.String()
|
||||
if valType.Kind() == reflect.Ptr {
|
||||
ft = valType.Elem().String()
|
||||
val = reflect.Indirect(val)
|
||||
}
|
||||
|
||||
// we only care about "time.Time" here
|
||||
if ft != "time.Time" {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// handle the zero time as a NULL.
|
||||
if !val.IsValid() || val.IsZero() {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
valInterface := val.Interface()
|
||||
t, ok := valInterface.(time.Time)
|
||||
if !ok {
|
||||
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
|
||||
}
|
||||
|
||||
switch colDev.Type {
|
||||
case sqlite.TypeInteger:
|
||||
if colDev.UnixNano {
|
||||
return t.UnixNano(), true, nil
|
||||
}
|
||||
return t.Unix(), true, nil
|
||||
case sqlite.TypeText:
|
||||
str := t.In(loc).Format(sqliteTimeFormat)
|
||||
|
||||
return str, true, nil
|
||||
}
|
||||
|
||||
return nil, false, fmt.Errorf("cannot store time.Time in %s", colDev.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func runEncodeHooks(colDev *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
|
||||
for _, fn := range hooks {
|
||||
res, end, err := fn(colDev, valType, val)
|
||||
if err != nil {
|
||||
return res, false, err
|
||||
}
|
||||
|
||||
if end {
|
||||
return res, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
var DefaultEncodeConfig = EncodeConfig{
|
||||
EncodeHooks: []EncodeFunc{
|
||||
DatetimeEncoder(time.UTC),
|
||||
},
|
||||
}
|
||||
126
netquery/orm/encoder_test.go
Normal file
126
netquery/orm/encoder_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_EncodeAsMap(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
Input interface{}
|
||||
Expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
"Encode basic types",
|
||||
struct {
|
||||
I int
|
||||
F float64
|
||||
S string
|
||||
B []byte
|
||||
}{
|
||||
I: 1,
|
||||
F: 1.2,
|
||||
S: "string",
|
||||
B: ([]byte)("bytes"),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"I": 1,
|
||||
"F": 1.2,
|
||||
"S": "string",
|
||||
"B": ([]byte)("bytes"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Encode using struct tags",
|
||||
struct {
|
||||
I int `sqlite:"col_int"`
|
||||
S string `sqlite:"col_string"`
|
||||
}{
|
||||
I: 1,
|
||||
S: "string value",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"col_int": 1,
|
||||
"col_string": "string value",
|
||||
},
|
||||
},
|
||||
{
|
||||
"Ignore Private fields",
|
||||
struct {
|
||||
I int
|
||||
s string
|
||||
}{
|
||||
I: 1,
|
||||
s: "string value",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"I": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handle Pointers",
|
||||
struct {
|
||||
I *int
|
||||
S *string
|
||||
}{
|
||||
I: new(int),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"I": 0,
|
||||
"S": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handle time.Time types",
|
||||
struct {
|
||||
TinInt time.Time `sqlite:",integer,unixnano"`
|
||||
TinString time.Time `sqlite:",text"`
|
||||
}{
|
||||
TinInt: refTime,
|
||||
TinString: refTime,
|
||||
},
|
||||
map[string]interface{}{
|
||||
"TinInt": refTime.UnixNano(),
|
||||
"TinString": refTime.Format(sqliteTimeFormat),
|
||||
},
|
||||
},
|
||||
{
|
||||
"Handle time.Time pointer types",
|
||||
struct {
|
||||
TinInt *time.Time `sqlite:",integer,unixnano"`
|
||||
TinString *time.Time `sqlite:",text"`
|
||||
Tnil1 *time.Time `sqlite:",text"`
|
||||
Tnil2 *time.Time `sqlite:",text"`
|
||||
}{
|
||||
TinInt: &refTime,
|
||||
TinString: &refTime,
|
||||
Tnil1: nil,
|
||||
Tnil2: (*time.Time)(nil),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"TinInt": refTime.UnixNano(),
|
||||
"TinString": refTime.Format(sqliteTimeFormat),
|
||||
"Tnil1": nil,
|
||||
"Tnil2": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases {
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
// t.Parallel()
|
||||
|
||||
res, err := EncodeAsMap(ctx, c.Input, "", DefaultEncodeConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.Expected, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
123
netquery/orm/query_runner.go
Normal file
123
netquery/orm/query_runner.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
)
|
||||
|
||||
type (
|
||||
QueryOption func(opts *queryOpts)
|
||||
|
||||
queryOpts struct {
|
||||
Transient bool
|
||||
Args []interface{}
|
||||
NamedArgs map[string]interface{}
|
||||
Result interface{}
|
||||
DecodeConfig DecodeConfig
|
||||
}
|
||||
)
|
||||
|
||||
func WithTransient() QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Transient = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithArgs(args ...interface{}) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Args = args
|
||||
}
|
||||
}
|
||||
|
||||
func WithNamedArgs(args map[string]interface{}) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.NamedArgs = args
|
||||
}
|
||||
}
|
||||
|
||||
func WithResult(result interface{}) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Result = result
|
||||
}
|
||||
}
|
||||
|
||||
func WithDecodeConfig(cfg DecodeConfig) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.DecodeConfig = cfg
|
||||
}
|
||||
}
|
||||
|
||||
func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...QueryOption) error {
|
||||
args := queryOpts{
|
||||
DecodeConfig: DefaultDecodeConfig,
|
||||
}
|
||||
|
||||
for _, fn := range modifiers {
|
||||
fn(&args)
|
||||
}
|
||||
|
||||
opts := &sqlitex.ExecOptions{
|
||||
Args: args.Args,
|
||||
Named: args.NamedArgs,
|
||||
}
|
||||
|
||||
var (
|
||||
sliceVal reflect.Value
|
||||
valElemType reflect.Type
|
||||
)
|
||||
|
||||
if args.Result != nil {
|
||||
target := args.Result
|
||||
outVal := reflect.ValueOf(target)
|
||||
if outVal.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("target must be a pointer, got %T", target)
|
||||
}
|
||||
|
||||
sliceVal = reflect.Indirect(outVal)
|
||||
if !sliceVal.IsValid() || sliceVal.IsNil() {
|
||||
newVal := reflect.Zero(outVal.Type().Elem())
|
||||
sliceVal.Set(newVal)
|
||||
}
|
||||
|
||||
kind := sliceVal.Kind()
|
||||
if kind != reflect.Slice {
|
||||
return fmt.Errorf("target but be pointer to slice, got %T", target)
|
||||
}
|
||||
valType := sliceVal.Type()
|
||||
valElemType = valType.Elem()
|
||||
|
||||
opts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
var currentField reflect.Value
|
||||
|
||||
currentField = reflect.New(valElemType)
|
||||
|
||||
if err := DecodeStmt(ctx, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField))
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if args.Transient {
|
||||
err = sqlitex.ExecuteTransient(conn, sql, opts)
|
||||
} else {
|
||||
err = sqlitex.Execute(conn, sql, opts)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if args.Result != nil {
|
||||
reflect.Indirect(reflect.ValueOf(args.Result)).Set(sliceVal)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
220
netquery/orm/schema_builder.go
Normal file
220
netquery/orm/schema_builder.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
var (
|
||||
errSkipStructField = errors.New("struct field should be skipped")
|
||||
)
|
||||
|
||||
var (
|
||||
TagUnixNano = "unixnano"
|
||||
TagPrimaryKey = "primary"
|
||||
TagAutoIncrement = "autoincrement"
|
||||
TagNotNull = "not-null"
|
||||
TagNullable = "nullable"
|
||||
TagTypeInt = "integer"
|
||||
TagTypeText = "text"
|
||||
TagTypePrefixVarchar = "varchar"
|
||||
TagTypeBlob = "blob"
|
||||
TagTypeFloat = "float"
|
||||
)
|
||||
|
||||
var sqlTypeMap = map[sqlite.ColumnType]string{
|
||||
sqlite.TypeBlob: "BLOB",
|
||||
sqlite.TypeFloat: "REAL",
|
||||
sqlite.TypeInteger: "INTEGER",
|
||||
sqlite.TypeText: "TEXT",
|
||||
}
|
||||
|
||||
type (
|
||||
TableSchema struct {
|
||||
Name string
|
||||
Columns []ColumnDef
|
||||
}
|
||||
|
||||
ColumnDef struct {
|
||||
Name string
|
||||
Nullable bool
|
||||
Type sqlite.ColumnType
|
||||
Length int
|
||||
PrimaryKey bool
|
||||
AutoIncrement bool
|
||||
UnixNano bool
|
||||
}
|
||||
)
|
||||
|
||||
func (ts TableSchema) CreateStatement(ifNotExists bool) string {
|
||||
sql := "CREATE TABLE"
|
||||
if ifNotExists {
|
||||
sql += " IF NOT EXISTS"
|
||||
}
|
||||
sql += " " + ts.Name + " ( "
|
||||
|
||||
for idx, col := range ts.Columns {
|
||||
sql += col.AsSQL()
|
||||
if idx < len(ts.Columns)-1 {
|
||||
sql += ", "
|
||||
}
|
||||
}
|
||||
|
||||
sql += " );"
|
||||
return sql
|
||||
}
|
||||
|
||||
func (def ColumnDef) AsSQL() string {
|
||||
sql := def.Name + " "
|
||||
|
||||
if def.Type == sqlite.TypeText && def.Length > 0 {
|
||||
sql += fmt.Sprintf("VARCHAR(%d)", def.Length)
|
||||
} else {
|
||||
sql += sqlTypeMap[def.Type]
|
||||
}
|
||||
|
||||
if def.PrimaryKey {
|
||||
sql += " PRIMARY KEY"
|
||||
}
|
||||
if def.AutoIncrement {
|
||||
sql += " AUTOINCREMENT"
|
||||
}
|
||||
if !def.Nullable {
|
||||
sql += " NOT NULL"
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) {
|
||||
ts := &TableSchema{
|
||||
Name: name,
|
||||
}
|
||||
|
||||
val := reflect.Indirect(reflect.ValueOf(d))
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("%w, got %T", errStructExpected, d)
|
||||
}
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
fieldType := val.Type().Field(i)
|
||||
if !fieldType.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
def, err := getColumnDef(fieldType)
|
||||
if err != nil {
|
||||
if errors.Is(err, errSkipStructField) {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("struct field %s: %w", fieldType.Name, err)
|
||||
}
|
||||
|
||||
ts.Columns = append(ts.Columns, *def)
|
||||
}
|
||||
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
|
||||
def := &ColumnDef{
|
||||
Name: fieldType.Name,
|
||||
Nullable: fieldType.Type.Kind() == reflect.Ptr,
|
||||
}
|
||||
|
||||
ft := fieldType.Type
|
||||
|
||||
if fieldType.Type.Kind() == reflect.Ptr {
|
||||
ft = fieldType.Type.Elem()
|
||||
}
|
||||
|
||||
kind := normalizeKind(ft.Kind())
|
||||
|
||||
switch kind {
|
||||
case reflect.Int:
|
||||
def.Type = sqlite.TypeInteger
|
||||
|
||||
case reflect.Float64:
|
||||
def.Type = sqlite.TypeFloat
|
||||
|
||||
case reflect.String:
|
||||
def.Type = sqlite.TypeText
|
||||
|
||||
case reflect.Slice:
|
||||
// only []byte/[]uint8 is supported
|
||||
if ft.Elem().Kind() != reflect.Uint8 {
|
||||
return nil, fmt.Errorf("slices of type %s is not supported", ft.Elem())
|
||||
}
|
||||
|
||||
def.Type = sqlite.TypeBlob
|
||||
}
|
||||
|
||||
if err := applyStructFieldTag(fieldType, def); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return def, nil
|
||||
}
|
||||
|
||||
// applyStructFieldTag parses the sqlite:"" struct field tag and update the column
|
||||
// definition def accordingly.
|
||||
func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
|
||||
parts := strings.Split(fieldType.Tag.Get("sqlite"), ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
if parts[0] == "-" {
|
||||
return errSkipStructField
|
||||
}
|
||||
|
||||
def.Name = parts[0]
|
||||
}
|
||||
|
||||
if len(parts) > 1 {
|
||||
for _, k := range parts[1:] {
|
||||
switch k {
|
||||
// column modifieres
|
||||
case TagPrimaryKey:
|
||||
def.PrimaryKey = true
|
||||
case TagAutoIncrement:
|
||||
def.AutoIncrement = true
|
||||
case TagNotNull:
|
||||
def.Nullable = false
|
||||
case TagNullable:
|
||||
def.Nullable = true
|
||||
case TagUnixNano:
|
||||
def.UnixNano = true
|
||||
|
||||
// basic column types
|
||||
case TagTypeInt:
|
||||
def.Type = sqlite.TypeInteger
|
||||
case TagTypeText:
|
||||
def.Type = sqlite.TypeText
|
||||
case TagTypeFloat:
|
||||
def.Type = sqlite.TypeFloat
|
||||
case TagTypeBlob:
|
||||
def.Type = sqlite.TypeBlob
|
||||
|
||||
// advanced column types
|
||||
default:
|
||||
if strings.HasPrefix(k, TagTypePrefixVarchar) {
|
||||
lenStr := strings.TrimSuffix(strings.TrimPrefix(k, TagTypePrefixVarchar+"("), ")")
|
||||
length, err := strconv.ParseInt(lenStr, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse varchar length %q: %w", lenStr, err)
|
||||
}
|
||||
|
||||
def.Type = sqlite.TypeText
|
||||
def.Length = int(length)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
41
netquery/orm/schema_builder_test.go
Normal file
41
netquery/orm/schema_builder_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_SchemaBuilder(t *testing.T) {
|
||||
cases := []struct {
|
||||
Name string
|
||||
Model interface{}
|
||||
ExpectedSQL string
|
||||
}{
|
||||
{
|
||||
"Simple",
|
||||
struct {
|
||||
ID int `sqlite:"id,primary,autoincrement"`
|
||||
Text string `sqlite:"text,nullable"`
|
||||
Int *int `sqlite:",not-null"`
|
||||
Float interface{} `sqlite:",float,nullable"`
|
||||
}{},
|
||||
`CREATE TABLE Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
|
||||
},
|
||||
{
|
||||
"Varchar",
|
||||
struct {
|
||||
S string `sqlite:",varchar(10)"`
|
||||
}{},
|
||||
`CREATE TABLE Varchar ( S VARCHAR(10) NOT NULL );`,
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases {
|
||||
c := cases[idx]
|
||||
|
||||
res, err := GenerateTableSchema(c.Name, c.Model)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.ExpectedSQL, res.CreateStatement(false))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user