diff --git a/netquery/database.go b/netquery/database.go index 0434d3c1..dab486c9 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -22,7 +22,7 @@ import ( ) // InMemory is the "file path" to open a new in-memory database. -const InMemory = "file:inmem.db" +const InMemory = "file:inmem.db?mode=memory" // Available connection types as their string representation. const ( @@ -115,13 +115,17 @@ func New(path string) (*Database, error) { sqlite.OpenReadOnly, sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit. sqlite.OpenSharedCache, - sqlite.OpenMemory, + //sqlite.OpenMemory, sqlite.OpenURI, ) if err != nil { return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", path, err) } + if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE 'file:///tmp/history.db?mode=ro' AS history", nil); err != nil { + return nil, fmt.Errorf("failed to attach history database: %w", err) + } + return c, nil } @@ -152,7 +156,7 @@ func New(path string) (*Database, error) { sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit. sqlite.OpenWAL, sqlite.OpenSharedCache, - sqlite.OpenMemory, + //sqlite.OpenMemory, sqlite.OpenURI, ) if err != nil { @@ -189,28 +193,44 @@ func NewInMemory() (*Database, error) { // any data-migrations. Once the history module is implemented this should // become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration. func (db *Database) ApplyMigrations() error { - // get the create-table SQL statement from the inferred schema - sql := db.Schema.CreateStatement(true) - + log.Errorf("applying migrations ...") db.l.Lock() defer db.l.Unlock() - // execute the SQL - if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil { - return fmt.Errorf("failed to create schema: %w", err) + // Attach the history database + log.Errorf("attaching database history") + if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE 'file:///tmp/history.db?mode=rwc' AS 'history';", nil); err != nil { + return fmt.Errorf("failed to attach history database: %w", err) } - // create a few indexes - indexes := []string{ - `CREATE INDEX profile_id_index ON %s (profile)`, - `CREATE INDEX started_time_index ON %s (strftime('%%s', started)+0)`, - `CREATE INDEX started_ended_time_index ON %s (strftime('%%s', started)+0, strftime('%%s', ended)+0) WHERE ended IS NOT NULL`, - } - for _, idx := range indexes { - stmt := fmt.Sprintf(idx, db.Schema.Name) + dbNames := []string{"main", "history"} + for _, dbName := range dbNames { + // get the create-table SQL statement from the inferred schema + sql := db.Schema.CreateStatement(dbName, true) + log.Errorf("creating table schema for database %q", dbName) - if err := sqlitex.ExecuteTransient(db.writeConn, stmt, nil); err != nil { - return fmt.Errorf("failed to create index: %q: %w", idx, err) + // execute the SQL + if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil { + return fmt.Errorf("failed to create schema on database %q: %w", dbName, err) + } + + // create a few indexes + indexes := []string{ + `CREATE INDEX IF NOT EXISTS %sprofile_id_index ON %s (profile)`, + `CREATE INDEX IF NOT EXISTS %sstarted_time_index ON %s (strftime('%%s', started)+0)`, + `CREATE INDEX IF NOT EXISTS %sstarted_ended_time_index ON %s (strftime('%%s', started)+0, strftime('%%s', ended)+0) WHERE ended IS NOT NULL`, + } + for _, idx := range indexes { + name := "" + if dbName != "" { + name = dbName + "." + } + + stmt := fmt.Sprintf(idx, name, db.Schema.Name) + + if err := sqlitex.ExecuteTransient(db.writeConn, stmt, nil); err != nil { + return fmt.Errorf("failed to create index on database %q: %q: %w", dbName, idx, err) + } } } @@ -254,7 +274,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) { Count int `sqlite:"count"` } - if err := db.Execute(ctx, "SELECT COUNT(*) AS count FROM connections", orm.WithResult(&result)); err != nil { + if err := db.Execute(ctx, "SELECT COUNT(*) AS count FROM (SELECT * FROM main.connections UNION SELECT * from history.connections)", orm.WithResult(&result)); err != nil { return 0, fmt.Errorf("failed to perform query: %w", err) } @@ -273,7 +293,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) { func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) { where := `WHERE ended IS NOT NULL AND datetime(ended) < datetime(:threshold)` - sql := "DELETE FROM connections " + where + ";" + sql := "DELETE FROM main.connections " + where + ";" args := orm.WithNamedArgs(map[string]interface{}{ ":threshold": threshold.UTC().Format(orm.SqliteTimeFormat), @@ -367,26 +387,29 @@ func (db *Database) Save(ctx context.Context, conn Conn) error { // TODO(ppacher): make sure this one can be cached to speed up inserting // and save some CPU cycles for the user - sql := fmt.Sprintf( - `INSERT INTO connections (%s) - VALUES(%s) - ON CONFLICT(id) DO UPDATE SET - %s - `, - strings.Join(columns, ", "), - strings.Join(placeholders, ", "), - strings.Join(updateSets, ", "), - ) + for _, dbName := range []string{"main", "history"} { + sql := fmt.Sprintf( + `INSERT INTO %s.connections (%s) + VALUES(%s) + ON CONFLICT(id) DO UPDATE SET + %s + `, + dbName, + strings.Join(columns, ", "), + strings.Join(placeholders, ", "), + strings.Join(updateSets, ", "), + ) - if err := sqlitex.Execute(db.writeConn, sql, &sqlitex.ExecOptions{ - Named: values, - ResultFunc: func(stmt *sqlite.Stmt) error { - log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount()) - return nil - }, - }); err != nil { - log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values) - return err + if err := sqlitex.Execute(db.writeConn, sql, &sqlitex.ExecOptions{ + Named: values, + ResultFunc: func(stmt *sqlite.Stmt) error { + log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount()) + return nil + }, + }); err != nil { + log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values) + return err + } } return nil diff --git a/netquery/orm/schema_builder.go b/netquery/orm/schema_builder.go index 508b7b18..e9e46874 100644 --- a/netquery/orm/schema_builder.go +++ b/netquery/orm/schema_builder.go @@ -66,12 +66,17 @@ func (ts TableSchema) GetColumnDef(name string) *ColumnDef { } // CreateStatement build the CREATE SQL statement for the table. -func (ts TableSchema) CreateStatement(ifNotExists bool) string { +func (ts TableSchema) CreateStatement(databaseName string, ifNotExists bool) string { sql := "CREATE TABLE" if ifNotExists { sql += " IF NOT EXISTS" } - sql += " " + ts.Name + " ( " + name := ts.Name + if databaseName != "" { + name = databaseName + "." + ts.Name + } + + sql += " " + name + " ( " for idx, col := range ts.Columns { sql += col.AsSQL() diff --git a/netquery/query_handler.go b/netquery/query_handler.go index 599c71ec..e03f6a83 100644 --- a/netquery/query_handler.go +++ b/netquery/query_handler.go @@ -190,7 +190,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab } selectClause := req.generateSelectClause() - query := `SELECT ` + selectClause + ` FROM connections` + query := `SELECT ` + selectClause + ` FROM ( SELECT *, 'memory' as _source FROM main.connections UNION SELECT *, 'history' as _source FROM history.connections) ` if whereClause != "" { query += " WHERE " + whereClause } @@ -298,7 +298,8 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) ( func (req *QueryRequestPayload) generateSelectClause() string { selectClause := "*" if len(req.selectedFields) > 0 { - selectClause = strings.Join(req.selectedFields, ", ") + selectedFields := append(req.selectedFields, "_source") + selectClause = strings.Join(selectedFields, ", ") } return selectClause