Merge pull request #1255 from safing/feature/history-module

feature/history module
This commit is contained in:
Daniel Hovie
2023-07-22 20:50:24 +02:00
committed by GitHub
31 changed files with 653 additions and 148 deletions

2
.gitignore vendored
View File

@@ -53,3 +53,5 @@ _testmain.go
# Custom dev scripts
win_dev_*
go.work
go.work.sum

View File

@@ -88,12 +88,12 @@ Portmaster is a privacy suite for your desktop OS.
All details and guides live in the dedicated [wiki](https://wiki.safing.io/)
- [Getting Started](https://wiki.safing.io/en/Portmaster/App/GettingStarted)
- [Getting Started](https://wiki.safing.io/en/Portmaster/App)
- Install
- [on Windows](https://wiki.safing.io/en/Portmaster/Install/Windows)
- [on Linux](https://wiki.safing.io/en/Portmaster/Install/Linux)
- [Contribute](https://wiki.safing.io/en/Contribute)
- [VPN Compatibility](https://wiki.safing.io/en/Portmaster/App/Compatibility#vpn-compatibly)
- [Software Compatibility](https://wiki.safing.io/en/Portmaster/App/Compatibility)
- [Architecture](https://wiki.safing.io/en/Portmaster/Architecture/Overview)
- [Architecture](https://wiki.safing.io/en/Portmaster/Architecture)

View File

@@ -28,12 +28,12 @@ var (
systemIntegrationCheckDialNet = fmt.Sprintf("ip4:%d", uint8(SystemIntegrationCheckProtocol))
systemIntegrationCheckDialIP = SystemIntegrationCheckDstIP.String()
systemIntegrationCheckPackets = make(chan packet.Packet, 1)
systemIntegrationCheckWaitDuration = 20 * time.Second
systemIntegrationCheckWaitDuration = 40 * time.Second
// DNSCheckInternalDomainScope is the domain scope to use for dns checks.
DNSCheckInternalDomainScope = ".self-check." + resolver.InternalSpecialUseDomain
dnsCheckReceivedDomain = make(chan string, 1)
dnsCheckWaitDuration = 20 * time.Second
dnsCheckWaitDuration = 40 * time.Second
dnsCheckAnswerLock sync.Mutex
dnsCheckAnswer net.IP
)

View File

@@ -23,11 +23,11 @@ var (
askTimeout config.IntOption
CfgOptionPermanentVerdictsKey = "filter/permanentVerdicts"
cfgOptionPermanentVerdictsOrder = 96
cfgOptionPermanentVerdictsOrder = 80
permanentVerdicts config.BoolOption
CfgOptionDNSQueryInterceptionKey = "filter/dnsQueryInterception"
cfgOptionDNSQueryInterceptionOrder = 97
cfgOptionDNSQueryInterceptionOrder = 81
dnsQueryInterception config.BoolOption
)

View File

@@ -133,15 +133,18 @@ func reportBandwidth(ctx context.Context, objs bpfObjects, bandwidthUpdates chan
false,
)
update := &packet.BandwidthUpdate{
ConnID: connID,
RecvBytes: skInfo.Rx,
SentBytes: skInfo.Tx,
Method: packet.Absolute,
ConnID: connID,
BytesReceived: skInfo.Rx,
BytesSent: skInfo.Tx,
Method: packet.Absolute,
}
select {
case bandwidthUpdates <- update:
case <-ctx.Done():
return
default:
log.Warning("ebpf: bandwidth update queue is full, skipping rest of batch")
return
}
}
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"net"
"sync/atomic"
"time"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/ringbuf"
@@ -112,9 +113,11 @@ func ConnectionListenerWorker(ctx context.Context, packets chan packet.Packet) e
Src: convertArrayToIPv4(event.Saddr, packet.IPVersion(event.IpVersion)),
Dst: convertArrayToIPv4(event.Daddr, packet.IPVersion(event.IpVersion)),
PID: int(event.Pid),
SeenAt: time.Now(),
})
if isEventValid(event) {
log.Debugf("ebpf: received valid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt)
// DEBUG:
// log.Debugf("ebpf: received valid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt)
packets <- pkt
} else {
log.Warningf("ebpf: received invalid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt)

View File

@@ -196,7 +196,8 @@ func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int {
select {
case q.packets <- pkt:
log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt))
// DEBUG:
// log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt))
case <-ctx.Done():
return 0
case <-time.After(time.Second):

View File

@@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/florianl/go-nfqueue"
"github.com/tevino/abool"
@@ -117,7 +116,13 @@ func (pkt *packet) setMark(mark int) error {
}
break
}
log.Tracer(pkt.Ctx()).Tracef("nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.Info().SeenAt))
// DEBUG:
// log.Tracer(pkt.Ctx()).Tracef(
// "nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s",
// pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id,
// markToString(mark), time.Since(pkt.Info().SeenAt),
// )
return nil
}

View File

@@ -55,7 +55,7 @@ func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.Bandwidt
}
// Report all statistics.
for _, stat := range stats {
for i, stat := range stats {
connID := packet.CreateConnectionID(
packet.IPProtocol(stat.protocol),
convertArrayToIP(stat.localIP, stat.ipV6 == 1), stat.localPort,
@@ -63,15 +63,18 @@ func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.Bandwidt
false,
)
update := &packet.BandwidthUpdate{
ConnID: connID,
RecvBytes: stat.receivedBytes,
SentBytes: stat.transmittedBytes,
Method: packet.Additive,
ConnID: connID,
BytesReceived: stat.receivedBytes,
BytesSent: stat.transmittedBytes,
Method: packet.Additive,
}
select {
case bandwidthUpdates <- update:
case <-ctx.Done():
return nil
default:
log.Warningf("kext: bandwidth update queue is full, skipping rest of batch (%d entries)", len(stats)-i)
return nil
}
}

View File

@@ -14,7 +14,7 @@ import (
var module *modules.Module
func init() {
module = modules.Register("filter", prep, start, stop, "core", "interception", "intel")
module = modules.Register("filter", prep, start, stop, "core", "interception", "intel", "netquery")
subsystems.Register(
"filter",
"Privacy Filter",

View File

@@ -18,6 +18,7 @@ import (
"github.com/safing/portmaster/firewall/inspection"
"github.com/safing/portmaster/firewall/interception"
"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/netquery"
"github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/network/packet"
@@ -510,7 +511,7 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V
atomic.AddUint64(packetsFailed, 1)
err = pkt.Drop()
case network.VerdictUndecided, network.VerdictUndeterminable:
log.Warningf("filter: tried to apply verdict %s to pkt %s: dropping instead", verdict, pkt)
log.Tracer(pkt.Ctx()).Warningf("filter: tried to apply verdict %s to pkt %s: dropping instead", verdict, pkt)
fallthrough
default:
atomic.AddUint64(packetsDropped, 1)
@@ -518,7 +519,7 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V
}
if err != nil {
log.Warningf("filter: failed to apply verdict to pkt %s: %s", pkt, err)
log.Tracer(pkt.Ctx()).Warningf("filter: failed to apply verdict to pkt %s: %s", pkt, err)
}
}
@@ -616,7 +617,7 @@ func bandwidthUpdateHandler(ctx context.Context) error {
return nil
case bwUpdate := <-interception.BandwidthUpdates:
if bwUpdate != nil {
updateBandwidth(bwUpdate)
updateBandwidth(ctx, bwUpdate)
// DEBUG:
// log.Debugf("filter: bandwidth update: %s", bwUpdate)
} else {
@@ -626,9 +627,9 @@ func bandwidthUpdateHandler(ctx context.Context) error {
}
}
func updateBandwidth(bwUpdate *packet.BandwidthUpdate) {
func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
// Check if update makes sense.
if bwUpdate.RecvBytes == 0 && bwUpdate.SentBytes == 0 {
if bwUpdate.BytesReceived == 0 && bwUpdate.BytesSent == 0 {
return
}
@@ -648,16 +649,29 @@ func updateBandwidth(bwUpdate *packet.BandwidthUpdate) {
// Update stats according to method.
switch bwUpdate.Method {
case packet.Absolute:
conn.RecvBytes = bwUpdate.RecvBytes
conn.SentBytes = bwUpdate.SentBytes
conn.BytesReceived = bwUpdate.BytesReceived
conn.BytesSent = bwUpdate.BytesSent
case packet.Additive:
conn.RecvBytes += bwUpdate.RecvBytes
conn.SentBytes += bwUpdate.SentBytes
conn.BytesReceived += bwUpdate.BytesReceived
conn.BytesSent += bwUpdate.BytesSent
default:
log.Warningf("filter: unsupported bandwidth update method: %d", bwUpdate.Method)
return
}
// TODO: Send update.
// Update bandwidth in the netquery module.
if netquery.DefaultModule != nil && conn.BandwidthEnabled {
if err := netquery.DefaultModule.Store.UpdateBandwidth(
ctx,
conn.HistoryEnabled,
conn.Process().GetID(),
conn.ID,
conn.BytesReceived,
conn.BytesSent,
); err != nil {
log.Errorf("filter: failed to persist bandwidth data: %s", err)
}
}
}
func statLogger(ctx context.Context) error {

View File

@@ -91,12 +91,12 @@ func createPrompt(ctx context.Context, conn *network.Connection) (n *notificatio
layeredProfile := conn.Process().Profile()
if layeredProfile == nil {
log.Tracer(ctx).Warningf("filter: tried creating prompt for connection without profile")
return
return nil
}
localProfile := layeredProfile.LocalProfile()
if localProfile == nil {
log.Tracer(ctx).Warningf("filter: tried creating prompt for connection without local profile")
return
return nil
}
// first check if there is an existing notification for this.

4
go.mod
View File

@@ -20,7 +20,7 @@ require (
github.com/safing/jess v0.3.1
github.com/safing/portbase v0.17.0
github.com/safing/portmaster-android/go v0.0.0-20230605085256-6abf4c495626
github.com/safing/spn v0.6.8
github.com/safing/spn v0.6.9
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/spf13/cobra v1.7.0
github.com/spkg/zipfs v0.7.1
@@ -28,6 +28,7 @@ require (
github.com/tannerryan/ring v1.1.2
github.com/tevino/abool v1.2.0
github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
golang.org/x/net v0.12.0
golang.org/x/sync v0.3.0
golang.org/x/sys v0.10.0
@@ -86,7 +87,6 @@ require (
github.com/zalando/go-keyring v0.2.3 // indirect
go.etcd.io/bbolt v1.3.7 // indirect
golang.org/x/crypto v0.11.0 // indirect
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.11.0 // indirect

4
go.sum
View File

@@ -210,8 +210,8 @@ github.com/safing/portbase v0.17.0 h1:RsDzbCGxRIbgaArri3y7MZskfxytEvvkzJpiboDUER
github.com/safing/portbase v0.17.0/go.mod h1:eKCRqsfMFLVhNpd2sY/fKvnbuk+LrIYnQEZCg1i86Ho=
github.com/safing/portmaster-android/go v0.0.0-20230605085256-6abf4c495626 h1:olc/REnUdpJN/Gmz8B030OxLpMYxyPDTrDILNEw0eKs=
github.com/safing/portmaster-android/go v0.0.0-20230605085256-6abf4c495626/go.mod h1:abwyAQrZGemWbSh/aCD9nnkp0SvFFf/mGWkAbOwPnFE=
github.com/safing/spn v0.6.8 h1:2obvyMzyw5X3CIYedLBE88kNBBrJumF84q1qtQSFqkc=
github.com/safing/spn v0.6.8/go.mod h1:Mh9bmkqFhO/dHNi9RWXzoXjQij893I4Lj8Wn4tQ0KZA=
github.com/safing/spn v0.6.9 h1:CCRN5jgshJrLBHwGHl0ywWwhukc+Wff7/I66qgYyymg=
github.com/safing/spn v0.6.9/go.mod h1:Mh9bmkqFhO/dHNi9RWXzoXjQij893I4Lj8Wn4tQ0KZA=
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/seehuhn/fortuna v1.0.1 h1:lu9+CHsmR0bZnx5Ay646XvCSRJ8PJTi5UYJwDBX68H0=

View File

@@ -2,18 +2,23 @@ package netquery
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"path"
"sort"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/jackc/puddle/v2"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netquery/orm"
"github.com/safing/portmaster/network"
@@ -22,7 +27,7 @@ import (
)
// InMemory is the "file path" to open a new in-memory database.
const InMemory = "file:inmem.db"
const InMemory = "file:inmem.db?mode=memory"
// Available connection types as their string representation.
const (
@@ -46,6 +51,7 @@ type (
Schema *orm.TableSchema
readConnPool *puddle.Pool[*sqlite.Conn]
historyPath string
l sync.Mutex
writeConn *sqlite.Conn
@@ -82,7 +88,9 @@ type (
Latitude float64 `sqlite:"latitude"`
Longitude float64 `sqlite:"longitude"`
Scope netutils.IPScope `sqlite:"scope"`
Verdict network.Verdict `sqlite:"verdict"`
WorstVerdict network.Verdict `sqlite:"worst_verdict"`
ActiveVerdict network.Verdict `sqlite:"verdict"`
FirewallVerdict network.Verdict `sqlite:"firewall_verdict"`
Started time.Time `sqlite:"started,text,time"`
Ended *time.Time `sqlite:"ended,text,time"`
Tunneled bool `sqlite:"tunneled"`
@@ -93,6 +101,8 @@ type (
Allowed *bool `sqlite:"allowed"`
ProfileRevision int `sqlite:"profile_revision"`
ExitNode *string `sqlite:"exit_node"`
BytesReceived uint64 `sqlite:"bytes_received,default=0"`
BytesSent uint64 `sqlite:"bytes_sent,default=0"`
// TODO(ppacher): support "NOT" in search query to get rid of the following helper fields
Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"
@@ -102,24 +112,33 @@ type (
}
)
// New opens a new in-memory database named path.
// New opens a new in-memory database named path and attaches a persistent history database.
//
// The returned Database used connection pooling for read-only connections
// (see Execute). To perform database writes use either Save() or ExecuteWrite().
// Note that write connections are serialized by the Database object before being
// handed over to SQLite.
func New(path string) (*Database, error) {
func New(dbPath string) (*Database, error) {
historyParentDir := dataroot.Root().ChildDir("databases", 0o700)
if err := historyParentDir.Ensure(); err != nil {
return nil, fmt.Errorf("failed to ensure database directory exists: %w", err)
}
historyPath := "file://" + path.Join(historyParentDir.Path, "history.db")
constructor := func(ctx context.Context) (*sqlite.Conn, error) {
c, err := sqlite.OpenConn(
path,
dbPath,
sqlite.OpenReadOnly,
sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit.
sqlite.OpenSharedCache,
sqlite.OpenMemory,
sqlite.OpenURI,
)
if err != nil {
return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", path, err)
return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", dbPath, err)
}
if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE '"+historyPath+"?mode=ro' AS history", nil); err != nil {
return nil, fmt.Errorf("failed to attach history database: %w", err)
}
return c, nil
@@ -146,23 +165,22 @@ func New(path string) (*Database, error) {
}
writeConn, err := sqlite.OpenConn(
path,
dbPath,
sqlite.OpenCreate,
sqlite.OpenReadWrite,
sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit.
sqlite.OpenWAL,
sqlite.OpenSharedCache,
sqlite.OpenMemory,
sqlite.OpenURI,
)
if err != nil {
return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err)
return nil, fmt.Errorf("failed to open sqlite at %s: %w", dbPath, err)
}
return &Database{
readConnPool: pool,
Schema: schema,
writeConn: writeConn,
historyPath: historyPath,
}, nil
}
@@ -189,28 +207,42 @@ func NewInMemory() (*Database, error) {
// any data-migrations. Once the history module is implemented this should
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration.
func (db *Database) ApplyMigrations() error {
// get the create-table SQL statement from the inferred schema
sql := db.Schema.CreateStatement(true)
log.Errorf("applying migrations ...")
db.l.Lock()
defer db.l.Unlock()
// execute the SQL
if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil {
return fmt.Errorf("failed to create schema: %w", err)
if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE '"+db.historyPath+"?mode=rwc' AS 'history';", nil); err != nil {
return fmt.Errorf("failed to attach history database: %w", err)
}
// create a few indexes
indexes := []string{
`CREATE INDEX profile_id_index ON %s (profile)`,
`CREATE INDEX started_time_index ON %s (strftime('%%s', started)+0)`,
`CREATE INDEX started_ended_time_index ON %s (strftime('%%s', started)+0, strftime('%%s', ended)+0) WHERE ended IS NOT NULL`,
}
for _, idx := range indexes {
stmt := fmt.Sprintf(idx, db.Schema.Name)
dbNames := []string{"main", "history"}
for _, dbName := range dbNames {
// get the create-table SQL statement from the inferred schema
sql := db.Schema.CreateStatement(dbName, true)
log.Debugf("creating table schema for database %q", dbName)
if err := sqlitex.ExecuteTransient(db.writeConn, stmt, nil); err != nil {
return fmt.Errorf("failed to create index: %q: %w", idx, err)
// execute the SQL
if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil {
return fmt.Errorf("failed to create schema on database %q: %w", dbName, err)
}
// create a few indexes
indexes := []string{
`CREATE INDEX IF NOT EXISTS %sprofile_id_index ON %s (profile)`,
`CREATE INDEX IF NOT EXISTS %sstarted_time_index ON %s (strftime('%%s', started)+0)`,
`CREATE INDEX IF NOT EXISTS %sstarted_ended_time_index ON %s (strftime('%%s', started)+0, strftime('%%s', ended)+0) WHERE ended IS NOT NULL`,
}
for _, idx := range indexes {
name := ""
if dbName != "" {
name = dbName + "."
}
stmt := fmt.Sprintf(idx, name, db.Schema.Name)
if err := sqlitex.ExecuteTransient(db.writeConn, stmt, nil); err != nil {
return fmt.Errorf("failed to create index on database %q: %q: %w", dbName, idx, err)
}
}
}
@@ -254,7 +286,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) {
Count int `sqlite:"count"`
}
if err := db.Execute(ctx, "SELECT COUNT(*) AS count FROM connections", orm.WithResult(&result)); err != nil {
if err := db.Execute(ctx, "SELECT COUNT(*) AS count FROM (SELECT * FROM main.connections UNION SELECT * from history.connections)", orm.WithResult(&result)); err != nil {
return 0, fmt.Errorf("failed to perform query: %w", err)
}
@@ -265,7 +297,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) {
return result[0].Count, nil
}
// Cleanup removes all connections that have ended before threshold.
// Cleanup removes all connections that have ended before threshold from the live database.
//
// NOTE(ppacher): there is no easy way to get the number of removed
// rows other than counting them in a first step. Though, that's
@@ -273,7 +305,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) {
func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) {
where := `WHERE ended IS NOT NULL
AND datetime(ended) < datetime(:threshold)`
sql := "DELETE FROM connections " + where + ";"
sql := "DELETE FROM main.connections " + where + ";"
args := orm.WithNamedArgs(map[string]interface{}{
":threshold": threshold.UTC().Format(orm.SqliteTimeFormat),
@@ -303,6 +335,21 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro
return result[0].Count, nil
}
// RemoveAllHistoryData removes all connections from the history database.
func (db *Database) RemoveAllHistoryData(ctx context.Context) error {
query := fmt.Sprintf("DELETE FROM %s.connections", HistoryDatabase)
return db.ExecuteWrite(ctx, query)
}
// RemoveHistoryForProfile removes all connections from the history database
// for a given profile ID (source/id).
func (db *Database) RemoveHistoryForProfile(ctx context.Context, profileID string) error {
query := fmt.Sprintf("DELETE FROM %s.connections WHERE profile = :profile", HistoryDatabase)
return db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{
":profile": profileID,
}))
}
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database
// as JSON to w.
// Any error aborts dumping rows and is returned.
@@ -330,13 +377,76 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:un
return enc.Encode(conns)
}
// MarkAllHistoryConnectionsEnded marks all connections in the history database as ended.
func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
query := fmt.Sprintf("UPDATE %s.connections SET active = FALSE, ended = :ended WHERE active = TRUE", HistoryDatabase)
if err := db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{
":ended": time.Now().Format(orm.SqliteTimeFormat),
})); err != nil {
return err
}
return nil
}
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database.
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
data := connID + "-" + processKey
hash := sha256.Sum256([]byte(data))
dbConnID := hex.EncodeToString(hash[:])
params := map[string]any{
":id": dbConnID,
}
parts := []string{}
if bytesReceived != 0 {
parts = append(parts, "bytes_received = :bytes_received")
params[":bytes_received"] = bytesReceived
}
if bytesSent != 0 {
parts = append(parts, "bytes_sent = :bytes_sent")
params[":bytes_sent"] = bytesSent
}
updateSet := strings.Join(parts, ", ")
updateStmts := []string{
fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, LiveDatabase, updateSet),
}
if enableHistory {
updateStmts = append(updateStmts,
fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, HistoryDatabase, updateSet),
)
}
merr := new(multierror.Error)
for _, stmt := range updateStmts {
if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
merr.Errors = append(merr.Errors, err)
}
}
return merr.ErrorOrNil()
}
// Save inserts the connection conn into the SQLite database. If conn
// already exists the table row is updated instead.
//
// Save uses the database write connection instead of relying on the
// connection pool.
func (db *Database) Save(ctx context.Context, conn Conn) error {
connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig)
func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) error {
// convert the connection to a param map where each key is already translated
// to the sql column name. We also skip bytes_received and bytes_sent since those
// will be updated independently from the connection object.
connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig, []string{
"bytes_received",
"bytes_sent",
})
if err != nil {
return fmt.Errorf("failed to encode connection for SQL: %w", err)
}
@@ -367,26 +477,35 @@ func (db *Database) Save(ctx context.Context, conn Conn) error {
// TODO(ppacher): make sure this one can be cached to speed up inserting
// and save some CPU cycles for the user
sql := fmt.Sprintf(
`INSERT INTO connections (%s)
VALUES(%s)
ON CONFLICT(id) DO UPDATE SET
%s
`,
strings.Join(columns, ", "),
strings.Join(placeholders, ", "),
strings.Join(updateSets, ", "),
)
dbNames := []DatabaseName{LiveDatabase}
if err := sqlitex.Execute(db.writeConn, sql, &sqlitex.ExecOptions{
Named: values,
ResultFunc: func(stmt *sqlite.Stmt) error {
log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount())
return nil
},
}); err != nil {
log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values)
return err
if enableHistory {
dbNames = append(dbNames, HistoryDatabase)
}
for _, dbName := range dbNames {
sql := fmt.Sprintf(
`INSERT INTO %s.connections (%s)
VALUES(%s)
ON CONFLICT(id) DO UPDATE SET
%s
`,
dbName,
strings.Join(columns, ", "),
strings.Join(placeholders, ", "),
strings.Join(updateSets, ", "),
)
if err := sqlitex.Execute(db.writeConn, sql, &sqlitex.ExecOptions{
Named: values,
ResultFunc: func(stmt *sqlite.Stmt) error {
log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount())
return nil
},
}); err != nil {
log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values)
return err
}
}
return nil

View File

@@ -25,7 +25,22 @@ type (
// insert or an update.
// The ID of Conn is unique and can be trusted to never collide with other
// connections of the save device.
Save(context.Context, Conn) error
Save(context.Context, Conn, bool) error
// MarkAllHistoryConnectionsEnded marks all active connections in the history
// database as ended NOW.
MarkAllHistoryConnectionsEnded(context.Context) error
// RemoveAllHistoryData removes all connections from the history database.
RemoveAllHistoryData(context.Context) error
// RemoveHistoryForProfile removes all connections from the history database.
// for a given profile ID (source/id)
RemoveHistoryForProfile(context.Context, string) error
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database.
UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
}
// Manager handles new and updated network.Connections feeds and persists them
@@ -98,9 +113,10 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
continue
}
log.Tracef("netquery: updating connection %s", conn.ID)
// DEBUG:
// log.Tracef("netquery: updating connection %s", conn.ID)
if err := mng.store.Save(ctx, *model); err != nil {
if err := mng.store.Save(ctx, *model, conn.HistoryEnabled); err != nil {
log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err)
continue
@@ -158,7 +174,9 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
IPProtocol: conn.IPProtocol,
LocalIP: conn.LocalIP.String(),
LocalPort: conn.LocalPort,
Verdict: conn.Verdict.Firewall, // TODO: Expose both Worst and Firewall verdicts.
FirewallVerdict: conn.Verdict.Firewall,
ActiveVerdict: conn.Verdict.Active,
WorstVerdict: conn.Verdict.Worst,
Started: time.Unix(conn.Started, 0),
Tunneled: conn.Tunneled,
Encrypted: conn.Encrypted,
@@ -250,7 +268,7 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
}
func genConnID(conn *network.Connection) string {
data := conn.ID + "-" + time.Unix(conn.Started, 0).String()
data := conn.ID + "-" + conn.Process().GetID()
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}

View File

@@ -2,39 +2,58 @@ package netquery
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/hashicorp/go-multierror"
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portbase/runtime"
"github.com/safing/portmaster/network"
)
// DefaultModule is the default netquery module.
var DefaultModule *module
type module struct {
*modules.Module
db *database.Interface
sqlStore *Database
mng *Manager
feed chan *network.Connection
Store *Database
db *database.Interface
mng *Manager
feed chan *network.Connection
}
func init() {
m := new(module)
m.Module = modules.Register(
DefaultModule = new(module)
DefaultModule.Module = modules.Register(
"netquery",
m.prepare,
m.start,
m.stop,
DefaultModule.prepare,
DefaultModule.start,
DefaultModule.stop,
"api",
"network",
"database",
)
subsystems.Register(
"history",
"Network History",
"Keep Network History Data",
DefaultModule.Module,
"config:history/",
nil,
)
}
func (m *module) prepare() error {
@@ -45,12 +64,12 @@ func (m *module) prepare() error {
Internal: true,
})
m.sqlStore, err = NewInMemory()
m.Store, err = NewInMemory()
if err != nil {
return fmt.Errorf("failed to create in-memory database: %w", err)
}
m.mng, err = NewManager(m.sqlStore, "netquery/data/", runtime.DefaultRegistry)
m.mng, err = NewManager(m.Store, "netquery/data/", runtime.DefaultRegistry)
if err != nil {
return fmt.Errorf("failed to create manager: %w", err)
}
@@ -58,12 +77,12 @@ func (m *module) prepare() error {
m.feed = make(chan *network.Connection, 1000)
queryHander := &QueryHandler{
Database: m.sqlStore,
Database: m.Store,
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
}
chartHandler := &ChartHandler{
Database: m.sqlStore,
Database: m.Store,
}
if err := api.RegisterEndpoint(api.Endpoint{
@@ -92,6 +111,56 @@ func (m *module) prepare() error {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
if err := api.RegisterEndpoint(api.Endpoint{
Path: "netquery/history/clear",
MimeType: "application/json",
Read: api.PermitUser,
Write: api.PermitUser,
BelongsTo: m.Module,
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
var body struct {
ProfileIDs []string `json:"profileIDs"`
}
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(&body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(body.ProfileIDs) == 0 {
if err := m.mng.store.RemoveAllHistoryData(r.Context()); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} else {
merr := new(multierror.Error)
for _, profileID := range body.ProfileIDs {
if err := m.mng.store.RemoveHistoryForProfile(r.Context(), profileID); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("failed to clear history for %q: %w", profileID, err))
} else {
log.Infof("netquery: successfully cleared history for %s", profileID)
}
}
if err := merr.ErrorOrNil(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
w.WriteHeader(http.StatusNoContent)
},
Name: "Remove connections from profile history",
Description: "Remove all connections from the history database for one or more profiles",
}); err != nil {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
return nil
}
@@ -139,7 +208,7 @@ func (m *module) start() error {
return nil
case <-time.After(10 * time.Second):
threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
count, err := m.sqlStore.Cleanup(ctx, threshold)
count, err := m.Store.Cleanup(ctx, threshold)
if err != nil {
log.Errorf("netquery: failed to count number of rows in memory: %s", err)
} else {
@@ -153,7 +222,7 @@ func (m *module) start() error {
// the runtime database.
// Only expose in development mode.
if config.GetAsBool(config.CfgDevModeKey, false)() {
_, err := NewRuntimeQueryRunner(m.sqlStore, "netquery/query/", runtime.DefaultRegistry)
_, err := NewRuntimeQueryRunner(m.Store, "netquery/query/", runtime.DefaultRegistry)
if err != nil {
return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
}
@@ -163,5 +232,16 @@ func (m *module) start() error {
}
func (m *module) stop() error {
// we don't use m.Module.Ctx here because it is already cancelled when stop is called.
// just give the clean up 1 minute to happen and abort otherwise.
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
if err := m.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil {
// handle the error by just logging it. There's not much we can do here
// and returning an error to the module system doesn't help much as well...
log.Errorf("netquery: failed to mark connections in history database as ended: %s", err)
}
return nil
}

View File

@@ -6,6 +6,7 @@ import (
"reflect"
"time"
"golang.org/x/exp/slices"
"zombiezen.com/go/sqlite"
)
@@ -22,7 +23,7 @@ type (
// ToParamMap returns a map that contains the sqlite compatible value of each struct field of
// r using the sqlite column name as a map key. It either uses the name of the
// exported struct field or the value of the "sqlite" tag.
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) {
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) {
// make sure we work on a struct type
val := reflect.Indirect(reflect.ValueOf(r))
if val.Kind() != reflect.Struct {
@@ -45,6 +46,10 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode
return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
}
if slices.Contains(skipFields, colDef.Name) {
continue
}
x, found, err := runEncodeHooks(
colDef,
fieldType.Type,

View File

@@ -119,7 +119,7 @@ func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
for idx := range cases { //nolint:paralleltest
c := cases[idx]
t.Run(c.Desc, func(t *testing.T) {
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig)
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig, nil)
assert.NoError(t, err)
assert.Equal(t, c.Expected, res)
})

View File

@@ -143,7 +143,23 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q
currentField := reflect.New(valElemType)
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
return err
resultDump := make(map[string]any)
for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ {
name := stmt.ColumnName(colIdx)
switch stmt.ColumnType(colIdx) { //nolint:exhaustive // TODO: handle type BLOB?
case sqlite.TypeText:
resultDump[name] = stmt.ColumnText(colIdx)
case sqlite.TypeFloat:
resultDump[name] = stmt.ColumnFloat(colIdx)
case sqlite.TypeInteger:
resultDump[name] = stmt.ColumnInt(colIdx)
case sqlite.TypeNull:
resultDump[name] = "<null>"
}
}
return fmt.Errorf("%w: %+v", err, resultDump)
}
sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField))

View File

@@ -8,6 +8,8 @@ import (
"strings"
"zombiezen.com/go/sqlite"
"github.com/safing/portbase/log"
)
var errSkipStructField = errors.New("struct field should be skipped")
@@ -25,6 +27,7 @@ var (
TagTypePrefixVarchar = "varchar"
TagTypeBlob = "blob"
TagTypeFloat = "float"
TagTypePrefixDefault = "default="
)
var sqlTypeMap = map[sqlite.ColumnType]string{
@@ -52,6 +55,7 @@ type (
AutoIncrement bool
UnixNano bool
IsTime bool
Default any
}
)
@@ -66,12 +70,17 @@ func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
}
// CreateStatement build the CREATE SQL statement for the table.
func (ts TableSchema) CreateStatement(ifNotExists bool) string {
func (ts TableSchema) CreateStatement(databaseName string, ifNotExists bool) string {
sql := "CREATE TABLE"
if ifNotExists {
sql += " IF NOT EXISTS"
}
sql += " " + ts.Name + " ( "
name := ts.Name
if databaseName != "" {
name = databaseName + "." + ts.Name
}
sql += " " + name + " ( "
for idx, col := range ts.Columns {
sql += col.AsSQL()
@@ -100,6 +109,21 @@ func (def ColumnDef) AsSQL() string {
if def.AutoIncrement {
sql += " AUTOINCREMENT"
}
if def.Default != nil {
sql += " DEFAULT "
switch def.Type { //nolint:exhaustive // TODO: handle types BLOB, NULL?
case sqlite.TypeFloat:
sql += strconv.FormatFloat(def.Default.(float64), 'b', 0, 64) //nolint:forcetypeassert
case sqlite.TypeInteger:
sql += strconv.FormatInt(def.Default.(int64), 10) //nolint:forcetypeassert
case sqlite.TypeText:
sql += fmt.Sprintf("%q", def.Default.(string)) //nolint:forcetypeassert
default:
log.Errorf("unsupported default value: %q %q", def.Type, def.Default)
sql = strings.TrimSuffix(sql, " DEFAULT ")
}
sql += " "
}
if !def.Nullable {
sql += " NOT NULL"
}
@@ -155,7 +179,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
kind := normalizeKind(ft.Kind())
switch kind { //nolint:exhaustive
case reflect.Int:
case reflect.Int, reflect.Uint:
def.Type = sqlite.TypeInteger
case reflect.Float64:
@@ -232,6 +256,30 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
def.Length = int(length)
}
if strings.HasPrefix(k, TagTypePrefixDefault) {
defaultValue := strings.TrimPrefix(k, TagTypePrefixDefault)
switch def.Type { //nolint:exhaustive
case sqlite.TypeFloat:
fv, err := strconv.ParseFloat(defaultValue, 64)
if err != nil {
return fmt.Errorf("failed to parse default value as float %q: %w", defaultValue, err)
}
def.Default = fv
case sqlite.TypeInteger:
fv, err := strconv.ParseInt(defaultValue, 10, 0)
if err != nil {
return fmt.Errorf("failed to parse default value as int %q: %w", defaultValue, err)
}
def.Default = fv
case sqlite.TypeText:
def.Default = defaultValue
case sqlite.TypeBlob:
return fmt.Errorf("default values for TypeBlob not yet supported")
default:
return fmt.Errorf("failed to apply default value for unknown sqlite column type %s", def.Type)
}
}
}
}
}

View File

@@ -22,14 +22,14 @@ func TestSchemaBuilder(t *testing.T) {
Int *int `sqlite:",not-null"`
Float interface{} `sqlite:",float,nullable"`
}{},
`CREATE TABLE Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
`CREATE TABLE main.Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`,
},
{
"Varchar",
struct {
S string `sqlite:",varchar(10)"`
}{},
`CREATE TABLE Varchar ( S VARCHAR(10) NOT NULL );`,
`CREATE TABLE main.Varchar ( S VARCHAR(10) NOT NULL );`,
},
}
@@ -38,6 +38,6 @@ func TestSchemaBuilder(t *testing.T) {
res, err := GenerateTableSchema(c.Name, c.Model)
assert.NoError(t, err)
assert.Equal(t, c.ExpectedSQL, res.CreateStatement(false))
assert.Equal(t, c.ExpectedSQL, res.CreateStatement("main", false))
}
}

View File

@@ -14,6 +14,15 @@ import (
"github.com/safing/portmaster/netquery/orm"
)
// DatabaseName is a database name constant.
type DatabaseName string
// Databases.
const (
LiveDatabase = DatabaseName("main")
HistoryDatabase = DatabaseName("history")
)
// Collection of Query and Matcher types.
// NOTE: whenever adding support for new operators make sure
// to update UnmarshalJSON as well.
@@ -48,11 +57,19 @@ type (
Distinct bool `json:"distinct"`
}
Min struct {
Condition *Query `json:"condition,omitempty"`
Field string `json:"field"`
As string `json:"as"`
Distinct bool `json:"distinct"`
}
Select struct {
Field string `json:"field"`
Count *Count `json:"$count,omitempty"`
Sum *Sum `json:"$sum,omitempty"`
Distinct *string `json:"$distinct"`
Min *Min `json:"$min,omitempty"`
Distinct *string `json:"$distinct,omitempty"`
}
Selects []Select
@@ -68,6 +85,9 @@ type (
OrderBy OrderBys `json:"orderBy"`
GroupBy []string `json:"groupBy"`
TextSearch *TextSearch `json:"textSearch"`
// A list of databases to query. If left empty,
// both, the LiveDatabase and the HistoryDatabase are queried
Databases []DatabaseName `json:"databases"`
Pagination
@@ -457,6 +477,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
Field string `json:"field"`
Count *Count `json:"$count"`
Sum *Sum `json:"$sum"`
Min *Min `json:"$min"`
Distinct *string `json:"$distinct"`
}
@@ -468,12 +489,23 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
sel.Field = res.Field
sel.Distinct = res.Distinct
sel.Sum = res.Sum
sel.Min = res.Min
if sel.Count != nil && sel.Count.As != "" {
if !charOnlyRegexp.MatchString(sel.Count.As) {
return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
}
}
if sel.Sum != nil && sel.Sum.As != "" {
if !charOnlyRegexp.MatchString(sel.Sum.As) {
return fmt.Errorf("invalid characters in $sum.as, value must match [a-zA-Z]+")
}
}
if sel.Min != nil && sel.Min.As != "" {
if !charOnlyRegexp.MatchString(sel.Min.As) {
return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+")
}
}
return nil
}

View File

@@ -12,6 +12,8 @@ import (
"strings"
"time"
"golang.org/x/exp/slices"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netquery/orm"
)
@@ -152,13 +154,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
return "", nil, fmt.Errorf("generating where clause: %w", err)
}
if req.paramMap == nil {
req.paramMap = make(map[string]interface{})
}
for key, val := range paramMap {
req.paramMap[key] = val
}
req.mergeParams(paramMap)
if req.TextSearch != nil {
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
@@ -173,9 +169,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
whereClause += textClause
for key, val := range textParams {
req.paramMap[key] = val
}
req.mergeParams(textParams)
}
}
@@ -190,11 +184,24 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
}
selectClause := req.generateSelectClause()
query := `SELECT ` + selectClause + ` FROM connections`
if whereClause != "" {
query += " WHERE " + whereClause
whereClause = "WHERE " + whereClause
}
if len(req.Databases) == 0 {
req.Databases = []DatabaseName{LiveDatabase, HistoryDatabase}
}
sources := make([]string, len(req.Databases))
for idx, db := range req.Databases {
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
}
source := strings.Join(sources, " UNION ")
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause()
return strings.TrimSpace(query), req.paramMap, nil
@@ -203,6 +210,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
for idx, s := range req.Select {
var field string
switch {
case s.Count != nil:
field = s.Count.Field
@@ -211,6 +219,12 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
case s.Sum != nil:
// field is not used in case of $sum
field = "*"
case s.Min != nil:
if s.Min.Field != "" {
field = s.Min.Field
} else {
field = "*"
}
default:
field = s.Field
}
@@ -251,13 +265,40 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
return fmt.Errorf("in $sum: %w", err)
}
req.paramMap = params
req.mergeParams(params)
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
)
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
case s.Min != nil:
if s.Min.As == "" {
return fmt.Errorf("missing 'as' for $min")
}
var (
clause string
params map[string]any
)
if s.Min.Field != "" {
clause = field
} else {
var err error
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
if err != nil {
return fmt.Errorf("in $min: %w", err)
}
}
req.mergeParams(params)
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
)
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
case s.Distinct != nil:
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
req.whitelistedFields = append(req.whitelistedFields, colName)
@@ -270,6 +311,16 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
return nil
}
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
if req.paramMap == nil {
req.paramMap = make(map[string]any)
}
for key, value := range params {
req.paramMap[key] = value
}
}
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
if len(req.GroupBy) == 0 {
return "", nil
@@ -332,16 +383,12 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
return colDef.Name, nil
}
for _, selected := range req.whitelistedFields {
if field == selected {
return field, nil
}
if slices.Contains(req.whitelistedFields, field) {
return field, nil
}
for _, selected := range req.selectedFields {
if field == selected {
return field, nil
}
if slices.Contains(req.selectedFields, field) {
return field, nil
}
return "", fmt.Errorf("column name %q not allowed", field)

View File

@@ -78,7 +78,8 @@ func cleanConnections() (activePIDs map[int]struct{}) {
}
case conn.Ended < deleteOlderThan:
// Step 3: delete
log.Tracef("network.clean: deleted %s (ended at %s)", conn.DatabaseKey(), time.Unix(conn.Ended, 0))
// DEBUG:
// log.Tracef("network.clean: deleted %s (ended at %s)", conn.DatabaseKey(), time.Unix(conn.Ended, 0))
conn.delete()
}

View File

@@ -19,6 +19,8 @@ import (
"github.com/safing/portmaster/process"
_ "github.com/safing/portmaster/process/tags"
"github.com/safing/portmaster/resolver"
"github.com/safing/spn/access"
"github.com/safing/spn/access/account"
"github.com/safing/spn/navigator"
)
@@ -173,8 +175,17 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
StopTunnel() error
}
RecvBytes uint64
SentBytes uint64
// HistoryEnabled is set to true when the connection should be persisted
// in the history database.
HistoryEnabled bool
// BanwidthEnabled is set to true if connection bandwidth data should be persisted
// in netquery.
BandwidthEnabled bool
// BytesReceived holds the observed received bytes of the connection.
BytesReceived uint64
// BytesSent holds the observed sent bytes of the connection.
BytesSent uint64
// pkgQueue is used to serialize packet handling for a single
// connection and is served by the connections packetHandler.
@@ -326,6 +337,10 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
// Inherit internal status of profile.
if localProfile := proc.Profile().LocalProfile(); localProfile != nil {
dnsConn.Internal = localProfile.Internal
if err := dnsConn.updateFeatures(); err != nil {
log.Tracer(ctx).Warningf("network: failed to check for enabled features: %s", err)
}
}
// DNS Requests are saved by the nameserver depending on the result of the
@@ -364,6 +379,10 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname
// Inherit internal status of profile.
if localProfile := remoteHost.Profile().LocalProfile(); localProfile != nil {
dnsConn.Internal = localProfile.Internal
if err := dnsConn.updateFeatures(); err != nil {
log.Tracer(ctx).Warningf("network: failed to check for enabled features: %s", err)
}
}
// DNS Requests are saved by the nameserver depending on the result of the
@@ -374,6 +393,8 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname
return dnsConn, nil
}
var tooOldTimestamp = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC).Unix()
// NewIncompleteConnection creates a new incomplete connection with only minimal information.
func NewIncompleteConnection(pkt packet.Packet) *Connection {
info := pkt.Info()
@@ -390,6 +411,12 @@ func NewIncompleteConnection(pkt packet.Packet) *Connection {
dataComplete: abool.NewBool(false),
}
// Bullshit check Started timestamp.
if conn.Started < tooOldTimestamp {
// Fix timestamp, use current time as fallback.
conn.Started = time.Now().Unix()
}
// Save connection to internal state in order to mitigate creation of
// duplicates. Do not propagate yet, as data is not yet complete.
conn.UpdateMeta()
@@ -420,7 +447,12 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
// Inherit internal status of profile.
if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil {
conn.Internal = localProfile.Internal
if err := conn.updateFeatures(); err != nil {
log.Tracer(pkt.Ctx()).Warningf("network: failed to check for enabled features: %s", err)
}
}
} else {
conn.process = nil
if pkt.InfoOnly() {
@@ -533,6 +565,31 @@ func (conn *Connection) SetLocalIP(ip net.IP) {
conn.LocalIPScope = netutils.GetIPScope(ip)
}
// updateFeatures checks which connection related features may be used and sets
// the flags accordingly.
func (conn *Connection) updateFeatures() error {
// Get user.
user, err := access.GetUser()
if err != nil {
return err
}
// Check if history may be used and if it is enabled for this application.
if user.MayUse(account.FeatureHistory) {
lProfile := conn.Process().Profile()
if lProfile != nil {
conn.HistoryEnabled = lProfile.HistoryEnabled()
}
}
// Check if bandwidth visibility may be used.
if user.MayUse(account.FeatureBWVis) {
conn.BandwidthEnabled = true
}
return nil
}
// AcceptWithContext accepts the connection.
func (conn *Connection) AcceptWithContext(reason, reasonOptionKey string, ctx interface{}) {
if !conn.SetVerdict(VerdictAccept, reason, reasonOptionKey, ctx) {

View File

@@ -4,10 +4,10 @@ import "fmt"
// BandwidthUpdate holds an update to the seen bandwidth of a connection.
type BandwidthUpdate struct {
ConnID string
RecvBytes uint64
SentBytes uint64
Method BandwidthUpdateMethod
ConnID string
BytesReceived uint64
BytesSent uint64
Method BandwidthUpdateMethod
}
// BandwidthUpdateMethod defines how the bandwidth data of a bandwidth update should be interpreted.
@@ -20,7 +20,7 @@ const (
)
func (bu *BandwidthUpdate) String() string {
return fmt.Sprintf("%s: %dB recv | %dB sent [%s]", bu.ConnID, bu.RecvBytes, bu.SentBytes, bu.Method)
return fmt.Sprintf("%s: %dB recv | %dB sent [%s]", bu.ConnID, bu.BytesReceived, bu.BytesSent, bu.Method)
}
func (bum BandwidthUpdateMethod) String() string {

View File

@@ -313,6 +313,13 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (*
return process, nil
}
// GetID returns the key that is used internally to identify the process.
// The ID consists of the PID and the start time of the process as reported by
// the system.
func (p *Process) GetID() string {
return p.processKey
}
// Builds a unique identifier for a processes.
func getProcessKey(pid int32, createdTime int64) string {
return fmt.Sprintf("%d-%d", pid, createdTime)

View File

@@ -6,6 +6,7 @@ import (
"github.com/safing/portbase/config"
"github.com/safing/portmaster/profile/endpoints"
"github.com/safing/portmaster/status"
"github.com/safing/spn/access/account"
"github.com/safing/spn/navigator"
)
@@ -103,7 +104,13 @@ var (
cfgOptionDisableAutoPermit config.IntOption // security level option
cfgOptionDisableAutoPermitOrder = 65
// Setting "Permanent Verdicts" at order 96.
// Setting "Permanent Verdicts" at order 80.
// Network History.
CfgOptionEnableHistoryKey = "history/enable"
cfgOptionEnableHistory config.BoolOption
cfgOptionEnableHistoryOrder = 96
// Setting "Enable SPN" at order 128.
@@ -239,6 +246,27 @@ func registerConfiguration() error { //nolint:maintidx
cfgOptionDisableAutoPermit = config.Concurrent.GetAsInt(CfgOptionDisableAutoPermitKey, int64(status.SecurityLevelsAll))
cfgIntOptions[CfgOptionDisableAutoPermitKey] = cfgOptionDisableAutoPermit
// Enable History
err = config.Register(&config.Option{
Name: "Enable Connection History",
Key: CfgOptionEnableHistoryKey,
Description: "Whether or not to save connections to the history database",
OptType: config.OptTypeBool,
ReleaseLevel: config.ReleaseLevelStable,
ExpertiseLevel: config.ExpertiseLevelExpert,
DefaultValue: false,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder,
config.CategoryAnnotation: "History",
config.RequiresFeatureID: account.FeatureHistory,
},
})
if err != nil {
return err
}
cfgOptionEnableHistory = config.Concurrent.GetAsBool(CfgOptionEnableHistoryKey, false)
cfgBoolOptions[CfgOptionEnableHistoryKey] = cfgOptionEnableHistory
rulesHelp := strings.ReplaceAll(`Rules are checked from top to bottom, stopping after the first match. They can match:
- By address: "192.168.0.1"

View File

@@ -49,6 +49,7 @@ type LayeredProfile struct {
DomainHeuristics config.BoolOption `json:"-"`
UseSPN config.BoolOption `json:"-"`
SPNRoutingAlgorithm config.StringOption `json:"-"`
HistoryEnabled config.BoolOption `json:"-"`
}
// NewLayeredProfile returns a new layered profile based on the given local profile.
@@ -120,6 +121,10 @@ func NewLayeredProfile(localProfile *Profile) *LayeredProfile {
CfgOptionRoutingAlgorithmKey,
cfgOptionRoutingAlgorithm,
)
lp.HistoryEnabled = lp.wrapBoolOption(
CfgOptionEnableHistoryKey,
cfgOptionEnableHistory,
)
lp.LayerIDs = append(lp.LayerIDs, localProfile.ScopedID())
lp.layers = append(lp.layers, localProfile)

View File

@@ -136,6 +136,7 @@ type Profile struct { //nolint:maligned // not worth the effort
filterListIDs []string
spnUsagePolicy endpoints.Endpoints
spnExitHubPolicy endpoints.Endpoints
enableHistory bool
// Lifecycle Management
outdated *abool.AtomicBool
@@ -233,6 +234,11 @@ func (profile *Profile) parseConfig() error {
}
}
enableHistory, ok := profile.configPerspective.GetAsBool(CfgOptionEnableHistoryKey)
if ok {
profile.enableHistory = enableHistory
}
return lastErr
}
@@ -315,6 +321,11 @@ func (profile *Profile) IsOutdated() bool {
return profile.outdated.IsSet()
}
// HistoryEnabled returns true if connection history is enabled for the profile.
func (profile *Profile) HistoryEnabled() bool {
return profile.enableHistory
}
// GetEndpoints returns the endpoint list of the profile. This functions
// requires the profile to be read locked.
func (profile *Profile) GetEndpoints() endpoints.Endpoints {