diff --git a/base/database/storage/errors.go b/base/database/storage/errors.go index ecc28530..f8534046 100644 --- a/base/database/storage/errors.go +++ b/base/database/storage/errors.go @@ -4,5 +4,6 @@ import "errors" // Errors for storages. var ( - ErrNotFound = errors.New("storage entry not found") + ErrNotFound = errors.New("storage entry not found") + ErrRecordMalformed = errors.New("record is malformed") ) diff --git a/base/database/storage/sqlite/prepared.go b/base/database/storage/sqlite/prepared.go index 11136dd8..797b53a4 100644 --- a/base/database/storage/sqlite/prepared.go +++ b/base/database/storage/sqlite/prepared.go @@ -2,6 +2,7 @@ package sqlite import ( "context" + "fmt" "strconv" "github.com/stephenafamo/bob" @@ -9,6 +10,7 @@ import ( "github.com/stephenafamo/bob/expr" "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" "github.com/safing/portmaster/base/database/storage/sqlite/models" "github.com/safing/structures/dsd" ) @@ -83,8 +85,23 @@ func writeWithPreparedStatement(ctx context.Context, pStmt *bob.StdPrepared, r r r.Lock() defer r.Unlock() - // Serialize to JSON. - data, err := r.MarshalDataOnly(r, dsd.JSON) + // default serialization format - JSON + format := uint8(dsd.JSON) + + // For wrapped records, check the required format + if r.IsWrapped() { + wrapper, ok := r.(*record.Wrapper) + if !ok { + return fmt.Errorf("%w: reports to be wrapped but is not of type *record.Wrapper", storage.ErrRecordMalformed) + } + format, ok = dsd.ValidateSerializationFormat(wrapper.Format) + if !ok { + return dsd.ErrIncompatibleFormat + } + } + + // Serialize. + data, err := r.MarshalDataOnly(r, format) if err != nil { return err } @@ -94,7 +111,7 @@ func writeWithPreparedStatement(ctx context.Context, pStmt *bob.StdPrepared, r r // Insert. if len(data) > 0 { - format := strconv.Itoa(dsd.JSON) + format := strconv.Itoa(int(format)) _, err = pStmt.ExecContext( ctx, r.DatabaseKey(), diff --git a/base/database/storage/sqlite/sqlite.go b/base/database/storage/sqlite/sqlite.go index fe5427ff..2a54cfa8 100644 --- a/base/database/storage/sqlite/sqlite.go +++ b/base/database/storage/sqlite/sqlite.go @@ -161,13 +161,28 @@ func (db *SQLite) putRecord(r record.Record, tx *bob.Tx) (record.Record, error) defer r.Unlock() } - // Serialize to JSON. - data, err := r.MarshalDataOnly(r, dsd.JSON) + // default serialization format - JSON + format := uint8(dsd.JSON) + + // For wrapped records, check the required format + if r.IsWrapped() { + wrapper, ok := r.(*record.Wrapper) + if !ok { + return nil, fmt.Errorf("%w: reports to be wrapped but is not of type *record.Wrapper", storage.ErrRecordMalformed) + } + format, ok = dsd.ValidateSerializationFormat(wrapper.Format) + if !ok { + return nil, dsd.ErrIncompatibleFormat + } + } + + // Serialize. + data, err := r.MarshalDataOnly(r, format) if err != nil { return nil, err } // Prepare for setter. - setFormat := omitnull.From(int16(dsd.JSON)) + setFormat := omitnull.From(int16(format)) setData := omitnull.From(data) if len(data) == 0 { setFormat.Null()