diff --git a/go.mod b/go.mod index 8ee22764..741cee16 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/jackc/puddle/v2 v2.0.0-beta.1 // indirect github.com/josharian/native v1.0.0 // indirect github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect github.com/mattn/go-isatty v0.0.14 // indirect diff --git a/go.sum b/go.sum index 587e7aeb..fcfb45d3 100644 --- a/go.sum +++ b/go.sum @@ -569,6 +569,8 @@ github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJ github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jackc/puddle/v2 v2.0.0-beta.1 h1:Y4Ao+kFWANtDhWUkdw1JcbH+x84/aq6WUfhVQ1wdib8= +github.com/jackc/puddle/v2 v2.0.0-beta.1/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= diff --git a/netquery/database.go b/netquery/database.go index ed1dcd8b..b53d8c14 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -12,6 +12,7 @@ import ( "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" + "github.com/jackc/puddle/v2" "github.com/safing/portbase/log" "github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/network" @@ -20,7 +21,7 @@ import ( ) // InMemory is the "file path" to open a new in-memory database. -const InMemory = ":memory:" +const InMemory = "file:inmemdb" // Available connection types as their string representation. const ( @@ -47,8 +48,10 @@ type ( Database struct { Schema *orm.TableSchema - l sync.Mutex - conn *sqlite.Conn + pool *puddle.Pool[*sqlite.Conn] + + l sync.Mutex + // conn *sqlite.Conn } // Conn is a network connection that is stored in a SQLite database and accepted @@ -109,26 +112,38 @@ type ( // SIGBUS/SIGSEGV when used with out a dedicated mutex in *Database. // func New(path string) (*Database, error) { - c, err := sqlite.OpenConn( - path, - sqlite.OpenCreate, - sqlite.OpenReadWrite, - sqlite.OpenFullMutex, - sqlite.OpenWAL, - sqlite.OpenSharedCache, - ) - if err != nil { - return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err) + constructor := func(ctx context.Context) (*sqlite.Conn, error) { + c, err := sqlite.OpenConn( + path, + sqlite.OpenCreate, + sqlite.OpenReadWrite, + sqlite.OpenNoMutex, + sqlite.OpenWAL, + sqlite.OpenSharedCache, + sqlite.OpenMemory, + ) + if err != nil { + return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err) + } + + return c, nil } + destructor := func(resource *sqlite.Conn) { + resource.Close() + } + + pool := puddle.NewPool(constructor, destructor, 10) + schema, err := orm.GenerateTableSchema("connections", Conn{}) if err != nil { return nil, err } return &Database{ + pool: pool, Schema: schema, - conn: c, + // conn: c, }, nil } @@ -156,14 +171,26 @@ func NewInMemory() (*Database, error) { // become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration. func (db *Database) ApplyMigrations() error { // get the create-table SQL statement from the inferred schema - sql := db.Schema.CreateStatement(false) + sql := db.Schema.CreateStatement(true) - // execute the SQL - if err := sqlitex.ExecuteTransient(db.conn, sql, nil); err != nil { - return fmt.Errorf("failed to create schema: %w", err) + return db.withConn(context.Background(), func(conn *sqlite.Conn) error { + // execute the SQL + if err := sqlitex.ExecuteTransient(conn, sql, nil); err != nil { + return fmt.Errorf("failed to create schema: %w", err) + } + + return nil + }) +} + +func (db *Database) withConn(ctx context.Context, fn func(conn *sqlite.Conn) error) error { + res, err := db.pool.Acquire(ctx) + if err != nil { + return err } + defer res.Release() - return nil + return fn(res.Value()) } // Execute executes a custom SQL query against the SQLite database used by db. @@ -173,7 +200,9 @@ func (db *Database) Execute(ctx context.Context, sql string, args ...orm.QueryOp db.l.Lock() defer db.l.Unlock() - return orm.RunQuery(ctx, db.conn, sql, args...) + return db.withConn(ctx, func(conn *sqlite.Conn) error { + return orm.RunQuery(ctx, conn, sql, args...) + }) } // CountRows returns the number of rows stored in the database. @@ -239,17 +268,20 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:un defer db.l.Unlock() var conns []Conn - if err := sqlitex.ExecuteTransient(db.conn, "SELECT * FROM connections", &sqlitex.ExecOptions{ - ResultFunc: func(stmt *sqlite.Stmt) error { - var c Conn - if err := orm.DecodeStmt(ctx, db.Schema, stmt, &c, orm.DefaultDecodeConfig); err != nil { - return err - } + err := db.withConn(ctx, func(conn *sqlite.Conn) error { + return sqlitex.ExecuteTransient(conn, "SELECT * FROM connections", &sqlitex.ExecOptions{ + ResultFunc: func(stmt *sqlite.Stmt) error { + var c Conn + if err := orm.DecodeStmt(ctx, db.Schema, stmt, &c, orm.DefaultDecodeConfig); err != nil { + return err + } - conns = append(conns, c) - return nil - }, - }); err != nil { + conns = append(conns, c) + return nil + }, + }) + }) + if err != nil { return err } @@ -294,12 +326,14 @@ func (db *Database) Save(ctx context.Context, conn Conn) error { strings.Join(updateSets, ", "), ) - if err := sqlitex.ExecuteTransient(db.conn, sql, &sqlitex.ExecOptions{ - Named: values, - ResultFunc: func(stmt *sqlite.Stmt) error { - log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount()) - return nil - }, + if err := db.withConn(ctx, func(conn *sqlite.Conn) error { + return sqlitex.ExecuteTransient(conn, sql, &sqlitex.ExecOptions{ + Named: values, + ResultFunc: func(stmt *sqlite.Stmt) error { + log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount()) + return nil + }, + }) }); err != nil { log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values) return err @@ -311,5 +345,7 @@ func (db *Database) Save(ctx context.Context, conn Conn) error { // Close closes the underlying database connection. db should and cannot be // used after Close() has returned. func (db *Database) Close() error { - return db.conn.Close() + return nil + + //return db.conn.Close() }