diff --git a/netquery/database.go b/netquery/database.go index f9ce1bce..a6333a77 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "strings" + "sync" "time" "zombiezen.com/go/sqlite" @@ -48,6 +49,9 @@ type ( Schema *orm.TableSchema pool *puddle.Pool[*sqlite.Conn] + + l sync.Mutex + writeConn *sqlite.Conn } // Conn is a network connection that is stored in a SQLite database and accepted @@ -111,16 +115,14 @@ func New(path string) (*Database, error) { constructor := func(ctx context.Context) (*sqlite.Conn, error) { c, err := sqlite.OpenConn( path, - sqlite.OpenCreate, - sqlite.OpenReadWrite, + sqlite.OpenReadOnly, sqlite.OpenNoMutex, - sqlite.OpenWAL, sqlite.OpenSharedCache, sqlite.OpenMemory, sqlite.OpenURI, ) if err != nil { - return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err) + return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", path, err) } return c, nil @@ -137,10 +139,24 @@ func New(path string) (*Database, error) { return nil, err } + writeConn, err := sqlite.OpenConn( + path, + sqlite.OpenCreate, + sqlite.OpenReadWrite, + sqlite.OpenNoMutex, + sqlite.OpenWAL, + sqlite.OpenSharedCache, + sqlite.OpenMemory, + sqlite.OpenURI, + ) + if err != nil { + return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err) + } + return &Database{ - pool: pool, - Schema: schema, - // conn: c, + pool: pool, + Schema: schema, + writeConn: writeConn, }, nil } @@ -283,6 +299,9 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:un // Save inserts the connection conn into the SQLite database. If conn // already exists the table row is updated instead. +// +// Save uses the database write connection instead of relying on the +// connection pool. func (db *Database) Save(ctx context.Context, conn Conn) error { connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig) if err != nil { @@ -301,6 +320,9 @@ func (db *Database) Save(ctx context.Context, conn Conn) error { updateSets = append(updateSets, fmt.Sprintf("%s = :%s", key, key)) } + db.l.Lock() + defer db.l.Unlock() + // TODO(ppacher): make sure this one can be cached to speed up inserting // and save some CPU cycles for the user sql := fmt.Sprintf( @@ -314,14 +336,12 @@ func (db *Database) Save(ctx context.Context, conn Conn) error { strings.Join(updateSets, ", "), ) - if err := db.withConn(ctx, func(conn *sqlite.Conn) error { - return sqlitex.ExecuteTransient(conn, sql, &sqlitex.ExecOptions{ - Named: values, - ResultFunc: func(stmt *sqlite.Stmt) error { - log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount()) - return nil - }, - }) + if err := sqlitex.Execute(db.writeConn, sql, &sqlitex.ExecOptions{ + Named: values, + ResultFunc: func(stmt *sqlite.Stmt) error { + log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount()) + return nil + }, }); err != nil { log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values) return err @@ -333,7 +353,5 @@ func (db *Database) Save(ctx context.Context, conn Conn) error { // Close closes the underlying database connection. db should and cannot be // used after Close() has returned. func (db *Database) Close() error { - return nil - - //return db.conn.Close() + return db.writeConn.Close() }