Add basic, tailored SQL ORM mapper

This commit is contained in:
Patrick Pacher
2022-03-16 20:37:57 +01:00
parent f135ec3242
commit 62ec170b90
9 changed files with 1669 additions and 63 deletions

404
netquery/orm/decoder.go Normal file
View File

@@ -0,0 +1,404 @@
package orm
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"reflect"
"strings"
"time"
"zombiezen.com/go/sqlite"
)
// Commonly used error messages when working with orm.
var (
	// errStructExpected is returned when the encoder receives a value that
	// is not a struct (or a pointer to one).
	errStructExpected = errors.New("encode: can only encode structs to maps")
	// errStructPointerExpected is returned by DecodeStmt when result is not
	// a pointer to a struct type or a *map[string]interface{}.
	errStructPointerExpected = errors.New("decode: result must be pointer to a struct type or map[string]interface{}")
	// errUnexpectedColumnType is wrapped by decodeBasic when a column's
	// storage class does not match the target field's type.
	errUnexpectedColumnType = errors.New("decode: unexpected column type")
)
// constants used when transforming data to and from sqlite.
var (
	// sqliteTimeFormat defines the string representation that is
	// expected by SQLite DATETIME functions.
	// Note that SQLite itself does not include support for a DATETIME
	// column type. Instead, dates and times are stored either as INTEGER,
	// TEXT or REAL.
	// This package provides support for time.Time being stored as TEXT (using a
	// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
	// unixepoch and unixnano-epoch where the nano variant is not officially supported by
	// SQLITE).
	sqliteTimeFormat = "2006-01-02 15:04:05"
)
type (
	// Stmt describes the interface that must be implemented in order to
	// be decodable to a struct type using DecodeStmt. This interface is implemented
	// by *sqlite.Stmt.
	Stmt interface {
		ColumnCount() int
		ColumnName(int) string
		ColumnType(int) sqlite.ColumnType
		ColumnText(int) string
		ColumnBool(int) bool
		ColumnFloat(int) float64
		ColumnInt(int) int
		ColumnReader(int) *bytes.Reader
	}

	// DecodeFunc is called for each non-basic type during decoding. It receives
	// the column index, the statement being decoded, the definition of the
	// target struct field and the (dereferenced) target value. A hook that does
	// not handle the column returns (nil, nil) so the next hook may run.
	DecodeFunc func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error)

	// DecodeConfig holds the decode hooks DecodeStmt executes before it
	// attempts basic-type decoding.
	DecodeConfig struct {
		DecodeHooks []DecodeFunc
	}
)
// DecodeStmt decodes the current result row loaded in stmt into result which
// must be a non-nil pointer to a struct type or a *map[string]interface{}.
//
// Struct fields are matched against column names using either the exported
// field name or the name part of the `sqlite:"..."` struct tag. Columns
// without a matching field are ignored, as are NULL columns (the matching
// field keeps its current value).
//
// Decoding hooks configured in cfg are executed before trying to decode basic
// types and may be specified to provide support for special types.
// See DatetimeDecoder() for an example of a DecodeHook that handles graceful
// time.Time conversion.
func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeConfig) error {
	// make sure we got something to decode into ...
	if result == nil {
		return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
	}

	// fast path for decoding into a map
	if mp, ok := result.(*map[string]interface{}); ok {
		return decodeIntoMap(ctx, stmt, mp)
	}

	// make sure we got a pointer in result
	if reflect.TypeOf(result).Kind() != reflect.Ptr {
		return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
	}

	resultValue := reflect.ValueOf(result)

	// a typed nil pointer gives us no storage to decode into. The previous
	// implementation panicked here (calling Type() on a zero Value and
	// Set() on a non-addressable one), so report a proper error instead.
	if resultValue.IsNil() {
		return fmt.Errorf("%w, got nil %T", errStructPointerExpected, result)
	}

	// make sure it's a pointer to a struct type
	t := resultValue.Elem().Type()
	if t.Kind() != reflect.Struct {
		return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
	}

	// we need access to the struct directly and not to the pointer.
	target := reflect.Indirect(resultValue)

	// create a lookup map from column name (field name or sqlite:"" tag)
	// to the Go struct field name
	lm := make(map[string]string, target.NumField())
	for i := 0; i < target.NumField(); i++ {
		fieldType := t.Field(i)

		// skip unexported fields
		if !fieldType.IsExported() {
			continue
		}

		lm[sqlColumnName(fieldType)] = fieldType.Name
	}

	// iterate over all columns and assign them to the correct fields
	for i := 0; i < stmt.ColumnCount(); i++ {
		colName := stmt.ColumnName(i)
		fieldName, ok := lm[colName]
		if !ok {
			// there's no target field for this column so we can skip it
			continue
		}

		fieldType, _ := t.FieldByName(fieldName)
		value := target.FieldByName(fieldName)

		colType := stmt.ColumnType(i)

		// if the column is reported as NULL we keep the field as it is.
		// TODO(ppacher): should we set it to nil here?
		if colType == sqlite.TypeNull {
			continue
		}

		// if value is a nil pointer we need to allocate some memory first
		if getKind(value) == reflect.Ptr && value.IsNil() {
			storage := reflect.New(fieldType.Type.Elem())
			value.Set(storage)

			// make sure value actually points to the dereferenced target storage
			value = storage.Elem()
		}

		// execute all decode hooks but make sure we use decodeBasic() as the
		// last one.
		columnValue, err := runDecodeHooks(
			i,
			stmt,
			fieldType,
			value,
			append(cfg.DecodeHooks, decodeBasic()),
		)
		if err != nil {
			return err
		}

		// not even decodeBasic() could handle the column, give up.
		if columnValue == nil {
			return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
		}

		// convert it to the target type if conversion is possible
		newValue := reflect.ValueOf(columnValue)
		if newValue.Type().ConvertibleTo(value.Type()) {
			newValue = newValue.Convert(value.Type())
		}

		// assign the new value to the struct field.
		value.Set(newValue)
	}

	return nil
}
// DatetimeDecoder returns a DecodeFunc capable of decoding sqlite INTEGER or
// TEXT storage classes into time.Time.
//
// INTEGER columns are interpreted as Unix epoch seconds; if the struct field
// tag contains ",unixnano" they are interpreted as Unix epoch nanoseconds
// instead. TEXT columns must match sqliteTimeFormat and — since the stored
// value carries no timezone information — are parsed in loc.
func DatetimeDecoder(loc *time.Location) DecodeFunc {
	return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) {
		// we only care about "time.Time" here
		if outval.Type().String() != "time.Time" {
			return nil, nil
		}

		switch stmt.ColumnType(colIdx) {
		case sqlite.TypeInteger:
			// stored as unix-epoch, if unixnano is set in the struct field tag
			// we parse it with nano-second resolution
			// TODO(ppacher): actually split the tag value at "," and search
			// the slice for "unixnano"
			if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") {
				return time.Unix(0, int64(stmt.ColumnInt(colIdx))), nil
			}

			return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), nil

		case sqlite.TypeText:
			// stored as ISO8601 but does not have any timezone information
			// assigned so we always treat it as loc here.
			t, err := time.ParseInLocation(sqliteTimeFormat, stmt.ColumnText(colIdx), loc)
			if err != nil {
				return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
			}

			return t, nil

		case sqlite.TypeFloat:
			// REAL would be Julian day numbers which we do not support.
			return nil, fmt.Errorf("REAL storage type not supported for time.Time")

		default:
			// report the offending storage class; the previous version
			// printed outval.Type() which is always time.Time here.
			return nil, fmt.Errorf("unsupported storage type %s for time.Time", stmt.ColumnType(colIdx))
		}
	}
}
// decodeIntoMap stores every column of the current result row in *mp keyed
// by column name, allocating the map first if it is nil.
func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) error {
	if *mp == nil {
		*mp = make(map[string]interface{})
	}

	decode := decodeBasic()
	for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ {
		var receiver interface{}

		decoded, err := decode(colIdx, stmt, reflect.StructField{}, reflect.ValueOf(&receiver).Elem())
		if err != nil {
			return fmt.Errorf("failed to decode column %s: %w", stmt.ColumnName(colIdx), err)
		}

		(*mp)[stmt.ColumnName(colIdx)] = decoded
	}

	return nil
}
// decodeBasic returns a DecodeFunc that handles the basic storage-class to
// field-kind combinations (TEXT→string, INTEGER→int/uint/bool, FLOAT→float,
// BLOB→[]byte) as well as decoding into interface{} targets. DecodeStmt
// always appends it as the final decode hook.
func decodeBasic() DecodeFunc {
	return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) {
		valueKind := getKind(outval)
		colType := stmt.ColumnType(colIdx)
		colName := stmt.ColumnName(colIdx)

		// prepared error used for any storage-class/field-kind mismatch below.
		errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type())

		switch valueKind {
		case reflect.String:
			if colType != sqlite.TypeText {
				return nil, errInvalidType
			}
			return stmt.ColumnText(colIdx), nil

		case reflect.Bool:
			// sqlite does not have a BOOL type, it rather stores a 1/0 in a column
			// with INTEGER affinity.
			if colType != sqlite.TypeInteger {
				return nil, errInvalidType
			}
			return stmt.ColumnBool(colIdx), nil

		case reflect.Float64:
			if colType != sqlite.TypeFloat {
				return nil, errInvalidType
			}
			return stmt.ColumnFloat(colIdx), nil

		case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ...
			if colType != sqlite.TypeInteger {
				return nil, errInvalidType
			}
			return stmt.ColumnInt(colIdx), nil

		case reflect.Slice:
			// only byte slices (BLOB columns) are supported.
			if outval.Type().Elem().Kind() != reflect.Uint8 {
				return nil, fmt.Errorf("slices other than []byte for BLOB are not supported")
			}
			if colType != sqlite.TypeBlob {
				return nil, errInvalidType
			}
			columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
			if err != nil {
				return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
			}
			return columnValue, nil

		case reflect.Interface:
			// for interface{} targets choose a Go type based on the column's
			// storage class.
			var (
				t reflect.Type
				x interface{}
			)

			switch colType {
			case sqlite.TypeBlob:
				t = reflect.TypeOf([]byte{})
				columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
				if err != nil {
					return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
				}
				x = columnValue
			case sqlite.TypeFloat:
				t = reflect.TypeOf(float64(0))
				x = stmt.ColumnFloat(colIdx)
			case sqlite.TypeInteger:
				t = reflect.TypeOf(int(0))
				x = stmt.ColumnInt(colIdx)
			case sqlite.TypeText:
				t = reflect.TypeOf(string(""))
				x = stmt.ColumnText(colIdx)
			case sqlite.TypeNull:
				t = nil
				x = nil
			default:
				return nil, fmt.Errorf("unsupported column type %s", colType)
			}

			// NULL columns decode to a nil interface value.
			if t == nil {
				return nil, nil
			}

			target := reflect.New(t).Elem()
			target.Set(reflect.ValueOf(x))

			return target.Interface(), nil

		default:
			return nil, fmt.Errorf("cannot decode into %s", valueKind)
		}
	}
}
func sqlColumnName(fieldType reflect.StructField) string {
tagValue, hasTag := fieldType.Tag.Lookup("sqlite")
if !hasTag {
return fieldType.Name
}
parts := strings.Split(tagValue, ",")
if parts[0] != "" {
return parts[0]
}
return fieldType.Name
}
// runDecodeHooks feeds the column at colIdx through hooks in order and
// returns the result of the first hook that produces a non-nil value; later
// hooks are not executed. An error from any hook aborts the run immediately.
// (nil, nil) means no hook felt responsible for the column.
func runDecodeHooks(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
	for _, hook := range hooks {
		decoded, err := hook(colIdx, stmt, fieldDef, outval)
		switch {
		case err != nil:
			return decoded, err
		case decoded != nil:
			return decoded, nil
		}
	}

	return nil, nil
}
// getKind returns the kind of val with all Int, Uint and Float variants
// normalized to their base kind.
func getKind(val reflect.Value) reflect.Kind {
	return normalizeKind(val.Kind())
}
func normalizeKind(kind reflect.Kind) reflect.Kind {
switch {
case kind >= reflect.Int && kind <= reflect.Int64:
return reflect.Int
case kind >= reflect.Uint && kind <= reflect.Uint64:
return reflect.Uint
case kind >= reflect.Float32 && kind <= reflect.Float64:
return reflect.Float64
default:
return kind
}
}
// DefaultDecodeConfig is the decode configuration used when callers do not
// provide their own. It decodes INTEGER/TEXT datetime columns into UTC
// time.Time values.
var DefaultDecodeConfig = DecodeConfig{
	DecodeHooks: []DecodeFunc{
		DatetimeDecoder(time.UTC),
	},
}

View File

@@ -0,0 +1,475 @@
package orm
import (
"bytes"
"context"
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"zombiezen.com/go/sqlite"
)
// testStmt is a static, in-memory implementation of the Stmt interface used
// to feed predefined result rows into DecodeStmt during tests.
type testStmt struct {
	columns []string            // column names, addressed by index
	values  []interface{}       // raw values; dynamic type must match the accessor used
	types   []sqlite.ColumnType // storage class reported per column
}

func (ts testStmt) ColumnCount() int                   { return len(ts.columns) }
func (ts testStmt) ColumnName(i int) string            { return ts.columns[i] }
func (ts testStmt) ColumnBool(i int) bool              { return ts.values[i].(bool) }
func (ts testStmt) ColumnText(i int) string            { return ts.values[i].(string) }
func (ts testStmt) ColumnFloat(i int) float64          { return ts.values[i].(float64) }
func (ts testStmt) ColumnInt(i int) int                { return ts.values[i].(int) }
func (ts testStmt) ColumnReader(i int) *bytes.Reader   { return bytes.NewReader(ts.values[i].([]byte)) }
func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] }

// compile time check
var _ Stmt = new(testStmt)
// exampleFieldTypes covers decoding into plain basic field types.
type exampleFieldTypes struct {
	S string
	I int
	F float64
	B bool
}

// examplePointerTypes covers decoding into pointers to basic types.
type examplePointerTypes struct {
	S *string
	I *int
	F *float64
	B *bool
}

// exampleStructTags covers column renaming via `sqlite:"..."` tags.
type exampleStructTags struct {
	S string `sqlite:"col_string"`
	I int    `sqlite:"col_int"`
}

// exampleIntConv covers conversion into differently sized integer fields.
type exampleIntConv struct {
	I8  int8
	I16 int16
	I32 int32
	I64 int64
	I   int
}

// exampleBlobTypes covers BLOB decoding into []byte.
type exampleBlobTypes struct {
	B []byte
}

// exampleJSONRawTypes covers BLOB decoding into a named []byte type.
type exampleJSONRawTypes struct {
	B json.RawMessage
}

// exampleTimeTypes covers time.Time and *time.Time decoding.
type exampleTimeTypes struct {
	T  time.Time
	TP *time.Time
}

// exampleInterface covers decoding into interface{} fields.
type exampleInterface struct {
	I  interface{}
	IP *interface{}
}
// Equal reports whether other is an *exampleTimeTypes carrying the same T
// and TP values (TP compares equal when both are nil or both point to equal
// times). Using time.Time.Equal avoids false negatives from monotonic-clock
// or location differences.
func (ett *exampleTimeTypes) Equal(other interface{}) bool {
	oett, ok := other.(*exampleTimeTypes)
	if !ok {
		return false
	}

	if !ett.T.Equal(oett.T) {
		return false
	}

	// TP must be nil on both sides, or set and equal on both sides.
	// (The previous expression was missing parentheses, so two nil TP
	// pointers made the whole comparison true even when T differed.)
	if ett.TP == nil || oett.TP == nil {
		return ett.TP == nil && oett.TP == nil
	}

	return ett.TP.Equal(*oett.TP)
}
// exampleTimeNano covers time.Time fields stored with nanosecond resolution
// (",unixnano" struct tag).
type exampleTimeNano struct {
	T time.Time `sqlite:",unixnano"`
}

// Equal reports whether other is an *exampleTimeNano with the same T.
func (etn *exampleTimeNano) Equal(other interface{}) bool {
	oetn, ok := other.(*exampleTimeNano)
	if !ok {
		return false
	}
	return etn.T.Equal(oetn.T)
}
// Test_Decoder exercises DecodeStmt against static testStmt fixtures for all
// supported targets: basic types, pointers, struct tags, integer widths,
// NULL handling, blobs, time.Time variants, interface{} fields and maps.
//
// A nil Result marks a case where DecodeStmt must return an error. An
// Expected value of type func() interface{} is evaluated lazily so that
// pointer-typed expectations can be built from local variables.
func Test_Decoder(t *testing.T) {
	ctx := context.TODO()
	refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)

	cases := []struct {
		Desc     string
		Stmt     testStmt
		Result   interface{}
		Expected interface{}
	}{
		{
			"Decoding into nil is not allowed",
			testStmt{
				columns: nil,
				values:  nil,
				types:   nil,
			},
			nil,
			nil,
		},
		{
			"Decoding into basic types",
			testStmt{
				columns: []string{"S", "I", "F", "B"},
				types: []sqlite.ColumnType{
					sqlite.TypeText,
					sqlite.TypeInteger,
					sqlite.TypeFloat,
					sqlite.TypeInteger,
				},
				values: []interface{}{
					"string value",
					1,
					1.2,
					true,
				},
			},
			&exampleFieldTypes{},
			&exampleFieldTypes{
				S: "string value",
				I: 1,
				F: 1.2,
				B: true,
			},
		},
		{
			"Decoding into basic types with different order",
			testStmt{
				columns: []string{"I", "S", "B", "F"},
				types: []sqlite.ColumnType{
					sqlite.TypeInteger,
					sqlite.TypeText,
					sqlite.TypeInteger,
					sqlite.TypeFloat,
				},
				values: []interface{}{
					1,
					"string value",
					true,
					1.2,
				},
			},
			&exampleFieldTypes{},
			&exampleFieldTypes{
				S: "string value",
				I: 1,
				F: 1.2,
				B: true,
			},
		},
		{
			"Decoding into basic types with missing values",
			testStmt{
				columns: []string{"F", "B"},
				types: []sqlite.ColumnType{
					sqlite.TypeFloat,
					sqlite.TypeInteger,
				},
				values: []interface{}{
					1.2,
					true,
				},
			},
			&exampleFieldTypes{},
			&exampleFieldTypes{
				F: 1.2,
				B: true,
			},
		},
		{
			"Decoding into pointer types",
			testStmt{
				columns: []string{"S", "I", "F", "B"},
				types: []sqlite.ColumnType{
					sqlite.TypeText,
					sqlite.TypeInteger,
					sqlite.TypeFloat,
					sqlite.TypeInteger,
				},
				values: []interface{}{
					"string value",
					1,
					1.2,
					true,
				},
			},
			&examplePointerTypes{},
			func() interface{} {
				s := "string value"
				i := 1
				f := 1.2
				b := true
				return &examplePointerTypes{
					S: &s,
					I: &i,
					F: &f,
					B: &b,
				}
			},
		},
		{
			"Decoding into pointer types with missing values",
			testStmt{
				columns: []string{"S", "B"},
				types: []sqlite.ColumnType{
					sqlite.TypeText,
					sqlite.TypeInteger,
					sqlite.TypeFloat,
					sqlite.TypeInteger,
				},
				values: []interface{}{
					"string value",
					true,
				},
			},
			&examplePointerTypes{},
			func() interface{} {
				s := "string value"
				b := true
				return &examplePointerTypes{
					S: &s,
					B: &b,
				}
			},
		},
		{
			"Decoding into fields with struct tags",
			testStmt{
				columns: []string{"col_string", "col_int"},
				types: []sqlite.ColumnType{
					sqlite.TypeText,
					sqlite.TypeInteger,
				},
				values: []interface{}{
					"string value",
					1,
				},
			},
			&exampleStructTags{},
			&exampleStructTags{
				S: "string value",
				I: 1,
			},
		},
		{
			"Decoding into correct int type",
			testStmt{
				columns: []string{"I8", "I16", "I32", "I64", "I"},
				types: []sqlite.ColumnType{
					sqlite.TypeInteger,
					sqlite.TypeInteger,
					sqlite.TypeInteger,
					sqlite.TypeInteger,
					sqlite.TypeInteger,
				},
				values: []interface{}{
					1,
					1,
					1,
					1,
					1,
				},
			},
			&exampleIntConv{},
			&exampleIntConv{
				1, 1, 1, 1, 1,
			},
		},
		{
			"Handling NULL values for basic types",
			testStmt{
				columns: []string{"S", "I", "F"},
				types: []sqlite.ColumnType{
					sqlite.TypeNull,
					sqlite.TypeNull,
					sqlite.TypeFloat,
				},
				values: []interface{}{
					// we use nil here but actually that does not matter
					nil,
					nil,
					1.0,
				},
			},
			&exampleFieldTypes{},
			&exampleFieldTypes{
				F: 1.0,
			},
		},
		{
			"Handling NULL values for pointer types",
			testStmt{
				columns: []string{"S", "I", "F"},
				types: []sqlite.ColumnType{
					sqlite.TypeNull,
					sqlite.TypeNull,
					sqlite.TypeFloat,
				},
				values: []interface{}{
					// we use nil here but actually that does not matter
					nil,
					nil,
					1.0,
				},
			},
			&examplePointerTypes{},
			func() interface{} {
				f := 1.0
				return &examplePointerTypes{F: &f}
			},
		},
		{
			"Handling blob types",
			testStmt{
				columns: []string{"B"},
				types: []sqlite.ColumnType{
					sqlite.TypeBlob,
				},
				values: []interface{}{
					([]byte)("hello world"),
				},
			},
			&exampleBlobTypes{},
			&exampleBlobTypes{
				B: ([]byte)("hello world"),
			},
		},
		{
			"Handling blob types as json.RawMessage",
			testStmt{
				columns: []string{"B"},
				types: []sqlite.ColumnType{
					sqlite.TypeBlob,
				},
				values: []interface{}{
					([]byte)("hello world"),
				},
			},
			&exampleJSONRawTypes{},
			&exampleJSONRawTypes{
				B: (json.RawMessage)("hello world"),
			},
		},
		{
			"Handling time.Time and pointers to it",
			testStmt{
				columns: []string{"T", "TP"},
				types: []sqlite.ColumnType{
					sqlite.TypeInteger,
					sqlite.TypeInteger,
				},
				values: []interface{}{
					int(refTime.Unix()),
					int(refTime.Unix()),
				},
			},
			&exampleTimeTypes{},
			&exampleTimeTypes{
				T:  refTime,
				TP: &refTime,
			},
		},
		{
			"Handling time.Time in nano-second resolution (struct tags)",
			testStmt{
				columns: []string{"T", "TP"},
				types: []sqlite.ColumnType{
					sqlite.TypeInteger,
					sqlite.TypeInteger,
				},
				values: []interface{}{
					int(refTime.UnixNano()),
					int(refTime.UnixNano()),
				},
			},
			&exampleTimeNano{},
			&exampleTimeNano{
				T: refTime,
			},
		},
		{
			"Decoding into interface",
			testStmt{
				columns: []string{"I", "IP"},
				types: []sqlite.ColumnType{
					sqlite.TypeText,
					sqlite.TypeText,
				},
				values: []interface{}{
					"value1",
					"value2",
				},
			},
			&exampleInterface{},
			func() interface{} {
				var x interface{}
				x = "value2"
				return &exampleInterface{
					I:  "value1",
					IP: &x,
				}
			},
		},
		{
			"Decoding into map[string]interface{}",
			testStmt{
				columns: []string{"I", "F", "S", "B"},
				types: []sqlite.ColumnType{
					sqlite.TypeInteger,
					sqlite.TypeFloat,
					sqlite.TypeText,
					sqlite.TypeBlob,
				},
				values: []interface{}{
					1,
					1.1,
					"string value",
					[]byte("blob value"),
				},
			},
			new(map[string]interface{}),
			&map[string]interface{}{
				"I": 1,
				"F": 1.1,
				"S": "string value",
				"B": []byte("blob value"),
			},
		},
	}

	for idx := range cases {
		c := cases[idx]
		t.Run(c.Desc, func(t *testing.T) {
			t.Parallel()

			err := DecodeStmt(ctx, c.Stmt, c.Result, DefaultDecodeConfig)

			// expectations may be factories (see comment on the case table).
			if fn, ok := c.Expected.(func() interface{}); ok {
				c.Expected = fn()
			}

			if c.Expected == nil {
				assert.Error(t, err, c.Desc)
			} else {
				assert.NoError(t, err, c.Desc)
				// a custom Equal implementation wins over deep equality.
				if equaler, ok := c.Expected.(interface{ Equal(x interface{}) bool }); ok {
					assert.True(t, equaler.Equal(c.Result))
				} else {
					assert.Equal(t, c.Expected, c.Result)
				}
			}
		})
	}
}

124
netquery/orm/encoder.go Normal file
View File

@@ -0,0 +1,124 @@
package orm
import (
"context"
"fmt"
"reflect"
"time"
"zombiezen.com/go/sqlite"
)
type (
	// EncodeFunc is called for each struct field during encoding. It returns
	// the encoded value, whether the hook handled the field, and any error.
	EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)

	// EncodeConfig holds the encode hooks EncodeAsMap executes before a
	// field's value is used as-is.
	EncodeConfig struct {
		EncodeHooks []EncodeFunc
	}
)
// EncodeAsMap returns a map that contains the value of each struct field of
// r using the sqlite column name as a map key. It either uses the name of the
// exported struct field or the value of the "sqlite" tag.
func EncodeAsMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) {
	// EncodeAsMap only works on struct types (pointers are dereferenced).
	structVal := reflect.Indirect(reflect.ValueOf(r))
	if structVal.Kind() != reflect.Struct {
		return nil, fmt.Errorf("%w, got %T", errStructExpected, r)
	}

	structType := structVal.Type()
	result := make(map[string]interface{}, structVal.NumField())

	for idx := 0; idx < structVal.NumField(); idx++ {
		fieldDef := structType.Field(idx)
		fieldVal := structVal.Field(idx)

		// unexported fields cannot be encoded.
		if !fieldDef.IsExported() {
			continue
		}

		colDef, err := getColumnDef(fieldDef)
		if err != nil {
			return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldDef.Name, err)
		}

		encoded, handled, err := runEncodeHooks(colDef, fieldDef.Type, fieldVal, cfg.EncodeHooks)
		if err != nil {
			return nil, fmt.Errorf("failed to run encode hooks: %w", err)
		}

		// no hook took care of the field: fall back to the plain
		// (dereferenced) value; nil pointers stay nil.
		if !handled && reflect.Indirect(fieldVal).IsValid() {
			encoded = reflect.Indirect(fieldVal).Interface()
		}

		result[keyPrefix+sqlColumnName(fieldDef)] = encoded
	}

	return result, nil
}
// DatetimeEncoder returns an EncodeFunc that stores time.Time values (and
// pointers to them) either as Unix epoch seconds/nanoseconds for INTEGER
// columns or as sqliteTimeFormat-formatted TEXT in loc. Zero times and nil
// pointers encode as NULL.
func DatetimeEncoder(loc *time.Location) EncodeFunc {
	return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
		// when the field holds a pointer we need to dereference the value.
		typeName := valType.String()
		if valType.Kind() == reflect.Ptr {
			typeName = valType.Elem().String()
			val = reflect.Indirect(val)
		}

		// this hook only cares about "time.Time".
		if typeName != "time.Time" {
			return nil, false, nil
		}

		// nil pointers and the zero time are stored as NULL.
		if !val.IsValid() || val.IsZero() {
			return nil, true, nil
		}

		t, ok := val.Interface().(time.Time)
		if !ok {
			return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
		}

		switch colDef.Type {
		case sqlite.TypeInteger:
			if colDef.UnixNano {
				return t.UnixNano(), true, nil
			}
			return t.Unix(), true, nil

		case sqlite.TypeText:
			return t.In(loc).Format(sqliteTimeFormat), true, nil
		}

		return nil, false, fmt.Errorf("cannot store time.Time in %s", colDef.Type)
	}
}
// runEncodeHooks feeds a field value through hooks in order. The first hook
// that reports true has consumed the value and its result is returned; an
// error aborts immediately. (nil, false, nil) means no hook felt responsible.
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
	for _, hook := range hooks {
		encoded, handled, err := hook(colDef, valType, val)
		if err != nil {
			return encoded, false, err
		}
		if handled {
			return encoded, true, nil
		}
	}

	return nil, false, nil
}
// DefaultEncodeConfig is the encode configuration used when callers do not
// provide their own. It encodes time.Time fields using UTC.
var DefaultEncodeConfig = EncodeConfig{
	EncodeHooks: []EncodeFunc{
		DatetimeEncoder(time.UTC),
	},
}

View File

@@ -0,0 +1,126 @@
package orm
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// Test_EncodeAsMap exercises EncodeAsMap for basic types, struct-tag column
// names, unexported fields, pointer fields and time.Time encoding (both by
// value and by pointer, including nil pointers encoding to NULL).
func Test_EncodeAsMap(t *testing.T) {
	ctx := context.TODO()
	refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)

	cases := []struct {
		Desc     string
		Input    interface{}
		Expected map[string]interface{}
	}{
		{
			"Encode basic types",
			struct {
				I int
				F float64
				S string
				B []byte
			}{
				I: 1,
				F: 1.2,
				S: "string",
				B: ([]byte)("bytes"),
			},
			map[string]interface{}{
				"I": 1,
				"F": 1.2,
				"S": "string",
				"B": ([]byte)("bytes"),
			},
		},
		{
			"Encode using struct tags",
			struct {
				I int    `sqlite:"col_int"`
				S string `sqlite:"col_string"`
			}{
				I: 1,
				S: "string value",
			},
			map[string]interface{}{
				"col_int":    1,
				"col_string": "string value",
			},
		},
		{
			"Ignore Private fields",
			struct {
				I int
				s string
			}{
				I: 1,
				s: "string value",
			},
			map[string]interface{}{
				"I": 1,
			},
		},
		{
			"Handle Pointers",
			struct {
				I *int
				S *string
			}{
				I: new(int),
			},
			map[string]interface{}{
				"I": 0,
				"S": nil,
			},
		},
		{
			"Handle time.Time types",
			struct {
				TinInt    time.Time `sqlite:",integer,unixnano"`
				TinString time.Time `sqlite:",text"`
			}{
				TinInt:    refTime,
				TinString: refTime,
			},
			map[string]interface{}{
				"TinInt":    refTime.UnixNano(),
				"TinString": refTime.Format(sqliteTimeFormat),
			},
		},
		{
			"Handle time.Time pointer types",
			struct {
				TinInt    *time.Time `sqlite:",integer,unixnano"`
				TinString *time.Time `sqlite:",text"`
				Tnil1     *time.Time `sqlite:",text"`
				Tnil2     *time.Time `sqlite:",text"`
			}{
				TinInt:    &refTime,
				TinString: &refTime,
				Tnil1:     nil,
				Tnil2:     (*time.Time)(nil),
			},
			map[string]interface{}{
				"TinInt":    refTime.UnixNano(),
				"TinString": refTime.Format(sqliteTimeFormat),
				"Tnil1":     nil,
				"Tnil2":     nil,
			},
		},
	}

	for idx := range cases {
		c := cases[idx]
		t.Run(c.Desc, func(t *testing.T) {
			// t.Parallel()
			res, err := EncodeAsMap(ctx, c.Input, "", DefaultEncodeConfig)
			assert.NoError(t, err)
			assert.Equal(t, c.Expected, res)
		})
	}
}

View File

@@ -0,0 +1,123 @@
package orm
import (
"context"
"fmt"
"reflect"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
type (
	// QueryOption modifies the behavior of a single RunQuery call.
	QueryOption func(opts *queryOpts)

	// queryOpts collects the options applied to a RunQuery call.
	queryOpts struct {
		Transient    bool                   // run via sqlitex.ExecuteTransient instead of Execute
		Args         []interface{}          // positional SQL arguments
		NamedArgs    map[string]interface{} // named SQL arguments
		Result       interface{}            // pointer to a slice receiving decoded result rows
		DecodeConfig DecodeConfig           // decode configuration for result rows
	}
)
// WithTransient runs the query using sqlitex.ExecuteTransient so the
// prepared statement is not cached on the connection.
func WithTransient() QueryOption {
	return func(opts *queryOpts) {
		opts.Transient = true
	}
}
// WithArgs binds args as positional SQL parameters.
func WithArgs(args ...interface{}) QueryOption {
	return func(opts *queryOpts) {
		opts.Args = args
	}
}
// WithNamedArgs binds args as named SQL parameters.
func WithNamedArgs(args map[string]interface{}) QueryOption {
	return func(opts *queryOpts) {
		opts.NamedArgs = args
	}
}
// WithResult configures the decode target; result must be a pointer to a
// slice into which all result rows are decoded.
func WithResult(result interface{}) QueryOption {
	return func(opts *queryOpts) {
		opts.Result = result
	}
}
// WithDecodeConfig overrides DefaultDecodeConfig for decoding result rows.
func WithDecodeConfig(cfg DecodeConfig) QueryOption {
	return func(opts *queryOpts) {
		opts.DecodeConfig = cfg
	}
}
// RunQuery executes sql against conn with any options applied. Use
// WithArgs/WithNamedArgs to bind parameters and WithResult to collect all
// result rows into a slice (decoded via DecodeStmt using the configured
// DecodeConfig, DefaultDecodeConfig unless overridden).
func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...QueryOption) error {
	args := queryOpts{
		DecodeConfig: DefaultDecodeConfig,
	}
	for _, fn := range modifiers {
		fn(&args)
	}

	opts := &sqlitex.ExecOptions{
		Args:  args.Args,
		Named: args.NamedArgs,
	}

	var (
		sliceVal    reflect.Value
		valElemType reflect.Type
	)

	if args.Result != nil {
		target := args.Result
		outVal := reflect.ValueOf(target)
		if outVal.Kind() != reflect.Ptr {
			return fmt.Errorf("target must be a pointer, got %T", target)
		}
		// a nil pointer gives us nothing to write the result back to; the
		// previous implementation panicked on this (Set on invalid value).
		if outVal.IsNil() {
			return fmt.Errorf("target must be a non-nil pointer, got %T", target)
		}

		sliceVal = reflect.Indirect(outVal)
		if sliceVal.Kind() != reflect.Slice {
			// error message fixed: "target but be pointer to slice"
			return fmt.Errorf("target must be pointer to slice, got %T", target)
		}

		valElemType = sliceVal.Type().Elem()

		// decode every result row into a fresh slice element.
		opts.ResultFunc = func(stmt *sqlite.Stmt) error {
			currentField := reflect.New(valElemType)

			if err := DecodeStmt(ctx, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
				return err
			}

			sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField))
			return nil
		}
	}

	var err error
	if args.Transient {
		err = sqlitex.ExecuteTransient(conn, sql, opts)
	} else {
		err = sqlitex.Execute(conn, sql, opts)
	}
	if err != nil {
		return err
	}

	// write the collected rows back through the result pointer.
	if args.Result != nil {
		reflect.Indirect(reflect.ValueOf(args.Result)).Set(sliceVal)
	}

	return nil
}

View File

@@ -0,0 +1,220 @@
package orm
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"zombiezen.com/go/sqlite"
)
var (
	// errSkipStructField is returned by getColumnDef for fields tagged with
	// `sqlite:"-"` which must not be mapped to a column.
	errSkipStructField = errors.New("struct field should be skipped")
)

// Keywords recognized in the `sqlite:"..."` struct tag.
var (
	TagUnixNano          = "unixnano"
	TagPrimaryKey        = "primary"
	TagAutoIncrement     = "autoincrement"
	TagNotNull           = "not-null"
	TagNullable          = "nullable"
	TagTypeInt           = "integer"
	TagTypeText          = "text"
	TagTypePrefixVarchar = "varchar"
	TagTypeBlob          = "blob"
	TagTypeFloat         = "float"
)

// sqlTypeMap maps sqlite storage classes to the SQL type names used in
// CREATE TABLE statements.
var sqlTypeMap = map[sqlite.ColumnType]string{
	sqlite.TypeBlob:    "BLOB",
	sqlite.TypeFloat:   "REAL",
	sqlite.TypeInteger: "INTEGER",
	sqlite.TypeText:    "TEXT",
}
type (
	// TableSchema describes a SQL table with its column definitions.
	TableSchema struct {
		Name    string
		Columns []ColumnDef
	}

	// ColumnDef describes a single column of a SQL table.
	ColumnDef struct {
		Name          string
		Nullable      bool
		Type          sqlite.ColumnType
		Length        int // VARCHAR length; only used for TEXT columns
		PrimaryKey    bool
		AutoIncrement bool
		UnixNano      bool // store time.Time with nanosecond resolution
	}
)
// CreateStatement returns the CREATE TABLE statement for the schema,
// optionally guarded by IF NOT EXISTS.
func (ts TableSchema) CreateStatement(ifNotExists bool) string {
	columns := make([]string, len(ts.Columns))
	for idx, col := range ts.Columns {
		columns[idx] = col.AsSQL()
	}

	create := "CREATE TABLE"
	if ifNotExists {
		create += " IF NOT EXISTS"
	}

	return create + " " + ts.Name + " ( " + strings.Join(columns, ", ") + " );"
}
// AsSQL renders the column definition as used inside a CREATE TABLE
// statement.
func (def ColumnDef) AsSQL() string {
	var sb strings.Builder

	sb.WriteString(def.Name)
	sb.WriteString(" ")

	// TEXT columns with a length limit render as VARCHAR(n).
	if def.Type == sqlite.TypeText && def.Length > 0 {
		fmt.Fprintf(&sb, "VARCHAR(%d)", def.Length)
	} else {
		sb.WriteString(sqlTypeMap[def.Type])
	}

	if def.PrimaryKey {
		sb.WriteString(" PRIMARY KEY")
	}
	if def.AutoIncrement {
		sb.WriteString(" AUTOINCREMENT")
	}
	if !def.Nullable {
		sb.WriteString(" NOT NULL")
	}

	return sb.String()
}
// GenerateTableSchema builds a TableSchema named name from the struct type
// of d (pointers are dereferenced first). Unexported fields and fields
// tagged `sqlite:"-"` are skipped.
func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) {
	structVal := reflect.Indirect(reflect.ValueOf(d))
	if structVal.Kind() != reflect.Struct {
		return nil, fmt.Errorf("%w, got %T", errStructExpected, d)
	}

	schema := &TableSchema{
		Name: name,
	}

	structType := structVal.Type()
	for idx := 0; idx < structVal.NumField(); idx++ {
		fieldDef := structType.Field(idx)
		if !fieldDef.IsExported() {
			continue
		}

		colDef, err := getColumnDef(fieldDef)
		if errors.Is(err, errSkipStructField) {
			continue
		}
		if err != nil {
			return nil, fmt.Errorf("struct field %s: %w", fieldDef.Name, err)
		}

		schema.Columns = append(schema.Columns, *colDef)
	}

	return schema, nil
}
// getColumnDef derives a ColumnDef from a struct field: pointer fields are
// nullable, the storage class follows the (normalized) Go kind, and any
// `sqlite:"..."` struct tag is applied on top (and may override the type).
func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
	def := &ColumnDef{
		Name:     fieldType.Name,
		Nullable: fieldType.Type.Kind() == reflect.Ptr,
	}

	ft := fieldType.Type
	if fieldType.Type.Kind() == reflect.Ptr {
		ft = fieldType.Type.Elem()
	}

	kind := normalizeKind(ft.Kind())
	switch kind {
	case reflect.Int, reflect.Uint, reflect.Bool:
		// SQLite has no separate BOOL or unsigned storage class; both are
		// stored as INTEGER (matching decodeBasic which reads bools and
		// uints from INTEGER columns). Previously these fell through with
		// an unset column type, producing an empty type in generated SQL.
		def.Type = sqlite.TypeInteger
	case reflect.Float64:
		def.Type = sqlite.TypeFloat
	case reflect.String:
		def.Type = sqlite.TypeText
	case reflect.Slice:
		// only []byte/[]uint8 is supported
		if ft.Elem().Kind() != reflect.Uint8 {
			return nil, fmt.Errorf("slices of type %s is not supported", ft.Elem())
		}
		def.Type = sqlite.TypeBlob
	}

	if err := applyStructFieldTag(fieldType, def); err != nil {
		return nil, err
	}

	return def, nil
}
// applyStructFieldTag parses the sqlite:"" struct field tag and updates the column
// definition def accordingly. The first comma-separated element is the column
// name ("-" skips the field entirely); the remaining elements are modifier or
// type keywords (see the Tag* variables).
func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
	parts := strings.Split(fieldType.Tag.Get("sqlite"), ",")

	// the first element overrides the column name.
	if len(parts) > 0 && parts[0] != "" {
		if parts[0] == "-" {
			return errSkipStructField
		}
		def.Name = parts[0]
	}

	if len(parts) > 1 {
		for _, k := range parts[1:] {
			switch k {
			// column modifiers
			case TagPrimaryKey:
				def.PrimaryKey = true
			case TagAutoIncrement:
				def.AutoIncrement = true
			case TagNotNull:
				def.Nullable = false
			case TagNullable:
				def.Nullable = true
			case TagUnixNano:
				def.UnixNano = true

			// basic column types
			case TagTypeInt:
				def.Type = sqlite.TypeInteger
			case TagTypeText:
				def.Type = sqlite.TypeText
			case TagTypeFloat:
				def.Type = sqlite.TypeFloat
			case TagTypeBlob:
				def.Type = sqlite.TypeBlob

			// advanced column types
			default:
				// varchar(N) yields a TEXT column with a length limit.
				// NOTE(review): keywords that match neither case are
				// silently ignored.
				if strings.HasPrefix(k, TagTypePrefixVarchar) {
					lenStr := strings.TrimSuffix(strings.TrimPrefix(k, TagTypePrefixVarchar+"("), ")")

					length, err := strconv.ParseInt(lenStr, 10, 0)
					if err != nil {
						return fmt.Errorf("failed to parse varchar length %q: %w", lenStr, err)
					}

					def.Type = sqlite.TypeText
					def.Length = int(length)
				}
			}
		}
	}

	return nil
}

View File

@@ -0,0 +1,41 @@
package orm
import (
"testing"
"github.com/stretchr/testify/assert"
)
// Test_SchemaBuilder verifies that GenerateTableSchema plus CreateStatement
// produce the expected CREATE TABLE SQL for tagged struct models.
func Test_SchemaBuilder(t *testing.T) {
	cases := []struct {
		Name        string
		Model       interface{}
		ExpectedSQL string
	}{
		{
			"Simple",
			struct {
				ID    int         `sqlite:"id,primary,autoincrement"`
				Text  string      `sqlite:"text,nullable"`
				Int   *int        `sqlite:",not-null"`
				Float interface{} `sqlite:",float,nullable"`
			}{},
			`CREATE TABLE Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
		},
		{
			"Varchar",
			struct {
				S string `sqlite:",varchar(10)"`
			}{},
			`CREATE TABLE Varchar ( S VARCHAR(10) NOT NULL );`,
		},
	}

	for idx := range cases {
		c := cases[idx]
		// run each model as its own subtest for precise failure reporting.
		t.Run(c.Name, func(t *testing.T) {
			res, err := GenerateTableSchema(c.Name, c.Model)
			// stop here on error: res is nil and must not be dereferenced
			// (the previous version called res.CreateStatement anyway).
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, c.ExpectedSQL, res.CreateStatement(false))
		})
	}
}