wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
116
service/broadcasts/api.go
Normal file
116
service/broadcasts/api.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package broadcasts
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/accessor"
|
||||
)
|
||||
|
||||
func registerAPIEndpoints() error {
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: `broadcasts/matching-data`,
|
||||
Read: api.PermitAdmin,
|
||||
BelongsTo: module,
|
||||
StructFunc: handleMatchingData,
|
||||
Name: "Get Broadcast Notifications Matching Data",
|
||||
Description: "Returns the data used by the broadcast notifications to match the instance.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: `broadcasts/reset-state`,
|
||||
Write: api.PermitAdmin,
|
||||
WriteMethod: http.MethodPost,
|
||||
BelongsTo: module,
|
||||
ActionFunc: handleResetState,
|
||||
Name: "Resets the Broadcast Notification States",
|
||||
Description: "Delete the cache of Broadcast Notifications, making them appear again.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: `broadcasts/simulate`,
|
||||
Write: api.PermitAdmin,
|
||||
WriteMethod: http.MethodPost,
|
||||
BelongsTo: module,
|
||||
ActionFunc: handleSimulate,
|
||||
Name: "Simulate Broadcast Notifications",
|
||||
Description: "Test broadcast notifications by sending a valid source file in the body.",
|
||||
Parameters: []api.Parameter{
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Field: "state",
|
||||
Value: "true",
|
||||
Description: "Check against state when deciding to display a broadcast notification. Acknowledgements are always saved.",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleMatchingData(ar *api.Request) (i interface{}, err error) {
|
||||
return collectData(), nil
|
||||
}
|
||||
|
||||
func handleResetState(ar *api.Request) (msg string, err error) {
|
||||
err = db.Delete(broadcastStatesDBKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "Reset complete.", nil
|
||||
}
|
||||
|
||||
func handleSimulate(ar *api.Request) (msg string, err error) {
|
||||
// Parse broadcast notification data.
|
||||
broadcasts, err := parseBroadcastSource(ar.InputData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse broadcast notifications update: %w", err)
|
||||
}
|
||||
|
||||
// Get and marshal matching data.
|
||||
matchingData := collectData()
|
||||
matchingJSON, err := json.Marshal(matchingData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal broadcast notifications matching data: %w", err)
|
||||
}
|
||||
matchingDataAccessor := accessor.NewJSONBytesAccessor(&matchingJSON)
|
||||
|
||||
var bss *BroadcastStates
|
||||
if ar.URL.Query().Get("state") == "true" {
|
||||
// Get broadcast notification states.
|
||||
bss, err = getBroadcastStates()
|
||||
if err != nil {
|
||||
if !errors.Is(err, database.ErrNotFound) {
|
||||
return "", fmt.Errorf("failed to get broadcast notifications states: %w", err)
|
||||
}
|
||||
bss = newBroadcastStates()
|
||||
}
|
||||
}
|
||||
|
||||
// Go through all broadcast nofications and check if they match.
|
||||
var results []string
|
||||
for _, bn := range broadcasts.Notifications {
|
||||
err := handleBroadcast(bn, matchingDataAccessor, bss)
|
||||
switch {
|
||||
case err == nil:
|
||||
results = append(results, fmt.Sprintf("%30s: displayed", bn.id))
|
||||
case errors.Is(err, ErrSkip):
|
||||
results = append(results, fmt.Sprintf("%30s: %s", bn.id, err))
|
||||
default:
|
||||
results = append(results, fmt.Sprintf("FAILED %23s: %s", bn.id, err))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n"), nil
|
||||
}
|
||||
141
service/broadcasts/data.go
Normal file
141
service/broadcasts/data.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package broadcasts
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portmaster/service/intel/geoip"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
"github.com/safing/portmaster/spn/access"
|
||||
"github.com/safing/portmaster/spn/access/account"
|
||||
"github.com/safing/portmaster/spn/captain"
|
||||
)
|
||||
|
||||
var portmasterStarted = time.Now()
|
||||
|
||||
func collectData() interface{} {
|
||||
data := make(map[string]interface{})
|
||||
|
||||
// Get data about versions.
|
||||
versions := updates.GetSimpleVersions()
|
||||
data["Updates"] = versions
|
||||
data["Version"] = versions.Build.Version
|
||||
numericVersion, err := MakeNumericVersion(versions.Build.Version)
|
||||
if err != nil {
|
||||
data["NumericVersion"] = &DataError{
|
||||
Error: err,
|
||||
}
|
||||
} else {
|
||||
data["NumericVersion"] = numericVersion
|
||||
}
|
||||
|
||||
// Get data about install.
|
||||
installInfo, err := GetInstallInfo()
|
||||
if err != nil {
|
||||
data["Install"] = &DataError{
|
||||
Error: err,
|
||||
}
|
||||
} else {
|
||||
data["Install"] = installInfo
|
||||
}
|
||||
|
||||
// Get global configuration.
|
||||
data["Config"] = config.GetActiveConfigValues()
|
||||
|
||||
// Get data about device location.
|
||||
locs, ok := netenv.GetInternetLocation()
|
||||
if ok && locs.Best().LocationOrNil() != nil {
|
||||
loc := locs.Best()
|
||||
data["Location"] = &Location{
|
||||
Country: loc.Location.Country.Code,
|
||||
Coordinates: loc.Location.Coordinates,
|
||||
ASN: loc.Location.AutonomousSystemNumber,
|
||||
ASOrg: loc.Location.AutonomousSystemOrganization,
|
||||
Source: loc.Source,
|
||||
SourceAccuracy: loc.SourceAccuracy,
|
||||
}
|
||||
}
|
||||
|
||||
// Get data about SPN status.
|
||||
data["SPN"] = captain.GetSPNStatus()
|
||||
|
||||
// Get data about account.
|
||||
userRecord, err := access.GetUser()
|
||||
if err != nil {
|
||||
data["Account"] = &DataError{
|
||||
Error: err,
|
||||
}
|
||||
} else {
|
||||
account := &Account{
|
||||
UserRecord: userRecord,
|
||||
Active: userRecord.MayUse(""),
|
||||
UpToDate: userRecord.Meta().Modified > time.Now().Add(-7*24*time.Hour).Unix(),
|
||||
}
|
||||
// Only add feature IDs when account is active.
|
||||
if account.Active {
|
||||
account.FeatureIDs = userRecord.CurrentPlan.FeatureIDs
|
||||
}
|
||||
data["Account"] = account
|
||||
}
|
||||
|
||||
// Time running.
|
||||
data["UptimeHours"] = int(time.Since(portmasterStarted).Hours())
|
||||
|
||||
// Get current time and date.
|
||||
now := time.Now()
|
||||
data["Current"] = &Current{
|
||||
UnixTime: now.Unix(),
|
||||
UTC: makeDateTimeInfo(now.UTC()),
|
||||
Local: makeDateTimeInfo(now),
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// Location holds location matching data.
|
||||
type Location struct {
|
||||
Country string
|
||||
Coordinates geoip.Coordinates
|
||||
ASN uint
|
||||
ASOrg string
|
||||
Source netenv.DeviceLocationSource
|
||||
SourceAccuracy int
|
||||
}
|
||||
|
||||
// Account holds SPN account matching data.
|
||||
type Account struct {
|
||||
*access.UserRecord
|
||||
Active bool
|
||||
UpToDate bool
|
||||
FeatureIDs []account.FeatureID
|
||||
}
|
||||
|
||||
// DataError represents an error getting some matching data.
|
||||
type DataError struct {
|
||||
Error error
|
||||
}
|
||||
|
||||
// Current holds current date and time data.
|
||||
type Current struct {
|
||||
UnixTime int64
|
||||
UTC *DateTime
|
||||
Local *DateTime
|
||||
}
|
||||
|
||||
// DateTime holds date and time data in different formats.
|
||||
type DateTime struct {
|
||||
NumericDateTime int64
|
||||
NumericDate int64
|
||||
NumericTime int64
|
||||
}
|
||||
|
||||
func makeDateTimeInfo(t time.Time) *DateTime {
|
||||
info := &DateTime{}
|
||||
info.NumericDateTime, _ = strconv.ParseInt(t.Format("20060102150405"), 10, 64)
|
||||
info.NumericDate, _ = strconv.ParseInt(t.Format("20060102"), 10, 64)
|
||||
info.NumericTime, _ = strconv.ParseInt(t.Format("150405"), 10, 64)
|
||||
|
||||
return info
|
||||
}
|
||||
175
service/broadcasts/install_info.go
Normal file
175
service/broadcasts/install_info.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package broadcasts
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
semver "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/query"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/info"
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
const installInfoDBKey = "core:status/install-info"
|
||||
|
||||
// InstallInfo holds generic info about the install.
|
||||
type InstallInfo struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
Version string
|
||||
NumericVersion int64
|
||||
|
||||
Time time.Time
|
||||
NumericDate int64
|
||||
DaysSinceInstall int64
|
||||
UnixTimestamp int64
|
||||
}
|
||||
|
||||
// GetInstallInfo returns the install info from the database.
|
||||
func GetInstallInfo() (*InstallInfo, error) {
|
||||
r, err := db.Get(installInfoDBKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unwrap.
|
||||
if r.IsWrapped() {
|
||||
// Only allocate a new struct, if we need it.
|
||||
newRecord := &InstallInfo{}
|
||||
err = record.Unwrap(r, newRecord)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newRecord, nil
|
||||
}
|
||||
|
||||
// or adjust type
|
||||
newRecord, ok := r.(*InstallInfo)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *InstallInfo, but %T", r)
|
||||
}
|
||||
return newRecord, nil
|
||||
}
|
||||
|
||||
func ensureInstallInfo() {
|
||||
// Get current install info from database.
|
||||
installInfo, err := GetInstallInfo()
|
||||
if err != nil {
|
||||
installInfo = &InstallInfo{}
|
||||
if !errors.Is(err, database.ErrNotFound) {
|
||||
log.Warningf("updates: failed to load install info: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fill in missing data and save.
|
||||
installInfo.checkAll()
|
||||
if err := installInfo.save(); err != nil {
|
||||
log.Warningf("updates: failed to save install info: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ii *InstallInfo) save() error {
|
||||
if !ii.KeyIsSet() {
|
||||
ii.SetKey(installInfoDBKey)
|
||||
}
|
||||
return db.Put(ii)
|
||||
}
|
||||
|
||||
func (ii *InstallInfo) checkAll() {
|
||||
ii.checkVersion()
|
||||
ii.checkInstallDate()
|
||||
}
|
||||
|
||||
func (ii *InstallInfo) checkVersion() {
|
||||
// Check if everything is present.
|
||||
if ii.Version != "" && ii.NumericVersion > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Update version information.
|
||||
versionInfo := info.GetInfo()
|
||||
ii.Version = versionInfo.Version
|
||||
|
||||
// Update numeric version.
|
||||
if versionInfo.Version != "" {
|
||||
numericVersion, err := MakeNumericVersion(versionInfo.Version)
|
||||
if err != nil {
|
||||
log.Warningf("updates: failed to make numeric version: %s", err)
|
||||
} else {
|
||||
ii.NumericVersion = numericVersion
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MakeNumericVersion makes a numeric version with the first three version
|
||||
// segment always using three digits.
|
||||
func MakeNumericVersion(version string) (numericVersion int64, err error) {
|
||||
// Parse version string.
|
||||
ver, err := semver.NewVersion(version)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse core version: %w", err)
|
||||
}
|
||||
|
||||
// Transform version for numeric representation.
|
||||
segments := ver.Segments()
|
||||
for i := 0; i < 3 && i < len(segments); i++ {
|
||||
segmentNumber := int64(segments[i])
|
||||
if segmentNumber > 999 {
|
||||
segmentNumber = 999
|
||||
}
|
||||
switch i {
|
||||
case 0:
|
||||
numericVersion += segmentNumber * 1000000
|
||||
case 1:
|
||||
numericVersion += segmentNumber * 1000
|
||||
case 2:
|
||||
numericVersion += segmentNumber
|
||||
}
|
||||
}
|
||||
|
||||
return numericVersion, nil
|
||||
}
|
||||
|
||||
func (ii *InstallInfo) checkInstallDate() {
|
||||
// Check if everything is present.
|
||||
if ii.UnixTimestamp > 0 &&
|
||||
ii.NumericDate > 0 &&
|
||||
ii.DaysSinceInstall > 0 &&
|
||||
!ii.Time.IsZero() {
|
||||
return
|
||||
}
|
||||
|
||||
// Find oldest created database entry and use it as install time.
|
||||
oldest := time.Now().Unix()
|
||||
it, err := db.Query(query.New("core"))
|
||||
if err != nil {
|
||||
log.Warningf("updates: failed to create iterator for searching DB for install time: %s", err)
|
||||
return
|
||||
}
|
||||
defer it.Cancel()
|
||||
for r := range it.Next {
|
||||
if oldest > r.Meta().Created {
|
||||
oldest = r.Meta().Created
|
||||
}
|
||||
}
|
||||
|
||||
// Set data.
|
||||
ii.UnixTimestamp = oldest
|
||||
ii.Time = time.Unix(oldest, 0)
|
||||
ii.DaysSinceInstall = int64(time.Since(ii.Time).Hours()) / 24
|
||||
|
||||
// Transform date for numeric representation.
|
||||
numericDate, err := strconv.ParseInt(ii.Time.Format("20060102"), 10, 64)
|
||||
if err != nil {
|
||||
log.Warningf("updates: failed to make numeric date from %s: %s", ii.Time, err)
|
||||
} else {
|
||||
ii.NumericDate = numericDate
|
||||
}
|
||||
}
|
||||
46
service/broadcasts/module.go
Normal file
46
service/broadcasts/module.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package broadcasts
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
db = database.NewInterface(&database.Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
|
||||
startOnce sync.Once
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("broadcasts", prep, start, nil, "updates", "netenv", "notifications")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
// Register API endpoints.
|
||||
if err := registerAPIEndpoints(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
// Ensure the install info is up to date.
|
||||
ensureInstallInfo()
|
||||
|
||||
// Start broadcast notifier task.
|
||||
startOnce.Do(func() {
|
||||
module.NewTask("broadcast notifier", broadcastNotify).
|
||||
Repeat(10 * time.Minute).Queue()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
286
service/broadcasts/notify.go
Normal file
286
service/broadcasts/notify.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package broadcasts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ghodss/yaml"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/accessor"
|
||||
"github.com/safing/portbase/database/query"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/notifications"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
)
|
||||
|
||||
const (
|
||||
broadcastsResourcePath = "intel/portmaster/notifications.yaml"
|
||||
|
||||
broadcastNotificationIDPrefix = "broadcasts:"
|
||||
|
||||
minRepeatDuration = 1 * time.Hour
|
||||
)
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrSkip = errors.New("broadcast skipped")
|
||||
ErrSkipDoesNotMatch = fmt.Errorf("%w: does not match", ErrSkip)
|
||||
ErrSkipAlreadyActive = fmt.Errorf("%w: already active", ErrSkip)
|
||||
ErrSkipAlreadyShown = fmt.Errorf("%w: already shown", ErrSkip)
|
||||
ErrSkipRemovedByMismatch = fmt.Errorf("%w: removed due to mismatch", ErrSkip)
|
||||
ErrSkipRemovedBySource = fmt.Errorf("%w: removed by source", ErrSkip)
|
||||
)
|
||||
|
||||
// BroadcastNotifications holds the data structure of the broadcast
|
||||
// notifications update file.
|
||||
type BroadcastNotifications struct {
|
||||
Notifications map[string]*BroadcastNotification
|
||||
}
|
||||
|
||||
// BroadcastNotification is a single broadcast notification.
|
||||
type BroadcastNotification struct {
|
||||
*notifications.Notification
|
||||
id string
|
||||
|
||||
// Match holds a query string that needs to match the local matching data in
|
||||
// order for the broadcast to be displayed.
|
||||
Match string
|
||||
matchingQuery *query.Query
|
||||
// AttachToModule signifies if the broadcast notification should be attached to the module.
|
||||
AttachToModule bool
|
||||
// Remove signifies that the broadcast should be canceled and its state removed.
|
||||
Remove bool
|
||||
// Permanent signifies that the broadcast cannot be acknowledge by the user
|
||||
// and remains in the UI indefinitely.
|
||||
Permanent bool
|
||||
// Repeat specifies a duration after which the broadcast should be shown again.
|
||||
Repeat string
|
||||
repeatDuration time.Duration
|
||||
}
|
||||
|
||||
func broadcastNotify(ctx context.Context, t *modules.Task) error {
|
||||
// Get broadcast notifications file, load it from disk and parse it.
|
||||
broadcastsResource, err := updates.GetFile(broadcastsResourcePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get broadcast notifications update: %w", err)
|
||||
}
|
||||
broadcastsData, err := os.ReadFile(broadcastsResource.Path())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load broadcast notifications update: %w", err)
|
||||
}
|
||||
broadcasts, err := parseBroadcastSource(broadcastsData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse broadcast notifications update: %w", err)
|
||||
}
|
||||
|
||||
// Get and marshal matching data.
|
||||
matchingData := collectData()
|
||||
matchingJSON, err := json.Marshal(matchingData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal broadcast notifications matching data: %w", err)
|
||||
}
|
||||
matchingDataAccessor := accessor.NewJSONBytesAccessor(&matchingJSON)
|
||||
|
||||
// Get broadcast notification states.
|
||||
bss, err := getBroadcastStates()
|
||||
if err != nil {
|
||||
if !errors.Is(err, database.ErrNotFound) {
|
||||
return fmt.Errorf("failed to get broadcast notifications states: %w", err)
|
||||
}
|
||||
bss = newBroadcastStates()
|
||||
}
|
||||
|
||||
// Go through all broadcast nofications and check if they match.
|
||||
for _, bn := range broadcasts.Notifications {
|
||||
err := handleBroadcast(bn, matchingDataAccessor, bss)
|
||||
switch {
|
||||
case err == nil:
|
||||
log.Infof("broadcasts: displaying broadcast %s", bn.id)
|
||||
case errors.Is(err, ErrSkip):
|
||||
log.Tracef("broadcasts: skipped displaying broadcast %s: %s", bn.id, err)
|
||||
default:
|
||||
log.Warningf("broadcasts: failed to handle broadcast %s: %s", bn.id, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseBroadcastSource(yamlData []byte) (*BroadcastNotifications, error) {
|
||||
// Parse data.
|
||||
broadcasts := &BroadcastNotifications{}
|
||||
err := yaml.Unmarshal(yamlData, broadcasts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add IDs to struct for easier handling.
|
||||
for id, bn := range broadcasts.Notifications {
|
||||
bn.id = id
|
||||
|
||||
// Parse matching query.
|
||||
if bn.Match != "" {
|
||||
q, err := query.ParseQuery("query / where " + bn.Match)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse query of broadcast notification %s: %w", bn.id, err)
|
||||
}
|
||||
bn.matchingQuery = q
|
||||
}
|
||||
|
||||
// Parse the repeat duration.
|
||||
if bn.Repeat != "" {
|
||||
duration, err := time.ParseDuration(bn.Repeat)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse repeat duration of broadcast notification %s: %w", bn.id, err)
|
||||
}
|
||||
bn.repeatDuration = duration
|
||||
// Raise duration to minimum.
|
||||
if bn.repeatDuration < minRepeatDuration {
|
||||
bn.repeatDuration = minRepeatDuration
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return broadcasts, nil
|
||||
}
|
||||
|
||||
func handleBroadcast(bn *BroadcastNotification, matchingDataAccessor accessor.Accessor, bss *BroadcastStates) error {
|
||||
// Check if broadcast was already shown.
|
||||
if bss != nil {
|
||||
state, ok := bss.States[bn.id]
|
||||
switch {
|
||||
case !ok || state.Read.IsZero():
|
||||
// Was never shown, continue.
|
||||
case bn.repeatDuration == 0:
|
||||
// Was already shown and is not repeated, skip.
|
||||
return ErrSkipAlreadyShown
|
||||
case time.Now().Before(state.Read.Add(bn.repeatDuration)):
|
||||
// Was already shown and should be repeated - but not yet, skip.
|
||||
return ErrSkipAlreadyShown
|
||||
}
|
||||
}
|
||||
|
||||
// Check if broadcast should be removed.
|
||||
if bn.Remove {
|
||||
removeBroadcast(bn, bss)
|
||||
return ErrSkipRemovedBySource
|
||||
}
|
||||
|
||||
// Skip if broadcast does not match.
|
||||
if bn.matchingQuery != nil && !bn.matchingQuery.MatchesAccessor(matchingDataAccessor) {
|
||||
removed := removeBroadcast(bn, bss)
|
||||
if removed {
|
||||
return ErrSkipRemovedByMismatch
|
||||
}
|
||||
return ErrSkipDoesNotMatch
|
||||
}
|
||||
|
||||
// Check if there is already an active notification for this.
|
||||
eventID := broadcastNotificationIDPrefix + bn.id
|
||||
n := notifications.Get(eventID)
|
||||
if n != nil {
|
||||
// Already active!
|
||||
return ErrSkipAlreadyActive
|
||||
}
|
||||
|
||||
// Prepare notification for displaying.
|
||||
n = bn.Notification
|
||||
n.EventID = eventID
|
||||
n.GUID = ""
|
||||
n.State = ""
|
||||
n.SelectedActionID = ""
|
||||
|
||||
// It is okay to edit the notification, as they are loaded from the file every time.
|
||||
// Add dismiss button if the notification is not permanent.
|
||||
if !bn.Permanent {
|
||||
n.AvailableActions = append(n.AvailableActions, ¬ifications.Action{
|
||||
ID: "ack",
|
||||
Text: "Got it!",
|
||||
})
|
||||
}
|
||||
n.SetActionFunction(markBroadcastAsRead)
|
||||
|
||||
// Display notification.
|
||||
n.Save()
|
||||
|
||||
// Attach to module to raise more awareness.
|
||||
if bn.AttachToModule {
|
||||
n.AttachToModule(module)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeBroadcast(bn *BroadcastNotification, bss *BroadcastStates) (removed bool) {
|
||||
// Remove any active notification.
|
||||
n := notifications.Get(broadcastNotificationIDPrefix + bn.id)
|
||||
if n != nil {
|
||||
removed = true
|
||||
n.Delete()
|
||||
}
|
||||
|
||||
// Remove any state.
|
||||
if bss != nil {
|
||||
delete(bss.States, bn.id)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var savingBroadcastStateLock sync.Mutex
|
||||
|
||||
func markBroadcastAsRead(ctx context.Context, n *notifications.Notification) error {
|
||||
// Lock persisting broadcast state.
|
||||
savingBroadcastStateLock.Lock()
|
||||
defer savingBroadcastStateLock.Unlock()
|
||||
|
||||
// Get notification data.
|
||||
var broadcastID, actionID string
|
||||
func() {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
broadcastID = strings.TrimPrefix(n.EventID, broadcastNotificationIDPrefix)
|
||||
actionID = n.SelectedActionID
|
||||
}()
|
||||
|
||||
// Check response.
|
||||
switch actionID {
|
||||
case "ack":
|
||||
case "":
|
||||
return fmt.Errorf("no action ID for %s", broadcastID)
|
||||
default:
|
||||
return fmt.Errorf("unexpected action ID for %s: %s", broadcastID, actionID)
|
||||
}
|
||||
|
||||
// Get broadcast notification states.
|
||||
bss, err := getBroadcastStates()
|
||||
if err != nil {
|
||||
if !errors.Is(err, database.ErrNotFound) {
|
||||
return fmt.Errorf("failed to get broadcast notifications states: %w", err)
|
||||
}
|
||||
bss = newBroadcastStates()
|
||||
}
|
||||
|
||||
// Get state for this notification.
|
||||
bs, ok := bss.States[broadcastID]
|
||||
if !ok {
|
||||
bs = &BroadcastState{}
|
||||
bss.States[broadcastID] = bs
|
||||
}
|
||||
|
||||
// Delete to allow for timely repeats.
|
||||
n.Delete()
|
||||
|
||||
// Mark as read and save to DB.
|
||||
log.Infof("broadcasts: user acknowledged broadcast %s", broadcastID)
|
||||
bs.Read = time.Now()
|
||||
return bss.save()
|
||||
}
|
||||
64
service/broadcasts/state.go
Normal file
64
service/broadcasts/state.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package broadcasts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
)
|
||||
|
||||
const broadcastStatesDBKey = "core:broadcasts/state"
|
||||
|
||||
// BroadcastStates holds states for broadcast notifications.
|
||||
type BroadcastStates struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
States map[string]*BroadcastState
|
||||
}
|
||||
|
||||
// BroadcastState holds state for a single broadcast notifications.
|
||||
type BroadcastState struct {
|
||||
Read time.Time
|
||||
}
|
||||
|
||||
func (bss *BroadcastStates) save() error {
|
||||
return db.Put(bss)
|
||||
}
|
||||
|
||||
// getbroadcastStates returns the broadcast states from the database.
|
||||
func getBroadcastStates() (*BroadcastStates, error) {
|
||||
r, err := db.Get(broadcastStatesDBKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unwrap.
|
||||
if r.IsWrapped() {
|
||||
// Only allocate a new struct, if we need it.
|
||||
newRecord := &BroadcastStates{}
|
||||
err = record.Unwrap(r, newRecord)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newRecord, nil
|
||||
}
|
||||
|
||||
// or adjust type
|
||||
newRecord, ok := r.(*BroadcastStates)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *BroadcastStates, but %T", r)
|
||||
}
|
||||
return newRecord, nil
|
||||
}
|
||||
|
||||
// newBroadcastStates returns a new BroadcastStates.
|
||||
func newBroadcastStates() *BroadcastStates {
|
||||
bss := &BroadcastStates{
|
||||
States: make(map[string]*BroadcastState),
|
||||
}
|
||||
bss.SetKey(broadcastStatesDBKey)
|
||||
|
||||
return bss
|
||||
}
|
||||
12
service/broadcasts/testdata/README.md
vendored
Normal file
12
service/broadcasts/testdata/README.md
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
# Testing Broadcast Notifications
|
||||
|
||||
```
|
||||
# View matching data
|
||||
curl http://127.0.0.1:817/api/v1/broadcasts/matching-data
|
||||
|
||||
# Reset state
|
||||
curl -X POST http://127.0.0.1:817/api/v1/broadcasts/reset-state
|
||||
|
||||
# Simulate notifications
|
||||
curl --upload-file notifications.yaml http://127.0.0.1:817/api/v1/broadcasts/simulate
|
||||
```
|
||||
26
service/broadcasts/testdata/notifications.yaml
vendored
Normal file
26
service/broadcasts/testdata/notifications.yaml
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
notifications:
|
||||
# test1:
|
||||
# title: "[TEST] Normal Broadcast"
|
||||
# message: "This is a normal broadcast without matching. (#1)"
|
||||
# test2:
|
||||
# title: "[TEST] Permanent Broadcast"
|
||||
# message: "This is a permanent broadcast without matching. (#2)"
|
||||
# type: 1 # Warning
|
||||
# permanent: true
|
||||
# test3:
|
||||
# title: "[TEST] Repeating Broadcast"
|
||||
# message: "This is a repeating broadcast without matching. (#3)"
|
||||
# repeat: "1m"
|
||||
# test4:
|
||||
# title: "[TEST] Matching Broadcast: PM version"
|
||||
# message: "This is a normal broadcast that matches the PM version. (#4)"
|
||||
# match: "NumericVersion > 8000"
|
||||
# test5:
|
||||
# title: "[TEST] Important Update"
|
||||
# message: "A criticial update has been released, please update immediately. (#5)"
|
||||
# type: 3 # Error
|
||||
# attachToModule: true
|
||||
new-network-monitor-feedback:
|
||||
title: "New: Full-Device Network Activity"
|
||||
message: "We re-imagined the network monitor to bring you better search and graphs. Thanks for being a Beta Tester! We would love to hear your feedback on [Twitter](https://twitter.com/SafingIO), [Reddit](https://www.reddit.com/r/safing), [GitHub](https://github.com/safing/portmaster-ui/issues/new/choose) or [via Email](mailto:support@safing.io)."
|
||||
match: "Version sameas 0.9.0 and Updates.Channel sameas beta"
|
||||
29
service/compat/api.go
Normal file
29
service/compat/api.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/api"
|
||||
)
|
||||
|
||||
func registerAPIEndpoints() error {
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "compat/self-check",
|
||||
Read: api.PermitUser,
|
||||
BelongsTo: module,
|
||||
ActionFunc: selfcheckViaAPI,
|
||||
Name: "Run Integration Self-Check",
|
||||
Description: "Runs a couple integration self-checks in order to see if the system integration works.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func selfcheckViaAPI(ar *api.Request) (msg string, err error) {
|
||||
_, err = selfcheck(ar.Context())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "self-check successful", nil
|
||||
}
|
||||
40
service/compat/callbacks.go
Normal file
40
service/compat/callbacks.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
)
|
||||
|
||||
// SubmitSystemIntegrationCheckPacket submit a packet for the system integrity check.
|
||||
func SubmitSystemIntegrationCheckPacket(p packet.Packet) {
|
||||
select {
|
||||
case systemIntegrationCheckPackets <- p:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// SubmitDNSCheckDomain submits a subdomain for the dns check.
|
||||
func SubmitDNSCheckDomain(subdomain string) (respondWith net.IP) {
|
||||
// Submit queried domain.
|
||||
select {
|
||||
case dnsCheckReceivedDomain <- subdomain:
|
||||
default:
|
||||
}
|
||||
|
||||
// Return the answer.
|
||||
dnsCheckAnswerLock.Lock()
|
||||
defer dnsCheckAnswerLock.Unlock()
|
||||
return dnsCheckAnswer
|
||||
}
|
||||
|
||||
// ReportSecureDNSBypassIssue reports a DNS bypassing issue for the given process.
|
||||
func ReportSecureDNSBypassIssue(p *process.Process) {
|
||||
secureDNSBypassIssue.notify(p)
|
||||
}
|
||||
|
||||
// ReportMultiPeerUDPTunnelIssue reports a multi-peer UDP tunnel for the given process.
|
||||
func ReportMultiPeerUDPTunnelIssue(p *process.Process) {
|
||||
multiPeerUDPTunnelIssue.notify(p)
|
||||
}
|
||||
10
service/compat/debug_default.go
Normal file
10
service/compat/debug_default.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !windows && !linux
|
||||
|
||||
package compat
|
||||
|
||||
import "github.com/safing/portbase/utils/debug"
|
||||
|
||||
// AddToDebugInfo adds compatibility data to the given debug.Info.
|
||||
func AddToDebugInfo(di *debug.Info) {
|
||||
// Not yet implemented on this platform.
|
||||
}
|
||||
28
service/compat/debug_linux.go
Normal file
28
service/compat/debug_linux.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portbase/utils/debug"
|
||||
)
|
||||
|
||||
// AddToDebugInfo adds compatibility data to the given debug.Info.
|
||||
func AddToDebugInfo(di *debug.Info) {
|
||||
// Get iptables state and add error info if it fails.
|
||||
chains, err := GetIPTablesChains()
|
||||
if err != nil {
|
||||
di.AddSection(
|
||||
"Compatibility: IPTables Chains (failed)",
|
||||
debug.UseCodeSection,
|
||||
err.Error(),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Add data as section.
|
||||
di.AddSection(
|
||||
fmt.Sprintf("Compatibility: IPTables Chains (%d)", len(chains)-10),
|
||||
debug.UseCodeSection|debug.AddContentLineBreaks,
|
||||
chains...,
|
||||
)
|
||||
}
|
||||
30
service/compat/debug_windows.go
Normal file
30
service/compat/debug_windows.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portbase/utils/debug"
|
||||
)
|
||||
|
||||
// AddToDebugInfo adds compatibility data to the given debug.Info.
|
||||
func AddToDebugInfo(di *debug.Info) {
|
||||
// Get WFP state and add error info if it fails.
|
||||
wfp, err := GetWFPState()
|
||||
if err != nil {
|
||||
di.AddSection(
|
||||
"Compatibility: WFP State (failed)",
|
||||
debug.UseCodeSection,
|
||||
err.Error(),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Add data as section.
|
||||
wfpTable := wfp.AsTable()
|
||||
di.AddSection(
|
||||
fmt.Sprintf("Compatibility: WFP State (%d)", strings.Count(wfpTable, "\n")),
|
||||
debug.UseCodeSection,
|
||||
wfpTable,
|
||||
)
|
||||
}
|
||||
60
service/compat/iptables.go
Normal file
60
service/compat/iptables.go
Normal file
@@ -0,0 +1,60 @@
|
||||
//go:build linux
|
||||
|
||||
package compat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
)
|
||||
|
||||
var (
|
||||
iptProtocols = []iptables.Protocol{
|
||||
iptables.ProtocolIPv4,
|
||||
iptables.ProtocolIPv6,
|
||||
}
|
||||
iptTables = []string{
|
||||
"filter",
|
||||
"nat",
|
||||
"mangle",
|
||||
"raw",
|
||||
}
|
||||
)
|
||||
|
||||
// GetIPTablesChains returns the chain names currently in ip(6)tables.
|
||||
func GetIPTablesChains() ([]string, error) {
|
||||
chains := make([]string, 0, 100)
|
||||
|
||||
// Iterate over protocols.
|
||||
for _, protocol := range iptProtocols {
|
||||
if protocol == iptables.ProtocolIPv4 {
|
||||
chains = append(chains, "v4")
|
||||
} else {
|
||||
chains = append(chains, "v6")
|
||||
}
|
||||
|
||||
// Get iptables access for protocol.
|
||||
tbls, err := iptables.NewWithProtocol(protocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate over tables.
|
||||
for _, table := range iptTables {
|
||||
chains = append(chains, " "+table)
|
||||
|
||||
// Get chain names
|
||||
chainNames, err := tbls.ListChains(table)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get chains of table %s: %w", table, err)
|
||||
}
|
||||
|
||||
// Add chain names to list.
|
||||
for _, name := range chainNames {
|
||||
chains = append(chains, " "+name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return chains, nil
|
||||
}
|
||||
24
service/compat/iptables_test.go
Normal file
24
service/compat/iptables_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
//go:build linux
|
||||
|
||||
package compat
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPTablesChains(t *testing.T) {
|
||||
// Skip in CI.
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
t.Parallel()
|
||||
|
||||
chain, err := GetIPTablesChains()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(chain) < 35 {
|
||||
t.Errorf("Expected at least 35 output lines, not %d", len(chain))
|
||||
}
|
||||
}
|
||||
137
service/compat/module.go
Normal file
137
service/compat/module.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
selfcheckTask *modules.Task
|
||||
selfcheckTaskRetryAfter = 15 * time.Second
|
||||
|
||||
// selfCheckIsFailing holds whether or not the self-check is currently
|
||||
// failing. This helps other failure systems to not make noise when there is
|
||||
// an underlying failure.
|
||||
selfCheckIsFailing = abool.New()
|
||||
|
||||
// selfcheckFails counts how often the self check failed successively.
|
||||
// selfcheckFails is not locked as it is only accessed by the self-check task.
|
||||
selfcheckFails int
|
||||
|
||||
// selfcheckNetworkChangedFlag is used to track changed to the network for
|
||||
// the self-check.
|
||||
selfcheckNetworkChangedFlag = netenv.GetNetworkChangedFlag()
|
||||
)
|
||||
|
||||
// selfcheckFailThreshold holds the threshold of how many times the selfcheck
|
||||
// must fail before it is reported.
|
||||
const selfcheckFailThreshold = 10
|
||||
|
||||
func init() {
|
||||
module = modules.Register("compat", prep, start, stop, "base", "network", "interception", "netenv", "notifications")
|
||||
|
||||
// Workaround resolver integration.
|
||||
// See resolver/compat.go for details.
|
||||
resolver.CompatDNSCheckInternalDomainScope = DNSCheckInternalDomainScope
|
||||
resolver.CompatSelfCheckIsFailing = SelfCheckIsFailing
|
||||
resolver.CompatSubmitDNSCheckDomain = SubmitDNSCheckDomain
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
return registerAPIEndpoints()
|
||||
}
|
||||
|
||||
func start() error {
|
||||
startNotify()
|
||||
|
||||
selfcheckNetworkChangedFlag.Refresh()
|
||||
selfcheckTask = module.NewTask("compatibility self-check", selfcheckTaskFunc).
|
||||
Repeat(5 * time.Minute).
|
||||
MaxDelay(selfcheckTaskRetryAfter).
|
||||
Schedule(time.Now().Add(selfcheckTaskRetryAfter))
|
||||
|
||||
module.NewTask("clean notify thresholds", cleanNotifyThreshold).
|
||||
Repeat(1 * time.Hour)
|
||||
|
||||
return module.RegisterEventHook(
|
||||
netenv.ModuleName,
|
||||
netenv.NetworkChangedEvent,
|
||||
"trigger compat self-check",
|
||||
func(_ context.Context, _ interface{}) error {
|
||||
selfcheckTask.Schedule(time.Now().Add(selfcheckTaskRetryAfter))
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
selfcheckTask.Cancel()
|
||||
selfcheckTask = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func selfcheckTaskFunc(ctx context.Context, task *modules.Task) error {
|
||||
// Create tracing logger.
|
||||
ctx, tracer := log.AddTracer(ctx)
|
||||
defer tracer.Submit()
|
||||
tracer.Tracef("compat: running self-check")
|
||||
|
||||
// Run selfcheck and return if successful.
|
||||
issue, err := selfcheck(ctx)
|
||||
switch {
|
||||
case err == nil:
|
||||
// Successful.
|
||||
tracer.Debugf("compat: self-check successful")
|
||||
case errors.Is(err, errSelfcheckSkipped):
|
||||
// Skipped.
|
||||
tracer.Debugf("compat: %s", err)
|
||||
case issue == nil:
|
||||
// Internal error.
|
||||
tracer.Warningf("compat: %s", err)
|
||||
case selfcheckNetworkChangedFlag.IsSet():
|
||||
// The network changed, ignore the issue.
|
||||
default:
|
||||
// The self-check failed.
|
||||
|
||||
// Set state and increase counter.
|
||||
selfCheckIsFailing.Set()
|
||||
selfcheckFails++
|
||||
|
||||
// Log and notify.
|
||||
tracer.Errorf("compat: %s", err)
|
||||
if selfcheckFails >= selfcheckFailThreshold {
|
||||
issue.notify(err)
|
||||
}
|
||||
|
||||
// Retry quicker when failed.
|
||||
task.Schedule(time.Now().Add(selfcheckTaskRetryAfter))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset self-check state.
|
||||
selfcheckNetworkChangedFlag.Refresh()
|
||||
selfCheckIsFailing.UnSet()
|
||||
selfcheckFails = 0
|
||||
resetSystemIssue()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SelfCheckIsFailing returns whether the self check is currently failing.
|
||||
// This returns true after the first check fails, and does not wait for the
|
||||
// failing threshold to be met.
|
||||
func SelfCheckIsFailing() bool {
|
||||
return selfCheckIsFailing.IsSet()
|
||||
}
|
||||
287
service/compat/notify.go
Normal file
287
service/compat/notify.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/notifications"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
"github.com/safing/portmaster/service/profile"
|
||||
)
|
||||
|
||||
type baseIssue struct {
|
||||
id string //nolint:structcheck // Inherited.
|
||||
title string //nolint:structcheck // Inherited.
|
||||
message string //nolint:structcheck // Inherited.
|
||||
level notifications.Type //nolint:structcheck // Inherited.
|
||||
actions []*notifications.Action //nolint:structcheck // Inherited.
|
||||
}
|
||||
|
||||
type systemIssue baseIssue
|
||||
|
||||
type appIssue baseIssue
|
||||
|
||||
var (
|
||||
// Copy of firewall.CfgOptionDNSQueryInterceptionKey.
|
||||
cfgOptionDNSQueryInterceptionKey = "filter/dnsQueryInterception"
|
||||
dnsQueryInterception config.BoolOption
|
||||
|
||||
systemIssueNotification *notifications.Notification
|
||||
systemIssueNotificationLock sync.Mutex
|
||||
|
||||
systemIntegrationIssue = &systemIssue{
|
||||
id: "compat:system-integration-issue",
|
||||
title: "Detected System Integration Issue",
|
||||
message: "Portmaster detected a problem with its system integration. You can try to restart or reinstall the Portmaster. If that does not help, [get support here](https://safing.io/support/).",
|
||||
level: notifications.Error,
|
||||
}
|
||||
systemCompatibilityIssue = &systemIssue{
|
||||
id: "compat:compatibility-issue",
|
||||
title: "Detected Compatibility Issue",
|
||||
message: "Portmaster detected that something is interfering with its operation. This could be a VPN, an Anti-Virus or another network protection software. Please check if you are running an incompatible [VPN client](https://docs.safing.io/portmaster/install/status/vpn-compatibility) or [software](https://docs.safing.io/portmaster/install/status/software-compatibility) and disable it. If that does not help, [get support here](https://safing.io/support/).",
|
||||
level: notifications.Error,
|
||||
}
|
||||
// manualDNSSetupRequired is additionally initialized in startNotify().
|
||||
manualDNSSetupRequired = &systemIssue{
|
||||
id: "compat:manual-dns-setup-required",
|
||||
title: "Manual DNS Setup Required",
|
||||
level: notifications.Error,
|
||||
actions: []*notifications.Action{
|
||||
{
|
||||
Text: "Revert",
|
||||
Type: notifications.ActionTypeOpenSetting,
|
||||
Payload: ¬ifications.ActionTypeOpenSettingPayload{
|
||||
Key: cfgOptionDNSQueryInterceptionKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
manualDNSSetupRequiredMessage = "You have disabled Seamless DNS Integration. As a result, Portmaster can no longer protect you or filter connections reliably. To fix this, you have to manually configure %s as the DNS Server in your system and in any conflicting application. This message will disappear some time after correct configuration."
|
||||
|
||||
secureDNSBypassIssue = &appIssue{
|
||||
id: "compat:secure-dns-bypass-%s",
|
||||
title: "Blocked Bypass Attempt by %s",
|
||||
message: `[APPNAME] is using its own Secure DNS resolver, which would bypass Portmaster's firewall protections. If [APPNAME] experiences problems, disable Secure DNS within [APPNAME] to restore functionality. Rest assured that Portmaster handles Secure DNS for your whole device, including [APPNAME].`,
|
||||
// TODO: Add this when the new docs page is finished:
|
||||
// , or [find out about other options](link to new docs page)
|
||||
level: notifications.Warning,
|
||||
}
|
||||
multiPeerUDPTunnelIssue = &appIssue{
|
||||
id: "compat:multi-peer-udp-tunnel-%s",
|
||||
title: "Detected SPN Incompatibility in %s",
|
||||
message: "Portmaster detected that [APPNAME] is trying to connect to multiple servers via the SPN using a single UDP connection. This is common for technologies such as torrents. Unfortunately, the SPN does not support this feature currently. You can try to change this behavior within the affected app or you could exempt it from using the SPN.",
|
||||
level: notifications.Warning,
|
||||
}
|
||||
)
|
||||
|
||||
func startNotify() {
|
||||
dnsQueryInterception = config.Concurrent.GetAsBool(cfgOptionDNSQueryInterceptionKey, true)
|
||||
|
||||
systemIssueNotificationLock.Lock()
|
||||
defer systemIssueNotificationLock.Unlock()
|
||||
|
||||
manualDNSSetupRequired.message = fmt.Sprintf(
|
||||
manualDNSSetupRequiredMessage,
|
||||
`"127.0.0.1"`,
|
||||
)
|
||||
}
|
||||
|
||||
// SetNameserverListenIP sets the IP address the nameserver is listening on.
|
||||
// The IP address is used in compatibility notifications.
|
||||
func SetNameserverListenIP(ip net.IP) {
|
||||
systemIssueNotificationLock.Lock()
|
||||
defer systemIssueNotificationLock.Unlock()
|
||||
|
||||
manualDNSSetupRequired.message = fmt.Sprintf(
|
||||
manualDNSSetupRequiredMessage,
|
||||
`"`+ip.String()+`"`,
|
||||
)
|
||||
}
|
||||
|
||||
func systemCompatOrManualDNSIssue() *systemIssue {
|
||||
if dnsQueryInterception() {
|
||||
return systemCompatibilityIssue
|
||||
}
|
||||
return manualDNSSetupRequired
|
||||
}
|
||||
|
||||
func (issue *systemIssue) notify(err error) {
|
||||
systemIssueNotificationLock.Lock()
|
||||
defer systemIssueNotificationLock.Unlock()
|
||||
|
||||
if systemIssueNotification != nil {
|
||||
// Ignore duplicate notification.
|
||||
if issue.id == systemIssueNotification.EventID {
|
||||
return
|
||||
}
|
||||
|
||||
// Remove old notification.
|
||||
systemIssueNotification.Delete()
|
||||
}
|
||||
|
||||
// Create new notification.
|
||||
n := ¬ifications.Notification{
|
||||
EventID: issue.id,
|
||||
Type: issue.level,
|
||||
Title: issue.title,
|
||||
Message: issue.message,
|
||||
ShowOnSystem: true,
|
||||
AvailableActions: issue.actions,
|
||||
}
|
||||
notifications.Notify(n)
|
||||
|
||||
systemIssueNotification = n
|
||||
n.AttachToModule(module)
|
||||
|
||||
// Report the raw error as module error.
|
||||
module.NewErrorMessage("selfcheck", err).Report()
|
||||
}
|
||||
|
||||
func resetSystemIssue() {
|
||||
systemIssueNotificationLock.Lock()
|
||||
defer systemIssueNotificationLock.Unlock()
|
||||
|
||||
if systemIssueNotification != nil {
|
||||
systemIssueNotification.Delete()
|
||||
}
|
||||
systemIssueNotification = nil
|
||||
}
|
||||
|
||||
func (issue *appIssue) notify(proc *process.Process) {
|
||||
// Get profile from process.
|
||||
p := proc.Profile().LocalProfile()
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Ignore notifications for unidentified processes.
|
||||
if p.ID == profile.UnidentifiedProfileID {
|
||||
return
|
||||
}
|
||||
|
||||
// Log warning.
|
||||
log.Warningf(
|
||||
"compat: detected %s issue with %s",
|
||||
strings.ReplaceAll(
|
||||
strings.TrimPrefix(
|
||||
strings.TrimSuffix(issue.id, "-%s"),
|
||||
"compat:",
|
||||
),
|
||||
"-", " ",
|
||||
),
|
||||
proc.Path,
|
||||
)
|
||||
|
||||
// Check if we already have this notification.
|
||||
eventID := fmt.Sprintf(issue.id, p.ID)
|
||||
n := notifications.Get(eventID)
|
||||
if n != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if we reach the threshold to actually send a notification.
|
||||
if !isOverThreshold(eventID) {
|
||||
return
|
||||
}
|
||||
|
||||
// Build message.
|
||||
message := strings.ReplaceAll(issue.message, "[APPNAME]", p.Name)
|
||||
|
||||
// Create a new notification.
|
||||
n = ¬ifications.Notification{
|
||||
EventID: eventID,
|
||||
Type: issue.level,
|
||||
Title: fmt.Sprintf(issue.title, p.Name),
|
||||
Message: message,
|
||||
ShowOnSystem: true,
|
||||
AvailableActions: issue.actions,
|
||||
}
|
||||
if len(n.AvailableActions) == 0 {
|
||||
n.AvailableActions = []*notifications.Action{
|
||||
{
|
||||
ID: "ack",
|
||||
Text: "OK",
|
||||
},
|
||||
}
|
||||
}
|
||||
notifications.Notify(n)
|
||||
|
||||
// Set warning on profile.
|
||||
module.StartWorker("set app compat warning", func(ctx context.Context) error {
|
||||
var changed bool
|
||||
|
||||
func() {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if p.Warning != message || time.Now().Add(-1*time.Hour).After(p.WarningLastUpdated) {
|
||||
p.Warning = message
|
||||
p.WarningLastUpdated = time.Now()
|
||||
changed = true
|
||||
}
|
||||
}()
|
||||
|
||||
if changed {
|
||||
return p.Save()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
const (
|
||||
notifyThresholdMinIncidents = 10
|
||||
notifyThresholdResetAfter = 2 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
notifyThresholds = make(map[string]*notifyThreshold)
|
||||
notifyThresholdsLock sync.Mutex
|
||||
)
|
||||
|
||||
type notifyThreshold struct {
|
||||
FirstSeen time.Time
|
||||
Incidents uint
|
||||
}
|
||||
|
||||
func (nt *notifyThreshold) expired() bool {
|
||||
return time.Now().Add(-notifyThresholdResetAfter).After(nt.FirstSeen)
|
||||
}
|
||||
|
||||
func isOverThreshold(id string) bool {
|
||||
notifyThresholdsLock.Lock()
|
||||
defer notifyThresholdsLock.Unlock()
|
||||
|
||||
// Get notify threshold and check if we reach the minimum incidents.
|
||||
nt, ok := notifyThresholds[id]
|
||||
if ok && !nt.expired() {
|
||||
nt.Incidents++
|
||||
return nt.Incidents >= notifyThresholdMinIncidents
|
||||
}
|
||||
|
||||
// Add new entry.
|
||||
notifyThresholds[id] = ¬ifyThreshold{
|
||||
FirstSeen: time.Now(),
|
||||
Incidents: 1,
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cleanNotifyThreshold(ctx context.Context, task *modules.Task) error {
|
||||
notifyThresholdsLock.Lock()
|
||||
defer notifyThresholdsLock.Unlock()
|
||||
|
||||
for id, nt := range notifyThresholds {
|
||||
if nt.expired() {
|
||||
delete(notifyThresholds, id)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
205
service/compat/selfcheck.go
Normal file
205
service/compat/selfcheck.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
)
|
||||
|
||||
var (
|
||||
selfcheckLock sync.Mutex
|
||||
|
||||
// SystemIntegrationCheckDstIP is the IP address to send a packet to for the
|
||||
// system integration test.
|
||||
SystemIntegrationCheckDstIP = net.IPv4(127, 65, 67, 75)
|
||||
// SystemIntegrationCheckProtocol is the IP protocol to use for the system
|
||||
// integration test.
|
||||
SystemIntegrationCheckProtocol = packet.AnyHostInternalProtocol61
|
||||
|
||||
systemIntegrationCheckDialNet = fmt.Sprintf("ip4:%d", uint8(SystemIntegrationCheckProtocol))
|
||||
systemIntegrationCheckDialIP = SystemIntegrationCheckDstIP.String()
|
||||
systemIntegrationCheckPackets = make(chan packet.Packet, 1)
|
||||
systemIntegrationCheckWaitDuration = 45 * time.Second
|
||||
|
||||
// DNSCheckInternalDomainScope is the domain scope to use for dns checks.
|
||||
DNSCheckInternalDomainScope = ".self-check." + resolver.InternalSpecialUseDomain
|
||||
dnsCheckReceivedDomain = make(chan string, 1)
|
||||
dnsCheckWaitDuration = 45 * time.Second
|
||||
dnsCheckAnswerLock sync.Mutex
|
||||
dnsCheckAnswer net.IP
|
||||
|
||||
errSelfcheckSkipped = errors.New("self-check skipped")
|
||||
)
|
||||
|
||||
func selfcheck(ctx context.Context) (issue *systemIssue, err error) {
|
||||
selfcheckLock.Lock()
|
||||
defer selfcheckLock.Unlock()
|
||||
|
||||
// Step 0: Check if self-check makes sense.
|
||||
if !netenv.Online() {
|
||||
return nil, fmt.Errorf("%w: device is offline or in limited network", errSelfcheckSkipped)
|
||||
}
|
||||
|
||||
// Step 1: Check if the system integration sees a packet.
|
||||
|
||||
// Empty recv channel.
|
||||
select {
|
||||
case <-systemIntegrationCheckPackets:
|
||||
case <-ctx.Done():
|
||||
return nil, context.Canceled
|
||||
default:
|
||||
}
|
||||
|
||||
// Send packet.
|
||||
conn, err := net.DialTimeout(
|
||||
systemIntegrationCheckDialNet,
|
||||
systemIntegrationCheckDialIP,
|
||||
time.Second,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system integration conn: %w", err)
|
||||
}
|
||||
_, err = conn.Write([]byte("PORTMASTER SELF CHECK"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send system integration packet: %w", err)
|
||||
}
|
||||
|
||||
// Wait for packet.
|
||||
select {
|
||||
case <-systemIntegrationCheckPackets:
|
||||
// Check passed!
|
||||
log.Tracer(ctx).Tracef("compat: self-check #1: system integration check passed")
|
||||
case <-time.After(systemIntegrationCheckWaitDuration):
|
||||
return systemIntegrationIssue, fmt.Errorf("self-check #1: system integration check failed: did not receive test packet after %s", systemIntegrationCheckWaitDuration)
|
||||
case <-ctx.Done():
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Step 2: Check if a DNS request arrives at the nameserver
|
||||
// This step necessary also includes some setup for step 3.
|
||||
|
||||
// Generate random subdomain.
|
||||
randomSubdomainBytes, err := rng.Bytes(16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("self-check #2: failed to get random bytes for subdomain check: %w", err)
|
||||
}
|
||||
randomSubdomain := "a" + strings.ToLower(hex.EncodeToString(randomSubdomainBytes)) + "b"
|
||||
|
||||
// Generate random answer.
|
||||
var B, C, D uint64
|
||||
B, err = rng.Number(255)
|
||||
if err == nil {
|
||||
C, err = rng.Number(255)
|
||||
}
|
||||
if err == nil {
|
||||
D, err = rng.Number(255)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("self-check #2: failed to get random number for subdomain check response: %w", err)
|
||||
}
|
||||
randomAnswer := net.IPv4(127, byte(B), byte(C), byte(D))
|
||||
func() {
|
||||
dnsCheckAnswerLock.Lock()
|
||||
defer dnsCheckAnswerLock.Unlock()
|
||||
dnsCheckAnswer = randomAnswer
|
||||
}()
|
||||
|
||||
// Setup variables for lookup worker.
|
||||
var (
|
||||
dnsCheckReturnedIP net.IP
|
||||
dnsCheckLookupError = make(chan error)
|
||||
)
|
||||
|
||||
// Empty recv channel.
|
||||
select {
|
||||
case <-dnsCheckReceivedDomain:
|
||||
case <-ctx.Done():
|
||||
return nil, context.Canceled
|
||||
default:
|
||||
}
|
||||
|
||||
// Start worker for the DNS lookup.
|
||||
module.StartWorker("dns check lookup", func(_ context.Context) error {
|
||||
ips, err := net.LookupIP(randomSubdomain + DNSCheckInternalDomainScope)
|
||||
if err == nil && len(ips) > 0 {
|
||||
dnsCheckReturnedIP = ips[0]
|
||||
}
|
||||
select {
|
||||
case dnsCheckLookupError <- err:
|
||||
case <-time.After(dnsCheckWaitDuration * 2):
|
||||
case <-ctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Wait for the resolver to receive the query.
|
||||
select {
|
||||
case receivedTestDomain := <-dnsCheckReceivedDomain:
|
||||
if receivedTestDomain != randomSubdomain {
|
||||
return systemCompatOrManualDNSIssue(), fmt.Errorf("self-check #2: dns integration check failed: received unmatching subdomain %q", receivedTestDomain)
|
||||
}
|
||||
case <-time.After(dnsCheckWaitDuration):
|
||||
return systemCompatOrManualDNSIssue(), fmt.Errorf("self-check #2: dns integration check failed: did not receive test query after %s", dnsCheckWaitDuration)
|
||||
}
|
||||
log.Tracer(ctx).Tracef("compat: self-check #2: dns integration query check passed")
|
||||
|
||||
// Step 3: Have the nameserver respond with random data in the answer section.
|
||||
|
||||
// Wait for the reply from the resolver.
|
||||
select {
|
||||
case err := <-dnsCheckLookupError:
|
||||
if err != nil {
|
||||
return systemCompatibilityIssue, fmt.Errorf("self-check #3: dns integration check failed: failed to receive test response: %w", err)
|
||||
}
|
||||
case <-time.After(dnsCheckWaitDuration):
|
||||
return systemCompatibilityIssue, fmt.Errorf("self-check #3: dns integration check failed: did not receive test response after %s", dnsCheckWaitDuration)
|
||||
case <-ctx.Done():
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Check response.
|
||||
if !dnsCheckReturnedIP.Equal(randomAnswer) {
|
||||
return systemCompatibilityIssue, fmt.Errorf("self-check #3: dns integration check failed: received unmatching response %q", dnsCheckReturnedIP)
|
||||
}
|
||||
log.Tracer(ctx).Tracef("compat: self-check #3: dns integration response check passed")
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
* Check if the system integration sees a packet:
|
||||
* Send raw IP packet with random content and protocol, report finding to compat module.
|
||||
* use `Dial("ip4:61", "127.65.67.75")`.
|
||||
* Firewall reports back the data seen on `ip4:61` to IP `127.65.67.75`.
|
||||
* If this fails, the system integration is broken. -> Integration Issue
|
||||
* Check if a DNS request arrives at the nameserver:
|
||||
* Send A question for `[random-subdomain].self-check.portmaster.home.arpa.`.
|
||||
* Nameserver reports back the data seen.
|
||||
* If this fails, redirection to the nameserver fails.
|
||||
* This means there is another software interfering with DNS. -> Compatibility Issue
|
||||
* Have the nameserver respond with random data in the answer section.
|
||||
* Compat provides nameserver with random response data.
|
||||
* Compat module checks if the received data matches.
|
||||
* If this fails, redirection to the nameserver fails.
|
||||
* This means there is another software interfering with DNS on the return path. -> Compatibility Issue
|
||||
* DROPPED: If resolvers are reported failing, but we are online:
|
||||
* Send out plain DNS requests to one.one.one.one. and dns.quad9.net via the Go standard lookup and check if the responses are correct.
|
||||
* If not, something is blocking the Portmaster -> Secure DNS Issue
|
||||
* Discuss if this is necessary:
|
||||
* Does this improve from only having a failed TCP connection to the resolver?
|
||||
* Could another program block port 853, but fully leave requests for one.one.one.one. to port 53 alone?
|
||||
|
||||
*/
|
||||
527
service/compat/wfpstate.go
Normal file
527
service/compat/wfpstate.go
Normal file
@@ -0,0 +1,527 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/safing/portbase/utils/osdetail"
|
||||
)
|
||||
|
||||
// GetWFPState queries the system for the WFP state and returns a simplified
|
||||
// and cleaned version.
|
||||
func GetWFPState() (*SimplifiedWFPState, error) {
|
||||
// Use a file to get the wfp state, as the terminal isn't able to return the
|
||||
// data encoded in UTF-8.
|
||||
tmpDir, err := os.MkdirTemp("", "portmaster-debug-data-wfpstate")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tmp dir for wfpstate: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}()
|
||||
tmpFile := filepath.Join(tmpDir, "wfpstate.xml")
|
||||
|
||||
// Get wfp state and write it to the tmp file.
|
||||
_, err = osdetail.RunCmd(
|
||||
"netsh.exe",
|
||||
"wfp",
|
||||
"show",
|
||||
"state",
|
||||
tmpFile,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write wfp state to tmp file: %w", err)
|
||||
}
|
||||
|
||||
// Get tmp file contents.
|
||||
output, err := os.ReadFile(tmpFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read wfp state to tmp file: %w", err)
|
||||
}
|
||||
if len(output) == 0 {
|
||||
return nil, errors.New("wfp state tmp file was empty")
|
||||
}
|
||||
|
||||
// Parse wfp state.
|
||||
parsedState, err := parseWFPState(output)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse wfpstate: %w", err)
|
||||
}
|
||||
|
||||
// Return simplified and cleaned state.
|
||||
return parsedState.simplified(), nil
|
||||
}
|
||||
|
||||
/*
|
||||
Interesting data is found at:
|
||||
|
||||
providers->item[]
|
||||
->displayData->name
|
||||
->displayData->description
|
||||
->providerKey
|
||||
|
||||
subLayers->item[]
|
||||
->displayData->name
|
||||
->displayData->description
|
||||
->subLayerKey
|
||||
|
||||
layers->item[]->callouts->item[]
|
||||
->displayData->name
|
||||
->displayData->description
|
||||
->calloutKey
|
||||
->providerKey
|
||||
->applicableLayer
|
||||
|
||||
layers->item[]->filters->item[]
|
||||
->displayData->name
|
||||
->displayData->description
|
||||
->filterKey
|
||||
->providerKey
|
||||
->layerKey
|
||||
->subLayerKey
|
||||
*/
|
||||
|
||||
// SimplifiedWFPState is a simplified version of the full WFP state.
|
||||
type SimplifiedWFPState struct {
|
||||
Providers []*WFPProvider
|
||||
SubLayers []*WFPSubLayer
|
||||
Callouts []*WFPCallout
|
||||
Filters []*WFPFilter
|
||||
}
|
||||
|
||||
// WFPProvider represents a WFP Provider.
|
||||
type WFPProvider struct {
|
||||
Name string
|
||||
Description string
|
||||
ProviderKey string
|
||||
}
|
||||
|
||||
// WFPSubLayer represents a WFP SubLayer.
|
||||
type WFPSubLayer struct {
|
||||
Name string
|
||||
Description string
|
||||
SubLayerKey string
|
||||
}
|
||||
|
||||
// WFPCallout represents a WFP Callout.
|
||||
type WFPCallout struct {
|
||||
Name string
|
||||
Description string
|
||||
CalloutKey string
|
||||
ProviderKey string
|
||||
ApplicableLayer string
|
||||
}
|
||||
|
||||
// WFPFilter represents a WFP Filter.
|
||||
type WFPFilter struct {
|
||||
Name string
|
||||
Description string
|
||||
FilterKey string
|
||||
ProviderKey string
|
||||
LayerKey string
|
||||
SubLayerKey string
|
||||
}
|
||||
|
||||
// Keys returns all keys found in the WFP state.
|
||||
func (sw *SimplifiedWFPState) Keys() map[string]struct{} {
|
||||
lookupMap := make(map[string]struct{}, len(sw.Providers)+len(sw.SubLayers)+len(sw.Callouts)+len(sw.Filters))
|
||||
|
||||
// Collect keys.
|
||||
for _, provider := range sw.Providers {
|
||||
lookupMap[provider.ProviderKey] = struct{}{}
|
||||
}
|
||||
for _, subLayer := range sw.SubLayers {
|
||||
lookupMap[subLayer.SubLayerKey] = struct{}{}
|
||||
}
|
||||
for _, callout := range sw.Callouts {
|
||||
lookupMap[callout.CalloutKey] = struct{}{}
|
||||
}
|
||||
for _, filter := range sw.Filters {
|
||||
lookupMap[filter.FilterKey] = struct{}{}
|
||||
}
|
||||
|
||||
return lookupMap
|
||||
}
|
||||
|
||||
// AsTable formats the simplified WFP state as a table.
|
||||
func (sw *SimplifiedWFPState) AsTable() string {
|
||||
rows := make([]string, 0, len(sw.Providers)+len(sw.SubLayers)+len(sw.Callouts)+len(sw.Filters))
|
||||
|
||||
// Collect data and put it into rows.
|
||||
for _, provider := range sw.Providers {
|
||||
rows = append(rows, strings.Join([]string{
|
||||
provider.Name,
|
||||
"Provider",
|
||||
provider.Description,
|
||||
provider.ProviderKey,
|
||||
}, "\t"))
|
||||
}
|
||||
for _, subLayer := range sw.SubLayers {
|
||||
rows = append(rows, strings.Join([]string{
|
||||
subLayer.Name,
|
||||
"SubLayer",
|
||||
subLayer.Description,
|
||||
subLayer.SubLayerKey,
|
||||
}, "\t"))
|
||||
}
|
||||
for _, callout := range sw.Callouts {
|
||||
rows = append(rows, strings.Join([]string{
|
||||
callout.Name,
|
||||
"Callout",
|
||||
callout.Description,
|
||||
callout.CalloutKey,
|
||||
callout.ProviderKey,
|
||||
callout.ApplicableLayer,
|
||||
}, "\t"))
|
||||
}
|
||||
for _, filter := range sw.Filters {
|
||||
rows = append(rows, strings.Join([]string{
|
||||
filter.Name,
|
||||
"Filter",
|
||||
filter.Description,
|
||||
filter.FilterKey,
|
||||
filter.ProviderKey,
|
||||
filter.LayerKey,
|
||||
filter.SubLayerKey,
|
||||
}, "\t"))
|
||||
}
|
||||
|
||||
// Sort and build table.
|
||||
sort.Strings(rows)
|
||||
buf := bytes.NewBuffer(nil)
|
||||
tabWriter := tabwriter.NewWriter(buf, 8, 4, 3, ' ', 0)
|
||||
for _, row := range rows {
|
||||
fmt.Fprint(tabWriter, row)
|
||||
fmt.Fprint(tabWriter, "\n")
|
||||
}
|
||||
_ = tabWriter.Flush()
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// wfpState is the WFP state as returned by `netsh.exe wfp show state -`.
|
||||
type wfpState struct {
|
||||
XMLName xml.Name `xml:"wfpstate"`
|
||||
Text string `xml:",chardata"`
|
||||
TimeStamp string `xml:"timeStamp"`
|
||||
Providers struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []struct {
|
||||
Text string `xml:",chardata"`
|
||||
ProviderKey string `xml:"providerKey"`
|
||||
DisplayData struct {
|
||||
Text string `xml:",chardata"`
|
||||
Name string `xml:"name"`
|
||||
Description string `xml:"description"`
|
||||
} `xml:"displayData"`
|
||||
Flags struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item string `xml:"item"`
|
||||
} `xml:"flags"`
|
||||
ProviderData string `xml:"providerData"`
|
||||
ServiceName string `xml:"serviceName"`
|
||||
} `xml:"item"`
|
||||
} `xml:"providers"`
|
||||
SubLayers struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []struct {
|
||||
Text string `xml:",chardata"`
|
||||
SubLayerKey string `xml:"subLayerKey"`
|
||||
DisplayData struct {
|
||||
Text string `xml:",chardata"`
|
||||
Name string `xml:"name"`
|
||||
Description string `xml:"description"`
|
||||
} `xml:"displayData"`
|
||||
Flags struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item string `xml:"item"`
|
||||
} `xml:"flags"`
|
||||
ProviderKey string `xml:"providerKey"`
|
||||
ProviderData string `xml:"providerData"`
|
||||
Weight string `xml:"weight"`
|
||||
} `xml:"item"`
|
||||
} `xml:"subLayers"`
|
||||
Layers struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []struct {
|
||||
Text string `xml:",chardata"`
|
||||
Layer struct {
|
||||
Text string `xml:",chardata"`
|
||||
LayerKey string `xml:"layerKey"`
|
||||
DisplayData struct {
|
||||
Text string `xml:",chardata"`
|
||||
Name string `xml:"name"`
|
||||
Description string `xml:"description"`
|
||||
} `xml:"displayData"`
|
||||
Flags struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []string `xml:"item"`
|
||||
} `xml:"flags"`
|
||||
Field struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []struct {
|
||||
Text string `xml:",chardata"`
|
||||
FieldKey string `xml:"fieldKey"`
|
||||
Type string `xml:"type"`
|
||||
DataType string `xml:"dataType"`
|
||||
} `xml:"item"`
|
||||
} `xml:"field"`
|
||||
DefaultSubLayerKey string `xml:"defaultSubLayerKey"`
|
||||
LayerID string `xml:"layerId"`
|
||||
} `xml:"layer"`
|
||||
Callouts struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []struct {
|
||||
Text string `xml:",chardata"`
|
||||
CalloutKey string `xml:"calloutKey"`
|
||||
DisplayData struct {
|
||||
Text string `xml:",chardata"`
|
||||
Name string `xml:"name"`
|
||||
Description string `xml:"description"`
|
||||
} `xml:"displayData"`
|
||||
Flags struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []string `xml:"item"`
|
||||
} `xml:"flags"`
|
||||
ProviderKey string `xml:"providerKey"`
|
||||
ProviderData string `xml:"providerData"`
|
||||
ApplicableLayer string `xml:"applicableLayer"`
|
||||
CalloutID string `xml:"calloutId"`
|
||||
} `xml:"item"`
|
||||
} `xml:"callouts"`
|
||||
Filters struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []struct {
|
||||
Text string `xml:",chardata"`
|
||||
FilterKey string `xml:"filterKey"`
|
||||
DisplayData struct {
|
||||
Text string `xml:",chardata"`
|
||||
Name string `xml:"name"`
|
||||
Description string `xml:"description"`
|
||||
} `xml:"displayData"`
|
||||
Flags struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []string `xml:"item"`
|
||||
} `xml:"flags"`
|
||||
ProviderKey string `xml:"providerKey"`
|
||||
ProviderData struct {
|
||||
Text string `xml:",chardata"`
|
||||
Data string `xml:"data"`
|
||||
AsString string `xml:"asString"`
|
||||
} `xml:"providerData"`
|
||||
LayerKey string `xml:"layerKey"`
|
||||
SubLayerKey string `xml:"subLayerKey"`
|
||||
Weight struct {
|
||||
Text string `xml:",chardata"`
|
||||
Type string `xml:"type"`
|
||||
Uint8 string `xml:"uint8"`
|
||||
Uint64 string `xml:"uint64"`
|
||||
} `xml:"weight"`
|
||||
FilterCondition struct {
|
||||
Text string `xml:",chardata"`
|
||||
NumItems string `xml:"numItems,attr"`
|
||||
Item []struct {
|
||||
Text string `xml:",chardata"`
|
||||
FieldKey string `xml:"fieldKey"`
|
||||
MatchType string `xml:"matchType"`
|
||||
ConditionValue struct {
|
||||
Text string `xml:",chardata"`
|
||||
Type string `xml:"type"`
|
||||
Uint32 string `xml:"uint32"`
|
||||
Uint16 string `xml:"uint16"`
|
||||
RangeValue struct {
|
||||
Text string `xml:",chardata"`
|
||||
ValueLow struct {
|
||||
Text string `xml:",chardata"`
|
||||
Type string `xml:"type"`
|
||||
Uint16 string `xml:"uint16"`
|
||||
Uint32 string `xml:"uint32"`
|
||||
ByteArray16 string `xml:"byteArray16"`
|
||||
} `xml:"valueLow"`
|
||||
ValueHigh struct {
|
||||
Text string `xml:",chardata"`
|
||||
Type string `xml:"type"`
|
||||
Uint16 string `xml:"uint16"`
|
||||
Uint32 string `xml:"uint32"`
|
||||
ByteArray16 string `xml:"byteArray16"`
|
||||
} `xml:"valueHigh"`
|
||||
} `xml:"rangeValue"`
|
||||
Uint8 string `xml:"uint8"`
|
||||
ByteBlob struct {
|
||||
Text string `xml:",chardata"`
|
||||
Data string `xml:"data"`
|
||||
AsString string `xml:"asString"`
|
||||
} `xml:"byteBlob"`
|
||||
Sd string `xml:"sd"`
|
||||
Sid string `xml:"sid"`
|
||||
Uint64 string `xml:"uint64"`
|
||||
} `xml:"conditionValue"`
|
||||
} `xml:"item"`
|
||||
} `xml:"filterCondition"`
|
||||
Action struct {
|
||||
Text string `xml:",chardata"`
|
||||
Type string `xml:"type"`
|
||||
FilterType string `xml:"filterType"`
|
||||
} `xml:"action"`
|
||||
RawContext string `xml:"rawContext"`
|
||||
Reserved string `xml:"reserved"`
|
||||
FilterID string `xml:"filterId"`
|
||||
EffectiveWeight struct {
|
||||
Text string `xml:",chardata"`
|
||||
Type string `xml:"type"`
|
||||
Uint64 string `xml:"uint64"`
|
||||
} `xml:"effectiveWeight"`
|
||||
ProviderContextKey string `xml:"providerContextKey"`
|
||||
} `xml:"item"`
|
||||
} `xml:"filters"`
|
||||
} `xml:"item"`
|
||||
} `xml:"layers"`
|
||||
}
|
||||
|
||||
func parseWFPState(data []byte) (*wfpState, error) {
|
||||
w := &wfpState{}
|
||||
err := xml.Unmarshal(data, w)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (w *wfpState) simplified() *SimplifiedWFPState {
|
||||
sw := &SimplifiedWFPState{
|
||||
Providers: make([]*WFPProvider, 0, len(w.Providers.Item)),
|
||||
SubLayers: make([]*WFPSubLayer, 0, len(w.SubLayers.Item)),
|
||||
Callouts: make([]*WFPCallout, 0, len(w.Layers.Item)),
|
||||
Filters: make([]*WFPFilter, 0, len(w.Layers.Item)),
|
||||
}
|
||||
|
||||
// Collect data.
|
||||
for _, provider := range w.Providers.Item {
|
||||
if isIgnoredProvider(provider.DisplayData.Name, provider.ProviderKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
sw.Providers = append(sw.Providers, &WFPProvider{
|
||||
Name: defaultTo(provider.DisplayData.Name, "[no name]"),
|
||||
Description: defaultTo(provider.DisplayData.Description, "[no description]"),
|
||||
ProviderKey: defaultTo(provider.ProviderKey, "[no provider key]"),
|
||||
})
|
||||
}
|
||||
for _, subLayer := range w.SubLayers.Item {
|
||||
if isIgnoredProvider(subLayer.DisplayData.Name, "") {
|
||||
continue
|
||||
}
|
||||
|
||||
sw.SubLayers = append(sw.SubLayers, &WFPSubLayer{
|
||||
Name: defaultTo(subLayer.DisplayData.Name, "[no name]"),
|
||||
Description: defaultTo(subLayer.DisplayData.Description, "[no description]"),
|
||||
SubLayerKey: defaultTo(subLayer.SubLayerKey, "[no sublayer key]"),
|
||||
})
|
||||
}
|
||||
for _, layer := range w.Layers.Item {
|
||||
for _, callout := range layer.Callouts.Item {
|
||||
if isIgnoredProvider(callout.DisplayData.Name, callout.ProviderKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
sw.Callouts = append(sw.Callouts, &WFPCallout{
|
||||
Name: defaultTo(callout.DisplayData.Name, "[no name]"),
|
||||
Description: defaultTo(callout.DisplayData.Description, "[no description]"),
|
||||
CalloutKey: defaultTo(callout.CalloutKey, "[no callout key]"),
|
||||
ProviderKey: defaultTo(callout.ProviderKey, "[no provider key]"),
|
||||
ApplicableLayer: defaultTo(callout.ApplicableLayer, "[no applicable layer]"),
|
||||
})
|
||||
}
|
||||
for _, filter := range layer.Filters.Item {
|
||||
if isIgnoredProvider(filter.DisplayData.Name, filter.ProviderKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
sw.Filters = append(sw.Filters, &WFPFilter{
|
||||
Name: defaultTo(filter.DisplayData.Name, "[no name]"),
|
||||
Description: defaultTo(filter.DisplayData.Description, "[no description]"),
|
||||
FilterKey: defaultTo(filter.FilterKey, "[no filter key]"),
|
||||
ProviderKey: defaultTo(filter.ProviderKey, "[no provider key]"),
|
||||
LayerKey: defaultTo(filter.LayerKey, "[no layer key]"),
|
||||
SubLayerKey: defaultTo(filter.SubLayerKey, "[no sublayer key]"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return sw
|
||||
}
|
||||
|
||||
func isIgnoredProvider(name, key string) bool {
|
||||
// Check provider key.
|
||||
if key != "" {
|
||||
matched := true
|
||||
switch key {
|
||||
case "{1bebc969-61a5-4732-a177-847a0817862a}": // Microsoft Windows Defender Firewall IPsec Provider.
|
||||
case "{4b153735-1049-4480-aab4-d1b9bdc03710}": // Microsoft Windows Defender Firewall Provider.
|
||||
case "{893a4f22-9bba-49b7-8c66-3d40929c8fd5}": // Microsoft Windows Teredo firewall provider.
|
||||
case "{8e44982a-f477-11df-85ce-78e7d1810190}": // Windows Network Data Usage (NDU) Provider.
|
||||
case "{9c2532b4-0314-434f-8274-0cbaebdbda56}": // Microsoft Windows edge traversal socket option authorization provider.
|
||||
case "{aa6a7d87-7f8f-4d2a-be53-fda555cd5fe3}": // Microsoft Windows Defender Firewall IPsec Provider.
|
||||
case "{c698301d-9129-450c-937c-f4b834bfb374}": // Microsoft Windows edge traversal socket option authorization provider.
|
||||
case "{decc16ca-3f33-4346-be1e-8fb4ae0f3d62}": // Microsoft Windows Defender Firewall Provider.
|
||||
case "FWPM_PROVIDER_IKEEXT": // Microsoft Windows WFP Built-in IKEEXT provider used to identify filters added by IKE/AuthIP.
|
||||
case "FWPM_PROVIDER_IPSEC_DOSP_CONFIG": // Microsoft Windows WFP Built-in IPsec DoS Protection configuration provider used to identify filters added by IPsec Denial of Service Protection.
|
||||
case "FWPM_PROVIDER_MPSSVC_APP_ISOLATION": // Microsoft Windows WFP Built-in MPSSVC App Isolation provider.
|
||||
case "FWPM_PROVIDER_MPSSVC_EDP": // Microsoft Windows WFP Built-in MPSSVC Enterprise Data Protection provider.
|
||||
case "FWPM_PROVIDER_MPSSVC_TENANT_RESTRICTIONS": // Microsoft Windows WFP Built-in MPSSVC Tenant Restrictions provider.
|
||||
case "FWPM_PROVIDER_MPSSVC_WF": // Microsoft Windows WFP Built-in MPSSVC Windows Firewall provider.
|
||||
case "FWPM_PROVIDER_MPSSVC_WSH": // Microsoft Windows WFP Built-in MPSSVC Windows Service Hardening and Quarantine provider.
|
||||
case "FWPM_PROVIDER_TCP_CHIMNEY_OFFLOAD": // Microsoft Windows WFP Built-in TCP Chimney Offload provider used to identify filters added by TCP Chimney Offload.
|
||||
case "FWPM_PROVIDER_TCP_TEMPLATES": // Microsoft Windows WFP Built-in TCP Templates provider used to identify filters added by TCP Template based configuration.
|
||||
default:
|
||||
matched = false
|
||||
}
|
||||
if matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Some entries don't have a provider key (set).
|
||||
// These are pretty generic, but the output strings are localized.
|
||||
if name != "" {
|
||||
switch {
|
||||
case strings.Contains(name, "Microsoft Corporation"):
|
||||
return true
|
||||
case strings.Contains(name, "windefend"):
|
||||
return true
|
||||
case strings.Contains(name, "WFP"):
|
||||
return true
|
||||
case strings.Contains(name, "RPC"):
|
||||
return true
|
||||
case strings.Contains(name, "NDU"):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultTo(a, b string) string {
|
||||
if a != "" {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
56270
service/compat/wfpstate_test.go
Normal file
56270
service/compat/wfpstate_test.go
Normal file
File diff suppressed because it is too large
Load Diff
294
service/core/api.go
Normal file
294
service/core/api.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/notifications"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portbase/utils/debug"
|
||||
"github.com/safing/portmaster/service/compat"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
"github.com/safing/portmaster/service/status"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
"github.com/safing/portmaster/spn/captain"
|
||||
)
|
||||
|
||||
func registerAPIEndpoints() error {
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "core/shutdown",
|
||||
Write: api.PermitSelf,
|
||||
// Do NOT register as belonging to the module, so that the API is available
|
||||
// when something fails during starting of this module or a dependency.
|
||||
ActionFunc: shutdown,
|
||||
Name: "Shut Down Portmaster",
|
||||
Description: "Shut down the Portmaster Core Service and all UI components.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "core/restart",
|
||||
Write: api.PermitAdmin,
|
||||
// Do NOT register as belonging to the module, so that the API is available
|
||||
// when something fails during starting of this module or a dependency.
|
||||
ActionFunc: restart,
|
||||
Name: "Restart Portmaster",
|
||||
Description: "Restart the Portmaster Core Service.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "debug/core",
|
||||
Read: api.PermitAnyone,
|
||||
BelongsTo: module,
|
||||
DataFunc: debugInfo,
|
||||
Name: "Get Debug Information",
|
||||
Description: "Returns network debugging information, similar to debug/info, but with system status data.",
|
||||
Parameters: []api.Parameter{{
|
||||
Method: http.MethodGet,
|
||||
Field: "style",
|
||||
Value: "github",
|
||||
Description: "Specify the formatting style. The default is simple markdown formatting.",
|
||||
}},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "app/auth",
|
||||
Read: api.PermitAnyone,
|
||||
BelongsTo: module,
|
||||
StructFunc: authorizeApp,
|
||||
Name: "Request an authentication token with a given set of permissions. The user will be prompted to either authorize or deny the request. Used for external or third-party tool integrations.",
|
||||
Parameters: []api.Parameter{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Field: "app-name",
|
||||
Description: "The name of the application requesting access",
|
||||
},
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Field: "read",
|
||||
Description: "The requested read permission",
|
||||
},
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Field: "write",
|
||||
Description: "The requested write permission",
|
||||
},
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Field: "ttl",
|
||||
Description: "The time-to-live for the new access token. Defaults to 24h",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "app/profile",
|
||||
Read: api.PermitUser,
|
||||
BelongsTo: module,
|
||||
StructFunc: getMyProfile,
|
||||
Name: "Get the ID of the calling profile",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// shutdown shuts the Portmaster down.
|
||||
func shutdown(_ *api.Request) (msg string, err error) {
|
||||
log.Warning("core: user requested shutdown via action")
|
||||
|
||||
// Do not run in worker, as this would block itself here.
|
||||
go modules.Shutdown() //nolint:errcheck
|
||||
|
||||
return "shutdown initiated", nil
|
||||
}
|
||||
|
||||
// restart restarts the Portmaster.
|
||||
func restart(_ *api.Request) (msg string, err error) {
|
||||
log.Info("core: user requested restart via action")
|
||||
|
||||
// Let the updates module handle restarting.
|
||||
updates.RestartNow()
|
||||
|
||||
return "restart initiated", nil
|
||||
}
|
||||
|
||||
// debugInfo returns the debugging information for support requests.
|
||||
func debugInfo(ar *api.Request) (data []byte, err error) {
|
||||
// Create debug information helper.
|
||||
di := new(debug.Info)
|
||||
di.Style = ar.Request.URL.Query().Get("style")
|
||||
|
||||
// Add debug information.
|
||||
|
||||
// Very basic information at the start.
|
||||
di.AddVersionInfo()
|
||||
di.AddPlatformInfo(ar.Context())
|
||||
|
||||
// Errors and unexpected logs.
|
||||
di.AddLastReportedModuleError()
|
||||
di.AddLastUnexpectedLogs()
|
||||
|
||||
// Status Information from various modules.
|
||||
status.AddToDebugInfo(di)
|
||||
captain.AddToDebugInfo(di)
|
||||
resolver.AddToDebugInfo(di)
|
||||
config.AddToDebugInfo(di)
|
||||
|
||||
// Detailed information.
|
||||
updates.AddToDebugInfo(di)
|
||||
compat.AddToDebugInfo(di)
|
||||
di.AddGoroutineStack()
|
||||
|
||||
// Return data.
|
||||
return di.Bytes(), nil
|
||||
}
|
||||
|
||||
// getSavePermission returns the requested api.Permission from p.
|
||||
// It only allows "user" and "admin" as external processes should
|
||||
// never be able to request "self".
|
||||
func getSavePermission(p string) api.Permission {
|
||||
switch p {
|
||||
case "user":
|
||||
return api.PermitUser
|
||||
case "admin":
|
||||
return api.PermitAdmin
|
||||
default:
|
||||
return api.NotSupported
|
||||
}
|
||||
}
|
||||
|
||||
func getMyProfile(ar *api.Request) (interface{}, error) {
|
||||
proc, err := process.GetProcessByRequestOrigin(ar)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
localProfile := proc.Profile().LocalProfile()
|
||||
|
||||
return map[string]interface{}{
|
||||
"profile": localProfile.ID,
|
||||
"source": localProfile.Source,
|
||||
"name": localProfile.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authorizeApp(ar *api.Request) (interface{}, error) {
|
||||
appName := ar.Request.URL.Query().Get("app-name")
|
||||
readPermStr := ar.Request.URL.Query().Get("read")
|
||||
writePermStr := ar.Request.URL.Query().Get("write")
|
||||
|
||||
ttl := time.Hour * 24
|
||||
if ttlStr := ar.Request.URL.Query().Get("ttl"); ttlStr != "" {
|
||||
var err error
|
||||
ttl, err = time.ParseDuration(ttlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// convert the requested read and write permissions to their api.Permission
|
||||
// value. This ensures only "user" or "admin" permissions can be requested.
|
||||
if getSavePermission(readPermStr) <= api.NotSupported {
|
||||
return nil, fmt.Errorf("invalid read permission")
|
||||
}
|
||||
if getSavePermission(writePermStr) <= api.NotSupported {
|
||||
return nil, fmt.Errorf("invalid read permission")
|
||||
}
|
||||
|
||||
proc, err := process.GetProcessByRequestOrigin(ar)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to identify requesting process: %w", err)
|
||||
}
|
||||
|
||||
n := notifications.Notification{
|
||||
Type: notifications.Prompt,
|
||||
EventID: "core:authorize-app-" + time.Now().String(),
|
||||
Title: "An app requests access to the Portmaster",
|
||||
Message: "Allow " + appName + " (" + proc.Profile().LocalProfile().Name + ") to query and modify the Portmaster?\n\nBinary: " + proc.Path,
|
||||
ShowOnSystem: true,
|
||||
Expires: time.Now().Add(time.Minute).Unix(),
|
||||
AvailableActions: []*notifications.Action{
|
||||
{
|
||||
ID: "allow",
|
||||
Text: "Authorize",
|
||||
},
|
||||
{
|
||||
ID: "deny",
|
||||
Text: "Deny",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ch := make(chan string)
|
||||
|
||||
validUntil := time.Now().Add(ttl)
|
||||
|
||||
n.SetActionFunction(func(ctx context.Context, n *notifications.Notification) error {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
if n.SelectedActionID != "allow" {
|
||||
close(ch)
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := config.Concurrent.GetAsStringArray(api.CfgAPIKeys, []string{})()
|
||||
|
||||
newKeyData, err := rng.Bytes(8)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newKeyHex := hex.EncodeToString(newKeyData)
|
||||
|
||||
query := url.Values{
|
||||
"read": []string{readPermStr},
|
||||
"write": []string{writePermStr},
|
||||
"expires": []string{validUntil.Format(time.RFC3339)},
|
||||
}
|
||||
|
||||
keys = append(keys, fmt.Sprintf("%s?%s", newKeyHex, query.Encode()))
|
||||
|
||||
if err := config.SetConfigOption(api.CfgAPIKeys, keys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ch <- newKeyHex
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
n.Save()
|
||||
|
||||
select {
|
||||
case key := <-ch:
|
||||
if len(key) == 0 {
|
||||
return nil, fmt.Errorf("access denied")
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"key": key,
|
||||
"validUntil": validUntil,
|
||||
}, nil
|
||||
case <-ar.Context().Done():
|
||||
return nil, fmt.Errorf("timeout")
|
||||
}
|
||||
}
|
||||
43
service/core/base/databases.go
Normal file
43
service/core/base/databases.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/database"
|
||||
_ "github.com/safing/portbase/database/dbmodule"
|
||||
_ "github.com/safing/portbase/database/storage/bbolt"
|
||||
)
|
||||
|
||||
// Default Values (changeable for testing).
|
||||
var (
|
||||
DefaultDatabaseStorageType = "bbolt"
|
||||
)
|
||||
|
||||
func registerDatabases() error {
|
||||
_, err := database.Register(&database.Database{
|
||||
Name: "core",
|
||||
Description: "Holds core data, such as settings and profiles",
|
||||
StorageType: DefaultDatabaseStorageType,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = database.Register(&database.Database{
|
||||
Name: "cache",
|
||||
Description: "Cached data, such as Intelligence and DNS Records",
|
||||
StorageType: DefaultDatabaseStorageType,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// _, err = database.Register(&database.Database{
|
||||
// Name: "history",
|
||||
// Description: "Historic event data",
|
||||
// StorageType: DefaultDatabaseStorageType,
|
||||
// })
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
69
service/core/base/global.go
Normal file
69
service/core/base/global.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/info"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
// Default Values (changeable for testing).
|
||||
var (
|
||||
DefaultAPIListenAddress = "127.0.0.1:817"
|
||||
|
||||
dataDir string
|
||||
databaseDir string
|
||||
showVersion bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&dataDir, "data", "", "set data directory")
|
||||
flag.StringVar(&databaseDir, "db", "", "alias to --data (deprecated)")
|
||||
flag.BoolVar(&showVersion, "version", false, "show version and exit")
|
||||
|
||||
modules.SetGlobalPrepFn(globalPrep)
|
||||
}
|
||||
|
||||
func globalPrep() error {
|
||||
// check if meta info is ok
|
||||
err := info.CheckVersion()
|
||||
if err != nil {
|
||||
return errors.New("compile error: please compile using the provided build script")
|
||||
}
|
||||
|
||||
// print version
|
||||
if showVersion {
|
||||
fmt.Println(info.FullVersion())
|
||||
return modules.ErrCleanExit
|
||||
}
|
||||
|
||||
// check data root
|
||||
if dataroot.Root() == nil {
|
||||
// initialize data dir
|
||||
|
||||
// backwards compatibility
|
||||
if dataDir == "" {
|
||||
dataDir = databaseDir
|
||||
}
|
||||
|
||||
// check data dir
|
||||
if dataDir == "" {
|
||||
return errors.New("please set the data directory using --data=/path/to/data/dir")
|
||||
}
|
||||
|
||||
// initialize structure
|
||||
err := dataroot.Initialize(dataDir, 0o0755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// set api listen address
|
||||
api.SetDefaultAPIListenAddress(DefaultAPIListenAddress)
|
||||
|
||||
return nil
|
||||
}
|
||||
60
service/core/base/logs.go
Normal file
60
service/core/base/logs.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
const (
|
||||
logTTL = 30 * 24 * time.Hour
|
||||
logFileDir = "logs"
|
||||
logFileSuffix = ".log"
|
||||
)
|
||||
|
||||
func registerLogCleaner() {
|
||||
module.NewTask("log cleaner", logCleaner).
|
||||
Repeat(24 * time.Hour).
|
||||
Schedule(time.Now().Add(15 * time.Minute))
|
||||
}
|
||||
|
||||
func logCleaner(_ context.Context, _ *modules.Task) error {
|
||||
ageThreshold := time.Now().Add(-logTTL)
|
||||
|
||||
return filepath.Walk(
|
||||
filepath.Join(dataroot.Root().Path, logFileDir),
|
||||
func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.Warningf("core: failed to access %s while deleting old log files: %s", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case !info.Mode().IsRegular():
|
||||
// Only delete regular files.
|
||||
case !strings.HasSuffix(path, logFileSuffix):
|
||||
// Only delete files that end with the correct suffix.
|
||||
case info.ModTime().After(ageThreshold):
|
||||
// Only delete files that are older that the log TTL.
|
||||
default:
|
||||
// Delete log file.
|
||||
err := os.Remove(path)
|
||||
if err != nil {
|
||||
log.Warningf("core: failed to delete old log file %s: %s", path, err)
|
||||
} else {
|
||||
log.Tracef("core: deleted old log file %s", path)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
38
service/core/base/module.go
Normal file
38
service/core/base/module.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
_ "github.com/safing/portbase/config"
|
||||
_ "github.com/safing/portbase/metrics"
|
||||
"github.com/safing/portbase/modules"
|
||||
_ "github.com/safing/portbase/rng"
|
||||
)
|
||||
|
||||
var module *modules.Module
|
||||
|
||||
func init() {
|
||||
module = modules.Register("base", nil, start, nil, "database", "config", "rng", "metrics")
|
||||
|
||||
// For prettier subsystem graph, printed with --print-subsystem-graph
|
||||
/*
|
||||
subsystems.Register(
|
||||
"base",
|
||||
"Base",
|
||||
"THE GROUND.",
|
||||
baseModule,
|
||||
"",
|
||||
nil,
|
||||
)
|
||||
*/
|
||||
}
|
||||
|
||||
func start() error {
|
||||
startProfiling()
|
||||
|
||||
if err := registerDatabases(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
registerLogCleaner()
|
||||
|
||||
return nil
|
||||
}
|
||||
41
service/core/base/profiling.go
Normal file
41
service/core/base/profiling.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
)
|
||||
|
||||
var cpuProfile string
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&cpuProfile, "cpuprofile", "", "write cpu profile to `file`")
|
||||
}
|
||||
|
||||
func startProfiling() {
|
||||
if cpuProfile != "" {
|
||||
module.StartWorker("cpu profiler", cpuProfiler)
|
||||
}
|
||||
}
|
||||
|
||||
func cpuProfiler(ctx context.Context) error {
|
||||
f, err := os.Create(cpuProfile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create CPU profile: %w", err)
|
||||
}
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
return fmt.Errorf("could not start CPU profile: %w", err)
|
||||
}
|
||||
|
||||
// wait for shutdown
|
||||
<-ctx.Done()
|
||||
|
||||
pprof.StopCPUProfile()
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close CPU profile file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
112
service/core/config.go
Normal file
112
service/core/config.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
locale "github.com/Xuanwo/go-locale"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
// Configuration Keys.
|
||||
var (
|
||||
// CfgDevModeKey was previously defined here.
|
||||
CfgDevModeKey = config.CfgDevModeKey
|
||||
|
||||
CfgNetworkServiceKey = "core/networkService"
|
||||
defaultNetworkServiceMode bool
|
||||
|
||||
CfgLocaleKey = "core/locale"
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(
|
||||
&defaultNetworkServiceMode,
|
||||
"network-service",
|
||||
false,
|
||||
"set default network service mode; configuration is stronger",
|
||||
)
|
||||
}
|
||||
|
||||
func registerConfig() error {
|
||||
if err := config.Register(&config.Option{
|
||||
Name: "Network Service",
|
||||
Key: CfgNetworkServiceKey,
|
||||
Description: "Use the Portmaster as a network service, where applicable. You will have to take care of lots of network setup yourself in order to run this properly and securely.",
|
||||
OptType: config.OptTypeBool,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
ReleaseLevel: config.ReleaseLevelExperimental,
|
||||
DefaultValue: defaultNetworkServiceMode,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: 513,
|
||||
config.CategoryAnnotation: "Network Service",
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := config.Register(&config.Option{
|
||||
Name: "Time and Date Format",
|
||||
Key: CfgLocaleKey,
|
||||
Description: "Configures the time and date format for the user interface. Selection is an example and correct formatting in the UI is a continual work in progress.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelUser,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
DefaultValue: getDefaultLocale(),
|
||||
PossibleValues: []config.PossibleValue{
|
||||
{
|
||||
Name: "24h DD-MM-YYYY",
|
||||
Value: enGBLocale,
|
||||
},
|
||||
{
|
||||
Name: "12h MM/DD/YYYY",
|
||||
Value: enUSLocale,
|
||||
},
|
||||
},
|
||||
Annotations: config.Annotations{
|
||||
config.CategoryAnnotation: "User Interface",
|
||||
config.DisplayHintAnnotation: config.DisplayHintOneOf,
|
||||
config.RequiresUIReloadAnnotation: true,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDefaultLocale() string {
|
||||
// Get locales from system.
|
||||
detectedLocales, err := locale.DetectAll()
|
||||
if err != nil {
|
||||
log.Warningf("core: failed to detect locale: %s", err)
|
||||
return enGBLocale
|
||||
}
|
||||
|
||||
// log.Debugf("core: detected locales: %s", detectedLocales)
|
||||
|
||||
// Check if there is a locale that corresponds to the en-US locale.
|
||||
for _, detectedLocale := range detectedLocales {
|
||||
if slices.Contains[[]string, string](defaultEnUSLocales, detectedLocale.String()) {
|
||||
return enUSLocale
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, return the en-GB locale as default.
|
||||
return enGBLocale
|
||||
}
|
||||
|
||||
var (
|
||||
enGBLocale = "en-GB"
|
||||
enUSLocale = "en-US"
|
||||
|
||||
defaultEnUSLocales = []string{
|
||||
"en-AS", // English (American Samoa)
|
||||
"en-GU", // English (Guam)
|
||||
"en-UM", // English (U.S. Minor Outlying Islands)
|
||||
"en-US", // English (United States)
|
||||
"en-VI", // English (U.S. Virgin Islands)
|
||||
}
|
||||
)
|
||||
100
service/core/core.go
Normal file
100
service/core/core.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/metrics"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/modules/subsystems"
|
||||
_ "github.com/safing/portmaster/service/broadcasts"
|
||||
_ "github.com/safing/portmaster/service/netenv"
|
||||
_ "github.com/safing/portmaster/service/netquery"
|
||||
_ "github.com/safing/portmaster/service/status"
|
||||
_ "github.com/safing/portmaster/service/sync"
|
||||
_ "github.com/safing/portmaster/service/ui"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
)
|
||||
|
||||
const (
|
||||
eventShutdown = "shutdown"
|
||||
eventRestart = "restart"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
disableShutdownEvent bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "netquery", "interception", "compat", "broadcasts", "sync")
|
||||
subsystems.Register(
|
||||
"core",
|
||||
"Core",
|
||||
"Base Structure and System Integration",
|
||||
module,
|
||||
"config:core/",
|
||||
nil,
|
||||
)
|
||||
|
||||
flag.BoolVar(
|
||||
&disableShutdownEvent,
|
||||
"disable-shutdown-event",
|
||||
false,
|
||||
"disable shutdown event to keep app and notifier open when core shuts down",
|
||||
)
|
||||
|
||||
modules.SetGlobalShutdownFn(shutdownHook)
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
registerEvents()
|
||||
|
||||
// init config
|
||||
err := registerConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := registerAPIEndpoints(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
if err := startPlatformSpecific(); err != nil {
|
||||
return fmt.Errorf("failed to start plattform-specific components: %w", err)
|
||||
}
|
||||
|
||||
// Enable persistent metrics.
|
||||
if err := metrics.EnableMetricPersistence("core:metrics/storage"); err != nil {
|
||||
log.Warningf("core: failed to enable persisted metrics: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerEvents() {
|
||||
module.RegisterEvent(eventShutdown, true)
|
||||
module.RegisterEvent(eventRestart, true)
|
||||
}
|
||||
|
||||
func shutdownHook() {
|
||||
// Notify everyone of the restart/shutdown.
|
||||
if !updates.IsRestarting() {
|
||||
// Only trigger shutdown event if not disabled.
|
||||
if !disableShutdownEvent {
|
||||
module.TriggerEvent(eventShutdown, nil)
|
||||
}
|
||||
} else {
|
||||
module.TriggerEvent(eventRestart, nil)
|
||||
}
|
||||
|
||||
// Wait a bit for the event to propagate.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
8
service/core/os_default.go
Normal file
8
service/core/os_default.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
|
||||
package core
|
||||
|
||||
// only return on Fatal error!
|
||||
func startPlatformSpecific() error {
|
||||
return nil
|
||||
}
|
||||
16
service/core/os_windows.go
Normal file
16
service/core/os_windows.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/utils/osdetail"
|
||||
)
|
||||
|
||||
// only return on Fatal error!
|
||||
func startPlatformSpecific() error {
|
||||
// We can't catch errors when calling WindowsNTVersion() in logging, so we call the function here, just to catch possible errors
|
||||
if _, err := osdetail.WindowsNTVersion(); err != nil {
|
||||
log.Errorf("failed to obtain WindowsNTVersion: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
137
service/core/pmtesting/testing.go
Normal file
137
service/core/pmtesting/testing.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Package pmtesting provides a simple unit test setup routine.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// package name
|
||||
//
|
||||
// import (
|
||||
// "testing"
|
||||
//
|
||||
// "github.com/safing/portmaster/service/core/pmtesting"
|
||||
// )
|
||||
//
|
||||
// func TestMain(m *testing.M) {
|
||||
// pmtesting.TestMain(m, module)
|
||||
// }
|
||||
package pmtesting
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
"testing"
|
||||
|
||||
_ "github.com/safing/portbase/database/storage/hashmap"
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/service/core/base"
|
||||
)
|
||||
|
||||
var printStackOnExit bool
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down")
|
||||
}
|
||||
|
||||
// TestHookFunc describes the functions passed to TestMainWithHooks.
|
||||
type TestHookFunc func() error
|
||||
|
||||
// TestMain provides a simple unit test setup routine.
|
||||
func TestMain(m *testing.M, module *modules.Module) {
|
||||
TestMainWithHooks(m, module, nil, nil)
|
||||
}
|
||||
|
||||
// TestMainWithHooks provides a simple unit test setup routine and calls
|
||||
// afterStartFn after modules have started and beforeStopFn before modules
|
||||
// are shutdown.
|
||||
func TestMainWithHooks(m *testing.M, module *modules.Module, afterStartFn, beforeStopFn TestHookFunc) {
|
||||
// Only enable needed modules.
|
||||
modules.EnableModuleManagement(nil)
|
||||
|
||||
// Enable this module for testing.
|
||||
if module != nil {
|
||||
module.Enable()
|
||||
}
|
||||
|
||||
// switch databases to memory only
|
||||
base.DefaultDatabaseStorageType = "hashmap"
|
||||
|
||||
// switch API to high port
|
||||
base.DefaultAPIListenAddress = "127.0.0.1:10817"
|
||||
|
||||
// set log level
|
||||
log.SetLogLevel(log.TraceLevel)
|
||||
|
||||
// tmp dir for data root (db & config)
|
||||
tmpDir := filepath.Join(os.TempDir(), "portmaster-testing")
|
||||
// initialize data dir
|
||||
err := dataroot.Initialize(tmpDir, 0o0755)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// start modules
|
||||
var exitCode int
|
||||
err = modules.Start()
|
||||
if err != nil {
|
||||
// starting failed
|
||||
fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err)
|
||||
exitCode = 1
|
||||
} else {
|
||||
runTests := true
|
||||
if afterStartFn != nil {
|
||||
if err := afterStartFn(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to run test start hook: %s\n", err)
|
||||
runTests = false
|
||||
exitCode = 1
|
||||
}
|
||||
}
|
||||
|
||||
if runTests {
|
||||
// run tests
|
||||
exitCode = m.Run()
|
||||
}
|
||||
}
|
||||
|
||||
if beforeStopFn != nil {
|
||||
if err := beforeStopFn(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to run test shutdown hook: %s\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// shutdown
|
||||
_ = modules.Shutdown()
|
||||
if modules.GetExitStatusCode() != 0 {
|
||||
exitCode = modules.GetExitStatusCode()
|
||||
fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err)
|
||||
}
|
||||
printStack()
|
||||
|
||||
// clean up and exit
|
||||
|
||||
// Important: Do not remove tmpDir, as it is used as a cache for updates.
|
||||
// remove config
|
||||
_ = os.Remove(filepath.Join(tmpDir, "config.json"))
|
||||
// remove databases
|
||||
_ = os.Remove(filepath.Join(tmpDir, "databases.json"))
|
||||
_ = os.RemoveAll(filepath.Join(tmpDir, "databases"))
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func printStack() {
|
||||
if printStackOnExit {
|
||||
fmt.Println("=== PRINTING TRACES ===")
|
||||
fmt.Println("=== GOROUTINES ===")
|
||||
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 2)
|
||||
fmt.Println("=== BLOCKING ===")
|
||||
_ = pprof.Lookup("block").WriteTo(os.Stdout, 2)
|
||||
fmt.Println("=== MUTEXES ===")
|
||||
_ = pprof.Lookup("mutex").WriteTo(os.Stdout, 2)
|
||||
fmt.Println("=== END TRACES ===")
|
||||
}
|
||||
}
|
||||
55
service/detection/dga/lms.go
Normal file
55
service/detection/dga/lms.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package dga
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// LmsScoreOfDomain calculates the mean longest meaningful substring of a domain.
|
||||
// It follows some special rules to increase accuracy. It returns a value between
|
||||
// 0 and 100, representing the length-based percentage of the meaningful substring.
|
||||
func LmsScoreOfDomain(domain string) float64 {
|
||||
var totalScore float64
|
||||
domain = strings.ToLower(domain)
|
||||
subjects := strings.Split(domain, ".")
|
||||
var totalLength int
|
||||
for _, subject := range subjects {
|
||||
totalLength += len(subject)
|
||||
}
|
||||
for _, subject := range subjects {
|
||||
// calculate score, weigh it and add it
|
||||
if len(subject) > 0 {
|
||||
totalScore += LmsScore(subject) * (float64(len(subject)) / float64(totalLength))
|
||||
}
|
||||
}
|
||||
return totalScore
|
||||
}
|
||||
|
||||
// LmsScore calculates the longest meaningful substring of a domain. It returns a
|
||||
// value between 0 and 100, representing the length-based percentage of the
|
||||
// meaningful substring.
|
||||
func LmsScore(subject string) float64 {
|
||||
lmsStart := -1
|
||||
lmsStop := -1
|
||||
longestLms := 0
|
||||
|
||||
for i, c := range subject {
|
||||
if int(c) >= int('a') && int(c) <= int('z') {
|
||||
if lmsStart == -1 {
|
||||
lmsStart = i
|
||||
}
|
||||
} else {
|
||||
if lmsStart > -1 {
|
||||
lmsStop = i
|
||||
if lmsStop-lmsStart > longestLms {
|
||||
longestLms = lmsStop - lmsStart
|
||||
}
|
||||
lmsStart = -1
|
||||
}
|
||||
}
|
||||
}
|
||||
if lmsStop == -1 {
|
||||
longestLms = len(subject)
|
||||
}
|
||||
// fmt.Printf("algs: lms score of %s is %.2f\n", subject, (float64(longest_lms) * 100.0 / float64(len(subject))))
|
||||
return (float64(longestLms) * 100.0 / float64(len(subject)))
|
||||
}
|
||||
21
service/detection/dga/lms_test.go
Normal file
21
service/detection/dga/lms_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package dga
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLmsScoreOfDomain(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDomain(t, "g.symcd.com.", 100, 100)
|
||||
testDomain(t, "www.google.com.", 100, 100)
|
||||
testDomain(t, "55ttt5.12abc3.test.com.", 68, 69)
|
||||
testDomain(t, "mbtq6opnuodp34gcrma65fxacgxv5ukr7lq6xuhr4mhoibe7.yvqptrozfbnqyemchpovw3q5xwjibuxfsgb72mix3znhpfhc.i2n7jh2gadqaadck3zs3vg3hbv5pkmwzeay4gc75etyettbb.isi5mhmowtfriu33uxzmgvjur5g2p3tloynwohfrggee6fkn.meop7kqyd5gwxxa3.er.spotify.com.", 0, 31)
|
||||
}
|
||||
|
||||
func testDomain(t *testing.T, domain string, min, max float64) {
|
||||
t.Helper()
|
||||
|
||||
score := LmsScoreOfDomain(domain)
|
||||
if score < min || score > max {
|
||||
t.Errorf("domain %s has scored %.2f, but should be between %.0f and %.0f", domain, score, min, max)
|
||||
}
|
||||
}
|
||||
215
service/firewall/api.go
Normal file
215
service/firewall/api.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/utils"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
)
|
||||
|
||||
const (
|
||||
deniedMsgUnidentified = `%wFailed to identify the requesting process. Reload to try again.`
|
||||
|
||||
deniedMsgSystem = `%wSystem access to the Portmaster API is not permitted.
|
||||
You can enable the Development Mode to disable API authentication for development purposes.`
|
||||
|
||||
deniedMsgUnauthorized = `%wThe requesting process is not authorized to access the Portmaster API.
|
||||
Checked process paths:
|
||||
%s
|
||||
|
||||
The authorized root path is %s.
|
||||
You can enable the Development Mode to disable API authentication for development purposes.
|
||||
For production use please create an API key in the settings.`
|
||||
|
||||
deniedMsgMisconfigured = `%wThe authentication system is misconfigured.`
|
||||
)
|
||||
|
||||
var (
|
||||
dataRoot *utils.DirStructure
|
||||
|
||||
apiPortSet bool
|
||||
apiIP net.IP
|
||||
apiPort uint16
|
||||
)
|
||||
|
||||
func prepAPIAuth() error {
|
||||
dataRoot = dataroot.Root()
|
||||
return api.SetAuthenticator(apiAuthenticator)
|
||||
}
|
||||
|
||||
func startAPIAuth() {
|
||||
var err error
|
||||
apiIP, apiPort, err = netutils.ParseIPPort(apiListenAddress())
|
||||
if err != nil {
|
||||
log.Warningf("filter: failed to parse API address for improved api auth mechanism: %s", err)
|
||||
return
|
||||
}
|
||||
apiPortSet = true
|
||||
log.Tracef("filter: api port set to %d", apiPort)
|
||||
}
|
||||
|
||||
func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, err error) {
|
||||
if configReady.IsSet() && devMode() {
|
||||
return &api.AuthToken{
|
||||
Read: api.PermitSelf,
|
||||
Write: api.PermitSelf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// get local IP/Port
|
||||
localIP, localPort, err := netutils.ParseIPPort(s.Addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get local IP/Port: %w", err)
|
||||
}
|
||||
|
||||
// get remote IP/Port
|
||||
remoteIP, remotePort, err := netutils.ParseIPPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get remote IP/Port: %w", err)
|
||||
}
|
||||
|
||||
// Check if the request is even local.
|
||||
myIP, err := netenv.IsMyIP(remoteIP)
|
||||
if err == nil && !myIP {
|
||||
// Return to caller that the request was not handled.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Tracer(r.Context()).Tracef("filter: authenticating API request from %s", r.RemoteAddr)
|
||||
|
||||
// It is important that this works, retry 5 times: every 500ms for 2.5s.
|
||||
var retry bool
|
||||
for tries := 0; tries < 5; tries++ {
|
||||
retry, err = authenticateAPIRequest(
|
||||
r.Context(),
|
||||
&packet.Info{
|
||||
Inbound: false, // outbound as we are looking for the process of the source address
|
||||
Version: packet.IPv4,
|
||||
Protocol: packet.TCP,
|
||||
Src: remoteIP, // source as in the process we are looking for
|
||||
SrcPort: remotePort, // source as in the process we are looking for
|
||||
Dst: localIP,
|
||||
DstPort: localPort,
|
||||
PID: process.UndefinedProcessID,
|
||||
},
|
||||
)
|
||||
if !retry {
|
||||
break
|
||||
}
|
||||
|
||||
// wait a little
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.AuthToken{
|
||||
Read: api.PermitSelf,
|
||||
Write: api.PermitSelf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bool, err error) {
|
||||
var procsChecked []string
|
||||
var originalPid int
|
||||
|
||||
// Get authenticated path.
|
||||
authenticatedPath := updates.RootPath()
|
||||
if authenticatedPath == "" {
|
||||
return false, fmt.Errorf(deniedMsgMisconfigured, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
|
||||
}
|
||||
// Get real path.
|
||||
authenticatedPath, err = filepath.EvalSymlinks(authenticatedPath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf(deniedMsgUnidentified, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
|
||||
}
|
||||
// Add filepath separator to confine to directory.
|
||||
authenticatedPath += string(filepath.Separator)
|
||||
|
||||
// Get process of request.
|
||||
pid, _, _ := process.GetPidOfConnection(ctx, pktInfo)
|
||||
if pid < 0 {
|
||||
return false, fmt.Errorf(deniedMsgUnidentified, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
|
||||
}
|
||||
proc, err := process.GetOrFindProcess(ctx, pid)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Debugf("filter: failed to get process of api request: %s", err)
|
||||
originalPid = process.UnidentifiedProcessID
|
||||
} else {
|
||||
originalPid = proc.Pid
|
||||
var previousPid int
|
||||
|
||||
// Find parent for up to two levels, if we don't match the path.
|
||||
checkLevels := 2
|
||||
checkLevelsLoop:
|
||||
for i := 0; i < checkLevels+1; i++ {
|
||||
// Check for eligible path.
|
||||
switch proc.Pid {
|
||||
case process.UnidentifiedProcessID, process.SystemProcessID:
|
||||
break checkLevelsLoop
|
||||
default: // normal process
|
||||
// Check if the requesting process is in database root / updates dir.
|
||||
if realPath, err := filepath.EvalSymlinks(proc.Path); err == nil {
|
||||
if strings.HasPrefix(realPath, authenticatedPath) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add checked path to list.
|
||||
procsChecked = append(procsChecked, proc.Path)
|
||||
|
||||
// Get the parent process.
|
||||
if i < checkLevels {
|
||||
// save previous PID
|
||||
previousPid = proc.Pid
|
||||
|
||||
// get parent process
|
||||
proc, err = process.GetOrFindProcess(ctx, proc.ParentPid)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Debugf("filter: failed to get parent process of api request: %s", err)
|
||||
break
|
||||
}
|
||||
|
||||
// abort if we are looping
|
||||
if proc.Pid == previousPid {
|
||||
// this also catches -1 pid loops
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch originalPid {
|
||||
case process.UnidentifiedProcessID:
|
||||
log.Tracer(ctx).Warningf("filter: denying api access: failed to identify process")
|
||||
return true, fmt.Errorf(deniedMsgUnidentified, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
|
||||
|
||||
case process.SystemProcessID:
|
||||
log.Tracer(ctx).Warningf("filter: denying api access: request by system")
|
||||
return false, fmt.Errorf(deniedMsgSystem, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
|
||||
|
||||
default: // normal process
|
||||
log.Tracer(ctx).Warningf("filter: denying api access to %s - also checked %s (trusted root is %s)", procsChecked[0], strings.Join(procsChecked[1:], " "), dataRoot.Path)
|
||||
return false, fmt.Errorf( //nolint:stylecheck // message for user
|
||||
deniedMsgUnauthorized,
|
||||
api.ErrAPIAccessDeniedMessage,
|
||||
strings.Join(procsChecked, "\n"),
|
||||
authenticatedPath,
|
||||
)
|
||||
}
|
||||
}
|
||||
64
service/firewall/bypassing.go
Normal file
64
service/firewall/bypassing.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/service/compat"
|
||||
"github.com/safing/portmaster/service/nameserver/nsutil"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
)
|
||||
|
||||
var resolverFilterLists = []string{"17-DNS"}
|
||||
|
||||
// PreventBypassing checks if the connection should be denied or permitted
|
||||
// based on some bypass protection checks.
|
||||
func PreventBypassing(ctx context.Context, conn *network.Connection) (endpoints.EPResult, string, nsutil.Responder) {
|
||||
// Exclude incoming connections.
|
||||
if conn.Inbound {
|
||||
return endpoints.NoMatch, "", nil
|
||||
}
|
||||
|
||||
// Exclude ICMP.
|
||||
switch packet.IPProtocol(conn.Entity.Protocol) { //nolint:exhaustive // Checking for specific values only.
|
||||
case packet.ICMP, packet.ICMPv6:
|
||||
return endpoints.NoMatch, "", nil
|
||||
}
|
||||
|
||||
// Block firefox canary domain to disable DoH.
|
||||
// This MUST also affect the System Resolver, because the return value must
|
||||
// be correct for this to work.
|
||||
if strings.ToLower(conn.Entity.Domain) == "use-application-dns.net." {
|
||||
return endpoints.Denied,
|
||||
"blocked canary domain to prevent enabling of DNS-over-HTTPs",
|
||||
nsutil.NxDomain()
|
||||
}
|
||||
|
||||
// Exclude DNS requests coming from the System Resolver.
|
||||
// This MUST also affect entities in the secure dns filter list, else the
|
||||
// System Resolver is wrongly accused of bypassing.
|
||||
if conn.Type == network.DNSRequest && conn.Process().IsSystemResolver() {
|
||||
return endpoints.NoMatch, "", nil
|
||||
}
|
||||
|
||||
// Block bypass attempts using an (encrypted) DNS server.
|
||||
switch {
|
||||
case conn.Entity.Port == 53:
|
||||
return endpoints.Denied,
|
||||
"blocked DNS query, manual dns setup required",
|
||||
nsutil.BlockIP()
|
||||
case conn.Entity.Port == 853:
|
||||
// Block connections to port 853 - DNS over TLS.
|
||||
fallthrough
|
||||
case conn.Entity.MatchLists(resolverFilterLists):
|
||||
// Block connection entities in the secure dns filter list.
|
||||
compat.ReportSecureDNSBypassIssue(conn.Process())
|
||||
return endpoints.Denied,
|
||||
"blocked rogue connection to DNS resolver",
|
||||
nsutil.BlockIP()
|
||||
}
|
||||
|
||||
return endpoints.NoMatch, "", nil
|
||||
}
|
||||
153
service/firewall/config.go
Normal file
153
service/firewall/config.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/notifications"
|
||||
"github.com/safing/portmaster/service/core"
|
||||
"github.com/safing/portmaster/spn/captain"
|
||||
)
|
||||
|
||||
// Configuration Keys.
|
||||
var (
|
||||
CfgOptionEnableFilterKey = "filter/enable"
|
||||
filterEnabled config.BoolOption
|
||||
|
||||
CfgOptionAskWithSystemNotificationsKey = "filter/askWithSystemNotifications"
|
||||
cfgOptionAskWithSystemNotificationsOrder = 2
|
||||
askWithSystemNotifications config.BoolOption
|
||||
|
||||
CfgOptionAskTimeoutKey = "filter/askTimeout"
|
||||
cfgOptionAskTimeoutOrder = 3
|
||||
askTimeout config.IntOption
|
||||
|
||||
CfgOptionPermanentVerdictsKey = "filter/permanentVerdicts"
|
||||
cfgOptionPermanentVerdictsOrder = 80
|
||||
permanentVerdicts config.BoolOption
|
||||
|
||||
CfgOptionDNSQueryInterceptionKey = "filter/dnsQueryInterception"
|
||||
cfgOptionDNSQueryInterceptionOrder = 81
|
||||
dnsQueryInterception config.BoolOption
|
||||
)
|
||||
|
||||
func registerConfig() error {
|
||||
err := config.Register(&config.Option{
|
||||
Name: "Enable Privacy Filter",
|
||||
Key: CfgOptionEnableFilterKey,
|
||||
Description: "Enable the Privacy Filter. If turned off, all privacy filter protections are fully disabled on this device. Not meant to be disabled in production - only turn off for testing.",
|
||||
OptType: config.OptTypeBool,
|
||||
ExpertiseLevel: config.ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: config.ReleaseLevelExperimental,
|
||||
DefaultValue: true,
|
||||
Annotations: config.Annotations{
|
||||
config.CategoryAnnotation: "General",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filterEnabled = config.Concurrent.GetAsBool(CfgOptionEnableFilterKey, true)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Permanent Verdicts",
|
||||
Key: CfgOptionPermanentVerdictsKey,
|
||||
Description: "The Portmaster's system integration intercepts every single packet. Usually the first packet is enough for the Portmaster to set the verdict for a connection - ie. to allow or deny it. Making these verdicts permanent means that the Portmaster will tell the system integration that is does not want to see any more packets of that single connection. This brings a major performance increase.",
|
||||
OptType: config.OptTypeBool,
|
||||
ExpertiseLevel: config.ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: config.ReleaseLevelExperimental,
|
||||
DefaultValue: true,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionPermanentVerdictsOrder,
|
||||
config.CategoryAnnotation: "Advanced",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
permanentVerdicts = config.Concurrent.GetAsBool(CfgOptionPermanentVerdictsKey, true)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Seamless DNS Integration",
|
||||
Key: CfgOptionDNSQueryInterceptionKey,
|
||||
Description: "Intercept and redirect astray DNS queries to the Portmaster's internal DNS server. This enables seamless DNS integration without having to configure the system or other software. However, this may lead to compatibility issues with other software that attempts the same.",
|
||||
OptType: config.OptTypeBool,
|
||||
ExpertiseLevel: config.ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: config.ReleaseLevelExperimental,
|
||||
DefaultValue: true,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionDNSQueryInterceptionOrder,
|
||||
config.CategoryAnnotation: "Advanced",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dnsQueryInterception = config.Concurrent.GetAsBool(CfgOptionDNSQueryInterceptionKey, true)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Prompt Desktop Notifications",
|
||||
Key: CfgOptionAskWithSystemNotificationsKey,
|
||||
Description: `In addition to showing prompt notifications in the Portmaster App, also send them to the Desktop. This requires the Portmaster Notifier to be running. Requires Desktop Notifications to be enabled.`,
|
||||
OptType: config.OptTypeBool,
|
||||
ExpertiseLevel: config.ExpertiseLevelUser,
|
||||
DefaultValue: true,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionAskWithSystemNotificationsOrder,
|
||||
config.CategoryAnnotation: "General",
|
||||
config.RequiresAnnotation: config.ValueRequirement{
|
||||
Key: notifications.CfgUseSystemNotificationsKey,
|
||||
Value: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
askWithSystemNotifications = config.Concurrent.GetAsBool(CfgOptionAskWithSystemNotificationsKey, true)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Prompt Timeout",
|
||||
Key: CfgOptionAskTimeoutKey,
|
||||
Description: "How long the Portmaster will wait for a reply to a prompt notification. Please note that Desktop Notifications might not respect this or have their own limits.",
|
||||
OptType: config.OptTypeInt,
|
||||
ExpertiseLevel: config.ExpertiseLevelUser,
|
||||
DefaultValue: 60,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionAskTimeoutOrder,
|
||||
config.UnitAnnotation: "seconds",
|
||||
config.CategoryAnnotation: "General",
|
||||
},
|
||||
ValidationRegex: `^[1-9][0-9]{1,5}$`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
askTimeout = config.Concurrent.GetAsInt(CfgOptionAskTimeoutKey, 60)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Config variables for interception and filter module.
|
||||
// Everything is registered by the interception module, as the filter module
|
||||
// can be disabled.
|
||||
var (
|
||||
devMode config.BoolOption
|
||||
apiListenAddress config.StringOption
|
||||
|
||||
tunnelEnabled config.BoolOption
|
||||
useCommunityNodes config.BoolOption
|
||||
|
||||
configReady = abool.New()
|
||||
)
|
||||
|
||||
func getConfig() {
|
||||
devMode = config.Concurrent.GetAsBool(core.CfgDevModeKey, false)
|
||||
apiListenAddress = config.GetAsString(api.CfgDefaultListenAddressKey, "")
|
||||
|
||||
tunnelEnabled = config.Concurrent.GetAsBool(captain.CfgOptionEnableSPNKey, false)
|
||||
useCommunityNodes = config.Concurrent.GetAsBool(captain.CfgOptionUseCommunityNodesKey, true)
|
||||
|
||||
configReady.Set()
|
||||
}
|
||||
348
service/firewall/dns.go
Normal file
348
service/firewall/dns.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/profile"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
)
|
||||
|
||||
func filterDNSSection(
|
||||
ctx context.Context,
|
||||
entries []dns.RR,
|
||||
p *profile.LayeredProfile,
|
||||
resolverScope netutils.IPScope,
|
||||
sysResolver bool,
|
||||
) ([]dns.RR, []string, int, string) {
|
||||
// Will be filled 1:1 most of the time.
|
||||
goodEntries := make([]dns.RR, 0, len(entries))
|
||||
|
||||
// Will stay empty most of the time.
|
||||
var filteredRecords []string
|
||||
|
||||
// keeps track of the number of valid and allowed
|
||||
// A and AAAA records.
|
||||
var allowedAddressRecords int
|
||||
var interveningOptionKey string
|
||||
|
||||
for _, rr := range entries {
|
||||
// get IP and classification
|
||||
var ip net.IP
|
||||
switch v := rr.(type) {
|
||||
case *dns.A:
|
||||
ip = v.A
|
||||
case *dns.AAAA:
|
||||
ip = v.AAAA
|
||||
default:
|
||||
// add non A/AAAA entries
|
||||
// TODO: Add support for dns.SVCB and dns.HTTPS
|
||||
goodEntries = append(goodEntries, rr)
|
||||
continue
|
||||
}
|
||||
ipScope := netutils.GetIPScope(ip)
|
||||
|
||||
if p.RemoveOutOfScopeDNS() {
|
||||
switch {
|
||||
case ipScope.IsLocalhost():
|
||||
// No DNS should return localhost addresses
|
||||
filteredRecords = append(filteredRecords, formatRR(rr))
|
||||
interveningOptionKey = profile.CfgOptionRemoveOutOfScopeDNSKey
|
||||
log.Tracer(ctx).Tracef("filter: RR violates resolver scope: %s", formatRR(rr))
|
||||
continue
|
||||
|
||||
case resolverScope.IsGlobal() && ipScope.IsLAN() && !sysResolver:
|
||||
// No global DNS should return LAN addresses
|
||||
filteredRecords = append(filteredRecords, formatRR(rr))
|
||||
interveningOptionKey = profile.CfgOptionRemoveOutOfScopeDNSKey
|
||||
log.Tracer(ctx).Tracef("filter: RR violates resolver scope: %s", formatRR(rr))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if p.RemoveBlockedDNS() && !sysResolver {
|
||||
// filter by flags
|
||||
switch {
|
||||
case p.BlockScopeInternet() && ipScope.IsGlobal():
|
||||
filteredRecords = append(filteredRecords, formatRR(rr))
|
||||
interveningOptionKey = profile.CfgOptionBlockScopeInternetKey
|
||||
log.Tracer(ctx).Tracef("filter: RR is in blocked scope Internet: %s", formatRR(rr))
|
||||
continue
|
||||
|
||||
case p.BlockScopeLAN() && ipScope.IsLAN():
|
||||
filteredRecords = append(filteredRecords, formatRR(rr))
|
||||
interveningOptionKey = profile.CfgOptionBlockScopeLANKey
|
||||
log.Tracer(ctx).Tracef("filter: RR is in blocked scope LAN: %s", formatRR(rr))
|
||||
continue
|
||||
|
||||
case p.BlockScopeLocal() && ipScope.IsLocalhost():
|
||||
filteredRecords = append(filteredRecords, formatRR(rr))
|
||||
interveningOptionKey = profile.CfgOptionBlockScopeLocalKey
|
||||
log.Tracer(ctx).Tracef("filter: RR is in blocked scope Localhost: %s", formatRR(rr))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: filter by endpoint list (IP only)
|
||||
}
|
||||
|
||||
// if survived, add to good entries
|
||||
allowedAddressRecords++
|
||||
goodEntries = append(goodEntries, rr)
|
||||
}
|
||||
|
||||
return goodEntries, filteredRecords, allowedAddressRecords, interveningOptionKey
|
||||
}
|
||||
|
||||
func filterDNSResponse(
|
||||
ctx context.Context,
|
||||
conn *network.Connection,
|
||||
p *profile.LayeredProfile,
|
||||
rrCache *resolver.RRCache,
|
||||
sysResolver bool,
|
||||
) *resolver.RRCache {
|
||||
// do not modify own queries
|
||||
if conn.Process().Pid == ownPID {
|
||||
return rrCache
|
||||
}
|
||||
|
||||
// check if DNS response filtering is completely turned off
|
||||
if !p.RemoveOutOfScopeDNS() && !p.RemoveBlockedDNS() {
|
||||
return rrCache
|
||||
}
|
||||
|
||||
var filteredRecords []string
|
||||
var validIPs int
|
||||
var interveningOptionKey string
|
||||
|
||||
rrCache.Answer, filteredRecords, validIPs, interveningOptionKey = filterDNSSection(ctx, rrCache.Answer, p, rrCache.Resolver.IPScope, sysResolver)
|
||||
if len(filteredRecords) > 0 {
|
||||
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
|
||||
}
|
||||
|
||||
// Don't count the valid IPs in the extra section.
|
||||
rrCache.Extra, filteredRecords, _, _ = filterDNSSection(ctx, rrCache.Extra, p, rrCache.Resolver.IPScope, sysResolver)
|
||||
if len(filteredRecords) > 0 {
|
||||
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
|
||||
}
|
||||
|
||||
if len(rrCache.FilteredEntries) > 0 {
|
||||
rrCache.Filtered = true
|
||||
if validIPs == 0 {
|
||||
switch interveningOptionKey {
|
||||
case profile.CfgOptionBlockScopeInternetKey:
|
||||
conn.Block("Internet access blocked", interveningOptionKey)
|
||||
case profile.CfgOptionBlockScopeLANKey:
|
||||
conn.Block("LAN access blocked", interveningOptionKey)
|
||||
case profile.CfgOptionBlockScopeLocalKey:
|
||||
conn.Block("Localhost access blocked", interveningOptionKey)
|
||||
case profile.CfgOptionRemoveOutOfScopeDNSKey:
|
||||
conn.Block("DNS global/private split-view violation", interveningOptionKey)
|
||||
default:
|
||||
conn.Block("DNS response only contained to-be-blocked IPs", interveningOptionKey)
|
||||
}
|
||||
|
||||
return rrCache
|
||||
}
|
||||
}
|
||||
|
||||
return rrCache
|
||||
}
|
||||
|
||||
// FilterResolvedDNS filters a dns response according to the application
|
||||
// profile and settings.
|
||||
func FilterResolvedDNS(
|
||||
ctx context.Context,
|
||||
conn *network.Connection,
|
||||
q *resolver.Query,
|
||||
rrCache *resolver.RRCache,
|
||||
) *resolver.RRCache {
|
||||
// Check if we have a process and profile.
|
||||
layeredProfile := conn.Process().Profile()
|
||||
if layeredProfile == nil {
|
||||
log.Tracer(ctx).Warning("unknown process or profile")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Don't filter env responses.
|
||||
if rrCache.Resolver.Type == resolver.ServerTypeEnv {
|
||||
return rrCache
|
||||
}
|
||||
|
||||
// special grant for connectivity domains
|
||||
if checkConnectivityDomain(ctx, conn, layeredProfile, nil) {
|
||||
// returns true if check triggered
|
||||
return rrCache
|
||||
}
|
||||
|
||||
// Only filter critical things if request comes from the system resolver.
|
||||
sysResolver := conn.Process().IsSystemResolver()
|
||||
|
||||
// Filter dns records and return if the query is blocked.
|
||||
rrCache = filterDNSResponse(ctx, conn, layeredProfile, rrCache, sysResolver)
|
||||
if conn.Verdict == network.VerdictBlock {
|
||||
return rrCache
|
||||
}
|
||||
|
||||
// Block by CNAMEs.
|
||||
if !sysResolver {
|
||||
mayBlockCNAMEs(ctx, conn, layeredProfile)
|
||||
}
|
||||
|
||||
return rrCache
|
||||
}
|
||||
|
||||
func mayBlockCNAMEs(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile) bool {
|
||||
// if we have CNAMEs and the profile is configured to filter them
|
||||
// we need to re-check the lists and endpoints here
|
||||
if p.FilterCNAMEs() {
|
||||
conn.Entity.ResetLists()
|
||||
conn.Entity.EnableCNAMECheck(ctx, true)
|
||||
|
||||
result, reason := p.MatchEndpoint(ctx, conn.Entity)
|
||||
if result == endpoints.Denied {
|
||||
conn.BlockWithContext(reason.String(), profile.CfgOptionFilterCNAMEKey, reason.Context())
|
||||
return true
|
||||
}
|
||||
|
||||
if result == endpoints.NoMatch {
|
||||
result, reason = p.MatchFilterLists(ctx, conn.Entity)
|
||||
if result == endpoints.Denied {
|
||||
conn.BlockWithContext(reason.String(), profile.CfgOptionFilterCNAMEKey, reason.Context())
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateIPsAndCNAMEs saves all the IP->Name mappings to the cache database and
|
||||
// updates the CNAMEs in the Connection's Entity.
|
||||
func UpdateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) {
|
||||
// Sanity check input, as this is called from defer.
|
||||
if q == nil || rrCache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get profileID for scoping IPInfo.
|
||||
var profileID string
|
||||
localProfile := conn.Process().Profile().LocalProfile()
|
||||
switch localProfile.ID {
|
||||
case profile.UnidentifiedProfileID,
|
||||
profile.SystemResolverProfileID:
|
||||
profileID = resolver.IPInfoProfileScopeGlobal
|
||||
default:
|
||||
profileID = localProfile.ID
|
||||
}
|
||||
|
||||
// Collect IPs and CNAMEs.
|
||||
cnames := make(map[string]string)
|
||||
ips := make([]net.IP, 0, len(rrCache.Answer))
|
||||
|
||||
for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
|
||||
switch v := rr.(type) {
|
||||
case *dns.CNAME:
|
||||
cnames[v.Hdr.Name] = v.Target
|
||||
|
||||
case *dns.A:
|
||||
ips = append(ips, v.A)
|
||||
|
||||
case *dns.AAAA:
|
||||
ips = append(ips, v.AAAA)
|
||||
|
||||
case *dns.SVCB:
|
||||
if len(v.Target) >= 2 { // Ignore "" and ".".
|
||||
cnames[v.Hdr.Name] = v.Target
|
||||
}
|
||||
for _, pair := range v.Value {
|
||||
switch svcbParam := pair.(type) {
|
||||
case *dns.SVCBIPv4Hint:
|
||||
ips = append(ips, svcbParam.Hint...)
|
||||
case *dns.SVCBIPv6Hint:
|
||||
ips = append(ips, svcbParam.Hint...)
|
||||
}
|
||||
}
|
||||
|
||||
case *dns.HTTPS:
|
||||
if len(v.Target) >= 2 { // Ignore "" and ".".
|
||||
cnames[v.Hdr.Name] = v.Target
|
||||
}
|
||||
for _, pair := range v.Value {
|
||||
switch svcbParam := pair.(type) {
|
||||
case *dns.SVCBIPv4Hint:
|
||||
ips = append(ips, svcbParam.Hint...)
|
||||
case *dns.SVCBIPv6Hint:
|
||||
ips = append(ips, svcbParam.Hint...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Package IPs and CNAMEs into IPInfo structs.
|
||||
for _, ip := range ips {
|
||||
// Never save domain attributions for localhost IPs.
|
||||
if netutils.GetIPScope(ip) == netutils.HostLocal {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create new record for this IP.
|
||||
record := resolver.ResolvedDomain{
|
||||
Domain: q.FQDN,
|
||||
Resolver: rrCache.Resolver,
|
||||
DNSRequestContext: rrCache.ToDNSRequestContext(),
|
||||
Expires: rrCache.Expires,
|
||||
}
|
||||
|
||||
// Resolve all CNAMEs in the correct order and add the to the record.
|
||||
domain := q.FQDN
|
||||
for {
|
||||
nextDomain, isCNAME := cnames[domain]
|
||||
if !isCNAME {
|
||||
break
|
||||
}
|
||||
|
||||
record.CNAMEs = append(record.CNAMEs, nextDomain)
|
||||
domain = nextDomain
|
||||
}
|
||||
|
||||
// Update the entity to include the CNAMEs of the query response.
|
||||
conn.Entity.CNAME = record.CNAMEs
|
||||
|
||||
// Check if there is an existing record for this DNS response.
|
||||
// Else create a new one.
|
||||
ipString := ip.String()
|
||||
info, err := resolver.GetIPInfo(profileID, ipString)
|
||||
if err != nil {
|
||||
if !errors.Is(err, database.ErrNotFound) {
|
||||
log.Errorf("nameserver: failed to search for IP info record: %s", err)
|
||||
}
|
||||
|
||||
info = &resolver.IPInfo{
|
||||
IP: ipString,
|
||||
ProfileID: profileID,
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new record to the resolved domains for this IP and scope.
|
||||
info.AddDomain(record)
|
||||
|
||||
// Save if the record is new or has been updated.
|
||||
if err := info.Save(); err != nil {
|
||||
log.Errorf("nameserver: failed to save IP info record: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// formatRR is a friendlier alternative to miekg/dns.RR.String().
|
||||
func formatRR(rr dns.RR) string {
|
||||
return strings.ReplaceAll(rr.String(), "\t", " ")
|
||||
}
|
||||
102
service/firewall/inspection/inspection.go
Normal file
102
service/firewall/inspection/inspection.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package inspection
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
//nolint:golint,stylecheck // FIXME
|
||||
const (
|
||||
DO_NOTHING uint8 = iota
|
||||
BLOCK_PACKET
|
||||
DROP_PACKET
|
||||
BLOCK_CONN
|
||||
DROP_CONN
|
||||
STOP_INSPECTING
|
||||
)
|
||||
|
||||
type inspectorFn func(*network.Connection, packet.Packet) uint8
|
||||
|
||||
var (
|
||||
inspectors []inspectorFn
|
||||
inspectorNames []string
|
||||
inspectVerdicts []network.Verdict
|
||||
inspectorsLock sync.Mutex
|
||||
)
|
||||
|
||||
// RegisterInspector registers a traffic inspector.
|
||||
func RegisterInspector(name string, inspector inspectorFn, inspectVerdict network.Verdict) (index int) {
|
||||
inspectorsLock.Lock()
|
||||
defer inspectorsLock.Unlock()
|
||||
index = len(inspectors)
|
||||
inspectors = append(inspectors, inspector)
|
||||
inspectorNames = append(inspectorNames, name)
|
||||
inspectVerdicts = append(inspectVerdicts, inspectVerdict)
|
||||
return
|
||||
}
|
||||
|
||||
// RunInspectors runs all the applicable inspectors on the given packet.
|
||||
func RunInspectors(conn *network.Connection, pkt packet.Packet) (network.Verdict, bool) {
|
||||
// inspectorsLock.Lock()
|
||||
// defer inspectorsLock.Unlock()
|
||||
|
||||
activeInspectors := conn.GetActiveInspectors()
|
||||
if activeInspectors == nil {
|
||||
activeInspectors = make([]bool, len(inspectors))
|
||||
conn.SetActiveInspectors(activeInspectors)
|
||||
}
|
||||
|
||||
inspectorData := conn.GetInspectorData()
|
||||
if inspectorData == nil {
|
||||
inspectorData = make(map[uint8]interface{})
|
||||
conn.SetInspectorData(inspectorData)
|
||||
}
|
||||
|
||||
continueInspection := false
|
||||
verdict := network.VerdictUndecided
|
||||
|
||||
for key, skip := range activeInspectors {
|
||||
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
|
||||
// check if the active verdict is already past the inspection criteria.
|
||||
if conn.Verdict > inspectVerdicts[key] {
|
||||
activeInspectors[key] = true
|
||||
continue
|
||||
}
|
||||
|
||||
action := inspectors[key](conn, pkt) // Actually run inspector
|
||||
switch action {
|
||||
case DO_NOTHING:
|
||||
if verdict < network.VerdictAccept {
|
||||
verdict = network.VerdictAccept
|
||||
}
|
||||
continueInspection = true
|
||||
case BLOCK_PACKET:
|
||||
if verdict < network.VerdictBlock {
|
||||
verdict = network.VerdictBlock
|
||||
}
|
||||
continueInspection = true
|
||||
case DROP_PACKET:
|
||||
verdict = network.VerdictDrop
|
||||
continueInspection = true
|
||||
case BLOCK_CONN:
|
||||
conn.SetVerdict(network.VerdictBlock, "", "", nil)
|
||||
verdict = conn.Verdict
|
||||
activeInspectors[key] = true
|
||||
case DROP_CONN:
|
||||
conn.SetVerdict(network.VerdictDrop, "", "", nil)
|
||||
verdict = conn.Verdict
|
||||
activeInspectors[key] = true
|
||||
case STOP_INSPECTING:
|
||||
activeInspectors[key] = true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return verdict, continueInspection
|
||||
}
|
||||
147
service/firewall/interception/ebpf/bandwidth/bpf_bpfeb.go
Normal file
147
service/firewall/interception/ebpf/bandwidth/bpf_bpfeb.go
Normal file
@@ -0,0 +1,147 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
type bpfSkInfo struct {
|
||||
Rx uint64
|
||||
Tx uint64
|
||||
Reported uint64
|
||||
}
|
||||
|
||||
type bpfSkKey struct {
|
||||
SrcIp [4]uint32
|
||||
DstIp [4]uint32
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
Protocol uint8
|
||||
Ipv6 uint8
|
||||
_ [2]byte
|
||||
}
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
SocketOperations *ebpf.ProgramSpec `ebpf:"socket_operations"`
|
||||
UdpRecvmsg *ebpf.ProgramSpec `ebpf:"udp_recvmsg"`
|
||||
UdpSendmsg *ebpf.ProgramSpec `ebpf:"udp_sendmsg"`
|
||||
Udpv6Recvmsg *ebpf.ProgramSpec `ebpf:"udpv6_recvmsg"`
|
||||
Udpv6Sendmsg *ebpf.ProgramSpec `ebpf:"udpv6_sendmsg"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
PmBandwidthMap *ebpf.MapSpec `ebpf:"pm_bandwidth_map"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
PmBandwidthMap *ebpf.Map `ebpf:"pm_bandwidth_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.PmBandwidthMap,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
SocketOperations *ebpf.Program `ebpf:"socket_operations"`
|
||||
UdpRecvmsg *ebpf.Program `ebpf:"udp_recvmsg"`
|
||||
UdpSendmsg *ebpf.Program `ebpf:"udp_sendmsg"`
|
||||
Udpv6Recvmsg *ebpf.Program `ebpf:"udpv6_recvmsg"`
|
||||
Udpv6Sendmsg *ebpf.Program `ebpf:"udpv6_sendmsg"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.SocketOperations,
|
||||
p.UdpRecvmsg,
|
||||
p.UdpSendmsg,
|
||||
p.Udpv6Recvmsg,
|
||||
p.Udpv6Sendmsg,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfeb.o
|
||||
var _BpfBytes []byte
|
||||
BIN
service/firewall/interception/ebpf/bandwidth/bpf_bpfeb.o
Normal file
BIN
service/firewall/interception/ebpf/bandwidth/bpf_bpfeb.o
Normal file
Binary file not shown.
147
service/firewall/interception/ebpf/bandwidth/bpf_bpfel.go
Normal file
147
service/firewall/interception/ebpf/bandwidth/bpf_bpfel.go
Normal file
@@ -0,0 +1,147 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build 386 || amd64 || amd64p32 || arm || arm64 || loong64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
type bpfSkInfo struct {
|
||||
Rx uint64
|
||||
Tx uint64
|
||||
Reported uint64
|
||||
}
|
||||
|
||||
type bpfSkKey struct {
|
||||
SrcIp [4]uint32
|
||||
DstIp [4]uint32
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
Protocol uint8
|
||||
Ipv6 uint8
|
||||
_ [2]byte
|
||||
}
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
SocketOperations *ebpf.ProgramSpec `ebpf:"socket_operations"`
|
||||
UdpRecvmsg *ebpf.ProgramSpec `ebpf:"udp_recvmsg"`
|
||||
UdpSendmsg *ebpf.ProgramSpec `ebpf:"udp_sendmsg"`
|
||||
Udpv6Recvmsg *ebpf.ProgramSpec `ebpf:"udpv6_recvmsg"`
|
||||
Udpv6Sendmsg *ebpf.ProgramSpec `ebpf:"udpv6_sendmsg"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
PmBandwidthMap *ebpf.MapSpec `ebpf:"pm_bandwidth_map"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
PmBandwidthMap *ebpf.Map `ebpf:"pm_bandwidth_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.PmBandwidthMap,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
SocketOperations *ebpf.Program `ebpf:"socket_operations"`
|
||||
UdpRecvmsg *ebpf.Program `ebpf:"udp_recvmsg"`
|
||||
UdpSendmsg *ebpf.Program `ebpf:"udp_sendmsg"`
|
||||
Udpv6Recvmsg *ebpf.Program `ebpf:"udpv6_recvmsg"`
|
||||
Udpv6Sendmsg *ebpf.Program `ebpf:"udpv6_sendmsg"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.SocketOperations,
|
||||
p.UdpRecvmsg,
|
||||
p.UdpSendmsg,
|
||||
p.Udpv6Recvmsg,
|
||||
p.Udpv6Sendmsg,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfel.o
|
||||
var _BpfBytes []byte
|
||||
BIN
service/firewall/interception/ebpf/bandwidth/bpf_bpfel.o
Normal file
BIN
service/firewall/interception/ebpf/bandwidth/bpf_bpfel.o
Normal file
Binary file not shown.
192
service/firewall/interception/ebpf/bandwidth/interface.go
Normal file
192
service/firewall/interception/ebpf/bandwidth/interface.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
"github.com/cilium/ebpf/link"
|
||||
"github.com/cilium/ebpf/rlimit"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang -cflags "-O2 -g -Wall -Werror" bpf ../programs/bandwidth.c
|
||||
|
||||
var ebpfLoadingFailed atomic.Uint32
|
||||
|
||||
// BandwidthStatsWorker monitors connection bandwidth using ebpf.
|
||||
func BandwidthStatsWorker(ctx context.Context, collectInterval time.Duration, bandwidthUpdates chan *packet.BandwidthUpdate) error {
|
||||
// Allow the current process to lock memory for eBPF resources.
|
||||
err := rlimit.RemoveMemlock()
|
||||
if err != nil {
|
||||
if ebpfLoadingFailed.Add(1) >= 5 {
|
||||
log.Warningf("ebpf: failed to remove memlock 5 times, giving up with error %s", err)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("ebpf: failed to remove memlock: %w", err)
|
||||
}
|
||||
|
||||
// Load pre-compiled programs and maps into the kernel.
|
||||
objs := bpfObjects{}
|
||||
if err := loadBpfObjects(&objs, nil); err != nil {
|
||||
if ebpfLoadingFailed.Add(1) >= 5 {
|
||||
log.Warningf("ebpf: failed to load ebpf object 5 times, giving up with error %s", err)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("ebpf: failed to load ebpf object: %w", err)
|
||||
}
|
||||
defer objs.Close() //nolint:errcheck
|
||||
|
||||
// Find the cgroup path
|
||||
path, err := findCgroupPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to find cgroup paths: %w", err)
|
||||
}
|
||||
|
||||
// Attach socket options for monitoring connections
|
||||
sockOptionsLink, err := link.AttachCgroup(link.CgroupOptions{
|
||||
Path: path,
|
||||
Program: objs.bpfPrograms.SocketOperations,
|
||||
Attach: ebpf.AttachCGroupSockOps,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to open module sockops: %w", err)
|
||||
}
|
||||
defer sockOptionsLink.Close() //nolint:errcheck
|
||||
|
||||
// Attach Udp Ipv4 recive message tracing
|
||||
udpv4RMLink, err := link.AttachTracing(link.TracingOptions{
|
||||
Program: objs.bpfPrograms.UdpRecvmsg,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to open trace Udp IPv4 recvmsg: %w", err)
|
||||
}
|
||||
defer udpv4RMLink.Close() //nolint:errcheck
|
||||
|
||||
// Attach UDP IPv4 send message tracing
|
||||
udpv4SMLink, err := link.AttachTracing(link.TracingOptions{
|
||||
Program: objs.bpfPrograms.UdpSendmsg,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to open trace Udp IPv4 sendmsg: %w", err)
|
||||
}
|
||||
defer udpv4SMLink.Close() //nolint:errcheck
|
||||
|
||||
// Attach UDP IPv6 receive message tracing
|
||||
udpv6RMLink, err := link.AttachTracing(link.TracingOptions{
|
||||
Program: objs.bpfPrograms.Udpv6Recvmsg,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to open trace Udp IPv6 recvmsg: %w", err)
|
||||
}
|
||||
defer udpv6RMLink.Close() //nolint:errcheck
|
||||
|
||||
// Attach UDP IPv6 send message tracing
|
||||
udpv6SMLink, err := link.AttachTracing(link.TracingOptions{
|
||||
Program: objs.bpfPrograms.Udpv6Sendmsg,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to open trace Udp IPv6 sendmsg: %w", err)
|
||||
}
|
||||
defer udpv6SMLink.Close() //nolint:errcheck
|
||||
|
||||
// Setup ticker.
|
||||
ticker := time.NewTicker(collectInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Collect bandwidth at every tick.
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
reportBandwidth(ctx, objs, bandwidthUpdates)
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reportBandwidth reports the bandwidth to the given updates channel.
|
||||
func reportBandwidth(ctx context.Context, objs bpfObjects, bandwidthUpdates chan *packet.BandwidthUpdate) {
|
||||
var (
|
||||
skKey bpfSkKey
|
||||
skInfo bpfSkInfo
|
||||
updated int
|
||||
skipped int
|
||||
)
|
||||
|
||||
iter := objs.bpfMaps.PmBandwidthMap.Iterate()
|
||||
for iter.Next(&skKey, &skInfo) {
|
||||
// Check if already reported.
|
||||
if skInfo.Reported >= 1 {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
// Mark as reported and update the map.
|
||||
skInfo.Reported = 1
|
||||
if err := objs.bpfMaps.PmBandwidthMap.Update(&skKey, &skInfo, ebpf.UpdateExist); err != nil {
|
||||
log.Debugf("ebpf: failed to mark bandwidth map entry as reported: %s", err)
|
||||
}
|
||||
|
||||
connID := packet.CreateConnectionID(
|
||||
packet.IPProtocol(skKey.Protocol),
|
||||
convertArrayToIP(skKey.SrcIp, skKey.Ipv6 == 1), skKey.SrcPort,
|
||||
convertArrayToIP(skKey.DstIp, skKey.Ipv6 == 1), skKey.DstPort,
|
||||
false,
|
||||
)
|
||||
update := &packet.BandwidthUpdate{
|
||||
ConnID: connID,
|
||||
BytesReceived: skInfo.Rx,
|
||||
BytesSent: skInfo.Tx,
|
||||
Method: packet.Absolute,
|
||||
}
|
||||
select {
|
||||
case bandwidthUpdates <- update:
|
||||
updated++
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
log.Warningf("ebpf: bandwidth update queue is full (updated=%d, skipped=%d), ignoring rest of batch", updated, skipped)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findCgroupPath returns the default unified path of the cgroup.
|
||||
func findCgroupPath() (string, error) {
|
||||
cgroupPath := "/sys/fs/cgroup"
|
||||
|
||||
var st syscall.Statfs_t
|
||||
err := syscall.Statfs(cgroupPath, &st)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
isCgroupV2Enabled := st.Type == unix.CGROUP2_SUPER_MAGIC
|
||||
if !isCgroupV2Enabled {
|
||||
cgroupPath = filepath.Join(cgroupPath, "unified")
|
||||
}
|
||||
return cgroupPath, nil
|
||||
}
|
||||
|
||||
// convertArrayToIP converts an array of uint32 values to a net.IP address.
|
||||
func convertArrayToIP(input [4]uint32, ipv6 bool) net.IP {
|
||||
if !ipv6 {
|
||||
addressBuf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(addressBuf, input[0])
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
addressBuf := make([]byte, 16)
|
||||
for i := 0; i < 4; i++ {
|
||||
binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
|
||||
}
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
type bpfEvent struct {
|
||||
Saddr [4]uint32
|
||||
Daddr [4]uint32
|
||||
Sport uint16
|
||||
Dport uint16
|
||||
Pid uint32
|
||||
IpVersion uint8
|
||||
Protocol uint8
|
||||
Direction uint8
|
||||
_ [1]byte
|
||||
}
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
TcpConnect *ebpf.ProgramSpec `ebpf:"tcp_connect"`
|
||||
UdpV4Connect *ebpf.ProgramSpec `ebpf:"udp_v4_connect"`
|
||||
UdpV6Connect *ebpf.ProgramSpec `ebpf:"udp_v6_connect"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
PmConnectionEvents *ebpf.MapSpec `ebpf:"pm_connection_events"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
PmConnectionEvents *ebpf.Map `ebpf:"pm_connection_events"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.PmConnectionEvents,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
TcpConnect *ebpf.Program `ebpf:"tcp_connect"`
|
||||
UdpV4Connect *ebpf.Program `ebpf:"udp_v4_connect"`
|
||||
UdpV6Connect *ebpf.Program `ebpf:"udp_v6_connect"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.TcpConnect,
|
||||
p.UdpV4Connect,
|
||||
p.UdpV6Connect,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfeb.o
|
||||
var _BpfBytes []byte
|
||||
Binary file not shown.
@@ -0,0 +1,137 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build 386 || amd64 || amd64p32 || arm || arm64 || loong64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
type bpfEvent struct {
|
||||
Saddr [4]uint32
|
||||
Daddr [4]uint32
|
||||
Sport uint16
|
||||
Dport uint16
|
||||
Pid uint32
|
||||
IpVersion uint8
|
||||
Protocol uint8
|
||||
Direction uint8
|
||||
_ [1]byte
|
||||
}
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
TcpConnect *ebpf.ProgramSpec `ebpf:"tcp_connect"`
|
||||
UdpV4Connect *ebpf.ProgramSpec `ebpf:"udp_v4_connect"`
|
||||
UdpV6Connect *ebpf.ProgramSpec `ebpf:"udp_v6_connect"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
PmConnectionEvents *ebpf.MapSpec `ebpf:"pm_connection_events"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
PmConnectionEvents *ebpf.Map `ebpf:"pm_connection_events"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.PmConnectionEvents,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
TcpConnect *ebpf.Program `ebpf:"tcp_connect"`
|
||||
UdpV4Connect *ebpf.Program `ebpf:"udp_v4_connect"`
|
||||
UdpV6Connect *ebpf.Program `ebpf:"udp_v6_connect"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.TcpConnect,
|
||||
p.UdpV4Connect,
|
||||
p.UdpV6Connect,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfel.o
|
||||
var _BpfBytes []byte
|
||||
Binary file not shown.
176
service/firewall/interception/ebpf/connection_listener/worker.go
Normal file
176
service/firewall/interception/ebpf/connection_listener/worker.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cilium/ebpf/link"
|
||||
"github.com/cilium/ebpf/ringbuf"
|
||||
"github.com/cilium/ebpf/rlimit"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang -cflags "-O2 -g -Wall -Werror" -type Event bpf ../programs/monitor.c
|
||||
|
||||
var ebpfLoadingFailed atomic.Uint32
|
||||
|
||||
// ConnectionListenerWorker listens to new connections using ebpf.
|
||||
func ConnectionListenerWorker(ctx context.Context, packets chan packet.Packet) error {
|
||||
// Allow the current process to lock memory for eBPF resources.
|
||||
if err := rlimit.RemoveMemlock(); err != nil {
|
||||
if ebpfLoadingFailed.Add(1) >= 5 {
|
||||
log.Warningf("ebpf: failed to remove memlock 5 times, giving up with error %s", err)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("ebpf: failed to remove ebpf memlock: %w", err)
|
||||
}
|
||||
|
||||
// Load pre-compiled programs and maps into the kernel.
|
||||
objs := bpfObjects{}
|
||||
if err := loadBpfObjects(&objs, nil); err != nil {
|
||||
if ebpfLoadingFailed.Add(1) >= 5 {
|
||||
log.Warningf("ebpf: failed to load ebpf object 5 times, giving up with error %s", err)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("ebpf: failed to load ebpf object: %w", err)
|
||||
}
|
||||
defer objs.Close() //nolint:errcheck
|
||||
|
||||
// Create a link to the tcp_connect program.
|
||||
linkTCPConnect, err := link.AttachTracing(link.TracingOptions{
|
||||
Program: objs.bpfPrograms.TcpConnect,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to attach to tcp_v4_connect: %w", err)
|
||||
}
|
||||
defer linkTCPConnect.Close() //nolint:errcheck
|
||||
|
||||
// Create a link to the udp_v4_connect program.
|
||||
linkUDPV4, err := link.AttachTracing(link.TracingOptions{
|
||||
Program: objs.bpfPrograms.UdpV4Connect,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to attach to udp_v4_connect: %w", err)
|
||||
}
|
||||
defer linkUDPV4.Close() //nolint:errcheck
|
||||
|
||||
// Create a link to the udp_v6_connect program.
|
||||
linkUDPV6, err := link.AttachTracing(link.TracingOptions{
|
||||
Program: objs.bpfPrograms.UdpV6Connect,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to attach to udp_v6_connect: %w", err)
|
||||
}
|
||||
defer linkUDPV6.Close() //nolint:errcheck
|
||||
|
||||
// Create new reader to read events.
|
||||
rd, err := ringbuf.NewReader(objs.bpfMaps.PmConnectionEvents)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ebpf: failed to open ring buffer: %w", err)
|
||||
}
|
||||
defer rd.Close() //nolint:errcheck
|
||||
|
||||
// Start watcher to close the reader when the context is canceled.
|
||||
// TODO: Can we put this into a worker?
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
if err := rd.Close(); err != nil {
|
||||
log.Errorf("ebpf: failed closing ringbuf reader: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
// Read next event
|
||||
record, err := rd.Read()
|
||||
if err != nil {
|
||||
if errors.Is(err, ringbuf.ErrClosed) {
|
||||
// Normal return
|
||||
return nil
|
||||
}
|
||||
log.Errorf("ebpf: failed to read from ring buffer: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var event bpfEvent
|
||||
// Parse the ringbuf event entry into a bpfEvent structure.
|
||||
if err := binary.Read(bytes.NewBuffer(record.RawSample), binary.BigEndian, &event); err != nil {
|
||||
log.Errorf("ebpf: failed to parse ringbuf event: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
pkt := packet.NewInfoPacket(packet.Info{
|
||||
Inbound: event.Direction == 1,
|
||||
InTunnel: false,
|
||||
Version: packet.IPVersion(event.IpVersion),
|
||||
Protocol: packet.IPProtocol(event.Protocol),
|
||||
SrcPort: event.Sport,
|
||||
DstPort: event.Dport,
|
||||
Src: convertArrayToIPv4(event.Saddr, packet.IPVersion(event.IpVersion)),
|
||||
Dst: convertArrayToIPv4(event.Daddr, packet.IPVersion(event.IpVersion)),
|
||||
PID: int(event.Pid),
|
||||
SeenAt: time.Now(),
|
||||
})
|
||||
if isEventValid(event) {
|
||||
// DEBUG:
|
||||
// log.Debugf("ebpf: received valid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt)
|
||||
packets <- pkt
|
||||
} else {
|
||||
log.Warningf("ebpf: received invalid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isEventValid checks whether the given bpfEvent is valid or not.
|
||||
// It returns true if the event is valid, otherwise false.
|
||||
func isEventValid(event bpfEvent) bool {
|
||||
// Check if the destination port is 0
|
||||
if event.Dport == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the source port is 0
|
||||
if event.Sport == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the process ID is 0
|
||||
if event.Pid == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// If the IP version is IPv4
|
||||
if event.IpVersion == 4 {
|
||||
if event.Saddr[0] == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if event.Daddr[0] == 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// convertArrayToIPv4 converts an array of uint32 values to an IPv4 net.IP address.
|
||||
func convertArrayToIPv4(input [4]uint32, ipVersion packet.IPVersion) net.IP {
|
||||
if ipVersion == packet.IPv4 {
|
||||
addressBuf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(addressBuf, input[0])
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
addressBuf := make([]byte, 16)
|
||||
for i := 0; i < 4; i++ {
|
||||
binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
|
||||
}
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
119
service/firewall/interception/ebpf/exec/bpf_bpfeb.go
Normal file
119
service/firewall/interception/ebpf/exec/bpf_bpfeb.go
Normal file
@@ -0,0 +1,119 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
EnterExecve *ebpf.ProgramSpec `ebpf:"enter_execve"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
PmExecMap *ebpf.MapSpec `ebpf:"pm_exec_map"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
PmExecMap *ebpf.Map `ebpf:"pm_exec_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.PmExecMap,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
EnterExecve *ebpf.Program `ebpf:"enter_execve"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.EnterExecve,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfeb.o
|
||||
var _BpfBytes []byte
|
||||
BIN
service/firewall/interception/ebpf/exec/bpf_bpfeb.o
Normal file
BIN
service/firewall/interception/ebpf/exec/bpf_bpfeb.o
Normal file
Binary file not shown.
119
service/firewall/interception/ebpf/exec/bpf_bpfel.go
Normal file
119
service/firewall/interception/ebpf/exec/bpf_bpfel.go
Normal file
@@ -0,0 +1,119 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build 386 || amd64 || amd64p32 || arm || arm64 || loong64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
EnterExecve *ebpf.ProgramSpec `ebpf:"enter_execve"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
PmExecMap *ebpf.MapSpec `ebpf:"pm_exec_map"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
PmExecMap *ebpf.Map `ebpf:"pm_exec_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.PmExecMap,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
EnterExecve *ebpf.Program `ebpf:"enter_execve"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.EnterExecve,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfel.o
|
||||
var _BpfBytes []byte
|
||||
BIN
service/firewall/interception/ebpf/exec/bpf_bpfel.o
Normal file
BIN
service/firewall/interception/ebpf/exec/bpf_bpfel.o
Normal file
Binary file not shown.
249
service/firewall/interception/ebpf/exec/exec.go
Normal file
249
service/firewall/interception/ebpf/exec/exec.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cilium/ebpf/link"
|
||||
"github.com/cilium/ebpf/ringbuf"
|
||||
"github.com/cilium/ebpf/rlimit"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang -cflags "-O2 -g -Wall -Werror" bpf ../programs/exec.c
|
||||
|
||||
// These constants are defined in `bpf/handler.c` and must be kept in sync.
|
||||
const (
|
||||
arglen = 32
|
||||
argsize = 1024
|
||||
)
|
||||
|
||||
var errTracerClosed = errors.New("tracer is closed")
|
||||
|
||||
// event contains details about each exec call, sent from the eBPF program to
|
||||
// userspace through a perf ring buffer. This type must be kept in sync with
|
||||
// `event_t` in `bpf/handler.c`.
|
||||
type event struct {
|
||||
// Details about the process being launched.
|
||||
Filename [argsize]byte
|
||||
Argv [arglen][argsize]byte
|
||||
Argc uint32
|
||||
UID uint32
|
||||
GID uint32
|
||||
PID uint32
|
||||
|
||||
// Name of the calling process.
|
||||
Comm [argsize]byte
|
||||
}
|
||||
|
||||
// Event contains data about each exec event with many fields for easy
|
||||
// filtering and logging.
|
||||
type Event struct {
|
||||
Filename string `json:"filename"`
|
||||
// Argv contains the raw argv supplied to the process, including argv[0]
|
||||
// (which is equal to `filepath.Base(e.Filename)` in most circumstances).
|
||||
Argv []string `json:"argv"`
|
||||
// Truncated is true if we were unable to read all process arguments into
|
||||
// Argv because there were more than ARGLEN arguments.
|
||||
Truncated bool `json:"truncated"`
|
||||
|
||||
// These values are of the new process. Keep in mind that the exec call may
|
||||
// fail and the PID will be released in such a case.
|
||||
PID uint32 `json:"pid"`
|
||||
UID uint32 `json:"uid"`
|
||||
GID uint32 `json:"gid"`
|
||||
|
||||
// Comm is the "name" of the parent process, usually the filename of the
|
||||
// executable (but not always).
|
||||
Comm string `json:"comm"`
|
||||
}
|
||||
|
||||
// Tracer is the exec tracer itself.
|
||||
// It must be closed after use.
|
||||
type Tracer struct {
|
||||
objs bpfObjects
|
||||
tp link.Link
|
||||
rb *ringbuf.Reader
|
||||
|
||||
closeLock sync.Mutex
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// New instantiates all of the BPF objects into the running kernel, starts
|
||||
// tracing, and returns the created Tracer. After calling this successfully, the
|
||||
// caller should immediately attach a for loop running `h.Read()`.
|
||||
//
|
||||
// The returned Tracer MUST be closed when not needed anymore otherwise kernel
|
||||
// resources may be leaked.
|
||||
func New() (*Tracer, error) {
|
||||
t := &Tracer{
|
||||
tp: nil,
|
||||
rb: nil,
|
||||
|
||||
closeLock: sync.Mutex{},
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
if err := loadBpfObjects(&t.objs, nil); err != nil {
|
||||
return nil, fmt.Errorf("ebpf: failed to load ebpf object: %w", err)
|
||||
}
|
||||
|
||||
if err := t.start(); err != nil {
|
||||
// Best effort.
|
||||
_ = t.Close()
|
||||
return nil, fmt.Errorf("start tracer: %w", err)
|
||||
}
|
||||
|
||||
// It could be very bad if someone forgot to close this, so we'll try to
|
||||
// detect when it doesn't get closed and log a warning.
|
||||
stack := debug.Stack()
|
||||
runtime.SetFinalizer(t, func(t *Tracer) {
|
||||
err := t.Close()
|
||||
if errors.Is(err, errTracerClosed) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("tracer was finalized but was not closed, created at: %s", stack)
|
||||
log.Infof("tracers must be closed when finished with to avoid leaked kernel resources")
|
||||
if err != nil {
|
||||
log.Errorf("closing tracer failed: %+v", err)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// start loads the eBPF programs and maps into the kernel and starts them.
|
||||
// You should immediately attach a for loop running `h.Read()` after calling
|
||||
// this successfully.
|
||||
func (t *Tracer) start() error {
|
||||
// If we don't startup successfully, we need to make sure all of the
|
||||
// stuff is cleaned up properly or we'll be leaking kernel resources.
|
||||
ok := false
|
||||
defer func() {
|
||||
if !ok {
|
||||
// Best effort.
|
||||
_ = t.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Allow the current process to lock memory for eBPF resources. This
|
||||
// does nothing on 5.11+ kernels which don't need this.
|
||||
err := rlimit.RemoveMemlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove memlock: %w", err)
|
||||
}
|
||||
|
||||
// Attach the eBPF program to the `sys_enter_execve` tracepoint, which
|
||||
// is triggered at the beginning of each `execve()` syscall.
|
||||
t.tp, err = link.Tracepoint("syscalls", "sys_enter_execve", t.objs.EnterExecve, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open tracepoint: %w", err)
|
||||
}
|
||||
|
||||
// Create the reader for the event ringbuf.
|
||||
t.rb, err = ringbuf.NewReader(t.objs.PmExecMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open ringbuf reader: %w", err)
|
||||
}
|
||||
|
||||
ok = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read reads an event from the eBPF program via the ringbuf, parses it and
|
||||
// returns it. If the *tracer is closed during the blocked call, and error that
|
||||
// wraps io.EOF will be returned.
|
||||
func (t *Tracer) Read() (*Event, error) {
|
||||
rb := t.rb
|
||||
if rb == nil {
|
||||
return nil, errors.New("ringbuf reader is not initialized, tracer may not be open or may have been closed")
|
||||
}
|
||||
|
||||
record, err := rb.Read()
|
||||
if err != nil {
|
||||
if errors.Is(err, ringbuf.ErrClosed) {
|
||||
return nil, fmt.Errorf("tracer closed: %w", io.EOF)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("read from ringbuf: %w", err)
|
||||
}
|
||||
|
||||
// Parse the ringbuf event entry into an event structure.
|
||||
var rawEvent event
|
||||
err = binary.Read(bytes.NewBuffer(record.RawSample), binary.NativeEndian, &rawEvent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse raw ringbuf entry into event struct: %w", err)
|
||||
}
|
||||
|
||||
ev := &Event{
|
||||
Filename: unix.ByteSliceToString(rawEvent.Filename[:]),
|
||||
Argv: []string{}, // populated below
|
||||
Truncated: rawEvent.Argc == arglen+1,
|
||||
PID: rawEvent.PID,
|
||||
UID: rawEvent.UID,
|
||||
GID: rawEvent.GID,
|
||||
Comm: unix.ByteSliceToString(rawEvent.Comm[:]),
|
||||
}
|
||||
|
||||
// Copy only the args we're allowed to read from the array. If we read more
|
||||
// than rawEvent.Argc, we could be copying non-zeroed memory.
|
||||
argc := int(rawEvent.Argc)
|
||||
if argc > arglen {
|
||||
argc = arglen
|
||||
}
|
||||
for i := 0; i < argc; i++ {
|
||||
str := unix.ByteSliceToString(rawEvent.Argv[i][:])
|
||||
if strings.TrimSpace(str) != "" {
|
||||
ev.Argv = append(ev.Argv, str)
|
||||
}
|
||||
}
|
||||
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
// Close gracefully closes and frees all resources associated with the eBPF
|
||||
// tracepoints, maps and other resources. Any blocked `Read()` operations will
|
||||
// return an error that wraps `io.EOF`.
|
||||
func (t *Tracer) Close() error {
|
||||
t.closeLock.Lock()
|
||||
defer t.closeLock.Unlock()
|
||||
select {
|
||||
case <-t.closed:
|
||||
return errTracerClosed
|
||||
default:
|
||||
}
|
||||
close(t.closed)
|
||||
runtime.SetFinalizer(t, nil)
|
||||
|
||||
// Close everything started in h.Start() in reverse order.
|
||||
var merr error
|
||||
if t.rb != nil {
|
||||
err := t.rb.Close()
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("close ringbuf reader: %w", err))
|
||||
}
|
||||
}
|
||||
if t.tp != nil {
|
||||
err := t.tp.Close()
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("close tracepoint: %w", err))
|
||||
}
|
||||
}
|
||||
err := t.objs.Close()
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("close eBPF objects: %w", err))
|
||||
}
|
||||
|
||||
return merr
|
||||
}
|
||||
219
service/firewall/interception/ebpf/programs/bandwidth.c
Normal file
219
service/firewall/interception/ebpf/programs/bandwidth.c
Normal file
@@ -0,0 +1,219 @@
|
||||
#include "vmlinux-x86.h"
|
||||
#include "bpf/bpf_helpers.h"
|
||||
#include "bpf/bpf_tracing.h"
|
||||
#include "bpf/bpf_core_read.h"
|
||||
|
||||
#define AF_INET 2
|
||||
#define AF_INET6 10
|
||||
|
||||
#define PROTOCOL_TCP 6
|
||||
#define PROTOCOL_UDP 17
|
||||
|
||||
char __license[] SEC("license") = "GPL";
|
||||
|
||||
struct sk_key {
|
||||
u32 src_ip[4];
|
||||
u32 dst_ip[4];
|
||||
u16 src_port;
|
||||
u16 dst_port;
|
||||
u8 protocol;
|
||||
u8 ipv6;
|
||||
};
|
||||
|
||||
struct sk_info {
|
||||
u64 rx;
|
||||
u64 tx;
|
||||
u64 reported;
|
||||
};
|
||||
|
||||
// Max number of connections that will be kept. Increse the number if it's not enough.
|
||||
#define SOCKOPS_MAP_SIZE 5000
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_LRU_HASH);
|
||||
__uint(max_entries, SOCKOPS_MAP_SIZE);
|
||||
__type(key, struct sk_key);
|
||||
__type(value, struct sk_info);
|
||||
} pm_bandwidth_map SEC(".maps");
|
||||
|
||||
SEC("sockops")
|
||||
int socket_operations(struct bpf_sock_ops *skops) {
|
||||
switch (skops->op) {
|
||||
case BPF_SOCK_OPS_TCP_CONNECT_CB: // Outgoing connections
|
||||
// Set flag so any modification on the socket, will trigger this function.
|
||||
bpf_sock_ops_cb_flags_set(skops, BPF_SOCK_OPS_ALL_CB_FLAGS);
|
||||
return 0;
|
||||
case BPF_SOCK_OPS_TCP_LISTEN_CB: // Listening ports
|
||||
bpf_sock_ops_cb_flags_set(skops, BPF_SOCK_OPS_ALL_CB_FLAGS);
|
||||
// No rx tx data for this socket object.
|
||||
return 0;
|
||||
case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB: // Incoming connections
|
||||
// Set flag so any modification on the socket, will trigger this function.
|
||||
bpf_sock_ops_cb_flags_set(skops, BPF_SOCK_OPS_ALL_CB_FLAGS);
|
||||
return 0;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
struct bpf_sock *sk = skops->sk;
|
||||
if (sk == NULL) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct sk_key key = {0};
|
||||
key.protocol = PROTOCOL_TCP;
|
||||
if(sk->family == AF_INET) {
|
||||
// Generate key for IPv4
|
||||
key.src_ip[0] = sk->src_ip4;
|
||||
key.src_port = sk->src_port;
|
||||
key.dst_ip[0] = sk->dst_ip4;
|
||||
key.dst_port = __builtin_bswap16(sk->dst_port);
|
||||
key.ipv6 = 0;
|
||||
|
||||
struct sk_info newInfo = {0};
|
||||
newInfo.rx = skops->bytes_received;
|
||||
newInfo.tx = skops->bytes_acked;
|
||||
|
||||
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
|
||||
} else if(sk->family == AF_INET6){
|
||||
// Generate key for IPv6
|
||||
key.src_ip[0] = sk->src_ip6[0];
|
||||
key.src_ip[1] = sk->src_ip6[1];
|
||||
key.src_ip[2] = sk->src_ip6[2];
|
||||
key.src_ip[3] = sk->src_ip6[3];
|
||||
key.src_port = sk->src_port;
|
||||
|
||||
key.dst_ip[0] = sk->dst_ip6[0];
|
||||
key.dst_ip[1] = sk->dst_ip6[1];
|
||||
key.dst_ip[2] = sk->dst_ip6[2];
|
||||
key.dst_ip[3] = sk->dst_ip6[3];
|
||||
key.dst_port = __builtin_bswap16(sk->dst_port);
|
||||
|
||||
key.ipv6 = 1;
|
||||
|
||||
struct sk_info newInfo = {0};
|
||||
newInfo.rx = skops->bytes_received;
|
||||
newInfo.tx = skops->bytes_acked;
|
||||
|
||||
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// udp_sendmsg hookes to the respective kernel function and saves the bandwidth data
|
||||
SEC("fentry/udp_sendmsg")
|
||||
int BPF_PROG(udp_sendmsg, struct sock *sk, struct msghdr *msg, size_t len) {
|
||||
struct sock_common *skc = &sk->__sk_common;
|
||||
|
||||
// Create a key for the map and set all the nececery information.
|
||||
struct sk_key key = {0};
|
||||
key.protocol = PROTOCOL_UDP;
|
||||
key.src_ip[0] = skc->skc_rcv_saddr;
|
||||
key.dst_ip[0] = skc->skc_daddr;
|
||||
key.src_port = skc->skc_num;
|
||||
key.dst_port = __builtin_bswap16(skc->skc_dport);
|
||||
key.ipv6 = 0;
|
||||
|
||||
// Update the map with the new information
|
||||
struct sk_info *info = bpf_map_lookup_elem(&pm_bandwidth_map, &key);
|
||||
if (info != NULL) {
|
||||
__sync_fetch_and_add(&info->tx, len); // TODO: Use atomic instead.
|
||||
__sync_fetch_and_and(&info->reported, 0); // TODO: Use atomic instead.
|
||||
} else {
|
||||
struct sk_info newInfo = {0};
|
||||
|
||||
newInfo.tx = len;
|
||||
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
|
||||
}
|
||||
|
||||
return 0;
|
||||
};
|
||||
|
||||
// udp_recvmsg hookes to the respective kernel function and saves the bandwidth data
|
||||
SEC("fentry/udp_recvmsg")
|
||||
int BPF_PROG(udp_recvmsg, struct sock *sk, struct msghdr *msg, size_t len, int flags, int *addr_len) {
|
||||
struct sock_common *skc = &sk->__sk_common;
|
||||
|
||||
// Create a key for the map and set all the nececery information.
|
||||
struct sk_key key = {0};
|
||||
key.protocol = PROTOCOL_UDP;
|
||||
key.src_ip[0] = skc->skc_rcv_saddr;
|
||||
key.dst_ip[0] = skc->skc_daddr;
|
||||
key.src_port = skc->skc_num;
|
||||
key.dst_port = __builtin_bswap16(skc->skc_dport);
|
||||
key.ipv6 = 0;
|
||||
|
||||
// Update the map with the new information
|
||||
struct sk_info *info = bpf_map_lookup_elem(&pm_bandwidth_map, &key);
|
||||
if (info != NULL) {
|
||||
__sync_fetch_and_add(&info->rx, len); // TODO: Use atomic instead.
|
||||
__sync_fetch_and_and(&info->reported, 0); // TODO: Use atomic instead.
|
||||
} else {
|
||||
struct sk_info newInfo = {0};
|
||||
|
||||
newInfo.rx = len;
|
||||
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
|
||||
}
|
||||
|
||||
return 0;
|
||||
};
|
||||
|
||||
// udpv6_sendmsg hookes to the respective kernel function and saves the bandwidth data
|
||||
SEC("fentry/udpv6_sendmsg")
|
||||
int BPF_PROG(udpv6_sendmsg, struct sock *sk, struct msghdr *msg, size_t len) {
|
||||
struct sock_common *skc = &sk->__sk_common;
|
||||
|
||||
// Create a key for the map and set all the nececery information.
|
||||
struct sk_key key = {0};
|
||||
key.protocol = PROTOCOL_UDP;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
key.src_ip[i] = skc->skc_v6_rcv_saddr.in6_u.u6_addr32[i];
|
||||
key.dst_ip[i] = skc->skc_v6_rcv_saddr.in6_u.u6_addr32[i];
|
||||
}
|
||||
key.src_port = skc->skc_num;
|
||||
key.dst_port = __builtin_bswap16(skc->skc_dport);
|
||||
key.ipv6 = 1;
|
||||
|
||||
// Update the map with the new information
|
||||
struct sk_info *info = bpf_map_lookup_elem(&pm_bandwidth_map, &key);
|
||||
if (info != NULL) {
|
||||
__sync_fetch_and_add(&info->tx, len); // TODO: Use atomic instead.
|
||||
__sync_fetch_and_and(&info->reported, 0); // TODO: Use atomic instead.
|
||||
} else {
|
||||
struct sk_info newInfo = {0};
|
||||
newInfo.tx = len;
|
||||
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// udpv6_recvmsg hookes to the respective kernel function and saves the bandwidth data
|
||||
SEC("fentry/udpv6_recvmsg")
|
||||
int BPF_PROG(udpv6_recvmsg, struct sock *sk, struct msghdr *msg, size_t len, int flags, int *addr_len) {
|
||||
struct sock_common *skc = &sk->__sk_common;
|
||||
|
||||
// Create a key for the map and set all the nececery information.
|
||||
struct sk_key key = {0};
|
||||
key.protocol = PROTOCOL_UDP;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
key.src_ip[i] = skc->skc_v6_rcv_saddr.in6_u.u6_addr32[i];
|
||||
key.dst_ip[i] = skc->skc_v6_rcv_saddr.in6_u.u6_addr32[i];
|
||||
}
|
||||
key.src_port = skc->skc_num;
|
||||
key.dst_port = __builtin_bswap16(skc->skc_dport);
|
||||
key.ipv6 = 1;
|
||||
|
||||
// Update the map with the new information
|
||||
struct sk_info *info = bpf_map_lookup_elem(&pm_bandwidth_map, &key);
|
||||
if (info != NULL) {
|
||||
__sync_fetch_and_add(&info->rx, len); // TODO: Use atomic instead.
|
||||
__sync_fetch_and_and(&info->reported, 0); // TODO: Use atomic instead.
|
||||
} else {
|
||||
struct sk_info newInfo = {0};
|
||||
newInfo.rx = len;
|
||||
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
484
service/firewall/interception/ebpf/programs/bpf/bpf_core_read.h
Normal file
484
service/firewall/interception/ebpf/programs/bpf/bpf_core_read.h
Normal file
@@ -0,0 +1,484 @@
|
||||
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
|
||||
#ifndef __BPF_CORE_READ_H__
|
||||
#define __BPF_CORE_READ_H__
|
||||
|
||||
/*
|
||||
* enum bpf_field_info_kind is passed as a second argument into
|
||||
* __builtin_preserve_field_info() built-in to get a specific aspect of
|
||||
* a field, captured as a first argument. __builtin_preserve_field_info(field,
|
||||
* info_kind) returns __u32 integer and produces BTF field relocation, which
|
||||
* is understood and processed by libbpf during BPF object loading. See
|
||||
* selftests/bpf for examples.
|
||||
*/
|
||||
enum bpf_field_info_kind {
|
||||
BPF_FIELD_BYTE_OFFSET = 0, /* field byte offset */
|
||||
BPF_FIELD_BYTE_SIZE = 1,
|
||||
BPF_FIELD_EXISTS = 2, /* field existence in target kernel */
|
||||
BPF_FIELD_SIGNED = 3,
|
||||
BPF_FIELD_LSHIFT_U64 = 4,
|
||||
BPF_FIELD_RSHIFT_U64 = 5,
|
||||
};
|
||||
|
||||
/* second argument to __builtin_btf_type_id() built-in */
|
||||
enum bpf_type_id_kind {
|
||||
BPF_TYPE_ID_LOCAL = 0, /* BTF type ID in local program */
|
||||
BPF_TYPE_ID_TARGET = 1, /* BTF type ID in target kernel */
|
||||
};
|
||||
|
||||
/* second argument to __builtin_preserve_type_info() built-in */
|
||||
enum bpf_type_info_kind {
|
||||
BPF_TYPE_EXISTS = 0, /* type existence in target kernel */
|
||||
BPF_TYPE_SIZE = 1, /* type size in target kernel */
|
||||
BPF_TYPE_MATCHES = 2, /* type match in target kernel */
|
||||
};
|
||||
|
||||
/* second argument to __builtin_preserve_enum_value() built-in */
|
||||
enum bpf_enum_value_kind {
|
||||
BPF_ENUMVAL_EXISTS = 0, /* enum value existence in kernel */
|
||||
BPF_ENUMVAL_VALUE = 1, /* enum value value relocation */
|
||||
};
|
||||
|
||||
#define __CORE_RELO(src, field, info) \
|
||||
__builtin_preserve_field_info((src)->field, BPF_FIELD_##info)
|
||||
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
#define __CORE_BITFIELD_PROBE_READ(dst, src, fld) \
|
||||
bpf_probe_read_kernel( \
|
||||
(void *)dst, \
|
||||
__CORE_RELO(src, fld, BYTE_SIZE), \
|
||||
(const void *)src + __CORE_RELO(src, fld, BYTE_OFFSET))
|
||||
#else
|
||||
/* semantics of LSHIFT_64 assumes loading values into low-ordered bytes, so
|
||||
* for big-endian we need to adjust destination pointer accordingly, based on
|
||||
* field byte size
|
||||
*/
|
||||
#define __CORE_BITFIELD_PROBE_READ(dst, src, fld) \
|
||||
bpf_probe_read_kernel( \
|
||||
(void *)dst + (8 - __CORE_RELO(src, fld, BYTE_SIZE)), \
|
||||
__CORE_RELO(src, fld, BYTE_SIZE), \
|
||||
(const void *)src + __CORE_RELO(src, fld, BYTE_OFFSET))
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Extract bitfield, identified by s->field, and return its value as u64.
|
||||
* All this is done in relocatable manner, so bitfield changes such as
|
||||
* signedness, bit size, offset changes, this will be handled automatically.
|
||||
* This version of macro is using bpf_probe_read_kernel() to read underlying
|
||||
* integer storage. Macro functions as an expression and its return type is
|
||||
* bpf_probe_read_kernel()'s return value: 0, on success, <0 on error.
|
||||
*/
|
||||
#define BPF_CORE_READ_BITFIELD_PROBED(s, field) ({ \
|
||||
unsigned long long val = 0; \
|
||||
\
|
||||
__CORE_BITFIELD_PROBE_READ(&val, s, field); \
|
||||
val <<= __CORE_RELO(s, field, LSHIFT_U64); \
|
||||
if (__CORE_RELO(s, field, SIGNED)) \
|
||||
val = ((long long)val) >> __CORE_RELO(s, field, RSHIFT_U64); \
|
||||
else \
|
||||
val = val >> __CORE_RELO(s, field, RSHIFT_U64); \
|
||||
val; \
|
||||
})
|
||||
|
||||
/*
|
||||
* Extract bitfield, identified by s->field, and return its value as u64.
|
||||
* This version of macro is using direct memory reads and should be used from
|
||||
* BPF program types that support such functionality (e.g., typed raw
|
||||
* tracepoints).
|
||||
*/
|
||||
#define BPF_CORE_READ_BITFIELD(s, field) ({ \
|
||||
const void *p = (const void *)s + __CORE_RELO(s, field, BYTE_OFFSET); \
|
||||
unsigned long long val; \
|
||||
\
|
||||
/* This is a so-called barrier_var() operation that makes specified \
|
||||
* variable "a black box" for optimizing compiler. \
|
||||
* It forces compiler to perform BYTE_OFFSET relocation on p and use \
|
||||
* its calculated value in the switch below, instead of applying \
|
||||
* the same relocation 4 times for each individual memory load. \
|
||||
*/ \
|
||||
asm volatile("" : "=r"(p) : "0"(p)); \
|
||||
\
|
||||
switch (__CORE_RELO(s, field, BYTE_SIZE)) { \
|
||||
case 1: val = *(const unsigned char *)p; break; \
|
||||
case 2: val = *(const unsigned short *)p; break; \
|
||||
case 4: val = *(const unsigned int *)p; break; \
|
||||
case 8: val = *(const unsigned long long *)p; break; \
|
||||
} \
|
||||
val <<= __CORE_RELO(s, field, LSHIFT_U64); \
|
||||
if (__CORE_RELO(s, field, SIGNED)) \
|
||||
val = ((long long)val) >> __CORE_RELO(s, field, RSHIFT_U64); \
|
||||
else \
|
||||
val = val >> __CORE_RELO(s, field, RSHIFT_U64); \
|
||||
val; \
|
||||
})
|
||||
|
||||
#define ___bpf_field_ref1(field) (field)
|
||||
#define ___bpf_field_ref2(type, field) (((typeof(type) *)0)->field)
|
||||
#define ___bpf_field_ref(args...) \
|
||||
___bpf_apply(___bpf_field_ref, ___bpf_narg(args))(args)
|
||||
|
||||
/*
|
||||
* Convenience macro to check that field actually exists in target kernel's.
|
||||
* Returns:
|
||||
* 1, if matching field is present in target kernel;
|
||||
* 0, if no matching field found.
|
||||
*
|
||||
* Supports two forms:
|
||||
* - field reference through variable access:
|
||||
* bpf_core_field_exists(p->my_field);
|
||||
* - field reference through type and field names:
|
||||
* bpf_core_field_exists(struct my_type, my_field).
|
||||
*/
|
||||
#define bpf_core_field_exists(field...) \
|
||||
__builtin_preserve_field_info(___bpf_field_ref(field), BPF_FIELD_EXISTS)
|
||||
|
||||
/*
|
||||
* Convenience macro to get the byte size of a field. Works for integers,
|
||||
* struct/unions, pointers, arrays, and enums.
|
||||
*
|
||||
* Supports two forms:
|
||||
* - field reference through variable access:
|
||||
* bpf_core_field_size(p->my_field);
|
||||
* - field reference through type and field names:
|
||||
* bpf_core_field_size(struct my_type, my_field).
|
||||
*/
|
||||
#define bpf_core_field_size(field...) \
|
||||
__builtin_preserve_field_info(___bpf_field_ref(field), BPF_FIELD_BYTE_SIZE)
|
||||
|
||||
/*
|
||||
* Convenience macro to get field's byte offset.
|
||||
*
|
||||
* Supports two forms:
|
||||
* - field reference through variable access:
|
||||
* bpf_core_field_offset(p->my_field);
|
||||
* - field reference through type and field names:
|
||||
* bpf_core_field_offset(struct my_type, my_field).
|
||||
*/
|
||||
#define bpf_core_field_offset(field...) \
|
||||
__builtin_preserve_field_info(___bpf_field_ref(field), BPF_FIELD_BYTE_OFFSET)
|
||||
|
||||
/*
|
||||
* Convenience macro to get BTF type ID of a specified type, using a local BTF
|
||||
* information. Return 32-bit unsigned integer with type ID from program's own
|
||||
* BTF. Always succeeds.
|
||||
*/
|
||||
#define bpf_core_type_id_local(type) \
|
||||
__builtin_btf_type_id(*(typeof(type) *)0, BPF_TYPE_ID_LOCAL)
|
||||
|
||||
/*
|
||||
* Convenience macro to get BTF type ID of a target kernel's type that matches
|
||||
* specified local type.
|
||||
* Returns:
|
||||
* - valid 32-bit unsigned type ID in kernel BTF;
|
||||
* - 0, if no matching type was found in a target kernel BTF.
|
||||
*/
|
||||
#define bpf_core_type_id_kernel(type) \
|
||||
__builtin_btf_type_id(*(typeof(type) *)0, BPF_TYPE_ID_TARGET)
|
||||
|
||||
/*
|
||||
* Convenience macro to check that provided named type
|
||||
* (struct/union/enum/typedef) exists in a target kernel.
|
||||
* Returns:
|
||||
* 1, if such type is present in target kernel's BTF;
|
||||
* 0, if no matching type is found.
|
||||
*/
|
||||
#define bpf_core_type_exists(type) \
|
||||
__builtin_preserve_type_info(*(typeof(type) *)0, BPF_TYPE_EXISTS)
|
||||
|
||||
/*
|
||||
* Convenience macro to check that provided named type
|
||||
* (struct/union/enum/typedef) "matches" that in a target kernel.
|
||||
* Returns:
|
||||
* 1, if the type matches in the target kernel's BTF;
|
||||
* 0, if the type does not match any in the target kernel
|
||||
*/
|
||||
#define bpf_core_type_matches(type) \
|
||||
__builtin_preserve_type_info(*(typeof(type) *)0, BPF_TYPE_MATCHES)
|
||||
|
||||
/*
|
||||
* Convenience macro to get the byte size of a provided named type
|
||||
* (struct/union/enum/typedef) in a target kernel.
|
||||
* Returns:
|
||||
* >= 0 size (in bytes), if type is present in target kernel's BTF;
|
||||
* 0, if no matching type is found.
|
||||
*/
|
||||
#define bpf_core_type_size(type) \
|
||||
__builtin_preserve_type_info(*(typeof(type) *)0, BPF_TYPE_SIZE)
|
||||
|
||||
/*
|
||||
* Convenience macro to check that provided enumerator value is defined in
|
||||
* a target kernel.
|
||||
* Returns:
|
||||
* 1, if specified enum type and its enumerator value are present in target
|
||||
* kernel's BTF;
|
||||
* 0, if no matching enum and/or enum value within that enum is found.
|
||||
*/
|
||||
#define bpf_core_enum_value_exists(enum_type, enum_value) \
|
||||
__builtin_preserve_enum_value(*(typeof(enum_type) *)enum_value, BPF_ENUMVAL_EXISTS)
|
||||
|
||||
/*
|
||||
* Convenience macro to get the integer value of an enumerator value in
|
||||
* a target kernel.
|
||||
* Returns:
|
||||
* 64-bit value, if specified enum type and its enumerator value are
|
||||
* present in target kernel's BTF;
|
||||
* 0, if no matching enum and/or enum value within that enum is found.
|
||||
*/
|
||||
#define bpf_core_enum_value(enum_type, enum_value) \
|
||||
__builtin_preserve_enum_value(*(typeof(enum_type) *)enum_value, BPF_ENUMVAL_VALUE)
|
||||
|
||||
/*
|
||||
* bpf_core_read() abstracts away bpf_probe_read_kernel() call and captures
|
||||
* offset relocation for source address using __builtin_preserve_access_index()
|
||||
* built-in, provided by Clang.
|
||||
*
|
||||
* __builtin_preserve_access_index() takes as an argument an expression of
|
||||
* taking an address of a field within struct/union. It makes compiler emit
|
||||
* a relocation, which records BTF type ID describing root struct/union and an
|
||||
* accessor string which describes exact embedded field that was used to take
|
||||
* an address. See detailed description of this relocation format and
|
||||
* semantics in comments to struct bpf_field_reloc in libbpf_internal.h.
|
||||
*
|
||||
* This relocation allows libbpf to adjust BPF instruction to use correct
|
||||
* actual field offset, based on target kernel BTF type that matches original
|
||||
* (local) BTF, used to record relocation.
|
||||
*/
|
||||
#define bpf_core_read(dst, sz, src) \
|
||||
bpf_probe_read_kernel(dst, sz, (const void *)__builtin_preserve_access_index(src))
|
||||
|
||||
/* NOTE: see comments for BPF_CORE_READ_USER() about the proper types use. */
|
||||
#define bpf_core_read_user(dst, sz, src) \
|
||||
bpf_probe_read_user(dst, sz, (const void *)__builtin_preserve_access_index(src))
|
||||
/*
|
||||
* bpf_core_read_str() is a thin wrapper around bpf_probe_read_str()
|
||||
* additionally emitting BPF CO-RE field relocation for specified source
|
||||
* argument.
|
||||
*/
|
||||
#define bpf_core_read_str(dst, sz, src) \
|
||||
bpf_probe_read_kernel_str(dst, sz, (const void *)__builtin_preserve_access_index(src))
|
||||
|
||||
/* NOTE: see comments for BPF_CORE_READ_USER() about the proper types use. */
|
||||
#define bpf_core_read_user_str(dst, sz, src) \
|
||||
bpf_probe_read_user_str(dst, sz, (const void *)__builtin_preserve_access_index(src))
|
||||
|
||||
#define ___concat(a, b) a ## b
|
||||
#define ___apply(fn, n) ___concat(fn, n)
|
||||
#define ___nth(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, __11, N, ...) N
|
||||
|
||||
/*
|
||||
* return number of provided arguments; used for switch-based variadic macro
|
||||
* definitions (see ___last, ___arrow, etc below)
|
||||
*/
|
||||
#define ___narg(...) ___nth(_, ##__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
|
||||
/*
|
||||
* return 0 if no arguments are passed, N - otherwise; used for
|
||||
* recursively-defined macros to specify termination (0) case, and generic
|
||||
* (N) case (e.g., ___read_ptrs, ___core_read)
|
||||
*/
|
||||
#define ___empty(...) ___nth(_, ##__VA_ARGS__, N, N, N, N, N, N, N, N, N, N, 0)
|
||||
|
||||
#define ___last1(x) x
|
||||
#define ___last2(a, x) x
|
||||
#define ___last3(a, b, x) x
|
||||
#define ___last4(a, b, c, x) x
|
||||
#define ___last5(a, b, c, d, x) x
|
||||
#define ___last6(a, b, c, d, e, x) x
|
||||
#define ___last7(a, b, c, d, e, f, x) x
|
||||
#define ___last8(a, b, c, d, e, f, g, x) x
|
||||
#define ___last9(a, b, c, d, e, f, g, h, x) x
|
||||
#define ___last10(a, b, c, d, e, f, g, h, i, x) x
|
||||
#define ___last(...) ___apply(___last, ___narg(__VA_ARGS__))(__VA_ARGS__)
|
||||
|
||||
#define ___nolast2(a, _) a
|
||||
#define ___nolast3(a, b, _) a, b
|
||||
#define ___nolast4(a, b, c, _) a, b, c
|
||||
#define ___nolast5(a, b, c, d, _) a, b, c, d
|
||||
#define ___nolast6(a, b, c, d, e, _) a, b, c, d, e
|
||||
#define ___nolast7(a, b, c, d, e, f, _) a, b, c, d, e, f
|
||||
#define ___nolast8(a, b, c, d, e, f, g, _) a, b, c, d, e, f, g
|
||||
#define ___nolast9(a, b, c, d, e, f, g, h, _) a, b, c, d, e, f, g, h
|
||||
#define ___nolast10(a, b, c, d, e, f, g, h, i, _) a, b, c, d, e, f, g, h, i
|
||||
#define ___nolast(...) ___apply(___nolast, ___narg(__VA_ARGS__))(__VA_ARGS__)
|
||||
|
||||
#define ___arrow1(a) a
|
||||
#define ___arrow2(a, b) a->b
|
||||
#define ___arrow3(a, b, c) a->b->c
|
||||
#define ___arrow4(a, b, c, d) a->b->c->d
|
||||
#define ___arrow5(a, b, c, d, e) a->b->c->d->e
|
||||
#define ___arrow6(a, b, c, d, e, f) a->b->c->d->e->f
|
||||
#define ___arrow7(a, b, c, d, e, f, g) a->b->c->d->e->f->g
|
||||
#define ___arrow8(a, b, c, d, e, f, g, h) a->b->c->d->e->f->g->h
|
||||
#define ___arrow9(a, b, c, d, e, f, g, h, i) a->b->c->d->e->f->g->h->i
|
||||
#define ___arrow10(a, b, c, d, e, f, g, h, i, j) a->b->c->d->e->f->g->h->i->j
|
||||
#define ___arrow(...) ___apply(___arrow, ___narg(__VA_ARGS__))(__VA_ARGS__)
|
||||
|
||||
#define ___type(...) typeof(___arrow(__VA_ARGS__))
|
||||
|
||||
#define ___read(read_fn, dst, src_type, src, accessor) \
|
||||
read_fn((void *)(dst), sizeof(*(dst)), &((src_type)(src))->accessor)
|
||||
|
||||
/* "recursively" read a sequence of inner pointers using local __t var */
|
||||
#define ___rd_first(fn, src, a) ___read(fn, &__t, ___type(src), src, a);
|
||||
#define ___rd_last(fn, ...) \
|
||||
___read(fn, &__t, ___type(___nolast(__VA_ARGS__)), __t, ___last(__VA_ARGS__));
|
||||
#define ___rd_p1(fn, ...) const void *__t; ___rd_first(fn, __VA_ARGS__)
|
||||
#define ___rd_p2(fn, ...) ___rd_p1(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
|
||||
#define ___rd_p3(fn, ...) ___rd_p2(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
|
||||
#define ___rd_p4(fn, ...) ___rd_p3(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
|
||||
#define ___rd_p5(fn, ...) ___rd_p4(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
|
||||
#define ___rd_p6(fn, ...) ___rd_p5(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
|
||||
#define ___rd_p7(fn, ...) ___rd_p6(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
|
||||
#define ___rd_p8(fn, ...) ___rd_p7(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
|
||||
#define ___rd_p9(fn, ...) ___rd_p8(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
|
||||
#define ___read_ptrs(fn, src, ...) \
|
||||
___apply(___rd_p, ___narg(__VA_ARGS__))(fn, src, __VA_ARGS__)
|
||||
|
||||
#define ___core_read0(fn, fn_ptr, dst, src, a) \
|
||||
___read(fn, dst, ___type(src), src, a);
|
||||
#define ___core_readN(fn, fn_ptr, dst, src, ...) \
|
||||
___read_ptrs(fn_ptr, src, ___nolast(__VA_ARGS__)) \
|
||||
___read(fn, dst, ___type(src, ___nolast(__VA_ARGS__)), __t, \
|
||||
___last(__VA_ARGS__));
|
||||
#define ___core_read(fn, fn_ptr, dst, src, a, ...) \
|
||||
___apply(___core_read, ___empty(__VA_ARGS__))(fn, fn_ptr, dst, \
|
||||
src, a, ##__VA_ARGS__)
|
||||
|
||||
/*
|
||||
* BPF_CORE_READ_INTO() is a more performance-conscious variant of
|
||||
* BPF_CORE_READ(), in which final field is read into user-provided storage.
|
||||
* See BPF_CORE_READ() below for more details on general usage.
|
||||
*/
|
||||
#define BPF_CORE_READ_INTO(dst, src, a, ...) ({ \
|
||||
___core_read(bpf_core_read, bpf_core_read, \
|
||||
dst, (src), a, ##__VA_ARGS__) \
|
||||
})
|
||||
|
||||
/*
|
||||
* Variant of BPF_CORE_READ_INTO() for reading from user-space memory.
|
||||
*
|
||||
* NOTE: see comments for BPF_CORE_READ_USER() about the proper types use.
|
||||
*/
|
||||
#define BPF_CORE_READ_USER_INTO(dst, src, a, ...) ({ \
|
||||
___core_read(bpf_core_read_user, bpf_core_read_user, \
|
||||
dst, (src), a, ##__VA_ARGS__) \
|
||||
})
|
||||
|
||||
/* Non-CO-RE variant of BPF_CORE_READ_INTO() */
|
||||
#define BPF_PROBE_READ_INTO(dst, src, a, ...) ({ \
|
||||
___core_read(bpf_probe_read_kernel, bpf_probe_read_kernel, \
|
||||
dst, (src), a, ##__VA_ARGS__) \
|
||||
})
|
||||
|
||||
/* Non-CO-RE variant of BPF_CORE_READ_USER_INTO().
|
||||
*
|
||||
* As no CO-RE relocations are emitted, source types can be arbitrary and are
|
||||
* not restricted to kernel types only.
|
||||
*/
|
||||
#define BPF_PROBE_READ_USER_INTO(dst, src, a, ...) ({ \
|
||||
___core_read(bpf_probe_read_user, bpf_probe_read_user, \
|
||||
dst, (src), a, ##__VA_ARGS__) \
|
||||
})
|
||||
|
||||
/*
|
||||
* BPF_CORE_READ_STR_INTO() does same "pointer chasing" as
|
||||
* BPF_CORE_READ() for intermediate pointers, but then executes (and returns
|
||||
* corresponding error code) bpf_core_read_str() for final string read.
|
||||
*/
|
||||
#define BPF_CORE_READ_STR_INTO(dst, src, a, ...) ({ \
|
||||
___core_read(bpf_core_read_str, bpf_core_read, \
|
||||
dst, (src), a, ##__VA_ARGS__) \
|
||||
})
|
||||
|
||||
/*
|
||||
* Variant of BPF_CORE_READ_STR_INTO() for reading from user-space memory.
|
||||
*
|
||||
* NOTE: see comments for BPF_CORE_READ_USER() about the proper types use.
|
||||
*/
|
||||
#define BPF_CORE_READ_USER_STR_INTO(dst, src, a, ...) ({ \
|
||||
___core_read(bpf_core_read_user_str, bpf_core_read_user, \
|
||||
dst, (src), a, ##__VA_ARGS__) \
|
||||
})
|
||||
|
||||
/* Non-CO-RE variant of BPF_CORE_READ_STR_INTO() */
|
||||
#define BPF_PROBE_READ_STR_INTO(dst, src, a, ...) ({ \
|
||||
___core_read(bpf_probe_read_kernel_str, bpf_probe_read_kernel, \
|
||||
dst, (src), a, ##__VA_ARGS__) \
|
||||
})
|
||||
|
||||
/*
|
||||
* Non-CO-RE variant of BPF_CORE_READ_USER_STR_INTO().
|
||||
*
|
||||
* As no CO-RE relocations are emitted, source types can be arbitrary and are
|
||||
* not restricted to kernel types only.
|
||||
*/
|
||||
#define BPF_PROBE_READ_USER_STR_INTO(dst, src, a, ...) ({ \
|
||||
___core_read(bpf_probe_read_user_str, bpf_probe_read_user, \
|
||||
dst, (src), a, ##__VA_ARGS__) \
|
||||
})
|
||||
|
||||
/*
|
||||
* BPF_CORE_READ() is used to simplify BPF CO-RE relocatable read, especially
|
||||
* when there are few pointer chasing steps.
|
||||
* E.g., what in non-BPF world (or in BPF w/ BCC) would be something like:
|
||||
* int x = s->a.b.c->d.e->f->g;
|
||||
* can be succinctly achieved using BPF_CORE_READ as:
|
||||
* int x = BPF_CORE_READ(s, a.b.c, d.e, f, g);
|
||||
*
|
||||
* BPF_CORE_READ will decompose above statement into 4 bpf_core_read (BPF
|
||||
* CO-RE relocatable bpf_probe_read_kernel() wrapper) calls, logically
|
||||
* equivalent to:
|
||||
* 1. const void *__t = s->a.b.c;
|
||||
* 2. __t = __t->d.e;
|
||||
* 3. __t = __t->f;
|
||||
* 4. return __t->g;
|
||||
*
|
||||
* Equivalence is logical, because there is a heavy type casting/preservation
|
||||
* involved, as well as all the reads are happening through
|
||||
* bpf_probe_read_kernel() calls using __builtin_preserve_access_index() to
|
||||
* emit CO-RE relocations.
|
||||
*
|
||||
* N.B. Only up to 9 "field accessors" are supported, which should be more
|
||||
* than enough for any practical purpose.
|
||||
*/
|
||||
#define BPF_CORE_READ(src, a, ...) ({ \
|
||||
___type((src), a, ##__VA_ARGS__) __r; \
|
||||
BPF_CORE_READ_INTO(&__r, (src), a, ##__VA_ARGS__); \
|
||||
__r; \
|
||||
})
|
||||
|
||||
/*
|
||||
* Variant of BPF_CORE_READ() for reading from user-space memory.
|
||||
*
|
||||
* NOTE: all the source types involved are still *kernel types* and need to
|
||||
* exist in kernel (or kernel module) BTF, otherwise CO-RE relocation will
|
||||
* fail. Custom user types are not relocatable with CO-RE.
|
||||
* The typical situation in which BPF_CORE_READ_USER() might be used is to
|
||||
* read kernel UAPI types from the user-space memory passed in as a syscall
|
||||
* input argument.
|
||||
*/
|
||||
#define BPF_CORE_READ_USER(src, a, ...) ({ \
|
||||
___type((src), a, ##__VA_ARGS__) __r; \
|
||||
BPF_CORE_READ_USER_INTO(&__r, (src), a, ##__VA_ARGS__); \
|
||||
__r; \
|
||||
})
|
||||
|
||||
/* Non-CO-RE variant of BPF_CORE_READ() */
|
||||
#define BPF_PROBE_READ(src, a, ...) ({ \
|
||||
___type((src), a, ##__VA_ARGS__) __r; \
|
||||
BPF_PROBE_READ_INTO(&__r, (src), a, ##__VA_ARGS__); \
|
||||
__r; \
|
||||
})
|
||||
|
||||
/*
|
||||
* Non-CO-RE variant of BPF_CORE_READ_USER().
|
||||
*
|
||||
* As no CO-RE relocations are emitted, source types can be arbitrary and are
|
||||
* not restricted to kernel types only.
|
||||
*/
|
||||
#define BPF_PROBE_READ_USER(src, a, ...) ({ \
|
||||
___type((src), a, ##__VA_ARGS__) __r; \
|
||||
BPF_PROBE_READ_USER_INTO(&__r, (src), a, ##__VA_ARGS__); \
|
||||
__r; \
|
||||
})
|
||||
|
||||
#endif
|
||||
|
||||
4748
service/firewall/interception/ebpf/programs/bpf/bpf_helper_defs.h
Normal file
4748
service/firewall/interception/ebpf/programs/bpf/bpf_helper_defs.h
Normal file
File diff suppressed because it is too large
Load Diff
397
service/firewall/interception/ebpf/programs/bpf/bpf_helpers.h
Normal file
397
service/firewall/interception/ebpf/programs/bpf/bpf_helpers.h
Normal file
@@ -0,0 +1,397 @@
|
||||
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
|
||||
#ifndef __BPF_HELPERS__
|
||||
#define __BPF_HELPERS__
|
||||
|
||||
/*
|
||||
* Note that bpf programs need to include either
|
||||
* vmlinux.h (auto-generated from BTF) or linux/types.h
|
||||
* in advance since bpf_helper_defs.h uses such types
|
||||
* as __u64.
|
||||
*/
|
||||
#include "bpf_helper_defs.h"
|
||||
|
||||
#define __uint(name, val) int (*name)[val]
|
||||
#define __type(name, val) typeof(val) *name
|
||||
#define __array(name, val) typeof(val) *name[]
|
||||
|
||||
/*
|
||||
* Helper macro to place programs, maps, license in
|
||||
* different sections in elf_bpf file. Section names
|
||||
* are interpreted by libbpf depending on the context (BPF programs, BPF maps,
|
||||
* extern variables, etc).
|
||||
* To allow use of SEC() with externs (e.g., for extern .maps declarations),
|
||||
* make sure __attribute__((unused)) doesn't trigger compilation warning.
|
||||
*/
|
||||
#if __GNUC__ && !__clang__
|
||||
|
||||
/*
|
||||
* Pragma macros are broken on GCC
|
||||
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55578
|
||||
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90400
|
||||
*/
|
||||
#define SEC(name) __attribute__((section(name), used))
|
||||
|
||||
#else
|
||||
|
||||
#define SEC(name) \
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wignored-attributes\"") \
|
||||
__attribute__((section(name), used)) \
|
||||
_Pragma("GCC diagnostic pop") \
|
||||
|
||||
#endif
|
||||
|
||||
/* Avoid 'linux/stddef.h' definition of '__always_inline'. */
|
||||
#undef __always_inline
|
||||
#define __always_inline inline __attribute__((always_inline))
|
||||
|
||||
#ifndef __noinline
|
||||
#define __noinline __attribute__((noinline))
|
||||
#endif
|
||||
#ifndef __weak
|
||||
#define __weak __attribute__((weak))
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Use __hidden attribute to mark a non-static BPF subprogram effectively
|
||||
* static for BPF verifier's verification algorithm purposes, allowing more
|
||||
* extensive and permissive BPF verification process, taking into account
|
||||
* subprogram's caller context.
|
||||
*/
|
||||
#define __hidden __attribute__((visibility("hidden")))
|
||||
|
||||
/* When utilizing vmlinux.h with BPF CO-RE, user BPF programs can't include
|
||||
* any system-level headers (such as stddef.h, linux/version.h, etc), and
|
||||
* commonly-used macros like NULL and KERNEL_VERSION aren't available through
|
||||
* vmlinux.h. This just adds unnecessary hurdles and forces users to re-define
|
||||
* them on their own. So as a convenience, provide such definitions here.
|
||||
*/
|
||||
#ifndef NULL
|
||||
#define NULL ((void *)0)
|
||||
#endif
|
||||
|
||||
#ifndef KERNEL_VERSION
|
||||
#define KERNEL_VERSION(a, b, c) (((a) << 16) + ((b) << 8) + ((c) > 255 ? 255 : (c)))
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Helper macros to manipulate data structures
|
||||
*/
|
||||
#ifndef offsetof
|
||||
#define offsetof(TYPE, MEMBER) ((unsigned long)&((TYPE *)0)->MEMBER)
|
||||
#endif
|
||||
#ifndef container_of
|
||||
#define container_of(ptr, type, member) \
|
||||
({ \
|
||||
void *__mptr = (void *)(ptr); \
|
||||
((type *)(__mptr - offsetof(type, member))); \
|
||||
})
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Compiler (optimization) barrier.
|
||||
*/
|
||||
#ifndef barrier
|
||||
#define barrier() asm volatile("" ::: "memory")
|
||||
#endif
|
||||
|
||||
/* Variable-specific compiler (optimization) barrier. It's a no-op which makes
|
||||
* compiler believe that there is some black box modification of a given
|
||||
* variable and thus prevents compiler from making extra assumption about its
|
||||
* value and potential simplifications and optimizations on this variable.
|
||||
*
|
||||
* E.g., compiler might often delay or even omit 32-bit to 64-bit casting of
|
||||
* a variable, making some code patterns unverifiable. Putting barrier_var()
|
||||
* in place will ensure that cast is performed before the barrier_var()
|
||||
* invocation, because compiler has to pessimistically assume that embedded
|
||||
* asm section might perform some extra operations on that variable.
|
||||
*
|
||||
* This is a variable-specific variant of more global barrier().
|
||||
*/
|
||||
#ifndef barrier_var
|
||||
#define barrier_var(var) asm volatile("" : "+r"(var))
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Helper macro to throw a compilation error if __bpf_unreachable() gets
|
||||
* built into the resulting code. This works given BPF back end does not
|
||||
* implement __builtin_trap(). This is useful to assert that certain paths
|
||||
* of the program code are never used and hence eliminated by the compiler.
|
||||
*
|
||||
* For example, consider a switch statement that covers known cases used by
|
||||
* the program. __bpf_unreachable() can then reside in the default case. If
|
||||
* the program gets extended such that a case is not covered in the switch
|
||||
* statement, then it will throw a build error due to the default case not
|
||||
* being compiled out.
|
||||
*/
|
||||
#ifndef __bpf_unreachable
|
||||
# define __bpf_unreachable() __builtin_trap()
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Helper function to perform a tail call with a constant/immediate map slot.
|
||||
*/
|
||||
#if __clang_major__ >= 8 && defined(__bpf__)
|
||||
static __always_inline void
|
||||
bpf_tail_call_static(void *ctx, const void *map, const __u32 slot)
|
||||
{
|
||||
if (!__builtin_constant_p(slot))
|
||||
__bpf_unreachable();
|
||||
|
||||
/*
|
||||
* Provide a hard guarantee that LLVM won't optimize setting r2 (map
|
||||
* pointer) and r3 (constant map index) from _different paths_ ending
|
||||
* up at the _same_ call insn as otherwise we won't be able to use the
|
||||
* jmpq/nopl retpoline-free patching by the x86-64 JIT in the kernel
|
||||
* given they mismatch. See also d2e4c1e6c294 ("bpf: Constant map key
|
||||
* tracking for prog array pokes") for details on verifier tracking.
|
||||
*
|
||||
* Note on clobber list: we need to stay in-line with BPF calling
|
||||
* convention, so even if we don't end up using r0, r4, r5, we need
|
||||
* to mark them as clobber so that LLVM doesn't end up using them
|
||||
* before / after the call.
|
||||
*/
|
||||
asm volatile("r1 = %[ctx]\n\t"
|
||||
"r2 = %[map]\n\t"
|
||||
"r3 = %[slot]\n\t"
|
||||
"call 12"
|
||||
:: [ctx]"r"(ctx), [map]"r"(map), [slot]"i"(slot)
|
||||
: "r0", "r1", "r2", "r3", "r4", "r5");
|
||||
}
|
||||
#endif
|
||||
|
||||
enum libbpf_pin_type {
|
||||
LIBBPF_PIN_NONE,
|
||||
/* PIN_BY_NAME: pin maps by name (in /sys/fs/bpf by default) */
|
||||
LIBBPF_PIN_BY_NAME,
|
||||
};
|
||||
|
||||
enum libbpf_tristate {
|
||||
TRI_NO = 0,
|
||||
TRI_YES = 1,
|
||||
TRI_MODULE = 2,
|
||||
};
|
||||
|
||||
#define __kconfig __attribute__((section(".kconfig")))
|
||||
#define __ksym __attribute__((section(".ksyms")))
|
||||
#define __kptr_untrusted __attribute__((btf_type_tag("kptr_untrusted")))
|
||||
#define __kptr __attribute__((btf_type_tag("kptr")))
|
||||
|
||||
#define bpf_ksym_exists(sym) ({ \
|
||||
_Static_assert(!__builtin_constant_p(!!sym), #sym " should be marked as __weak"); \
|
||||
!!sym; \
|
||||
})
|
||||
|
||||
#ifndef ___bpf_concat
|
||||
#define ___bpf_concat(a, b) a ## b
|
||||
#endif
|
||||
#ifndef ___bpf_apply
|
||||
#define ___bpf_apply(fn, n) ___bpf_concat(fn, n)
|
||||
#endif
|
||||
#ifndef ___bpf_nth
|
||||
#define ___bpf_nth(_, _1, _2, _3, _4, _5, _6, _7, _8, _9, _a, _b, _c, N, ...) N
|
||||
#endif
|
||||
#ifndef ___bpf_narg
|
||||
#define ___bpf_narg(...) \
|
||||
___bpf_nth(_, ##__VA_ARGS__, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
|
||||
#endif
|
||||
|
||||
#define ___bpf_fill0(arr, p, x) do {} while (0)
|
||||
#define ___bpf_fill1(arr, p, x) arr[p] = x
|
||||
#define ___bpf_fill2(arr, p, x, args...) arr[p] = x; ___bpf_fill1(arr, p + 1, args)
|
||||
#define ___bpf_fill3(arr, p, x, args...) arr[p] = x; ___bpf_fill2(arr, p + 1, args)
|
||||
#define ___bpf_fill4(arr, p, x, args...) arr[p] = x; ___bpf_fill3(arr, p + 1, args)
|
||||
#define ___bpf_fill5(arr, p, x, args...) arr[p] = x; ___bpf_fill4(arr, p + 1, args)
|
||||
#define ___bpf_fill6(arr, p, x, args...) arr[p] = x; ___bpf_fill5(arr, p + 1, args)
|
||||
#define ___bpf_fill7(arr, p, x, args...) arr[p] = x; ___bpf_fill6(arr, p + 1, args)
|
||||
#define ___bpf_fill8(arr, p, x, args...) arr[p] = x; ___bpf_fill7(arr, p + 1, args)
|
||||
#define ___bpf_fill9(arr, p, x, args...) arr[p] = x; ___bpf_fill8(arr, p + 1, args)
|
||||
#define ___bpf_fill10(arr, p, x, args...) arr[p] = x; ___bpf_fill9(arr, p + 1, args)
|
||||
#define ___bpf_fill11(arr, p, x, args...) arr[p] = x; ___bpf_fill10(arr, p + 1, args)
|
||||
#define ___bpf_fill12(arr, p, x, args...) arr[p] = x; ___bpf_fill11(arr, p + 1, args)
|
||||
#define ___bpf_fill(arr, args...) \
|
||||
___bpf_apply(___bpf_fill, ___bpf_narg(args))(arr, 0, args)
|
||||
|
||||
/*
|
||||
* BPF_SEQ_PRINTF to wrap bpf_seq_printf to-be-printed values
|
||||
* in a structure.
|
||||
*/
|
||||
#define BPF_SEQ_PRINTF(seq, fmt, args...) \
|
||||
({ \
|
||||
static const char ___fmt[] = fmt; \
|
||||
unsigned long long ___param[___bpf_narg(args)]; \
|
||||
\
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
|
||||
___bpf_fill(___param, args); \
|
||||
_Pragma("GCC diagnostic pop") \
|
||||
\
|
||||
bpf_seq_printf(seq, ___fmt, sizeof(___fmt), \
|
||||
___param, sizeof(___param)); \
|
||||
})
|
||||
|
||||
/*
|
||||
* BPF_SNPRINTF wraps the bpf_snprintf helper with variadic arguments instead of
|
||||
* an array of u64.
|
||||
*/
|
||||
#define BPF_SNPRINTF(out, out_size, fmt, args...) \
|
||||
({ \
|
||||
static const char ___fmt[] = fmt; \
|
||||
unsigned long long ___param[___bpf_narg(args)]; \
|
||||
\
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
|
||||
___bpf_fill(___param, args); \
|
||||
_Pragma("GCC diagnostic pop") \
|
||||
\
|
||||
bpf_snprintf(out, out_size, ___fmt, \
|
||||
___param, sizeof(___param)); \
|
||||
})
|
||||
|
||||
#ifdef BPF_NO_GLOBAL_DATA
|
||||
#define BPF_PRINTK_FMT_MOD
|
||||
#else
|
||||
#define BPF_PRINTK_FMT_MOD static const
|
||||
#endif
|
||||
|
||||
#define __bpf_printk(fmt, ...) \
|
||||
({ \
|
||||
BPF_PRINTK_FMT_MOD char ____fmt[] = fmt; \
|
||||
bpf_trace_printk(____fmt, sizeof(____fmt), \
|
||||
##__VA_ARGS__); \
|
||||
})
|
||||
|
||||
/*
|
||||
* __bpf_vprintk wraps the bpf_trace_vprintk helper with variadic arguments
|
||||
* instead of an array of u64.
|
||||
*/
|
||||
#define __bpf_vprintk(fmt, args...) \
|
||||
({ \
|
||||
static const char ___fmt[] = fmt; \
|
||||
unsigned long long ___param[___bpf_narg(args)]; \
|
||||
\
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
|
||||
___bpf_fill(___param, args); \
|
||||
_Pragma("GCC diagnostic pop") \
|
||||
\
|
||||
bpf_trace_vprintk(___fmt, sizeof(___fmt), \
|
||||
___param, sizeof(___param)); \
|
||||
})
|
||||
|
||||
/* Use __bpf_printk when bpf_printk call has 3 or fewer fmt args
|
||||
* Otherwise use __bpf_vprintk
|
||||
*/
|
||||
#define ___bpf_pick_printk(...) \
|
||||
___bpf_nth(_, ##__VA_ARGS__, __bpf_vprintk, __bpf_vprintk, __bpf_vprintk, \
|
||||
__bpf_vprintk, __bpf_vprintk, __bpf_vprintk, __bpf_vprintk, \
|
||||
__bpf_vprintk, __bpf_vprintk, __bpf_printk /*3*/, __bpf_printk /*2*/,\
|
||||
__bpf_printk /*1*/, __bpf_printk /*0*/)
|
||||
|
||||
/* Helper macro to print out debug messages */
|
||||
#define bpf_printk(fmt, args...) ___bpf_pick_printk(args)(fmt, ##args)
|
||||
|
||||
struct bpf_iter_num;
|
||||
|
||||
extern int bpf_iter_num_new(struct bpf_iter_num *it, int start, int end) __weak __ksym;
|
||||
extern int *bpf_iter_num_next(struct bpf_iter_num *it) __weak __ksym;
|
||||
extern void bpf_iter_num_destroy(struct bpf_iter_num *it) __weak __ksym;
|
||||
|
||||
#ifndef bpf_for_each
|
||||
/* bpf_for_each(iter_type, cur_elem, args...) provides generic construct for
|
||||
* using BPF open-coded iterators without having to write mundane explicit
|
||||
* low-level loop logic. Instead, it provides for()-like generic construct
|
||||
* that can be used pretty naturally. E.g., for some hypothetical cgroup
|
||||
* iterator, you'd write:
|
||||
*
|
||||
* struct cgroup *cg, *parent_cg = <...>;
|
||||
*
|
||||
* bpf_for_each(cgroup, cg, parent_cg, CG_ITER_CHILDREN) {
|
||||
* bpf_printk("Child cgroup id = %d", cg->cgroup_id);
|
||||
* if (cg->cgroup_id == 123)
|
||||
* break;
|
||||
* }
|
||||
*
|
||||
* I.e., it looks almost like high-level for each loop in other languages,
|
||||
* supports continue/break, and is verifiable by BPF verifier.
|
||||
*
|
||||
* For iterating integers, the difference betwen bpf_for_each(num, i, N, M)
|
||||
* and bpf_for(i, N, M) is in that bpf_for() provides additional proof to
|
||||
* verifier that i is in [N, M) range, and in bpf_for_each() case i is `int
|
||||
* *`, not just `int`. So for integers bpf_for() is more convenient.
|
||||
*
|
||||
* Note: this macro relies on C99 feature of allowing to declare variables
|
||||
* inside for() loop, bound to for() loop lifetime. It also utilizes GCC
|
||||
* extension: __attribute__((cleanup(<func>))), supported by both GCC and
|
||||
* Clang.
|
||||
*/
|
||||
#define bpf_for_each(type, cur, args...) for ( \
|
||||
/* initialize and define destructor */ \
|
||||
struct bpf_iter_##type ___it __attribute__((aligned(8), /* enforce, just in case */, \
|
||||
cleanup(bpf_iter_##type##_destroy))), \
|
||||
/* ___p pointer is just to call bpf_iter_##type##_new() *once* to init ___it */ \
|
||||
*___p __attribute__((unused)) = ( \
|
||||
bpf_iter_##type##_new(&___it, ##args), \
|
||||
/* this is a workaround for Clang bug: it currently doesn't emit BTF */ \
|
||||
/* for bpf_iter_##type##_destroy() when used from cleanup() attribute */ \
|
||||
(void)bpf_iter_##type##_destroy, (void *)0); \
|
||||
/* iteration and termination check */ \
|
||||
(((cur) = bpf_iter_##type##_next(&___it))); \
|
||||
)
|
||||
#endif /* bpf_for_each */
|
||||
|
||||
#ifndef bpf_for
|
||||
/* bpf_for(i, start, end) implements a for()-like looping construct that sets
|
||||
* provided integer variable *i* to values starting from *start* through,
|
||||
* but not including, *end*. It also proves to BPF verifier that *i* belongs
|
||||
* to range [start, end), so this can be used for accessing arrays without
|
||||
* extra checks.
|
||||
*
|
||||
* Note: *start* and *end* are assumed to be expressions with no side effects
|
||||
* and whose values do not change throughout bpf_for() loop execution. They do
|
||||
* not have to be statically known or constant, though.
|
||||
*
|
||||
* Note: similarly to bpf_for_each(), it relies on C99 feature of declaring for()
|
||||
* loop bound variables and cleanup attribute, supported by GCC and Clang.
|
||||
*/
|
||||
#define bpf_for(i, start, end) for ( \
|
||||
/* initialize and define destructor */ \
|
||||
struct bpf_iter_num ___it __attribute__((aligned(8), /* enforce, just in case */ \
|
||||
cleanup(bpf_iter_num_destroy))), \
|
||||
/* ___p pointer is necessary to call bpf_iter_num_new() *once* to init ___it */ \
|
||||
*___p __attribute__((unused)) = ( \
|
||||
bpf_iter_num_new(&___it, (start), (end)), \
|
||||
/* this is a workaround for Clang bug: it currently doesn't emit BTF */ \
|
||||
/* for bpf_iter_num_destroy() when used from cleanup() attribute */ \
|
||||
(void)bpf_iter_num_destroy, (void *)0); \
|
||||
({ \
|
||||
/* iteration step */ \
|
||||
int *___t = bpf_iter_num_next(&___it); \
|
||||
/* termination and bounds check */ \
|
||||
(___t && ((i) = *___t, (i) >= (start) && (i) < (end))); \
|
||||
}); \
|
||||
)
|
||||
#endif /* bpf_for */
|
||||
|
||||
#ifndef bpf_repeat
|
||||
/* bpf_repeat(N) performs N iterations without exposing iteration number
|
||||
*
|
||||
* Note: similarly to bpf_for_each(), it relies on C99 feature of declaring for()
|
||||
* loop bound variables and cleanup attribute, supported by GCC and Clang.
|
||||
*/
|
||||
#define bpf_repeat(N) for ( \
|
||||
/* initialize and define destructor */ \
|
||||
struct bpf_iter_num ___it __attribute__((aligned(8), /* enforce, just in case */ \
|
||||
cleanup(bpf_iter_num_destroy))), \
|
||||
/* ___p pointer is necessary to call bpf_iter_num_new() *once* to init ___it */ \
|
||||
*___p __attribute__((unused)) = ( \
|
||||
bpf_iter_num_new(&___it, 0, (N)), \
|
||||
/* this is a workaround for Clang bug: it currently doesn't emit BTF */ \
|
||||
/* for bpf_iter_num_destroy() when used from cleanup() attribute */ \
|
||||
(void)bpf_iter_num_destroy, (void *)0); \
|
||||
bpf_iter_num_next(&___it); \
|
||||
/* nothing here */ \
|
||||
)
|
||||
#endif /* bpf_repeat */
|
||||
|
||||
#endif
|
||||
924
service/firewall/interception/ebpf/programs/bpf/bpf_tracing.h
Normal file
924
service/firewall/interception/ebpf/programs/bpf/bpf_tracing.h
Normal file
@@ -0,0 +1,924 @@
|
||||
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
|
||||
#ifndef __BPF_TRACING_H__
|
||||
#define __BPF_TRACING_H__
|
||||
|
||||
// #include <bpf/bpf_helpers.h>
|
||||
#include "bpf_helpers.h"
|
||||
|
||||
/* Scan the ARCH passed in from ARCH env variable (see Makefile) */
|
||||
#if defined(__TARGET_ARCH_x86)
|
||||
#define bpf_target_x86
|
||||
#define bpf_target_defined
|
||||
#elif defined(__TARGET_ARCH_s390)
|
||||
#define bpf_target_s390
|
||||
#define bpf_target_defined
|
||||
#elif defined(__TARGET_ARCH_arm)
|
||||
#define bpf_target_arm
|
||||
#define bpf_target_defined
|
||||
#elif defined(__TARGET_ARCH_arm64)
|
||||
#define bpf_target_arm64
|
||||
#define bpf_target_defined
|
||||
#elif defined(__TARGET_ARCH_mips)
|
||||
#define bpf_target_mips
|
||||
#define bpf_target_defined
|
||||
#elif defined(__TARGET_ARCH_powerpc)
|
||||
#define bpf_target_powerpc
|
||||
#define bpf_target_defined
|
||||
#elif defined(__TARGET_ARCH_sparc)
|
||||
#define bpf_target_sparc
|
||||
#define bpf_target_defined
|
||||
#elif defined(__TARGET_ARCH_riscv)
|
||||
#define bpf_target_riscv
|
||||
#define bpf_target_defined
|
||||
#elif defined(__TARGET_ARCH_arc)
|
||||
#define bpf_target_arc
|
||||
#define bpf_target_defined
|
||||
#elif defined(__TARGET_ARCH_loongarch)
|
||||
#define bpf_target_loongarch
|
||||
#define bpf_target_defined
|
||||
#else
|
||||
|
||||
/* Fall back to what the compiler says */
|
||||
#if defined(__x86_64__)
|
||||
#define bpf_target_x86
|
||||
#define bpf_target_defined
|
||||
#elif defined(__s390__)
|
||||
#define bpf_target_s390
|
||||
#define bpf_target_defined
|
||||
#elif defined(__arm__)
|
||||
#define bpf_target_arm
|
||||
#define bpf_target_defined
|
||||
#elif defined(__aarch64__)
|
||||
#define bpf_target_arm64
|
||||
#define bpf_target_defined
|
||||
#elif defined(__mips__)
|
||||
#define bpf_target_mips
|
||||
#define bpf_target_defined
|
||||
#elif defined(__powerpc__)
|
||||
#define bpf_target_powerpc
|
||||
#define bpf_target_defined
|
||||
#elif defined(__sparc__)
|
||||
#define bpf_target_sparc
|
||||
#define bpf_target_defined
|
||||
#elif defined(__riscv) && __riscv_xlen == 64
|
||||
#define bpf_target_riscv
|
||||
#define bpf_target_defined
|
||||
#elif defined(__arc__)
|
||||
#define bpf_target_arc
|
||||
#define bpf_target_defined
|
||||
#elif defined(__loongarch__)
|
||||
#define bpf_target_loongarch
|
||||
#define bpf_target_defined
|
||||
#endif /* no compiler target */
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef __BPF_TARGET_MISSING
|
||||
#define __BPF_TARGET_MISSING "GCC error \"Must specify a BPF target arch via __TARGET_ARCH_xxx\""
|
||||
#endif
|
||||
|
||||
#if defined(bpf_target_x86)
|
||||
|
||||
/*
|
||||
* https://en.wikipedia.org/wiki/X86_calling_conventions#System_V_AMD64_ABI
|
||||
*/
|
||||
|
||||
#if defined(__KERNEL__) || defined(__VMLINUX_H__)
|
||||
|
||||
#define __PT_PARM1_REG di
|
||||
#define __PT_PARM2_REG si
|
||||
#define __PT_PARM3_REG dx
|
||||
#define __PT_PARM4_REG cx
|
||||
#define __PT_PARM5_REG r8
|
||||
#define __PT_PARM6_REG r9
|
||||
/*
|
||||
* Syscall uses r10 for PARM4. See arch/x86/entry/entry_64.S:entry_SYSCALL_64
|
||||
* comments in Linux sources. And refer to syscall(2) manpage.
|
||||
*/
|
||||
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG r10
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
|
||||
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
|
||||
|
||||
#define __PT_RET_REG sp
|
||||
#define __PT_FP_REG bp
|
||||
#define __PT_RC_REG ax
|
||||
#define __PT_SP_REG sp
|
||||
#define __PT_IP_REG ip
|
||||
|
||||
#else
|
||||
|
||||
#ifdef __i386__
|
||||
|
||||
/* i386 kernel is built with -mregparm=3 */
|
||||
#define __PT_PARM1_REG eax
|
||||
#define __PT_PARM2_REG edx
|
||||
#define __PT_PARM3_REG ecx
|
||||
/* i386 syscall ABI is very different, refer to syscall(2) manpage */
|
||||
#define __PT_PARM1_SYSCALL_REG ebx
|
||||
#define __PT_PARM2_SYSCALL_REG ecx
|
||||
#define __PT_PARM3_SYSCALL_REG edx
|
||||
#define __PT_PARM4_SYSCALL_REG esi
|
||||
#define __PT_PARM5_SYSCALL_REG edi
|
||||
#define __PT_PARM6_SYSCALL_REG ebp
|
||||
|
||||
#define __PT_RET_REG esp
|
||||
#define __PT_FP_REG ebp
|
||||
#define __PT_RC_REG eax
|
||||
#define __PT_SP_REG esp
|
||||
#define __PT_IP_REG eip
|
||||
|
||||
#else /* __i386__ */
|
||||
|
||||
#define __PT_PARM1_REG rdi
|
||||
#define __PT_PARM2_REG rsi
|
||||
#define __PT_PARM3_REG rdx
|
||||
#define __PT_PARM4_REG rcx
|
||||
#define __PT_PARM5_REG r8
|
||||
#define __PT_PARM6_REG r9
|
||||
|
||||
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG r10
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
|
||||
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
|
||||
|
||||
#define __PT_RET_REG rsp
|
||||
#define __PT_FP_REG rbp
|
||||
#define __PT_RC_REG rax
|
||||
#define __PT_SP_REG rsp
|
||||
#define __PT_IP_REG rip
|
||||
|
||||
#endif /* __i386__ */
|
||||
|
||||
#endif /* __KERNEL__ || __VMLINUX_H__ */
|
||||
|
||||
#elif defined(bpf_target_s390)
|
||||
|
||||
/*
|
||||
* https://github.com/IBM/s390x-abi/releases/download/v1.6/lzsabi_s390x.pdf
|
||||
*/
|
||||
|
||||
struct pt_regs___s390 {
|
||||
unsigned long orig_gpr2;
|
||||
};
|
||||
|
||||
/* s390 provides user_pt_regs instead of struct pt_regs to userspace */
|
||||
#define __PT_REGS_CAST(x) ((const user_pt_regs *)(x))
|
||||
#define __PT_PARM1_REG gprs[2]
|
||||
#define __PT_PARM2_REG gprs[3]
|
||||
#define __PT_PARM3_REG gprs[4]
|
||||
#define __PT_PARM4_REG gprs[5]
|
||||
#define __PT_PARM5_REG gprs[6]
|
||||
|
||||
#define __PT_PARM1_SYSCALL_REG orig_gpr2
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
|
||||
#define __PT_PARM6_SYSCALL_REG gprs[7]
|
||||
#define PT_REGS_PARM1_SYSCALL(x) PT_REGS_PARM1_CORE_SYSCALL(x)
|
||||
#define PT_REGS_PARM1_CORE_SYSCALL(x) \
|
||||
BPF_CORE_READ((const struct pt_regs___s390 *)(x), __PT_PARM1_SYSCALL_REG)
|
||||
|
||||
#define __PT_RET_REG gprs[14]
|
||||
#define __PT_FP_REG gprs[11] /* Works only with CONFIG_FRAME_POINTER */
|
||||
#define __PT_RC_REG gprs[2]
|
||||
#define __PT_SP_REG gprs[15]
|
||||
#define __PT_IP_REG psw.addr
|
||||
|
||||
#elif defined(bpf_target_arm)
|
||||
|
||||
/*
|
||||
* https://github.com/ARM-software/abi-aa/blob/main/aapcs32/aapcs32.rst#machine-registers
|
||||
*/
|
||||
|
||||
#define __PT_PARM1_REG uregs[0]
|
||||
#define __PT_PARM2_REG uregs[1]
|
||||
#define __PT_PARM3_REG uregs[2]
|
||||
#define __PT_PARM4_REG uregs[3]
|
||||
|
||||
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
|
||||
#define __PT_PARM5_SYSCALL_REG uregs[4]
|
||||
#define __PT_PARM6_SYSCALL_REG uregs[5]
|
||||
#define __PT_PARM7_SYSCALL_REG uregs[6]
|
||||
|
||||
#define __PT_RET_REG uregs[14]
|
||||
#define __PT_FP_REG uregs[11] /* Works only with CONFIG_FRAME_POINTER */
|
||||
#define __PT_RC_REG uregs[0]
|
||||
#define __PT_SP_REG uregs[13]
|
||||
#define __PT_IP_REG uregs[12]
|
||||
|
||||
#elif defined(bpf_target_arm64)
|
||||
|
||||
/*
|
||||
* https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#machine-registers
|
||||
*/
|
||||
|
||||
struct pt_regs___arm64 {
|
||||
unsigned long orig_x0;
|
||||
};
|
||||
|
||||
/* arm64 provides struct user_pt_regs instead of struct pt_regs to userspace */
|
||||
#define __PT_REGS_CAST(x) ((const struct user_pt_regs *)(x))
|
||||
#define __PT_PARM1_REG regs[0]
|
||||
#define __PT_PARM2_REG regs[1]
|
||||
#define __PT_PARM3_REG regs[2]
|
||||
#define __PT_PARM4_REG regs[3]
|
||||
#define __PT_PARM5_REG regs[4]
|
||||
#define __PT_PARM6_REG regs[5]
|
||||
#define __PT_PARM7_REG regs[6]
|
||||
#define __PT_PARM8_REG regs[7]
|
||||
|
||||
#define __PT_PARM1_SYSCALL_REG orig_x0
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
|
||||
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
|
||||
#define PT_REGS_PARM1_SYSCALL(x) PT_REGS_PARM1_CORE_SYSCALL(x)
|
||||
#define PT_REGS_PARM1_CORE_SYSCALL(x) \
|
||||
BPF_CORE_READ((const struct pt_regs___arm64 *)(x), __PT_PARM1_SYSCALL_REG)
|
||||
|
||||
#define __PT_RET_REG regs[30]
|
||||
#define __PT_FP_REG regs[29] /* Works only with CONFIG_FRAME_POINTER */
|
||||
#define __PT_RC_REG regs[0]
|
||||
#define __PT_SP_REG sp
|
||||
#define __PT_IP_REG pc
|
||||
|
||||
#elif defined(bpf_target_mips)
|
||||
|
||||
/*
|
||||
* N64 ABI is assumed right now.
|
||||
* https://en.wikipedia.org/wiki/MIPS_architecture#Calling_conventions
|
||||
*/
|
||||
|
||||
#define __PT_PARM1_REG regs[4]
|
||||
#define __PT_PARM2_REG regs[5]
|
||||
#define __PT_PARM3_REG regs[6]
|
||||
#define __PT_PARM4_REG regs[7]
|
||||
#define __PT_PARM5_REG regs[8]
|
||||
#define __PT_PARM6_REG regs[9]
|
||||
#define __PT_PARM7_REG regs[10]
|
||||
#define __PT_PARM8_REG regs[11]
|
||||
|
||||
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG /* only N32/N64 */
|
||||
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG /* only N32/N64 */
|
||||
|
||||
#define __PT_RET_REG regs[31]
|
||||
#define __PT_FP_REG regs[30] /* Works only with CONFIG_FRAME_POINTER */
|
||||
#define __PT_RC_REG regs[2]
|
||||
#define __PT_SP_REG regs[29]
|
||||
#define __PT_IP_REG cp0_epc
|
||||
|
||||
#elif defined(bpf_target_powerpc)
|
||||
|
||||
/*
|
||||
* http://refspecs.linux-foundation.org/elf/elfspec_ppc.pdf (page 3-14,
|
||||
* section "Function Calling Sequence")
|
||||
*/
|
||||
|
||||
#define __PT_PARM1_REG gpr[3]
|
||||
#define __PT_PARM2_REG gpr[4]
|
||||
#define __PT_PARM3_REG gpr[5]
|
||||
#define __PT_PARM4_REG gpr[6]
|
||||
#define __PT_PARM5_REG gpr[7]
|
||||
#define __PT_PARM6_REG gpr[8]
|
||||
#define __PT_PARM7_REG gpr[9]
|
||||
#define __PT_PARM8_REG gpr[10]
|
||||
|
||||
/* powerpc does not select ARCH_HAS_SYSCALL_WRAPPER. */
|
||||
#define PT_REGS_SYSCALL_REGS(ctx) ctx
|
||||
#define __PT_PARM1_SYSCALL_REG orig_gpr3
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
|
||||
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
|
||||
#if !defined(__arch64__)
|
||||
#define __PT_PARM7_SYSCALL_REG __PT_PARM7_REG /* only powerpc (not powerpc64) */
|
||||
#endif
|
||||
|
||||
#define __PT_RET_REG regs[31]
|
||||
#define __PT_FP_REG __unsupported__
|
||||
#define __PT_RC_REG gpr[3]
|
||||
#define __PT_SP_REG sp
|
||||
#define __PT_IP_REG nip
|
||||
|
||||
#elif defined(bpf_target_sparc)
|
||||
|
||||
/*
|
||||
* https://en.wikipedia.org/wiki/Calling_convention#SPARC
|
||||
*/
|
||||
|
||||
#define __PT_PARM1_REG u_regs[UREG_I0]
|
||||
#define __PT_PARM2_REG u_regs[UREG_I1]
|
||||
#define __PT_PARM3_REG u_regs[UREG_I2]
|
||||
#define __PT_PARM4_REG u_regs[UREG_I3]
|
||||
#define __PT_PARM5_REG u_regs[UREG_I4]
|
||||
#define __PT_PARM6_REG u_regs[UREG_I5]
|
||||
|
||||
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
|
||||
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
|
||||
|
||||
#define __PT_RET_REG u_regs[UREG_I7]
|
||||
#define __PT_FP_REG __unsupported__
|
||||
#define __PT_RC_REG u_regs[UREG_I0]
|
||||
#define __PT_SP_REG u_regs[UREG_FP]
|
||||
/* Should this also be a bpf_target check for the sparc case? */
|
||||
#if defined(__arch64__)
|
||||
#define __PT_IP_REG tpc
|
||||
#else
|
||||
#define __PT_IP_REG pc
|
||||
#endif
|
||||
|
||||
#elif defined(bpf_target_riscv)
|
||||
|
||||
/*
|
||||
* https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-cc.adoc#risc-v-calling-conventions
|
||||
*/
|
||||
|
||||
#define __PT_REGS_CAST(x) ((const struct user_regs_struct *)(x))
|
||||
#define __PT_PARM1_REG a0
|
||||
#define __PT_PARM2_REG a1
|
||||
#define __PT_PARM3_REG a2
|
||||
#define __PT_PARM4_REG a3
|
||||
#define __PT_PARM5_REG a4
|
||||
#define __PT_PARM6_REG a5
|
||||
#define __PT_PARM7_REG a6
|
||||
#define __PT_PARM8_REG a7
|
||||
|
||||
/* riscv does not select ARCH_HAS_SYSCALL_WRAPPER. */
|
||||
#define PT_REGS_SYSCALL_REGS(ctx) ctx
|
||||
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
|
||||
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
|
||||
|
||||
#define __PT_RET_REG ra
|
||||
#define __PT_FP_REG s0
|
||||
#define __PT_RC_REG a0
|
||||
#define __PT_SP_REG sp
|
||||
#define __PT_IP_REG pc
|
||||
|
||||
#elif defined(bpf_target_arc)
|
||||
|
||||
/*
|
||||
* Section "Function Calling Sequence" (page 24):
|
||||
* https://raw.githubusercontent.com/wiki/foss-for-synopsys-dwc-arc-processors/toolchain/files/ARCv2_ABI.pdf
|
||||
*/
|
||||
|
||||
/* arc provides struct user_pt_regs instead of struct pt_regs to userspace */
|
||||
#define __PT_REGS_CAST(x) ((const struct user_regs_struct *)(x))
|
||||
#define __PT_PARM1_REG scratch.r0
|
||||
#define __PT_PARM2_REG scratch.r1
|
||||
#define __PT_PARM3_REG scratch.r2
|
||||
#define __PT_PARM4_REG scratch.r3
|
||||
#define __PT_PARM5_REG scratch.r4
|
||||
#define __PT_PARM6_REG scratch.r5
|
||||
#define __PT_PARM7_REG scratch.r6
|
||||
#define __PT_PARM8_REG scratch.r7
|
||||
|
||||
/* arc does not select ARCH_HAS_SYSCALL_WRAPPER. */
|
||||
#define PT_REGS_SYSCALL_REGS(ctx) ctx
|
||||
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
|
||||
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
|
||||
|
||||
#define __PT_RET_REG scratch.blink
|
||||
#define __PT_FP_REG scratch.fp
|
||||
#define __PT_RC_REG scratch.r0
|
||||
#define __PT_SP_REG scratch.sp
|
||||
#define __PT_IP_REG scratch.ret
|
||||
|
||||
#elif defined(bpf_target_loongarch)
|
||||
|
||||
/*
|
||||
* https://docs.kernel.org/loongarch/introduction.html
|
||||
* https://loongson.github.io/LoongArch-Documentation/LoongArch-ELF-ABI-EN.html
|
||||
*/
|
||||
|
||||
/* loongarch provides struct user_pt_regs instead of struct pt_regs to userspace */
|
||||
#define __PT_REGS_CAST(x) ((const struct user_pt_regs *)(x))
|
||||
#define __PT_PARM1_REG regs[4]
|
||||
#define __PT_PARM2_REG regs[5]
|
||||
#define __PT_PARM3_REG regs[6]
|
||||
#define __PT_PARM4_REG regs[7]
|
||||
#define __PT_PARM5_REG regs[8]
|
||||
#define __PT_PARM6_REG regs[9]
|
||||
#define __PT_PARM7_REG regs[10]
|
||||
#define __PT_PARM8_REG regs[11]
|
||||
|
||||
/* loongarch does not select ARCH_HAS_SYSCALL_WRAPPER. */
|
||||
#define PT_REGS_SYSCALL_REGS(ctx) ctx
|
||||
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
|
||||
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
|
||||
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
|
||||
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
|
||||
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
|
||||
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
|
||||
|
||||
#define __PT_RET_REG regs[1]
|
||||
#define __PT_FP_REG regs[22]
|
||||
#define __PT_RC_REG regs[4]
|
||||
#define __PT_SP_REG regs[3]
|
||||
#define __PT_IP_REG csr_era
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(bpf_target_defined)
|
||||
|
||||
struct pt_regs;
|
||||
|
||||
/* allow some architectures to override `struct pt_regs` */
|
||||
#ifndef __PT_REGS_CAST
|
||||
#define __PT_REGS_CAST(x) (x)
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Different architectures support different number of arguments passed
|
||||
* through registers. i386 supports just 3, some arches support up to 8.
|
||||
*/
|
||||
#ifndef __PT_PARM4_REG
|
||||
#define __PT_PARM4_REG __unsupported__
|
||||
#endif
|
||||
#ifndef __PT_PARM5_REG
|
||||
#define __PT_PARM5_REG __unsupported__
|
||||
#endif
|
||||
#ifndef __PT_PARM6_REG
|
||||
#define __PT_PARM6_REG __unsupported__
|
||||
#endif
|
||||
#ifndef __PT_PARM7_REG
|
||||
#define __PT_PARM7_REG __unsupported__
|
||||
#endif
|
||||
#ifndef __PT_PARM8_REG
|
||||
#define __PT_PARM8_REG __unsupported__
|
||||
#endif
|
||||
/*
|
||||
* Similarly, syscall-specific conventions might differ between function call
|
||||
* conventions within each architecutre. All supported architectures pass
|
||||
* either 6 or 7 syscall arguments in registers.
|
||||
*
|
||||
* See syscall(2) manpage for succinct table with information on each arch.
|
||||
*/
|
||||
#ifndef __PT_PARM7_SYSCALL_REG
|
||||
#define __PT_PARM7_SYSCALL_REG __unsupported__
|
||||
#endif
|
||||
|
||||
#define PT_REGS_PARM1(x) (__PT_REGS_CAST(x)->__PT_PARM1_REG)
|
||||
#define PT_REGS_PARM2(x) (__PT_REGS_CAST(x)->__PT_PARM2_REG)
|
||||
#define PT_REGS_PARM3(x) (__PT_REGS_CAST(x)->__PT_PARM3_REG)
|
||||
#define PT_REGS_PARM4(x) (__PT_REGS_CAST(x)->__PT_PARM4_REG)
|
||||
#define PT_REGS_PARM5(x) (__PT_REGS_CAST(x)->__PT_PARM5_REG)
|
||||
#define PT_REGS_PARM6(x) (__PT_REGS_CAST(x)->__PT_PARM6_REG)
|
||||
#define PT_REGS_PARM7(x) (__PT_REGS_CAST(x)->__PT_PARM7_REG)
|
||||
#define PT_REGS_PARM8(x) (__PT_REGS_CAST(x)->__PT_PARM8_REG)
|
||||
#define PT_REGS_RET(x) (__PT_REGS_CAST(x)->__PT_RET_REG)
|
||||
#define PT_REGS_FP(x) (__PT_REGS_CAST(x)->__PT_FP_REG)
|
||||
#define PT_REGS_RC(x) (__PT_REGS_CAST(x)->__PT_RC_REG)
|
||||
#define PT_REGS_SP(x) (__PT_REGS_CAST(x)->__PT_SP_REG)
|
||||
#define PT_REGS_IP(x) (__PT_REGS_CAST(x)->__PT_IP_REG)
|
||||
|
||||
#define PT_REGS_PARM1_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM1_REG)
|
||||
#define PT_REGS_PARM2_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM2_REG)
|
||||
#define PT_REGS_PARM3_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM3_REG)
|
||||
#define PT_REGS_PARM4_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM4_REG)
|
||||
#define PT_REGS_PARM5_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM5_REG)
|
||||
#define PT_REGS_PARM6_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM6_REG)
|
||||
#define PT_REGS_PARM7_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM7_REG)
|
||||
#define PT_REGS_PARM8_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM8_REG)
|
||||
#define PT_REGS_RET_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_RET_REG)
|
||||
#define PT_REGS_FP_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_FP_REG)
|
||||
#define PT_REGS_RC_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_RC_REG)
|
||||
#define PT_REGS_SP_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_SP_REG)
|
||||
#define PT_REGS_IP_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_IP_REG)
|
||||
|
||||
#if defined(bpf_target_powerpc)
|
||||
|
||||
#define BPF_KPROBE_READ_RET_IP(ip, ctx) ({ (ip) = (ctx)->link; })
|
||||
#define BPF_KRETPROBE_READ_RET_IP BPF_KPROBE_READ_RET_IP
|
||||
|
||||
#elif defined(bpf_target_sparc)
|
||||
|
||||
#define BPF_KPROBE_READ_RET_IP(ip, ctx) ({ (ip) = PT_REGS_RET(ctx); })
|
||||
#define BPF_KRETPROBE_READ_RET_IP BPF_KPROBE_READ_RET_IP
|
||||
|
||||
#else
|
||||
|
||||
#define BPF_KPROBE_READ_RET_IP(ip, ctx) \
|
||||
({ bpf_probe_read_kernel(&(ip), sizeof(ip), (void *)PT_REGS_RET(ctx)); })
|
||||
#define BPF_KRETPROBE_READ_RET_IP(ip, ctx) \
|
||||
({ bpf_probe_read_kernel(&(ip), sizeof(ip), (void *)(PT_REGS_FP(ctx) + sizeof(ip))); })
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef PT_REGS_PARM1_SYSCALL
|
||||
#define PT_REGS_PARM1_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM1_SYSCALL_REG)
|
||||
#define PT_REGS_PARM1_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM1_SYSCALL_REG)
|
||||
#endif
|
||||
#ifndef PT_REGS_PARM2_SYSCALL
|
||||
#define PT_REGS_PARM2_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM2_SYSCALL_REG)
|
||||
#define PT_REGS_PARM2_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM2_SYSCALL_REG)
|
||||
#endif
|
||||
#ifndef PT_REGS_PARM3_SYSCALL
|
||||
#define PT_REGS_PARM3_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM3_SYSCALL_REG)
|
||||
#define PT_REGS_PARM3_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM3_SYSCALL_REG)
|
||||
#endif
|
||||
#ifndef PT_REGS_PARM4_SYSCALL
|
||||
#define PT_REGS_PARM4_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM4_SYSCALL_REG)
|
||||
#define PT_REGS_PARM4_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM4_SYSCALL_REG)
|
||||
#endif
|
||||
#ifndef PT_REGS_PARM5_SYSCALL
|
||||
#define PT_REGS_PARM5_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM5_SYSCALL_REG)
|
||||
#define PT_REGS_PARM5_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM5_SYSCALL_REG)
|
||||
#endif
|
||||
#ifndef PT_REGS_PARM6_SYSCALL
|
||||
#define PT_REGS_PARM6_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM6_SYSCALL_REG)
|
||||
#define PT_REGS_PARM6_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM6_SYSCALL_REG)
|
||||
#endif
|
||||
#ifndef PT_REGS_PARM7_SYSCALL
|
||||
#define PT_REGS_PARM7_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM7_SYSCALL_REG)
|
||||
#define PT_REGS_PARM7_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM7_SYSCALL_REG)
|
||||
#endif
|
||||
|
||||
#else /* defined(bpf_target_defined) */
|
||||
|
||||
#define PT_REGS_PARM1(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM2(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM3(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM4(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM5(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM6(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM7(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM8(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_RET(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_FP(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_RC(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_SP(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_IP(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
|
||||
#define PT_REGS_PARM1_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM2_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM3_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM4_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM5_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM6_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM7_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM8_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_RET_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_FP_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_RC_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_SP_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_IP_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
|
||||
#define BPF_KPROBE_READ_RET_IP(ip, ctx) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define BPF_KRETPROBE_READ_RET_IP(ip, ctx) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
|
||||
#define PT_REGS_PARM1_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM2_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM3_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM4_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM5_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM6_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM7_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
|
||||
#define PT_REGS_PARM1_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM2_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM3_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM4_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM5_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM6_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
#define PT_REGS_PARM7_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
|
||||
|
||||
#endif /* defined(bpf_target_defined) */
|
||||
|
||||
/*
|
||||
* When invoked from a syscall handler kprobe, returns a pointer to a
|
||||
* struct pt_regs containing syscall arguments and suitable for passing to
|
||||
* PT_REGS_PARMn_SYSCALL() and PT_REGS_PARMn_CORE_SYSCALL().
|
||||
*/
|
||||
#ifndef PT_REGS_SYSCALL_REGS
|
||||
/* By default, assume that the arch selects ARCH_HAS_SYSCALL_WRAPPER. */
|
||||
#define PT_REGS_SYSCALL_REGS(ctx) ((struct pt_regs *)PT_REGS_PARM1(ctx))
|
||||
#endif
|
||||
|
||||
#ifndef ___bpf_concat
|
||||
#define ___bpf_concat(a, b) a ## b
|
||||
#endif
|
||||
#ifndef ___bpf_apply
|
||||
#define ___bpf_apply(fn, n) ___bpf_concat(fn, n)
|
||||
#endif
|
||||
#ifndef ___bpf_nth
|
||||
#define ___bpf_nth(_, _1, _2, _3, _4, _5, _6, _7, _8, _9, _a, _b, _c, N, ...) N
|
||||
#endif
|
||||
#ifndef ___bpf_narg
|
||||
#define ___bpf_narg(...) ___bpf_nth(_, ##__VA_ARGS__, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
|
||||
#endif
|
||||
|
||||
#define ___bpf_ctx_cast0() ctx
|
||||
#define ___bpf_ctx_cast1(x) ___bpf_ctx_cast0(), (void *)ctx[0]
|
||||
#define ___bpf_ctx_cast2(x, args...) ___bpf_ctx_cast1(args), (void *)ctx[1]
|
||||
#define ___bpf_ctx_cast3(x, args...) ___bpf_ctx_cast2(args), (void *)ctx[2]
|
||||
#define ___bpf_ctx_cast4(x, args...) ___bpf_ctx_cast3(args), (void *)ctx[3]
|
||||
#define ___bpf_ctx_cast5(x, args...) ___bpf_ctx_cast4(args), (void *)ctx[4]
|
||||
#define ___bpf_ctx_cast6(x, args...) ___bpf_ctx_cast5(args), (void *)ctx[5]
|
||||
#define ___bpf_ctx_cast7(x, args...) ___bpf_ctx_cast6(args), (void *)ctx[6]
|
||||
#define ___bpf_ctx_cast8(x, args...) ___bpf_ctx_cast7(args), (void *)ctx[7]
|
||||
#define ___bpf_ctx_cast9(x, args...) ___bpf_ctx_cast8(args), (void *)ctx[8]
|
||||
#define ___bpf_ctx_cast10(x, args...) ___bpf_ctx_cast9(args), (void *)ctx[9]
|
||||
#define ___bpf_ctx_cast11(x, args...) ___bpf_ctx_cast10(args), (void *)ctx[10]
|
||||
#define ___bpf_ctx_cast12(x, args...) ___bpf_ctx_cast11(args), (void *)ctx[11]
|
||||
#define ___bpf_ctx_cast(args...) ___bpf_apply(___bpf_ctx_cast, ___bpf_narg(args))(args)
|
||||
|
||||
/*
|
||||
* BPF_PROG is a convenience wrapper for generic tp_btf/fentry/fexit and
|
||||
* similar kinds of BPF programs, that accept input arguments as a single
|
||||
* pointer to untyped u64 array, where each u64 can actually be a typed
|
||||
* pointer or integer of different size. Instead of requring user to write
|
||||
* manual casts and work with array elements by index, BPF_PROG macro
|
||||
* allows user to declare a list of named and typed input arguments in the
|
||||
* same syntax as for normal C function. All the casting is hidden and
|
||||
* performed transparently, while user code can just assume working with
|
||||
* function arguments of specified type and name.
|
||||
*
|
||||
* Original raw context argument is preserved as well as 'ctx' argument.
|
||||
* This is useful when using BPF helpers that expect original context
|
||||
* as one of the parameters (e.g., for bpf_perf_event_output()).
|
||||
*/
|
||||
#define BPF_PROG(name, args...) \
|
||||
name(unsigned long long *ctx); \
|
||||
static __always_inline typeof(name(0)) \
|
||||
____##name(unsigned long long *ctx, ##args); \
|
||||
typeof(name(0)) name(unsigned long long *ctx) \
|
||||
{ \
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
|
||||
return ____##name(___bpf_ctx_cast(args)); \
|
||||
_Pragma("GCC diagnostic pop") \
|
||||
} \
|
||||
static __always_inline typeof(name(0)) \
|
||||
____##name(unsigned long long *ctx, ##args)
|
||||
|
||||
#ifndef ___bpf_nth2
|
||||
#define ___bpf_nth2(_, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \
|
||||
_14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, N, ...) N
|
||||
#endif
|
||||
#ifndef ___bpf_narg2
|
||||
#define ___bpf_narg2(...) \
|
||||
___bpf_nth2(_, ##__VA_ARGS__, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, \
|
||||
6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0)
|
||||
#endif
|
||||
|
||||
#define ___bpf_treg_cnt(t) \
|
||||
__builtin_choose_expr(sizeof(t) == 1, 1, \
|
||||
__builtin_choose_expr(sizeof(t) == 2, 1, \
|
||||
__builtin_choose_expr(sizeof(t) == 4, 1, \
|
||||
__builtin_choose_expr(sizeof(t) == 8, 1, \
|
||||
__builtin_choose_expr(sizeof(t) == 16, 2, \
|
||||
(void)0)))))
|
||||
|
||||
#define ___bpf_reg_cnt0() (0)
|
||||
#define ___bpf_reg_cnt1(t, x) (___bpf_reg_cnt0() + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt2(t, x, args...) (___bpf_reg_cnt1(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt3(t, x, args...) (___bpf_reg_cnt2(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt4(t, x, args...) (___bpf_reg_cnt3(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt5(t, x, args...) (___bpf_reg_cnt4(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt6(t, x, args...) (___bpf_reg_cnt5(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt7(t, x, args...) (___bpf_reg_cnt6(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt8(t, x, args...) (___bpf_reg_cnt7(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt9(t, x, args...) (___bpf_reg_cnt8(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt10(t, x, args...) (___bpf_reg_cnt9(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt11(t, x, args...) (___bpf_reg_cnt10(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt12(t, x, args...) (___bpf_reg_cnt11(args) + ___bpf_treg_cnt(t))
|
||||
#define ___bpf_reg_cnt(args...) ___bpf_apply(___bpf_reg_cnt, ___bpf_narg2(args))(args)
|
||||
|
||||
#define ___bpf_union_arg(t, x, n) \
|
||||
__builtin_choose_expr(sizeof(t) == 1, ({ union { __u8 z[1]; t x; } ___t = { .z = {ctx[n]}}; ___t.x; }), \
|
||||
__builtin_choose_expr(sizeof(t) == 2, ({ union { __u16 z[1]; t x; } ___t = { .z = {ctx[n]} }; ___t.x; }), \
|
||||
__builtin_choose_expr(sizeof(t) == 4, ({ union { __u32 z[1]; t x; } ___t = { .z = {ctx[n]} }; ___t.x; }), \
|
||||
__builtin_choose_expr(sizeof(t) == 8, ({ union { __u64 z[1]; t x; } ___t = {.z = {ctx[n]} }; ___t.x; }), \
|
||||
__builtin_choose_expr(sizeof(t) == 16, ({ union { __u64 z[2]; t x; } ___t = {.z = {ctx[n], ctx[n + 1]} }; ___t.x; }), \
|
||||
(void)0)))))
|
||||
|
||||
#define ___bpf_ctx_arg0(n, args...)
|
||||
#define ___bpf_ctx_arg1(n, t, x) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt1(t, x))
|
||||
#define ___bpf_ctx_arg2(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt2(t, x, args)) ___bpf_ctx_arg1(n, args)
|
||||
#define ___bpf_ctx_arg3(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt3(t, x, args)) ___bpf_ctx_arg2(n, args)
|
||||
#define ___bpf_ctx_arg4(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt4(t, x, args)) ___bpf_ctx_arg3(n, args)
|
||||
#define ___bpf_ctx_arg5(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt5(t, x, args)) ___bpf_ctx_arg4(n, args)
|
||||
#define ___bpf_ctx_arg6(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt6(t, x, args)) ___bpf_ctx_arg5(n, args)
|
||||
#define ___bpf_ctx_arg7(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt7(t, x, args)) ___bpf_ctx_arg6(n, args)
|
||||
#define ___bpf_ctx_arg8(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt8(t, x, args)) ___bpf_ctx_arg7(n, args)
|
||||
#define ___bpf_ctx_arg9(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt9(t, x, args)) ___bpf_ctx_arg8(n, args)
|
||||
#define ___bpf_ctx_arg10(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt10(t, x, args)) ___bpf_ctx_arg9(n, args)
|
||||
#define ___bpf_ctx_arg11(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt11(t, x, args)) ___bpf_ctx_arg10(n, args)
|
||||
#define ___bpf_ctx_arg12(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt12(t, x, args)) ___bpf_ctx_arg11(n, args)
|
||||
#define ___bpf_ctx_arg(args...) ___bpf_apply(___bpf_ctx_arg, ___bpf_narg2(args))(___bpf_reg_cnt(args), args)
|
||||
|
||||
#define ___bpf_ctx_decl0()
|
||||
#define ___bpf_ctx_decl1(t, x) , t x
|
||||
#define ___bpf_ctx_decl2(t, x, args...) , t x ___bpf_ctx_decl1(args)
|
||||
#define ___bpf_ctx_decl3(t, x, args...) , t x ___bpf_ctx_decl2(args)
|
||||
#define ___bpf_ctx_decl4(t, x, args...) , t x ___bpf_ctx_decl3(args)
|
||||
#define ___bpf_ctx_decl5(t, x, args...) , t x ___bpf_ctx_decl4(args)
|
||||
#define ___bpf_ctx_decl6(t, x, args...) , t x ___bpf_ctx_decl5(args)
|
||||
#define ___bpf_ctx_decl7(t, x, args...) , t x ___bpf_ctx_decl6(args)
|
||||
#define ___bpf_ctx_decl8(t, x, args...) , t x ___bpf_ctx_decl7(args)
|
||||
#define ___bpf_ctx_decl9(t, x, args...) , t x ___bpf_ctx_decl8(args)
|
||||
#define ___bpf_ctx_decl10(t, x, args...) , t x ___bpf_ctx_decl9(args)
|
||||
#define ___bpf_ctx_decl11(t, x, args...) , t x ___bpf_ctx_decl10(args)
|
||||
#define ___bpf_ctx_decl12(t, x, args...) , t x ___bpf_ctx_decl11(args)
|
||||
#define ___bpf_ctx_decl(args...) ___bpf_apply(___bpf_ctx_decl, ___bpf_narg2(args))(args)
|
||||
|
||||
/*
|
||||
* BPF_PROG2 is an enhanced version of BPF_PROG in order to handle struct
|
||||
* arguments. Since each struct argument might take one or two u64 values
|
||||
* in the trampoline stack, argument type size is needed to place proper number
|
||||
* of u64 values for each argument. Therefore, BPF_PROG2 has different
|
||||
* syntax from BPF_PROG. For example, for the following BPF_PROG syntax:
|
||||
*
|
||||
* int BPF_PROG(test2, int a, int b) { ... }
|
||||
*
|
||||
* the corresponding BPF_PROG2 syntax is:
|
||||
*
|
||||
* int BPF_PROG2(test2, int, a, int, b) { ... }
|
||||
*
|
||||
* where type and the corresponding argument name are separated by comma.
|
||||
*
|
||||
* Use BPF_PROG2 macro if one of the arguments might be a struct/union larger
|
||||
* than 8 bytes:
|
||||
*
|
||||
* int BPF_PROG2(test_struct_arg, struct bpf_testmod_struct_arg_1, a, int, b,
|
||||
* int, c, int, d, struct bpf_testmod_struct_arg_2, e, int, ret)
|
||||
* {
|
||||
* // access a, b, c, d, e, and ret directly
|
||||
* ...
|
||||
* }
|
||||
*/
|
||||
#define BPF_PROG2(name, args...) \
|
||||
name(unsigned long long *ctx); \
|
||||
static __always_inline typeof(name(0)) \
|
||||
____##name(unsigned long long *ctx ___bpf_ctx_decl(args)); \
|
||||
typeof(name(0)) name(unsigned long long *ctx) \
|
||||
{ \
|
||||
return ____##name(ctx ___bpf_ctx_arg(args)); \
|
||||
} \
|
||||
static __always_inline typeof(name(0)) \
|
||||
____##name(unsigned long long *ctx ___bpf_ctx_decl(args))
|
||||
|
||||
struct pt_regs;
|
||||
|
||||
#define ___bpf_kprobe_args0() ctx
|
||||
#define ___bpf_kprobe_args1(x) ___bpf_kprobe_args0(), (void *)PT_REGS_PARM1(ctx)
|
||||
#define ___bpf_kprobe_args2(x, args...) ___bpf_kprobe_args1(args), (void *)PT_REGS_PARM2(ctx)
|
||||
#define ___bpf_kprobe_args3(x, args...) ___bpf_kprobe_args2(args), (void *)PT_REGS_PARM3(ctx)
|
||||
#define ___bpf_kprobe_args4(x, args...) ___bpf_kprobe_args3(args), (void *)PT_REGS_PARM4(ctx)
|
||||
#define ___bpf_kprobe_args5(x, args...) ___bpf_kprobe_args4(args), (void *)PT_REGS_PARM5(ctx)
|
||||
#define ___bpf_kprobe_args6(x, args...) ___bpf_kprobe_args5(args), (void *)PT_REGS_PARM6(ctx)
|
||||
#define ___bpf_kprobe_args7(x, args...) ___bpf_kprobe_args6(args), (void *)PT_REGS_PARM7(ctx)
|
||||
#define ___bpf_kprobe_args8(x, args...) ___bpf_kprobe_args7(args), (void *)PT_REGS_PARM8(ctx)
|
||||
#define ___bpf_kprobe_args(args...) ___bpf_apply(___bpf_kprobe_args, ___bpf_narg(args))(args)
|
||||
|
||||
/*
|
||||
* BPF_KPROBE serves the same purpose for kprobes as BPF_PROG for
|
||||
* tp_btf/fentry/fexit BPF programs. It hides the underlying platform-specific
|
||||
* low-level way of getting kprobe input arguments from struct pt_regs, and
|
||||
* provides a familiar typed and named function arguments syntax and
|
||||
* semantics of accessing kprobe input paremeters.
|
||||
*
|
||||
* Original struct pt_regs* context is preserved as 'ctx' argument. This might
|
||||
* be necessary when using BPF helpers like bpf_perf_event_output().
|
||||
*/
|
||||
#define BPF_KPROBE(name, args...) \
|
||||
name(struct pt_regs *ctx); \
|
||||
static __always_inline typeof(name(0)) \
|
||||
____##name(struct pt_regs *ctx, ##args); \
|
||||
typeof(name(0)) name(struct pt_regs *ctx) \
|
||||
{ \
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
|
||||
return ____##name(___bpf_kprobe_args(args)); \
|
||||
_Pragma("GCC diagnostic pop") \
|
||||
} \
|
||||
static __always_inline typeof(name(0)) \
|
||||
____##name(struct pt_regs *ctx, ##args)
|
||||
|
||||
#define ___bpf_kretprobe_args0() ctx
|
||||
#define ___bpf_kretprobe_args1(x) ___bpf_kretprobe_args0(), (void *)PT_REGS_RC(ctx)
|
||||
#define ___bpf_kretprobe_args(args...) ___bpf_apply(___bpf_kretprobe_args, ___bpf_narg(args))(args)
|
||||
|
||||
/*
|
||||
* BPF_KRETPROBE is similar to BPF_KPROBE, except, it only provides optional
|
||||
* return value (in addition to `struct pt_regs *ctx`), but no input
|
||||
* arguments, because they will be clobbered by the time probed function
|
||||
* returns.
|
||||
*/
|
||||
#define BPF_KRETPROBE(name, args...) \
|
||||
name(struct pt_regs *ctx); \
|
||||
static __always_inline typeof(name(0)) \
|
||||
____##name(struct pt_regs *ctx, ##args); \
|
||||
typeof(name(0)) name(struct pt_regs *ctx) \
|
||||
{ \
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
|
||||
return ____##name(___bpf_kretprobe_args(args)); \
|
||||
_Pragma("GCC diagnostic pop") \
|
||||
} \
|
||||
static __always_inline typeof(name(0)) ____##name(struct pt_regs *ctx, ##args)
|
||||
|
||||
/* If kernel has CONFIG_ARCH_HAS_SYSCALL_WRAPPER, read pt_regs directly */
|
||||
#define ___bpf_syscall_args0() ctx
|
||||
#define ___bpf_syscall_args1(x) ___bpf_syscall_args0(), (void *)PT_REGS_PARM1_SYSCALL(regs)
|
||||
#define ___bpf_syscall_args2(x, args...) ___bpf_syscall_args1(args), (void *)PT_REGS_PARM2_SYSCALL(regs)
|
||||
#define ___bpf_syscall_args3(x, args...) ___bpf_syscall_args2(args), (void *)PT_REGS_PARM3_SYSCALL(regs)
|
||||
#define ___bpf_syscall_args4(x, args...) ___bpf_syscall_args3(args), (void *)PT_REGS_PARM4_SYSCALL(regs)
|
||||
#define ___bpf_syscall_args5(x, args...) ___bpf_syscall_args4(args), (void *)PT_REGS_PARM5_SYSCALL(regs)
|
||||
#define ___bpf_syscall_args6(x, args...) ___bpf_syscall_args5(args), (void *)PT_REGS_PARM6_SYSCALL(regs)
|
||||
#define ___bpf_syscall_args7(x, args...) ___bpf_syscall_args6(args), (void *)PT_REGS_PARM7_SYSCALL(regs)
|
||||
#define ___bpf_syscall_args(args...) ___bpf_apply(___bpf_syscall_args, ___bpf_narg(args))(args)
|
||||
|
||||
/* If kernel doesn't have CONFIG_ARCH_HAS_SYSCALL_WRAPPER, we have to BPF_CORE_READ from pt_regs */
|
||||
#define ___bpf_syswrap_args0() ctx
|
||||
#define ___bpf_syswrap_args1(x) ___bpf_syswrap_args0(), (void *)PT_REGS_PARM1_CORE_SYSCALL(regs)
|
||||
#define ___bpf_syswrap_args2(x, args...) ___bpf_syswrap_args1(args), (void *)PT_REGS_PARM2_CORE_SYSCALL(regs)
|
||||
#define ___bpf_syswrap_args3(x, args...) ___bpf_syswrap_args2(args), (void *)PT_REGS_PARM3_CORE_SYSCALL(regs)
|
||||
#define ___bpf_syswrap_args4(x, args...) ___bpf_syswrap_args3(args), (void *)PT_REGS_PARM4_CORE_SYSCALL(regs)
|
||||
#define ___bpf_syswrap_args5(x, args...) ___bpf_syswrap_args4(args), (void *)PT_REGS_PARM5_CORE_SYSCALL(regs)
|
||||
#define ___bpf_syswrap_args6(x, args...) ___bpf_syswrap_args5(args), (void *)PT_REGS_PARM6_CORE_SYSCALL(regs)
|
||||
#define ___bpf_syswrap_args7(x, args...) ___bpf_syswrap_args6(args), (void *)PT_REGS_PARM7_CORE_SYSCALL(regs)
|
||||
#define ___bpf_syswrap_args(args...) ___bpf_apply(___bpf_syswrap_args, ___bpf_narg(args))(args)
|
||||
|
||||
/*
|
||||
* BPF_KSYSCALL is a variant of BPF_KPROBE, which is intended for
|
||||
* tracing syscall functions, like __x64_sys_close. It hides the underlying
|
||||
* platform-specific low-level way of getting syscall input arguments from
|
||||
* struct pt_regs, and provides a familiar typed and named function arguments
|
||||
* syntax and semantics of accessing syscall input parameters.
|
||||
*
|
||||
* Original struct pt_regs * context is preserved as 'ctx' argument. This might
|
||||
* be necessary when using BPF helpers like bpf_perf_event_output().
|
||||
*
|
||||
* At the moment BPF_KSYSCALL does not transparently handle all the calling
|
||||
* convention quirks for the following syscalls:
|
||||
*
|
||||
* - mmap(): __ARCH_WANT_SYS_OLD_MMAP.
|
||||
* - clone(): CONFIG_CLONE_BACKWARDS, CONFIG_CLONE_BACKWARDS2 and
|
||||
* CONFIG_CLONE_BACKWARDS3.
|
||||
* - socket-related syscalls: __ARCH_WANT_SYS_SOCKETCALL.
|
||||
* - compat syscalls.
|
||||
*
|
||||
* This may or may not change in the future. User needs to take extra measures
|
||||
* to handle such quirks explicitly, if necessary.
|
||||
*
|
||||
* This macro relies on BPF CO-RE support and virtual __kconfig externs.
|
||||
*/
|
||||
#define BPF_KSYSCALL(name, args...) \
|
||||
name(struct pt_regs *ctx); \
|
||||
extern _Bool LINUX_HAS_SYSCALL_WRAPPER __kconfig; \
|
||||
static __always_inline typeof(name(0)) \
|
||||
____##name(struct pt_regs *ctx, ##args); \
|
||||
typeof(name(0)) name(struct pt_regs *ctx) \
|
||||
{ \
|
||||
struct pt_regs *regs = LINUX_HAS_SYSCALL_WRAPPER \
|
||||
? (struct pt_regs *)PT_REGS_PARM1(ctx) \
|
||||
: ctx; \
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
|
||||
if (LINUX_HAS_SYSCALL_WRAPPER) \
|
||||
return ____##name(___bpf_syswrap_args(args)); \
|
||||
else \
|
||||
return ____##name(___bpf_syscall_args(args)); \
|
||||
_Pragma("GCC diagnostic pop") \
|
||||
} \
|
||||
static __always_inline typeof(name(0)) \
|
||||
____##name(struct pt_regs *ctx, ##args)
|
||||
|
||||
#define BPF_KPROBE_SYSCALL BPF_KSYSCALL
|
||||
|
||||
/* BPF_UPROBE and BPF_URETPROBE are identical to BPF_KPROBE and BPF_KRETPROBE,
|
||||
* but are named way less confusingly for SEC("uprobe") and SEC("uretprobe")
|
||||
* use cases.
|
||||
*/
|
||||
#define BPF_UPROBE(name, args...) BPF_KPROBE(name, ##args)
|
||||
#define BPF_URETPROBE(name, args...) BPF_KRETPROBE(name, ##args)
|
||||
|
||||
#endif
|
||||
116
service/firewall/interception/ebpf/programs/exec.c
Normal file
116
service/firewall/interception/ebpf/programs/exec.c
Normal file
@@ -0,0 +1,116 @@
|
||||
#include "vmlinux-x86.h"
|
||||
#include "bpf/bpf_helpers.h"
|
||||
#include "bpf/bpf_tracing.h"
|
||||
|
||||
#define ARGLEN 32 // maximum amount of args in argv we'll copy
|
||||
#define ARGSIZE 1024 // maximum byte length of each arg in argv we'll copy
|
||||
|
||||
char __license[] SEC("license") = "GPL";
|
||||
|
||||
// Ring buffer for all connection events
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_RINGBUF);
|
||||
__uint(max_entries, 1 << 24);
|
||||
} pm_exec_map SEC(".maps");
|
||||
|
||||
// This struct is defined according to
|
||||
// /sys/kernel/debug/tracing/events/syscalls/sys_enter_execve/format
|
||||
struct exec_info {
|
||||
u16 common_type; // offset=0, size=2
|
||||
u8 common_flags; // offset=2, size=1
|
||||
u8 common_preempt_count; // offset=3, size=1
|
||||
s32 common_pid; // offset=4, size=4
|
||||
|
||||
s32 syscall_nr; // offset=8, size=4
|
||||
u32 pad; // offset=12, size=4 (pad)
|
||||
const u8 *filename; // offset=16, size=8 (ptr)
|
||||
const u8 *const *argv; // offset=24, size=8 (ptr)
|
||||
const u8 *const *envp; // offset=32, size=8 (ptr)
|
||||
};
|
||||
|
||||
// The event struct. This struct must be kept in sync with the Golang
|
||||
// counterpart.
|
||||
struct event_t {
|
||||
// Details about the process being launched.
|
||||
u8 filename[ARGSIZE];
|
||||
u8 argv[ARGLEN][ARGSIZE];
|
||||
u32 argc; // set to ARGLEN + 1 if there were more than ARGLEN arguments
|
||||
u32 uid;
|
||||
u32 gid;
|
||||
u32 pid;
|
||||
|
||||
// Name of the calling process.
|
||||
u8 comm[ARGSIZE];
|
||||
};
|
||||
|
||||
// Tracepoint at the top of execve() syscall.
|
||||
SEC("tracepoint/syscalls/sys_enter_execve")
|
||||
s32 enter_execve(struct exec_info *ctx) {
|
||||
// Reserve memory for our event on the `events` ring buffer defined above.
|
||||
struct event_t *event;
|
||||
event = bpf_ringbuf_reserve(&pm_exec_map, sizeof(struct event_t), 0);
|
||||
if (!event) {
|
||||
bpf_printk("could not reserve ringbuf memory");
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Store process/calling process details.
|
||||
u64 uidgid = bpf_get_current_uid_gid();
|
||||
u64 pidtgid = bpf_get_current_pid_tgid();
|
||||
event->uid = uidgid; // uid is the first 32 bits
|
||||
event->gid = uidgid >> 32; // gid is the last 32 bits NOLINT(readability-magic-numbers)
|
||||
event->pid = pidtgid; // pid is the first 32 bits
|
||||
s32 ret = bpf_get_current_comm(&event->comm, sizeof(event->comm));
|
||||
if (ret) {
|
||||
bpf_printk("could not get current comm: %d", ret);
|
||||
bpf_ringbuf_discard(event, 0);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Write the filename in addition to argv[0] because the filename contains
|
||||
// the full path to the file which could be more useful in some situations.
|
||||
ret = bpf_probe_read_user_str(event->filename, sizeof(event->filename), ctx->filename);
|
||||
if (ret < 0) {
|
||||
bpf_printk("could not read filename into event struct: %d", ret);
|
||||
bpf_ringbuf_discard(event, 0);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Copy everything from ctx->argv to event->argv, incrementing event->argc
|
||||
// as we go.
|
||||
for (s32 i = 0; i < ARGLEN; i++) {
|
||||
if (!(&ctx->argv[i])) {
|
||||
goto out;
|
||||
}
|
||||
|
||||
// Copying the arg into it's own variable before copying it into
|
||||
// event->argv[i] prevents memory corruption.
|
||||
const u8 *argp = NULL;
|
||||
ret = bpf_probe_read_user(&argp, sizeof(argp), &ctx->argv[i]);
|
||||
if (ret || !argp) {
|
||||
goto out;
|
||||
}
|
||||
|
||||
// Copy argp to event->argv[i].
|
||||
ret = bpf_probe_read_user_str(event->argv[i], sizeof(event->argv[i]), argp);
|
||||
if (ret < 0) {
|
||||
bpf_printk("read argv %d: %d", i, ret);
|
||||
goto out;
|
||||
}
|
||||
|
||||
event->argc++;
|
||||
}
|
||||
|
||||
// This won't get hit if we `goto out` in the loop above. This is to signify
|
||||
// to userspace that we couldn't copy all of the arguments because it
|
||||
// exceeded ARGLEN.
|
||||
event->argc++;
|
||||
|
||||
out:
|
||||
// Write the event to the ring buffer and notify userspace. This will cause
|
||||
// the `Read()` call in userspace to return if it was blocked.
|
||||
bpf_ringbuf_submit(event, 0);
|
||||
|
||||
return 0;
|
||||
}
|
||||
186
service/firewall/interception/ebpf/programs/monitor.c
Normal file
186
service/firewall/interception/ebpf/programs/monitor.c
Normal file
@@ -0,0 +1,186 @@
|
||||
#include "vmlinux-x86.h"
|
||||
#include "bpf/bpf_helpers.h"
|
||||
#include "bpf/bpf_tracing.h"
|
||||
|
||||
// IP Version
|
||||
#define AF_INET 2
|
||||
#define AF_INET6 10
|
||||
|
||||
// Protocols
|
||||
#define TCP 6
|
||||
#define UDP 17
|
||||
#define UDPLite 136
|
||||
|
||||
#define OUTBOUND 0
|
||||
#define INBOUND 1
|
||||
|
||||
char __license[] SEC("license") = "GPL";
|
||||
|
||||
// Ring buffer for all connection events
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_RINGBUF);
|
||||
__uint(max_entries, 1 << 24);
|
||||
} pm_connection_events SEC(".maps");
|
||||
|
||||
// Event struct that will be sent to Go on each new connection. (The name should be the same as the go generate command)
|
||||
struct Event {
|
||||
u32 saddr[4];
|
||||
u32 daddr[4];
|
||||
u16 sport;
|
||||
u16 dport;
|
||||
u32 pid;
|
||||
u8 ipVersion;
|
||||
u8 protocol;
|
||||
u8 direction;
|
||||
};
|
||||
struct Event *unused __attribute__((unused));
|
||||
|
||||
// Fentry of tcp_connect will be executed when equivalent kernel function is called.
|
||||
// In the kernel all IP address and ports should be set before tcp_connect is called. [this-function] -> tcp_connect
|
||||
SEC("fentry/tcp_connect")
|
||||
int BPF_PROG(tcp_connect, struct sock *sk) {
|
||||
// Alloc space for the event
|
||||
struct Event *tcp_info;
|
||||
tcp_info = bpf_ringbuf_reserve(&pm_connection_events, sizeof(struct Event), 0);
|
||||
if (!tcp_info) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Read PID (Careful: This is the Thread Group ID in kernel speak!)
|
||||
tcp_info->pid = __builtin_bswap32((u32)(bpf_get_current_pid_tgid() >> 32));
|
||||
|
||||
// Set protocol
|
||||
tcp_info->protocol = TCP;
|
||||
|
||||
// Set direction
|
||||
tcp_info->direction = OUTBOUND;
|
||||
|
||||
// Set src and dist ports
|
||||
tcp_info->sport = __builtin_bswap16(sk->__sk_common.skc_num);
|
||||
tcp_info->dport = sk->__sk_common.skc_dport;
|
||||
|
||||
// Set src and dist IPs
|
||||
if (sk->__sk_common.skc_family == AF_INET) {
|
||||
tcp_info->saddr[0] = __builtin_bswap32(sk->__sk_common.skc_rcv_saddr);
|
||||
tcp_info->daddr[0] = __builtin_bswap32(sk->__sk_common.skc_daddr);
|
||||
// Set IP version
|
||||
tcp_info->ipVersion = 4;
|
||||
} else if (sk->__sk_common.skc_family == AF_INET6) {
|
||||
for(int i = 0; i < 4; i++) {
|
||||
tcp_info->saddr[i] = __builtin_bswap32(sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[i]);
|
||||
}
|
||||
for(int i = 0; i < 4; i++) {
|
||||
tcp_info->daddr[i] = __builtin_bswap32(sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[i]);
|
||||
}
|
||||
// Set IP version
|
||||
tcp_info->ipVersion = 6;
|
||||
}
|
||||
|
||||
// Send event
|
||||
bpf_ringbuf_submit(tcp_info, 0);
|
||||
return 0;
|
||||
};
|
||||
|
||||
// Fexit(function exit) of udp_v4_connect will be executed after the ip4_datagram_connect kernel function is called.
|
||||
// ip4_datagram_connect -> udp_v4_connect
|
||||
SEC("fexit/ip4_datagram_connect")
|
||||
int BPF_PROG(udp_v4_connect, struct sock *sk) {
|
||||
// Ignore everything else then IPv4
|
||||
if (sk->__sk_common.skc_family != AF_INET) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// ip4_datagram_connect return error
|
||||
if (sk->__sk_common.skc_dport == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Allocate space for the event.
|
||||
struct Event *udp_info;
|
||||
udp_info = bpf_ringbuf_reserve(&pm_connection_events, sizeof(struct Event), 0);
|
||||
if (!udp_info) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Read PID (Careful: This is the Thread Group ID in kernel speak!)
|
||||
udp_info->pid = __builtin_bswap32((u32)(bpf_get_current_pid_tgid() >> 32));
|
||||
|
||||
// Set src and dst ports
|
||||
udp_info->sport = __builtin_bswap16(sk->__sk_common.skc_num);
|
||||
udp_info->dport = sk->__sk_common.skc_dport;
|
||||
|
||||
// Set src and dst IPs
|
||||
udp_info->saddr[0] = __builtin_bswap32(sk->__sk_common.skc_rcv_saddr);
|
||||
udp_info->daddr[0] = __builtin_bswap32(sk->__sk_common.skc_daddr);
|
||||
|
||||
// Set IP version
|
||||
udp_info->ipVersion = 4;
|
||||
|
||||
// Set protocol
|
||||
if(sk->sk_protocol == IPPROTO_UDPLITE) {
|
||||
udp_info->protocol = UDPLite;
|
||||
} else {
|
||||
udp_info->protocol = UDP;
|
||||
}
|
||||
|
||||
// Send event
|
||||
bpf_ringbuf_submit(udp_info, 0);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Fentry(function enter) of udp_v6_connect will be executed after the ip6_datagram_connect kernel function is called.
|
||||
// ip6_datagram_connect -> udp_v6_connect
|
||||
SEC("fexit/ip6_datagram_connect")
|
||||
int BPF_PROG(udp_v6_connect, struct sock *sk) {
|
||||
// Ignore everything else then IPv6
|
||||
if (sk->__sk_common.skc_family != AF_INET6) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// ip6_datagram_connect return error
|
||||
if (sk->__sk_common.skc_dport == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Make sure its udp6 socket
|
||||
struct udp6_sock *us = bpf_skc_to_udp6_sock(sk);
|
||||
if (!us) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Allocate space for the event.
|
||||
struct Event *udp_info;
|
||||
udp_info = bpf_ringbuf_reserve(&pm_connection_events, sizeof(struct Event), 0);
|
||||
if (!udp_info) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Read PID (Careful: This is the Thread Group ID in kernel speak!)
|
||||
udp_info->pid = __builtin_bswap32((u32)(bpf_get_current_pid_tgid() >> 32));
|
||||
|
||||
// Set src and dst ports
|
||||
udp_info->sport = __builtin_bswap16(sk->__sk_common.skc_num);
|
||||
udp_info->dport = sk->__sk_common.skc_dport;
|
||||
|
||||
// Set src and dst IPs
|
||||
for(int i = 0; i < 4; i++) {
|
||||
udp_info->saddr[i] = __builtin_bswap32(sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[i]);
|
||||
}
|
||||
for(int i = 0; i < 4; i++) {
|
||||
udp_info->daddr[i] = __builtin_bswap32(sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[i]);
|
||||
}
|
||||
|
||||
// IP version
|
||||
udp_info->ipVersion = 6;
|
||||
|
||||
// Set protocol
|
||||
if(sk->sk_protocol == IPPROTO_UDPLITE) {
|
||||
udp_info->protocol = UDPLite;
|
||||
} else {
|
||||
udp_info->protocol = UDP;
|
||||
}
|
||||
|
||||
// Send event
|
||||
bpf_ringbuf_submit(udp_info, 0);
|
||||
return 0;
|
||||
}
|
||||
17
service/firewall/interception/ebpf/programs/update.sh
Executable file
17
service/firewall/interception/ebpf/programs/update.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Version of libbpf to fetch headers from
|
||||
LIBBPF_VERSION=1.2.0
|
||||
|
||||
# The headers we want
|
||||
prefix=libbpf-"$LIBBPF_VERSION"
|
||||
headers=(
|
||||
"$prefix"/src/bpf_core_read.h
|
||||
"$prefix"/src/bpf_helper_defs.h
|
||||
"$prefix"/src/bpf_helpers.h
|
||||
"$prefix"/src/bpf_tracing.h
|
||||
)
|
||||
|
||||
# Fetch libbpf release and extract the desired headers
|
||||
curl -sL "https://github.com/libbpf/libbpf/archive/refs/tags/v${LIBBPF_VERSION}.tar.gz" | \
|
||||
tar -xz --xform='s#.*/#bpf/#' "${headers[@]}"
|
||||
120348
service/firewall/interception/ebpf/programs/vmlinux-x86.h
Normal file
120348
service/firewall/interception/ebpf/programs/vmlinux-x86.h
Normal file
File diff suppressed because it is too large
Load Diff
30
service/firewall/interception/interception_default.go
Normal file
30
service/firewall/interception/interception_default.go
Normal file
@@ -0,0 +1,30 @@
|
||||
//go:build !windows && !linux
|
||||
|
||||
package interception
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
// start starts the interception.
|
||||
func startInterception(_ chan packet.Packet) error {
|
||||
log.Critical("interception: this platform has no support for packet interception - a lot of functionality will be broken")
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop starts the interception.
|
||||
func stopInterception() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again.
|
||||
func ResetVerdictOfAllConnections() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateVerdictOfConnection updates the verdict of the given connection in the OS integration.
|
||||
func UpdateVerdictOfConnection(conn *network.Connection) error {
|
||||
return nil
|
||||
}
|
||||
48
service/firewall/interception/interception_linux.go
Normal file
48
service/firewall/interception/interception_linux.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package interception
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
bandwidth "github.com/safing/portmaster/service/firewall/interception/ebpf/bandwidth"
|
||||
conn_listener "github.com/safing/portmaster/service/firewall/interception/ebpf/connection_listener"
|
||||
"github.com/safing/portmaster/service/firewall/interception/nfq"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
// start starts the interception.
|
||||
func startInterception(packets chan packet.Packet) error {
|
||||
// Start packet interception via nfqueue.
|
||||
err := StartNfqueueInterception(packets)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start ebpf new connection listener.
|
||||
module.StartServiceWorker("ebpf connection listener", 0, func(ctx context.Context) error {
|
||||
return conn_listener.ConnectionListenerWorker(ctx, packets)
|
||||
})
|
||||
|
||||
// Start ebpf bandwidth stats monitor.
|
||||
module.StartServiceWorker("ebpf bandwidth stats monitor", 0, func(ctx context.Context) error {
|
||||
return bandwidth.BandwidthStatsWorker(ctx, 1*time.Second, BandwidthUpdates)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop starts the interception.
|
||||
func stopInterception() error {
|
||||
return StopNfqueueInterception()
|
||||
}
|
||||
|
||||
// ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again.
|
||||
func ResetVerdictOfAllConnections() error {
|
||||
return nfq.DeleteAllMarkedConnection()
|
||||
}
|
||||
|
||||
// UpdateVerdictOfConnection deletes the verdict of the given connection so it can be initialized again with the next packet.
|
||||
func UpdateVerdictOfConnection(conn *network.Connection) error {
|
||||
return nfq.DeleteMarkedConnection(conn)
|
||||
}
|
||||
68
service/firewall/interception/interception_windows.go
Normal file
68
service/firewall/interception/interception_windows.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package interception
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/firewall/interception/windowskext"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
)
|
||||
|
||||
// start starts the interception.
|
||||
func startInterception(packets chan packet.Packet) error {
|
||||
kextFile, err := updates.GetPlatformFile("kext/portmaster-kext.sys")
|
||||
if err != nil {
|
||||
return fmt.Errorf("interception: could not get kext sys: %s", err)
|
||||
}
|
||||
|
||||
err = windowskext.Init(kextFile.Path())
|
||||
if err != nil {
|
||||
return fmt.Errorf("interception: could not init windows kext: %s", err)
|
||||
}
|
||||
|
||||
err = windowskext.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("interception: could not start windows kext: %s", err)
|
||||
}
|
||||
|
||||
// Start packet handler.
|
||||
module.StartServiceWorker("kext packet handler", 0, func(ctx context.Context) error {
|
||||
windowskext.Handler(ctx, packets)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start bandwidth stats monitor.
|
||||
module.StartServiceWorker("kext bandwidth stats monitor", 0, func(ctx context.Context) error {
|
||||
return windowskext.BandwidthStatsWorker(ctx, 1*time.Second, BandwidthUpdates)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop starts the interception.
|
||||
func stopInterception() error {
|
||||
return windowskext.Stop()
|
||||
}
|
||||
|
||||
// ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again.
|
||||
func ResetVerdictOfAllConnections() error {
|
||||
return windowskext.ClearCache()
|
||||
}
|
||||
|
||||
// UpdateVerdictOfConnection updates the verdict of the given connection in the kernel extension.
|
||||
func UpdateVerdictOfConnection(conn *network.Connection) error {
|
||||
return windowskext.UpdateVerdict(conn)
|
||||
}
|
||||
|
||||
// GetKextVersion returns the version of the kernel extension.
|
||||
func GetKextVersion() (string, error) {
|
||||
version, err := windowskext.GetVersion()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return version.String(), nil
|
||||
}
|
||||
80
service/firewall/interception/introspection.go
Normal file
80
service/firewall/interception/introspection.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package interception
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
var (
|
||||
packetMetricsDestination string
|
||||
metrics = &packetMetrics{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&packetMetricsDestination, "write-packet-metrics", "", "write packet metrics to the specified file")
|
||||
}
|
||||
|
||||
type (
|
||||
performanceRecord struct {
|
||||
start int64
|
||||
duration time.Duration
|
||||
verdict string
|
||||
}
|
||||
|
||||
packetMetrics struct {
|
||||
done chan struct{}
|
||||
l sync.Mutex
|
||||
records []*performanceRecord
|
||||
}
|
||||
)
|
||||
|
||||
func (pm *packetMetrics) record(tp *tracedPacket, verdict string) {
|
||||
go func(start int64, duration time.Duration) {
|
||||
pm.l.Lock()
|
||||
defer pm.l.Unlock()
|
||||
|
||||
pm.records = append(pm.records, &performanceRecord{
|
||||
start: start,
|
||||
duration: duration,
|
||||
verdict: verdict,
|
||||
})
|
||||
}(tp.start.UnixNano(), time.Since(tp.start))
|
||||
}
|
||||
|
||||
func (pm *packetMetrics) writeMetrics() {
|
||||
if packetMetricsDestination == "" {
|
||||
return
|
||||
}
|
||||
|
||||
f, err := os.Create(packetMetricsDestination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create packet metrics file: %s", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pm.done:
|
||||
return
|
||||
case <-time.After(time.Second * 5):
|
||||
}
|
||||
pm.l.Lock()
|
||||
records := pm.records
|
||||
pm.records = nil
|
||||
pm.l.Unlock()
|
||||
|
||||
for _, r := range records {
|
||||
fmt.Fprintf(f, "%d;%s;%s;%.2f\n", r.start, r.verdict, r.duration, float64(r.duration)/float64(time.Microsecond))
|
||||
}
|
||||
}
|
||||
}
|
||||
63
service/firewall/interception/module.go
Normal file
63
service/firewall/interception/module.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package interception
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
// Packets is a stream of interception network packest.
|
||||
Packets = make(chan packet.Packet, 1000)
|
||||
|
||||
// BandwidthUpdates is a stream of bandwidth usage update for connections.
|
||||
BandwidthUpdates = make(chan *packet.BandwidthUpdate, 1000)
|
||||
|
||||
disableInterception bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&disableInterception, "disable-interception", false, "disable packet interception; this breaks a lot of functionality")
|
||||
|
||||
module = modules.Register("interception", prep, start, stop, "base", "updates", "network", "notifications", "profiles")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start starts the interception.
|
||||
func start() error {
|
||||
if disableInterception {
|
||||
log.Warning("interception: packet interception is disabled via flag - this breaks a lot of functionality")
|
||||
return nil
|
||||
}
|
||||
|
||||
inputPackets := Packets
|
||||
if packetMetricsDestination != "" {
|
||||
go metrics.writeMetrics()
|
||||
inputPackets = make(chan packet.Packet)
|
||||
go func() {
|
||||
for p := range inputPackets {
|
||||
Packets <- tracePacket(p)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return startInterception(inputPackets)
|
||||
}
|
||||
|
||||
// Stop starts the interception.
|
||||
func stop() error {
|
||||
if disableInterception {
|
||||
return nil
|
||||
}
|
||||
|
||||
close(metrics.done)
|
||||
|
||||
return stopInterception()
|
||||
}
|
||||
125
service/firewall/interception/nfq/conntrack.go
Normal file
125
service/firewall/interception/nfq/conntrack.go
Normal file
@@ -0,0 +1,125 @@
|
||||
//go:build linux
|
||||
|
||||
package nfq
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
ct "github.com/florianl/go-conntrack"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
)
|
||||
|
||||
var nfct *ct.Nfct // Conntrack handler. NFCT: Network Filter Connection Tracking.
|
||||
|
||||
// InitNFCT initializes the network filter conntrack library.
|
||||
func InitNFCT() error {
|
||||
var err error
|
||||
nfct, err = ct.Open(&ct.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TeardownNFCT deinitializes the network filter conntrack library.
|
||||
func TeardownNFCT() {
|
||||
if nfct != nil {
|
||||
_ = nfct.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteAllMarkedConnection deletes all marked entries from the conntrack table.
|
||||
func DeleteAllMarkedConnection() error {
|
||||
if nfct == nil {
|
||||
return fmt.Errorf("nfq: nfct not initialized")
|
||||
}
|
||||
|
||||
// Delete all ipv4 marked connections
|
||||
deleted := deleteMarkedConnections(nfct, ct.IPv4)
|
||||
|
||||
if netenv.IPv6Enabled() {
|
||||
// Delete all ipv6 marked connections
|
||||
deleted += deleteMarkedConnections(nfct, ct.IPv6)
|
||||
}
|
||||
|
||||
log.Infof("nfq: deleted %d conntrack entries to reset permanent connection verdicts", deleted)
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteMarkedConnections(nfct *ct.Nfct, f ct.Family) (deleted int) {
|
||||
// initialize variables
|
||||
permanentFlags := []uint32{MarkAcceptAlways, MarkBlockAlways, MarkDropAlways, MarkRerouteNS, MarkRerouteSPN}
|
||||
filter := ct.FilterAttr{}
|
||||
filter.MarkMask = []byte{0xFF, 0xFF, 0xFF, 0xFF}
|
||||
filter.Mark = []byte{0x00, 0x00, 0x00, 0x00} // 4 zeros starting value
|
||||
|
||||
numberOfErrors := 0
|
||||
var deleteError error = nil
|
||||
// Get all connections from the specified family (ipv4 or ipv6)
|
||||
for _, mark := range permanentFlags {
|
||||
binary.BigEndian.PutUint32(filter.Mark, mark) // Little endian is in reverse not sure why. BigEndian makes it in correct order.
|
||||
currentConnections, err := nfct.Query(ct.Conntrack, f, filter)
|
||||
if err != nil {
|
||||
log.Warningf("nfq: error on conntrack query: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, connection := range currentConnections {
|
||||
deleteError = nfct.Delete(ct.Conntrack, ct.IPv4, connection)
|
||||
if err != nil {
|
||||
numberOfErrors++
|
||||
} else {
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if numberOfErrors > 0 {
|
||||
log.Warningf("nfq: failed to delete %d conntrack entries last error is: %s", numberOfErrors, deleteError)
|
||||
}
|
||||
return deleted
|
||||
}
|
||||
|
||||
// DeleteMarkedConnection removes a specific connection from the conntrack table.
|
||||
func DeleteMarkedConnection(conn *network.Connection) error {
|
||||
if nfct == nil {
|
||||
return fmt.Errorf("nfq: nfct not initialized")
|
||||
}
|
||||
|
||||
con := ct.Con{
|
||||
Origin: &ct.IPTuple{
|
||||
Src: &conn.LocalIP,
|
||||
Dst: &conn.Entity.IP,
|
||||
Proto: &ct.ProtoTuple{
|
||||
Number: &conn.Entity.Protocol,
|
||||
SrcPort: &conn.LocalPort,
|
||||
DstPort: &conn.Entity.Port,
|
||||
},
|
||||
},
|
||||
}
|
||||
connections, err := nfct.Get(ct.Conntrack, ct.IPv4, con)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nfq: failed to find entry for connection %s: %w", conn.String(), err)
|
||||
}
|
||||
|
||||
if len(connections) > 1 {
|
||||
log.Warningf("nfq: multiple entries found for single connection: %s -> %d", conn.String(), len(connections))
|
||||
}
|
||||
|
||||
for _, connection := range connections {
|
||||
deleteErr := nfct.Delete(ct.Conntrack, ct.IPv4, connection)
|
||||
if err == nil {
|
||||
err = deleteErr
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warningf("nfq: error while deleting conntrack entries for connection %s: %s", conn.String(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
241
service/firewall/interception/nfq/nfq.go
Normal file
241
service/firewall/interception/nfq/nfq.go
Normal file
@@ -0,0 +1,241 @@
|
||||
//go:build linux
|
||||
|
||||
// Package nfq contains a nfqueue library experiment.
|
||||
package nfq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/florianl/go-nfqueue"
|
||||
"github.com/tevino/abool"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
pmpacket "github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
)
|
||||
|
||||
// Queue wraps a nfqueue.
|
||||
type Queue struct {
|
||||
id uint16
|
||||
afFamily uint8
|
||||
nf atomic.Value
|
||||
packets chan pmpacket.Packet
|
||||
cancelSocketCallback context.CancelFunc
|
||||
restart chan struct{}
|
||||
|
||||
pendingVerdicts uint64
|
||||
verdictCompleted chan struct{}
|
||||
}
|
||||
|
||||
func (q *Queue) getNfq() *nfqueue.Nfqueue {
|
||||
return q.nf.Load().(*nfqueue.Nfqueue) //nolint:forcetypeassert // TODO: Check.
|
||||
}
|
||||
|
||||
// New opens a new nfQueue.
|
||||
func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit
|
||||
afFamily := unix.AF_INET
|
||||
if v6 {
|
||||
afFamily = unix.AF_INET6
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
q := &Queue{
|
||||
id: qid,
|
||||
afFamily: uint8(afFamily),
|
||||
nf: atomic.Value{},
|
||||
restart: make(chan struct{}, 1),
|
||||
packets: make(chan pmpacket.Packet, 1000),
|
||||
cancelSocketCallback: cancel,
|
||||
verdictCompleted: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
// Do not retry if the first one fails immediately as it
|
||||
// might point to a deeper integration error that's not fixable
|
||||
// with retrying ...
|
||||
if err := q.open(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
Wait:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-q.restart:
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
for {
|
||||
err := q.open(ctx)
|
||||
if err == nil {
|
||||
continue Wait
|
||||
}
|
||||
|
||||
// Wait 100 ms and then try again ...
|
||||
log.Errorf("Failed to open nfqueue: %s", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// open opens a new netlink socket and creates a new nfqueue.
|
||||
// Upon success, the new nfqueue is atomically stored in Queue.nf.
|
||||
// Users must use Queue.getNfq to access it. open does not care about
|
||||
// any other value or queue that might be stored in Queue.nf at
|
||||
// the time open is called.
|
||||
func (q *Queue) open(ctx context.Context) error {
|
||||
cfg := &nfqueue.Config{
|
||||
NfQueue: q.id,
|
||||
MaxPacketLen: 1600, // mtu is normally around 1500, make sure to capture it.
|
||||
MaxQueueLen: 0xffff,
|
||||
AfFamily: q.afFamily,
|
||||
Copymode: nfqueue.NfQnlCopyPacket,
|
||||
ReadTimeout: 1000 * time.Millisecond,
|
||||
WriteTimeout: 1000 * time.Millisecond,
|
||||
}
|
||||
|
||||
nf, err := nfqueue.Open(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := nf.RegisterWithErrorFunc(ctx, q.packetHandler(ctx), q.handleError); err != nil {
|
||||
_ = nf.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
q.nf.Store(nf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Queue) handleError(e error) int {
|
||||
// embedded interface is required to work-around some
|
||||
// dep-vendoring weirdness
|
||||
if opError, ok := e.(interface { //nolint:errorlint // TODO: Check if we can remove workaround.
|
||||
Timeout() bool
|
||||
Temporary() bool
|
||||
}); ok {
|
||||
if opError.Timeout() || opError.Temporary() {
|
||||
c := atomic.LoadUint64(&q.pendingVerdicts)
|
||||
if c > 0 {
|
||||
log.Tracef("nfqueue: waiting for %d pending verdicts", c)
|
||||
|
||||
for atomic.LoadUint64(&q.pendingVerdicts) > 0 { // must NOT use c here
|
||||
<-q.verdictCompleted
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the queue was already closed. Unfortunately, the exposed error
|
||||
// variable is in an internal stdlib package. Therefore, check for the error
|
||||
// string instead. :(
|
||||
// Official error variable is defined here:
|
||||
// https://github.com/golang/go/blob/0e85fd7561de869add933801c531bf25dee9561c/src/internal/poll/fd.go#L24
|
||||
if !strings.HasSuffix(e.Error(), "use of closed file") {
|
||||
log.Errorf("nfqueue: encountered error while receiving packets: %s\n", e.Error())
|
||||
}
|
||||
|
||||
// Close the existing socket
|
||||
if nf := q.getNfq(); nf != nil {
|
||||
// Call Close() on the Con directly, as nf.Close() calls waitgroup.Wait(), which then may deadlock.
|
||||
_ = nf.Con.Close()
|
||||
}
|
||||
|
||||
// Trigger a restart of the queue
|
||||
q.restart <- struct{}{}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int {
|
||||
return func(attrs nfqueue.Attribute) int {
|
||||
if attrs.PacketID == nil {
|
||||
// we need a packet id to set a verdict,
|
||||
// if we don't get an ID there's hardly anything
|
||||
// we can do.
|
||||
return 0
|
||||
}
|
||||
|
||||
pkt := &packet{
|
||||
pktID: *attrs.PacketID,
|
||||
queue: q,
|
||||
verdictSet: make(chan struct{}),
|
||||
verdictPending: abool.New(),
|
||||
}
|
||||
pkt.Info().PID = process.UndefinedProcessID
|
||||
pkt.Info().SeenAt = time.Now()
|
||||
|
||||
if attrs.Payload == nil {
|
||||
// There is not payload.
|
||||
log.Warningf("nfqueue: packet #%d has no payload", pkt.pktID)
|
||||
return 0
|
||||
}
|
||||
|
||||
if err := pmpacket.Parse(*attrs.Payload, &pkt.Base); err != nil {
|
||||
log.Warningf("nfqueue: failed to parse payload: %s", err)
|
||||
_ = pkt.Drop()
|
||||
return 0
|
||||
}
|
||||
|
||||
select {
|
||||
case q.packets <- pkt:
|
||||
// DEBUG:
|
||||
// log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt))
|
||||
case <-ctx.Done():
|
||||
return 0
|
||||
case <-time.After(time.Second):
|
||||
log.Warningf("nfqueue: failed to queue packet (%s since it was handed over by the kernel)", time.Since(pkt.Info().SeenAt))
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-pkt.verdictSet:
|
||||
|
||||
case <-time.After(20 * time.Second):
|
||||
log.Warningf("nfqueue: no verdict set for packet %s (%s -> %s) after %s, dropping", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt))
|
||||
if err := pkt.Drop(); err != nil {
|
||||
log.Warningf("nfqueue: failed to apply default-drop to unveridcted packet %s (%s -> %s)", pkt.ID(), pkt.Info().Src, pkt.Info().Dst)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return 0 // continue calling this fn
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy destroys the queue. Any error encountered is logged.
|
||||
func (q *Queue) Destroy() {
|
||||
if q == nil {
|
||||
return
|
||||
}
|
||||
|
||||
q.cancelSocketCallback()
|
||||
|
||||
if nf := q.getNfq(); nf != nil {
|
||||
if err := nf.Close(); err != nil {
|
||||
log.Errorf("nfqueue: failed to close queue %d: %s", q.id, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PacketChannel returns the packet channel.
|
||||
func (q *Queue) PacketChannel() <-chan pmpacket.Packet {
|
||||
return q.packets
|
||||
}
|
||||
178
service/firewall/interception/nfq/packet.go
Normal file
178
service/firewall/interception/nfq/packet.go
Normal file
@@ -0,0 +1,178 @@
|
||||
//go:build linux
|
||||
|
||||
package nfq
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/florianl/go-nfqueue"
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
pmpacket "github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
// Firewalling marks used by the Portmaster.
|
||||
// See TODO on packet.mark() on their relevance
|
||||
// and a possibility to remove most IPtables rules.
|
||||
const (
|
||||
MarkAccept = 1700
|
||||
MarkBlock = 1701
|
||||
MarkDrop = 1702
|
||||
MarkAcceptAlways = 1710
|
||||
MarkBlockAlways = 1711
|
||||
MarkDropAlways = 1712
|
||||
MarkRerouteNS = 1799
|
||||
MarkRerouteSPN = 1717
|
||||
)
|
||||
|
||||
func markToString(mark int) string {
|
||||
switch mark {
|
||||
case MarkAccept:
|
||||
return "Accept"
|
||||
case MarkBlock:
|
||||
return "Block"
|
||||
case MarkDrop:
|
||||
return "Drop"
|
||||
case MarkAcceptAlways:
|
||||
return "AcceptAlways"
|
||||
case MarkBlockAlways:
|
||||
return "BlockAlways"
|
||||
case MarkDropAlways:
|
||||
return "DropAlways"
|
||||
case MarkRerouteNS:
|
||||
return "RerouteNS"
|
||||
case MarkRerouteSPN:
|
||||
return "RerouteSPN"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// packet implements the packet.Packet interface.
|
||||
type packet struct {
|
||||
pmpacket.Base
|
||||
pktID uint32
|
||||
queue *Queue
|
||||
verdictSet chan struct{}
|
||||
verdictPending *abool.AtomicBool
|
||||
}
|
||||
|
||||
func (pkt *packet) ID() string {
|
||||
return fmt.Sprintf("pkt:%d qid:%d", pkt.pktID, pkt.queue.id)
|
||||
}
|
||||
|
||||
// LoadPacketData does nothing on Linux, as data is always fully parsed.
|
||||
func (pkt *packet) LoadPacketData() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(ppacher): revisit the following behavior:
|
||||
//
|
||||
// The legacy implementation of nfqueue (and the interception) module
|
||||
// always accept a packet but may mark it so that a subsequent rule in
|
||||
// the C17 chain drops, rejects or modifies it.
|
||||
//
|
||||
// For drop/return we could use the actual nfQueue verdicts Drop and Stop.
|
||||
// Re-routing to local NS or SPN can be done by modifying the packet here
|
||||
// and using SetVerdictModPacket and reject can be implemented using a simple
|
||||
// raw-socket.
|
||||
func (pkt *packet) mark(mark int) (err error) {
|
||||
if pkt.verdictPending.SetToIf(false, true) {
|
||||
defer close(pkt.verdictSet)
|
||||
return pkt.setMark(mark)
|
||||
}
|
||||
|
||||
return errors.New("verdict already set")
|
||||
}
|
||||
|
||||
func (pkt *packet) setMark(mark int) error {
|
||||
atomic.AddUint64(&pkt.queue.pendingVerdicts, 1)
|
||||
|
||||
defer func() {
|
||||
atomic.AddUint64(&pkt.queue.pendingVerdicts, ^uint64(0))
|
||||
select {
|
||||
case pkt.queue.verdictCompleted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if err := pkt.queue.getNfq().SetVerdictWithMark(pkt.pktID, nfqueue.NfAccept, mark); err != nil {
|
||||
// embedded interface is required to work-around some
|
||||
// dep-vendoring weirdness
|
||||
if opErr, ok := err.(interface { //nolint:errorlint // TODO: Check if we can remove workaround.
|
||||
Timeout() bool
|
||||
Temporary() bool
|
||||
}); ok {
|
||||
if opErr.Timeout() || opErr.Temporary() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
log.Tracer(pkt.Ctx()).Errorf("nfqueue: failed to set verdict %s for %s (%s -> %s): %s", markToString(mark), pkt.ID(), pkt.Info().Src, pkt.Info().Dst, err)
|
||||
return err
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// DEBUG:
|
||||
// log.Tracer(pkt.Ctx()).Tracef(
|
||||
// "nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s",
|
||||
// pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id,
|
||||
// markToString(mark), time.Since(pkt.Info().SeenAt),
|
||||
// )
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pkt *packet) Accept() error {
|
||||
return pkt.mark(MarkAccept)
|
||||
}
|
||||
|
||||
func (pkt *packet) Block() error {
|
||||
if pkt.Info().Protocol == pmpacket.ICMP {
|
||||
// ICMP packets attributed to a blocked connection are always allowed, as
|
||||
// rejection ICMP packets will have the same mark as the blocked
|
||||
// connection. This is why we need to drop blocked ICMP packets instead.
|
||||
return pkt.mark(MarkDrop)
|
||||
}
|
||||
return pkt.mark(MarkBlock)
|
||||
}
|
||||
|
||||
func (pkt *packet) Drop() error {
|
||||
return pkt.mark(MarkDrop)
|
||||
}
|
||||
|
||||
func (pkt *packet) PermanentAccept() error {
|
||||
// If the packet is localhost only, do not permanently accept the outgoing
|
||||
// packet, as the packet mark will be copied to the connection mark, which
|
||||
// will stick and it will bypass the incoming queue.
|
||||
if !pkt.Info().Inbound && pkt.Info().Dst.IsLoopback() {
|
||||
return pkt.Accept()
|
||||
}
|
||||
|
||||
return pkt.mark(MarkAcceptAlways)
|
||||
}
|
||||
|
||||
func (pkt *packet) PermanentBlock() error {
|
||||
if pkt.Info().Protocol == pmpacket.ICMP || pkt.Info().Protocol == pmpacket.ICMPv6 {
|
||||
// ICMP packets attributed to a blocked connection are always allowed, as
|
||||
// rejection ICMP packets will have the same mark as the blocked
|
||||
// connection. This is why we need to drop blocked ICMP packets instead.
|
||||
return pkt.mark(MarkDropAlways)
|
||||
}
|
||||
return pkt.mark(MarkBlockAlways)
|
||||
}
|
||||
|
||||
func (pkt *packet) PermanentDrop() error {
|
||||
return pkt.mark(MarkDropAlways)
|
||||
}
|
||||
|
||||
func (pkt *packet) RerouteToNameserver() error {
|
||||
return pkt.mark(MarkRerouteNS)
|
||||
}
|
||||
|
||||
func (pkt *packet) RerouteToTunnel() error {
|
||||
return pkt.mark(MarkRerouteSPN)
|
||||
}
|
||||
354
service/firewall/interception/nfqueue_linux.go
Normal file
354
service/firewall/interception/nfqueue_linux.go
Normal file
@@ -0,0 +1,354 @@
|
||||
package interception
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/firewall/interception/nfq"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
var (
|
||||
v4chains []string
|
||||
v4rules []string
|
||||
v4once []string
|
||||
|
||||
v6chains []string
|
||||
v6rules []string
|
||||
v6once []string
|
||||
|
||||
out4Queue nfQueue
|
||||
in4Queue nfQueue
|
||||
out6Queue nfQueue
|
||||
in6Queue nfQueue
|
||||
|
||||
shutdownSignal = make(chan struct{})
|
||||
|
||||
experimentalNfqueueBackend bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&experimentalNfqueueBackend, "experimental-nfqueue", false, "(deprecated flag; always used)")
|
||||
}
|
||||
|
||||
// nfQueue encapsulates nfQueue providers.
|
||||
type nfQueue interface {
|
||||
PacketChannel() <-chan packet.Packet
|
||||
Destroy()
|
||||
}
|
||||
|
||||
func init() {
|
||||
v4chains = []string{
|
||||
"mangle PORTMASTER-INGEST-OUTPUT",
|
||||
"mangle PORTMASTER-INGEST-INPUT",
|
||||
"filter PORTMASTER-FILTER",
|
||||
"nat PORTMASTER-REDIRECT",
|
||||
}
|
||||
|
||||
v4rules = []string{
|
||||
"mangle PORTMASTER-INGEST-OUTPUT -j CONNMARK --restore-mark",
|
||||
"mangle PORTMASTER-INGEST-OUTPUT -m mark --mark 0 -j NFQUEUE --queue-num 17040 --queue-bypass",
|
||||
|
||||
"mangle PORTMASTER-INGEST-INPUT -j CONNMARK --restore-mark",
|
||||
"mangle PORTMASTER-INGEST-INPUT -m mark --mark 0 -j NFQUEUE --queue-num 17140 --queue-bypass",
|
||||
|
||||
"filter PORTMASTER-FILTER -m mark --mark 0 -j DROP",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1700 -j RETURN",
|
||||
// Accepting ICMP packets with mark 1701 is required for rejecting to work,
|
||||
// as the rejection ICMP packet will have the same mark. Blocked ICMP
|
||||
// packets will always result in a drop within the Portmaster.
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1701 -p icmp -j RETURN",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1701 -j REJECT --reject-with icmp-admin-prohibited",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1702 -j DROP",
|
||||
"filter PORTMASTER-FILTER -j CONNMARK --save-mark",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1710 -j RETURN",
|
||||
// Accepting ICMP packets with mark 1711 is required for rejecting to work,
|
||||
// as the rejection ICMP packet will have the same mark. Blocked ICMP
|
||||
// packets will always result in a drop within the Portmaster.
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1711 -p icmp -j RETURN",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1711 -j REJECT --reject-with icmp-admin-prohibited",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1712 -j DROP",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1717 -j RETURN",
|
||||
|
||||
"nat PORTMASTER-REDIRECT -m mark --mark 1799 -p udp -j DNAT --to 127.0.0.17:53",
|
||||
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p tcp -j DNAT --to 127.0.0.17:717",
|
||||
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p udp -j DNAT --to 127.0.0.17:717",
|
||||
// "nat PORTMASTER-REDIRECT -m mark --mark 1717 ! -p tcp ! -p udp -j DNAT --to 127.0.0.17",
|
||||
}
|
||||
|
||||
v4once = []string{
|
||||
"mangle OUTPUT -j PORTMASTER-INGEST-OUTPUT",
|
||||
"mangle INPUT -j PORTMASTER-INGEST-INPUT",
|
||||
"filter OUTPUT -j PORTMASTER-FILTER",
|
||||
"filter INPUT -j PORTMASTER-FILTER",
|
||||
"nat OUTPUT -j PORTMASTER-REDIRECT",
|
||||
}
|
||||
|
||||
v6chains = []string{
|
||||
"mangle PORTMASTER-INGEST-OUTPUT",
|
||||
"mangle PORTMASTER-INGEST-INPUT",
|
||||
"filter PORTMASTER-FILTER",
|
||||
"nat PORTMASTER-REDIRECT",
|
||||
}
|
||||
|
||||
v6rules = []string{
|
||||
"mangle PORTMASTER-INGEST-OUTPUT -j CONNMARK --restore-mark",
|
||||
"mangle PORTMASTER-INGEST-OUTPUT -m mark --mark 0 -j NFQUEUE --queue-num 17060 --queue-bypass",
|
||||
|
||||
"mangle PORTMASTER-INGEST-INPUT -j CONNMARK --restore-mark",
|
||||
"mangle PORTMASTER-INGEST-INPUT -m mark --mark 0 -j NFQUEUE --queue-num 17160 --queue-bypass",
|
||||
|
||||
"filter PORTMASTER-FILTER -m mark --mark 0 -j DROP",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1700 -j RETURN",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1701 -p icmpv6 -j RETURN",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1701 -j REJECT --reject-with icmp6-adm-prohibited",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1702 -j DROP",
|
||||
"filter PORTMASTER-FILTER -j CONNMARK --save-mark",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1710 -j RETURN",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1711 -p icmpv6 -j RETURN",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1711 -j REJECT --reject-with icmp6-adm-prohibited",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1712 -j DROP",
|
||||
"filter PORTMASTER-FILTER -m mark --mark 1717 -j RETURN",
|
||||
|
||||
"nat PORTMASTER-REDIRECT -m mark --mark 1799 -p udp -j DNAT --to [::1]:53",
|
||||
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p tcp -j DNAT --to [::1]:717",
|
||||
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p udp -j DNAT --to [::1]:717",
|
||||
// "nat PORTMASTER-REDIRECT -m mark --mark 1717 ! -p tcp ! -p udp -j DNAT --to [::1]",
|
||||
}
|
||||
|
||||
v6once = []string{
|
||||
"mangle OUTPUT -j PORTMASTER-INGEST-OUTPUT",
|
||||
"mangle INPUT -j PORTMASTER-INGEST-INPUT",
|
||||
"filter OUTPUT -j PORTMASTER-FILTER",
|
||||
"filter INPUT -j PORTMASTER-FILTER",
|
||||
"nat OUTPUT -j PORTMASTER-REDIRECT",
|
||||
}
|
||||
|
||||
// Reverse because we'd like to insert in a loop
|
||||
_ = sort.Reverse(sort.StringSlice(v4once)) // silence vet (sort is used just like in the docs)
|
||||
_ = sort.Reverse(sort.StringSlice(v6once)) // silence vet (sort is used just like in the docs)
|
||||
}
|
||||
|
||||
func activateNfqueueFirewall() error {
|
||||
if err := activateIPTables(iptables.ProtocolIPv4, v4rules, v4once, v4chains); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if netenv.IPv6Enabled() {
|
||||
if err := activateIPTables(iptables.ProtocolIPv6, v6rules, v6once, v6chains); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := nfq.InitNFCT(); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = nfq.DeleteAllMarkedConnection()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeactivateNfqueueFirewall drops portmaster related IP tables rules.
|
||||
// Any errors encountered accumulated into a *multierror.Error.
|
||||
func DeactivateNfqueueFirewall() error {
|
||||
// IPv4
|
||||
var result *multierror.Error
|
||||
if err := deactivateIPTables(iptables.ProtocolIPv4, v4once, v4chains); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// IPv6
|
||||
if netenv.IPv6Enabled() {
|
||||
if err := deactivateIPTables(iptables.ProtocolIPv6, v6once, v6chains); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
_ = nfq.DeleteAllMarkedConnection()
|
||||
nfq.TeardownNFCT()
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func activateIPTables(protocol iptables.Protocol, rules, once, chains []string) error {
|
||||
tbls, err := iptables.NewWithProtocol(protocol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
splittedRule := strings.Split(chain, " ")
|
||||
if err = tbls.ClearChain(splittedRule[0], splittedRule[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
splittedRule := strings.Split(rule, " ")
|
||||
if err = tbls.Append(splittedRule[0], splittedRule[1], splittedRule[2:]...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range once {
|
||||
splittedRule := strings.Split(rule, " ")
|
||||
ok, err := tbls.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
if err = tbls.Insert(splittedRule[0], splittedRule[1], 1, splittedRule[2:]...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func deactivateIPTables(protocol iptables.Protocol, rules, chains []string) error {
|
||||
tbls, err := iptables.NewWithProtocol(protocol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var multierr *multierror.Error
|
||||
|
||||
for _, rule := range rules {
|
||||
splittedRule := strings.Split(rule, " ")
|
||||
ok, err := tbls.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
|
||||
if err != nil {
|
||||
multierr = multierror.Append(multierr, err)
|
||||
}
|
||||
if ok {
|
||||
if err = tbls.Delete(splittedRule[0], splittedRule[1], splittedRule[2:]...); err != nil {
|
||||
multierr = multierror.Append(multierr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
splittedRule := strings.Split(chain, " ")
|
||||
if err = tbls.ClearChain(splittedRule[0], splittedRule[1]); err != nil {
|
||||
multierr = multierror.Append(multierr, err)
|
||||
}
|
||||
if err = tbls.DeleteChain(splittedRule[0], splittedRule[1]); err != nil {
|
||||
multierr = multierror.Append(multierr, err)
|
||||
}
|
||||
}
|
||||
|
||||
return multierr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// StartNfqueueInterception starts the nfqueue interception.
|
||||
func StartNfqueueInterception(packets chan<- packet.Packet) (err error) {
|
||||
// @deprecated, remove in v1
|
||||
if experimentalNfqueueBackend {
|
||||
log.Warningf("[DEPRECATED] --experimental-nfqueue has been deprecated as the backend is now used by default")
|
||||
log.Warningf("[DEPRECATED] please remove the flag from your configuration!")
|
||||
}
|
||||
|
||||
err = activateNfqueueFirewall()
|
||||
if err != nil {
|
||||
_ = StopNfqueueInterception()
|
||||
return fmt.Errorf("could not initialize nfqueue: %w", err)
|
||||
}
|
||||
|
||||
out4Queue, err = nfq.New(17040, false)
|
||||
if err != nil {
|
||||
_ = StopNfqueueInterception()
|
||||
return fmt.Errorf("nfqueue(IPv4, out): %w", err)
|
||||
}
|
||||
in4Queue, err = nfq.New(17140, false)
|
||||
if err != nil {
|
||||
_ = StopNfqueueInterception()
|
||||
return fmt.Errorf("nfqueue(IPv4, in): %w", err)
|
||||
}
|
||||
|
||||
if netenv.IPv6Enabled() {
|
||||
out6Queue, err = nfq.New(17060, true)
|
||||
if err != nil {
|
||||
_ = StopNfqueueInterception()
|
||||
return fmt.Errorf("nfqueue(IPv6, out): %w", err)
|
||||
}
|
||||
in6Queue, err = nfq.New(17160, true)
|
||||
if err != nil {
|
||||
_ = StopNfqueueInterception()
|
||||
return fmt.Errorf("nfqueue(IPv6, in): %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Warningf("interception: no IPv6 stack detected, disabling IPv6 network integration")
|
||||
out6Queue = &disabledNfQueue{}
|
||||
in6Queue = &disabledNfQueue{}
|
||||
}
|
||||
|
||||
module.StartServiceWorker("nfqueue packet handler", 0, func(_ context.Context) error {
|
||||
return handleInterception(packets)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopNfqueueInterception stops the nfqueue interception.
|
||||
func StopNfqueueInterception() error {
|
||||
defer close(shutdownSignal)
|
||||
|
||||
if out4Queue != nil {
|
||||
out4Queue.Destroy()
|
||||
}
|
||||
if in4Queue != nil {
|
||||
in4Queue.Destroy()
|
||||
}
|
||||
if out6Queue != nil {
|
||||
out6Queue.Destroy()
|
||||
}
|
||||
if in6Queue != nil {
|
||||
in6Queue.Destroy()
|
||||
}
|
||||
|
||||
err := DeactivateNfqueueFirewall()
|
||||
if err != nil {
|
||||
return fmt.Errorf("interception: error while deactivating nfqueue: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleInterception(packets chan<- packet.Packet) error {
|
||||
for {
|
||||
var pkt packet.Packet
|
||||
select {
|
||||
case <-shutdownSignal:
|
||||
return nil
|
||||
case pkt = <-out4Queue.PacketChannel():
|
||||
pkt.SetOutbound()
|
||||
case pkt = <-in4Queue.PacketChannel():
|
||||
pkt.SetInbound()
|
||||
case pkt = <-out6Queue.PacketChannel():
|
||||
pkt.SetOutbound()
|
||||
case pkt = <-in6Queue.PacketChannel():
|
||||
pkt.SetInbound()
|
||||
}
|
||||
|
||||
select {
|
||||
case packets <- pkt:
|
||||
case <-shutdownSignal:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type disabledNfQueue struct{}
|
||||
|
||||
func (dnfq *disabledNfQueue) PacketChannel() <-chan packet.Packet {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dnfq *disabledNfQueue) Destroy() {}
|
||||
67
service/firewall/interception/packet_tracer.go
Normal file
67
service/firewall/interception/packet_tracer.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package interception
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
type tracedPacket struct {
|
||||
start time.Time
|
||||
packet.Packet
|
||||
}
|
||||
|
||||
func tracePacket(p packet.Packet) packet.Packet {
|
||||
return &tracedPacket{
|
||||
start: time.Now(),
|
||||
Packet: p,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *tracedPacket) markServed(v string) {
|
||||
if packetMetricsDestination == "" {
|
||||
return
|
||||
}
|
||||
|
||||
metrics.record(p, v)
|
||||
}
|
||||
|
||||
func (p *tracedPacket) Accept() error {
|
||||
defer p.markServed("accept")
|
||||
return p.Packet.Accept()
|
||||
}
|
||||
|
||||
func (p *tracedPacket) Block() error {
|
||||
defer p.markServed("block")
|
||||
return p.Packet.Block()
|
||||
}
|
||||
|
||||
func (p *tracedPacket) Drop() error {
|
||||
defer p.markServed("drop")
|
||||
return p.Packet.Drop()
|
||||
}
|
||||
|
||||
func (p *tracedPacket) PermanentAccept() error {
|
||||
defer p.markServed("perm-accept")
|
||||
return p.Packet.PermanentAccept()
|
||||
}
|
||||
|
||||
func (p *tracedPacket) PermanentBlock() error {
|
||||
defer p.markServed("perm-block")
|
||||
return p.Packet.PermanentBlock()
|
||||
}
|
||||
|
||||
func (p *tracedPacket) PermanentDrop() error {
|
||||
defer p.markServed("perm-drop")
|
||||
return p.Packet.PermanentDrop()
|
||||
}
|
||||
|
||||
func (p *tracedPacket) RerouteToNameserver() error {
|
||||
defer p.markServed("reroute-ns")
|
||||
return p.Packet.RerouteToNameserver()
|
||||
}
|
||||
|
||||
func (p *tracedPacket) RerouteToTunnel() error {
|
||||
defer p.markServed("reroute-tunnel")
|
||||
return p.Packet.RerouteToTunnel()
|
||||
}
|
||||
132
service/firewall/interception/windowskext/bandwidth_stats.go
Normal file
132
service/firewall/interception/windowskext/bandwidth_stats.go
Normal file
@@ -0,0 +1,132 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
// This file contains example code how to read bandwidth stats from the kext. Its not ment to be used in production.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
type Rxtxdata struct {
|
||||
rx uint64
|
||||
tx uint64
|
||||
}
|
||||
|
||||
type Key struct {
|
||||
localIP [4]uint32
|
||||
remoteIP [4]uint32
|
||||
localPort uint16
|
||||
remotePort uint16
|
||||
ipv6 bool
|
||||
protocol uint8
|
||||
}
|
||||
|
||||
var m = make(map[Key]Rxtxdata)
|
||||
|
||||
func BandwidthStatsWorker(ctx context.Context, collectInterval time.Duration, bandwidthUpdates chan *packet.BandwidthUpdate) error {
|
||||
// Setup ticker.
|
||||
ticker := time.NewTicker(collectInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Collect bandwidth at every tick.
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := reportBandwidth(ctx, bandwidthUpdates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.BandwidthUpdate) error {
|
||||
stats, err := GetConnectionsStats()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Report all statistics.
|
||||
for i, stat := range stats {
|
||||
connID := packet.CreateConnectionID(
|
||||
packet.IPProtocol(stat.protocol),
|
||||
convertArrayToIP(stat.localIP, stat.ipV6 == 1), stat.localPort,
|
||||
convertArrayToIP(stat.remoteIP, stat.ipV6 == 1), stat.remotePort,
|
||||
false,
|
||||
)
|
||||
update := &packet.BandwidthUpdate{
|
||||
ConnID: connID,
|
||||
BytesReceived: stat.receivedBytes,
|
||||
BytesSent: stat.transmittedBytes,
|
||||
Method: packet.Additive,
|
||||
}
|
||||
select {
|
||||
case bandwidthUpdates <- update:
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
log.Warningf("kext: bandwidth update queue is full, skipping rest of batch (%d entries)", len(stats)-i)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func StartBandwithConsoleLogger() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
conns, err := GetConnectionsStats()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, conn := range conns {
|
||||
if conn.receivedBytes == 0 && conn.transmittedBytes == 0 {
|
||||
continue
|
||||
}
|
||||
key := Key{
|
||||
localIP: conn.localIP,
|
||||
remoteIP: conn.remoteIP,
|
||||
localPort: conn.localPort,
|
||||
remotePort: conn.remotePort,
|
||||
ipv6: conn.ipV6 == 1,
|
||||
protocol: conn.protocol,
|
||||
}
|
||||
|
||||
// First we get a "copy" of the entry
|
||||
if entry, ok := m[key]; ok {
|
||||
// Then we modify the copy
|
||||
entry.rx += conn.receivedBytes
|
||||
entry.tx += conn.transmittedBytes
|
||||
|
||||
// Then we reassign map entry
|
||||
m[key] = entry
|
||||
} else {
|
||||
m[key] = Rxtxdata{
|
||||
rx: conn.receivedBytes,
|
||||
tx: conn.transmittedBytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Debug("----------------------------------")
|
||||
for key, value := range m {
|
||||
log.Debugf(
|
||||
"Conn: %d %s:%d %s:%d rx:%d tx:%d", key.protocol,
|
||||
convertArrayToIP(key.localIP, key.ipv6), key.localPort,
|
||||
convertArrayToIP(key.remoteIP, key.ipv6), key.remotePort,
|
||||
value.rx, value.tx,
|
||||
)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
4
service/firewall/interception/windowskext/doc.go
Normal file
4
service/firewall/interception/windowskext/doc.go
Normal file
@@ -0,0 +1,4 @@
|
||||
// +build windows
|
||||
|
||||
// Package windowskext provides network interception capabilities on windows via the Portmaster Kernel Extension.
|
||||
package windowskext
|
||||
209
service/firewall/interception/windowskext/handler.go
Normal file
209
service/firewall/interception/windowskext/handler.go
Normal file
@@ -0,0 +1,209 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/safing/portmaster/service/process"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
const (
|
||||
// VerdictRequestFlagFastTrackPermitted is set on packets that have been
|
||||
// already permitted by the kernel extension and the verdict request is only
|
||||
// informational.
|
||||
VerdictRequestFlagFastTrackPermitted = 1
|
||||
|
||||
// VerdictRequestFlagSocketAuth indicates that the verdict request is for a
|
||||
// connection that was intercepted on an ALE layer instead of in the network
|
||||
// stack itself. Thus, no packet data is available.
|
||||
VerdictRequestFlagSocketAuth = 2
|
||||
|
||||
// VerdictRequestFlagExpectSocketAuth indicates that the next verdict
|
||||
// requests is expected to be an informational socket auth request from
|
||||
// the ALE layer.
|
||||
VerdictRequestFlagExpectSocketAuth = 4
|
||||
)
|
||||
|
||||
// Do not change the order of the members! The structure is used to communicate with the kernel extension.
|
||||
// VerdictRequest is the request structure from the Kext.
|
||||
type VerdictRequest struct {
|
||||
id uint32 // ID from RegisterPacket
|
||||
pid uint64 // Process ID - info only packets
|
||||
direction uint8
|
||||
ipV6 uint8 // True: IPv6, False: IPv4
|
||||
protocol uint8 // Protocol
|
||||
flags uint8 // Flags
|
||||
localIP [4]uint32 // Source Address
|
||||
remoteIP [4]uint32 // Destination Address
|
||||
localPort uint16 // Source Port
|
||||
remotePort uint16 // Destination port
|
||||
_ uint32 // compartmentID
|
||||
_ uint32 // interfaceIndex
|
||||
_ uint32 // subInterfaceIndex
|
||||
packetSize uint32
|
||||
}
|
||||
|
||||
// Do not change the order of the members! The structure is used to communicate with the kernel extension.
|
||||
type VerdictInfo struct {
|
||||
id uint32 // ID from RegisterPacket
|
||||
verdict network.Verdict // verdict for the connection
|
||||
}
|
||||
|
||||
// Do not change the order of the members! The structure to communicate with the kernel extension.
|
||||
type VerdictUpdateInfo struct {
|
||||
localIP [4]uint32 // Source Address, only srcIP[0] if IPv4
|
||||
remoteIP [4]uint32 // Destination Address
|
||||
localPort uint16 // Source Port
|
||||
remotePort uint16 // Destination port
|
||||
ipV6 uint8 // True: IPv6, False: IPv4
|
||||
protocol uint8 // Protocol (UDP, TCP, ...)
|
||||
verdict uint8 // New verdict
|
||||
}
|
||||
|
||||
type ConnectionStat struct {
|
||||
localIP [4]uint32 //Source Address, only srcIP[0] if IPv4
|
||||
remoteIP [4]uint32 //Destination Address
|
||||
localPort uint16 //Source Port
|
||||
remotePort uint16 //Destination port
|
||||
receivedBytes uint64 //Number of bytes recived on this connection
|
||||
transmittedBytes uint64 //Number of bytes transsmited from this connection
|
||||
ipV6 uint8 //True: IPv6, False: IPv4
|
||||
protocol uint8 //Protocol (UDP, TCP, ...)
|
||||
}
|
||||
|
||||
type VersionInfo struct {
|
||||
major uint8
|
||||
minor uint8
|
||||
revision uint8
|
||||
build uint8
|
||||
}
|
||||
|
||||
func (v *VersionInfo) String() string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", v.major, v.minor, v.revision, v.build)
|
||||
}
|
||||
|
||||
// Handler transforms received packets to the Packet interface.
|
||||
func Handler(ctx context.Context, packets chan packet.Packet) {
|
||||
for {
|
||||
packetInfo, err := RecvVerdictRequest()
|
||||
if err != nil {
|
||||
// Check if we are done with processing.
|
||||
if errors.Is(err, ErrKextNotReady) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Warningf("failed to get packet from windows kext: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if packetInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// log.Tracef("packet: %+v", packetInfo)
|
||||
|
||||
// New Packet
|
||||
new := &Packet{
|
||||
verdictRequest: packetInfo,
|
||||
verdictSet: abool.NewBool(false),
|
||||
}
|
||||
info := new.Info()
|
||||
info.Inbound = packetInfo.direction > 0
|
||||
info.InTunnel = false
|
||||
info.Protocol = packet.IPProtocol(packetInfo.protocol)
|
||||
info.PID = int(packetInfo.pid)
|
||||
info.SeenAt = time.Now()
|
||||
|
||||
// Check PID
|
||||
if info.PID == 0 {
|
||||
// Windows does not have zero PIDs.
|
||||
// Set to UndefinedProcessID.
|
||||
info.PID = process.UndefinedProcessID
|
||||
}
|
||||
|
||||
// Set IP version
|
||||
if packetInfo.ipV6 == 1 {
|
||||
info.Version = packet.IPv6
|
||||
} else {
|
||||
info.Version = packet.IPv4
|
||||
}
|
||||
|
||||
// Set IPs
|
||||
if info.Inbound {
|
||||
// Inbound
|
||||
info.Src = convertArrayToIP(packetInfo.remoteIP, info.Version == packet.IPv6)
|
||||
info.Dst = convertArrayToIP(packetInfo.localIP, info.Version == packet.IPv6)
|
||||
} else {
|
||||
// Outbound
|
||||
info.Src = convertArrayToIP(packetInfo.localIP, info.Version == packet.IPv6)
|
||||
info.Dst = convertArrayToIP(packetInfo.remoteIP, info.Version == packet.IPv6)
|
||||
}
|
||||
|
||||
// Set Ports
|
||||
if info.Inbound {
|
||||
// Inbound
|
||||
info.SrcPort = packetInfo.remotePort
|
||||
info.DstPort = packetInfo.localPort
|
||||
} else {
|
||||
// Outbound
|
||||
info.SrcPort = packetInfo.localPort
|
||||
info.DstPort = packetInfo.remotePort
|
||||
}
|
||||
|
||||
packets <- new
|
||||
}
|
||||
}
|
||||
|
||||
// convertArrayToIP converts an array of uint32 values to a net.IP address.
|
||||
func convertArrayToIP(input [4]uint32, ipv6 bool) net.IP {
|
||||
if !ipv6 {
|
||||
addressBuf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(addressBuf, input[0])
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
addressBuf := make([]byte, 16)
|
||||
for i := 0; i < 4; i++ {
|
||||
binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
|
||||
}
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
func ipAddressToArray(ip net.IP, isIPv6 bool) [4]uint32 {
|
||||
array := [4]uint32{0}
|
||||
if isIPv6 {
|
||||
for i := 0; i < 4; i++ {
|
||||
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[i], 4), getUInt32Value(&ip[i]))
|
||||
}
|
||||
} else {
|
||||
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[0], 4), getUInt32Value(&ip[0]))
|
||||
}
|
||||
|
||||
return array
|
||||
}
|
||||
|
||||
func asByteArray[T any](obj *T) []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj))
|
||||
}
|
||||
|
||||
func asByteArrayWithLength[T any](obj *T, size uint32) []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), size)
|
||||
}
|
||||
|
||||
func getUInt32Value[T any](obj *T) uint32 {
|
||||
return *(*uint32)(unsafe.Pointer(obj))
|
||||
}
|
||||
335
service/firewall/interception/windowskext/kext.go
Normal file
335
service/firewall/interception/windowskext/kext.go
Normal file
@@ -0,0 +1,335 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// Package errors
|
||||
var (
|
||||
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
|
||||
ErrNoPacketID = errors.New("the packet has no ID, possibly because it was fast-tracked by the kernel extension")
|
||||
|
||||
kextLock sync.RWMutex
|
||||
driverPath string
|
||||
|
||||
kextHandle windows.Handle
|
||||
service *KextService
|
||||
)
|
||||
|
||||
const (
|
||||
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
|
||||
winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value
|
||||
driverName = "PortmasterKext"
|
||||
)
|
||||
|
||||
// Init initializes the DLL and the Kext (Kernel Driver).
|
||||
func Init(path string) error {
|
||||
kextHandle = winInvalidHandleValue
|
||||
driverPath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start intercepting.
|
||||
func Start() error {
|
||||
kextLock.Lock()
|
||||
defer kextLock.Unlock()
|
||||
|
||||
// initialize and start driver service
|
||||
var err error
|
||||
service, err = createKextService(driverName, driverPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
running, err := service.isRunning()
|
||||
if err == nil && !running {
|
||||
err = service.start(true)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("service not initialized: %w", err)
|
||||
}
|
||||
|
||||
// Open the driver
|
||||
filename := `\\.\` + driverName
|
||||
kextHandle, err = openDriver(filename)
|
||||
|
||||
// driver was not installed
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open driver: %q %w", filename, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop intercepting.
|
||||
func Stop() error {
|
||||
// Prepare kernel for shutdown
|
||||
err := shutdownRequest()
|
||||
if err != nil {
|
||||
log.Warningf("winkext: shutdown request failed: %s", err)
|
||||
}
|
||||
|
||||
kextLock.Lock()
|
||||
defer kextLock.Unlock()
|
||||
|
||||
err = closeDriver(kextHandle)
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to close the handle: %s", err)
|
||||
}
|
||||
|
||||
err = service.stop(true)
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to stop service: %s", err)
|
||||
}
|
||||
// Driver file may change on the next start so it's better to delete the service
|
||||
err = service.delete()
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to delete service: %s", err)
|
||||
}
|
||||
|
||||
kextHandle = winInvalidHandleValue
|
||||
return nil
|
||||
}
|
||||
|
||||
func shutdownRequest() error {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return ErrKextNotReady
|
||||
}
|
||||
// Sent a shutdown request so the kernel extension can prepare.
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_SHUTDOWN_REQUEST, nil, nil)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil.
|
||||
func RecvVerdictRequest() (*VerdictRequest, error) {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
// DEBUG:
|
||||
// timestamp := time.Now()
|
||||
// defer log.Tracef("winkext: getting verdict request took %s", time.Since(timestamp))
|
||||
|
||||
// Initialize struct for the output data
|
||||
var new VerdictRequest
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&new)
|
||||
bytesRead, err := deviceIOControl(kextHandle, IOCTL_RECV_VERDICT_REQ, nil, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if bytesRead == 0 {
|
||||
return nil, nil // no error, no new verdict request
|
||||
}
|
||||
|
||||
return &new, nil
|
||||
}
|
||||
|
||||
// SetVerdict sets the verdict for a packet and/or connection.
|
||||
func SetVerdict(pkt *Packet, verdict network.Verdict) error {
|
||||
if pkt.verdictRequest.pid != 0 {
|
||||
return nil // Ignore info only packets
|
||||
}
|
||||
if pkt.verdictRequest.id == 0 {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: no packet ID", verdict)
|
||||
return ErrNoPacketID
|
||||
}
|
||||
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: kext not ready", verdict)
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
verdictInfo := VerdictInfo{pkt.verdictRequest.id, verdict}
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&verdictInfo)
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_SET_VERDICT, data, nil)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPayload returns the payload of a packet.
|
||||
func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
|
||||
if packetID == 0 {
|
||||
return nil, ErrNoPacketID
|
||||
}
|
||||
|
||||
// Check if driver is initialized
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
buf := make([]byte, packetSize)
|
||||
|
||||
// Combine id and length
|
||||
payload := struct {
|
||||
id uint32
|
||||
length uint32
|
||||
}{packetID, packetSize}
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&payload)
|
||||
bytesRead, err := deviceIOControl(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check the result and return
|
||||
if bytesRead == 0 {
|
||||
return nil, errors.New("windows kext did not return any data")
|
||||
}
|
||||
|
||||
if bytesRead < uint32(len(buf)) {
|
||||
return buf[:bytesRead], nil
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func ClearCache() error {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
// Make driver request
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_CLEAR_CACHE, nil, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func UpdateVerdict(conn *network.Connection) error {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
var isIpv6 uint8 = 0
|
||||
if conn.IPVersion == packet.IPv6 {
|
||||
isIpv6 = 1
|
||||
}
|
||||
|
||||
// initialize variables
|
||||
info := VerdictUpdateInfo{
|
||||
ipV6: isIpv6,
|
||||
protocol: uint8(conn.IPProtocol),
|
||||
localIP: ipAddressToArray(conn.LocalIP, isIpv6 == 1),
|
||||
localPort: conn.LocalPort,
|
||||
remoteIP: ipAddressToArray(conn.Entity.IP, isIpv6 == 1),
|
||||
remotePort: conn.Entity.Port,
|
||||
verdict: uint8(conn.Verdict),
|
||||
}
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&info)
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_UPDATE_VERDICT, data, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func GetVersion() (*VersionInfo, error) {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
data := make([]uint8, 4)
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_VERSION, nil, data)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
version := &VersionInfo{
|
||||
major: data[0],
|
||||
minor: data[1],
|
||||
revision: data[2],
|
||||
build: data[3],
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
var sizeOfConnectionStat = uint32(unsafe.Sizeof(ConnectionStat{}))
|
||||
|
||||
func GetConnectionsStats() ([]ConnectionStat, error) {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
var data [100]ConnectionStat
|
||||
size := len(data)
|
||||
bytesReturned, err := deviceIOControl(kextHandle, IOCTL_GET_CONNECTIONS_STAT, asByteArray(&size), asByteArray(&data))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return data[:bytesReturned/sizeOfConnectionStat], nil
|
||||
}
|
||||
|
||||
func openDriver(filename string) (windows.Handle, error) {
|
||||
u16filename, err := syscall.UTF16FromString(filename)
|
||||
if err != nil {
|
||||
return winInvalidHandleValue, fmt.Errorf("failed to convert driver filename to UTF16 string %w", err)
|
||||
}
|
||||
|
||||
handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED, 0)
|
||||
if err != nil {
|
||||
return winInvalidHandleValue, err
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func closeDriver(handle windows.Handle) error {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
return windows.CloseHandle(handle)
|
||||
}
|
||||
137
service/firewall/interception/windowskext/packet.go
Normal file
137
service/firewall/interception/windowskext/packet.go
Normal file
@@ -0,0 +1,137 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
// Packet represents an IP packet.
|
||||
type Packet struct {
|
||||
packet.Base
|
||||
|
||||
verdictRequest *VerdictRequest
|
||||
verdictSet *abool.AtomicBool
|
||||
|
||||
payloadLoaded bool
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// FastTrackedByIntegration returns whether the packet has been fast-track
|
||||
// accepted by the OS integration.
|
||||
func (pkt *Packet) FastTrackedByIntegration() bool {
|
||||
return pkt.verdictRequest.flags&VerdictRequestFlagFastTrackPermitted > 0
|
||||
}
|
||||
|
||||
// InfoOnly returns whether the packet is informational only and does not
|
||||
// represent an actual packet.
|
||||
func (pkt *Packet) InfoOnly() bool {
|
||||
return pkt.verdictRequest.flags&VerdictRequestFlagSocketAuth > 0
|
||||
}
|
||||
|
||||
// ExpectInfo returns whether the next packet is expected to be informational only.
|
||||
func (pkt *Packet) ExpectInfo() bool {
|
||||
return pkt.verdictRequest.flags&VerdictRequestFlagExpectSocketAuth > 0
|
||||
}
|
||||
|
||||
// GetPayload returns the full raw packet.
|
||||
func (pkt *Packet) LoadPacketData() error {
|
||||
pkt.lock.Lock()
|
||||
defer pkt.lock.Unlock()
|
||||
|
||||
if pkt.verdictRequest.id == 0 {
|
||||
return ErrNoPacketID
|
||||
}
|
||||
|
||||
if !pkt.payloadLoaded {
|
||||
pkt.payloadLoaded = true
|
||||
|
||||
payload, err := GetPayload(pkt.verdictRequest.id, pkt.verdictRequest.packetSize)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to load payload: %s", err)
|
||||
return packet.ErrFailedToLoadPayload
|
||||
}
|
||||
|
||||
err = packet.Parse(payload, &pkt.Base)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to parse payload: %s", err)
|
||||
return packet.ErrFailedToLoadPayload
|
||||
}
|
||||
}
|
||||
|
||||
if len(pkt.Raw()) == 0 {
|
||||
return packet.ErrFailedToLoadPayload
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Accept accepts the packet.
|
||||
func (pkt *Packet) Accept() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, -network.VerdictAccept)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Block blocks the packet.
|
||||
func (pkt *Packet) Block() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, -network.VerdictBlock)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Drop drops the packet.
|
||||
func (pkt *Packet) Drop() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, -network.VerdictDrop)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PermanentAccept permanently accepts connection (and the current packet).
|
||||
func (pkt *Packet) PermanentAccept() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictAccept)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PermanentBlock permanently blocks connection (and the current packet).
|
||||
func (pkt *Packet) PermanentBlock() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictBlock)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PermanentDrop permanently drops connection (and the current packet).
|
||||
func (pkt *Packet) PermanentDrop() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictDrop)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RerouteToNameserver permanently reroutes the connection to the local nameserver (and the current packet).
|
||||
func (pkt *Packet) RerouteToNameserver() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictRerouteToNameserver)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RerouteToTunnel permanently reroutes the connection to the local tunnel entrypoint (and the current packet).
|
||||
func (pkt *Packet) RerouteToTunnel() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictRerouteToTunnel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
197
service/firewall/interception/windowskext/service.go
Normal file
197
service/firewall/interception/windowskext/service.go
Normal file
@@ -0,0 +1,197 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type KextService struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func createKextService(driverName string, driverPath string) (*KextService, error) {
|
||||
// Open the service manager:
|
||||
manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open service manager: %d", err)
|
||||
}
|
||||
defer windows.CloseServiceHandle(manager)
|
||||
|
||||
driverNameU16, err := syscall.UTF16FromString(driverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err)
|
||||
}
|
||||
|
||||
// Check if there is an old service.
|
||||
service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS)
|
||||
if err == nil {
|
||||
log.Warning("kext: old driver service was found")
|
||||
oldService := &KextService{handle: service}
|
||||
err := deleteService(manager, oldService, driverNameU16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete old driver service: %s", err)
|
||||
}
|
||||
|
||||
service = winInvalidHandleValue
|
||||
log.Info("kext: old driver service was deleted successfully")
|
||||
}
|
||||
|
||||
driverPathU16, err := syscall.UTF16FromString(driverPath)
|
||||
|
||||
// Create the service
|
||||
service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KextService{handle: service}, nil
|
||||
}
|
||||
|
||||
func deleteService(manager windows.Handle, service *KextService, driverName []uint16) error {
|
||||
// Stop and wait before deleting
|
||||
_ = service.stop(true)
|
||||
|
||||
// Try to delete even if stop failed
|
||||
err := service.delete()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete old service: %s", err)
|
||||
}
|
||||
|
||||
// Wait until we can no longer open the old service.
|
||||
// Not very efficient but NotifyServiceStatusChange cannot be used with driver service.
|
||||
start := time.Now()
|
||||
timeLimit := time.Duration(30 * time.Second)
|
||||
for {
|
||||
handle, err := windows.OpenService(manager, &driverName[0], windows.SERVICE_ALL_ACCESS)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
_ = windows.CloseServiceHandle(handle)
|
||||
|
||||
if time.Since(start) > timeLimit {
|
||||
return fmt.Errorf("time limit reached")
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) isValid() bool {
|
||||
return s != nil && s.handle != winInvalidHandleValue && s.handle != 0
|
||||
}
|
||||
|
||||
func (s *KextService) isRunning() (bool, error) {
|
||||
if !s.isValid() {
|
||||
return false, fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
var status windows.SERVICE_STATUS
|
||||
err := windows.QueryServiceStatus(s.handle, &status)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return status.CurrentState == windows.SERVICE_RUNNING, nil
|
||||
}
|
||||
|
||||
func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) {
|
||||
var status windows.SERVICE_STATUS
|
||||
status.CurrentState = windows.SERVICE_NO_CHANGE
|
||||
start := time.Now()
|
||||
for status.CurrentState == neededStatus {
|
||||
err := windows.QueryServiceStatus(handle, &status)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed while waiting for service to start: %w", err)
|
||||
}
|
||||
|
||||
if time.Since(start) > timeLimit {
|
||||
return false, fmt.Errorf("time limit reached")
|
||||
}
|
||||
|
||||
// Sleep for 1/10 of the wait hint, recommended time from microsoft
|
||||
time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *KextService) start(wait bool) error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
// Start the service:
|
||||
err := windows.StartService(s.handle, 0, nil)
|
||||
|
||||
if err != nil {
|
||||
err = windows.GetLastError()
|
||||
if err != windows.ERROR_SERVICE_ALREADY_RUNNING {
|
||||
// Failed to start service; clean-up:
|
||||
var status windows.SERVICE_STATUS
|
||||
_ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
|
||||
_ = windows.DeleteService(s.handle)
|
||||
_ = windows.CloseServiceHandle(s.handle)
|
||||
s.handle = winInvalidHandleValue
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for service to start
|
||||
if wait {
|
||||
success, err := waitForServiceStatus(s.handle, windows.SERVICE_RUNNING, time.Duration(10*time.Second))
|
||||
if err != nil || !success {
|
||||
return fmt.Errorf("service did not start: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) stop(wait bool) error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
// Stop the service
|
||||
var status windows.SERVICE_STATUS
|
||||
err := windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service failed to stop: %w", err)
|
||||
}
|
||||
|
||||
// Wait for service to stop
|
||||
if wait {
|
||||
success, err := waitForServiceStatus(s.handle, windows.SERVICE_STOPPED, time.Duration(10*time.Second))
|
||||
if err != nil || !success {
|
||||
return fmt.Errorf("service did not stop: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) delete() error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
err := windows.DeleteService(s.handle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete service: %s", err)
|
||||
}
|
||||
|
||||
// Service wont be deleted until all handles are closed.
|
||||
err = windows.CloseServiceHandle(s.handle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close service handle: %s", err)
|
||||
}
|
||||
|
||||
s.handle = winInvalidHandleValue
|
||||
return nil
|
||||
}
|
||||
72
service/firewall/interception/windowskext/syscall.go
Normal file
72
service/firewall/interception/windowskext/syscall.go
Normal file
@@ -0,0 +1,72 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
METHOD_BUFFERED = 0
|
||||
METHOD_IN_DIRECT = 1
|
||||
METHOD_OUT_DIRECT = 2
|
||||
METHOD_NEITHER = 3
|
||||
|
||||
SIOCTL_TYPE = 40000
|
||||
)
|
||||
|
||||
var (
|
||||
IOCTL_VERSION = ctlCode(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_SHUTDOWN_REQUEST = ctlCode(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_RECV_VERDICT_REQ = ctlCode(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_SET_VERDICT = ctlCode(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_GET_PAYLOAD = ctlCode(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_CLEAR_CACHE = ctlCode(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_UPDATE_VERDICT = ctlCode(SIOCTL_TYPE, 0x806, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_GET_CONNECTIONS_STAT = ctlCode(SIOCTL_TYPE, 0x807, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
)
|
||||
|
||||
func ctlCode(device_type, function, method, access uint32) uint32 {
|
||||
return (device_type << 16) | (access << 14) | (function << 2) | method
|
||||
}
|
||||
|
||||
func deviceIOControlAsync(handle windows.Handle, code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) {
|
||||
var inDataPtr *byte = nil
|
||||
var inDataSize uint32 = 0
|
||||
if inData != nil {
|
||||
inDataPtr = &inData[0]
|
||||
inDataSize = uint32(len(inData))
|
||||
}
|
||||
|
||||
var outDataPtr *byte = nil
|
||||
var outDataSize uint32 = 0
|
||||
if outData != nil {
|
||||
outDataPtr = &outData[0]
|
||||
outDataSize = uint32(len(outData))
|
||||
}
|
||||
|
||||
overlapped := &windows.Overlapped{}
|
||||
err := windows.DeviceIoControl(handle,
|
||||
code,
|
||||
inDataPtr, inDataSize,
|
||||
outDataPtr, outDataSize,
|
||||
nil, overlapped)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return overlapped, nil
|
||||
|
||||
}
|
||||
|
||||
func deviceIOControl(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) {
|
||||
overlapped, err := deviceIOControlAsync(handle, code, inData, outData)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var bytesReturned uint32
|
||||
err = windows.GetOverlappedResult(handle, overlapped, &bytesReturned, true)
|
||||
|
||||
return bytesReturned, err
|
||||
}
|
||||
649
service/firewall/master.go
Normal file
649
service/firewall/master.go
Normal file
@@ -0,0 +1,649 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/agext/levenshtein"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/detection/dga"
|
||||
"github.com/safing/portmaster/service/intel/customlists"
|
||||
"github.com/safing/portmaster/service/intel/filterlists"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/profile"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
)
|
||||
|
||||
const noReasonOptionKey = ""
|
||||
|
||||
type deciderFn func(context.Context, *network.Connection, *profile.LayeredProfile, packet.Packet) bool
|
||||
|
||||
var defaultDeciders = []deciderFn{
|
||||
checkPortmasterConnection,
|
||||
checkIfBroadcastReply,
|
||||
checkConnectionType,
|
||||
checkConnectionScope,
|
||||
checkEndpointLists,
|
||||
checkInvalidIP,
|
||||
checkResolverScope,
|
||||
checkConnectivityDomain,
|
||||
checkBypassPrevention,
|
||||
checkFilterLists,
|
||||
checkCustomFilterList,
|
||||
checkDomainHeuristics,
|
||||
checkAutoPermitRelated,
|
||||
}
|
||||
|
||||
// decideOnConnection makes a decision about a connection.
|
||||
// When called, the connection and profile is already locked.
|
||||
func decideOnConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) {
|
||||
// Check if we have a process and profile.
|
||||
layeredProfile := conn.Process().Profile()
|
||||
if layeredProfile == nil {
|
||||
conn.Deny("unknown process or profile", noReasonOptionKey)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the layered profile needs updating.
|
||||
if layeredProfile.NeedsUpdate() {
|
||||
// Update revision counter in connection.
|
||||
conn.ProfileRevisionCounter = layeredProfile.Update(
|
||||
conn.Process().MatchingData(),
|
||||
conn.Process().CreateProfileCallback,
|
||||
)
|
||||
conn.SaveWhenFinished()
|
||||
|
||||
// Reset verdict for connection.
|
||||
log.Tracer(ctx).Infof("filter: profile updated, re-evaluating verdict of %s", conn)
|
||||
|
||||
// Reset entity if it exists.
|
||||
if conn.Entity != nil {
|
||||
conn.Entity.ResetLists()
|
||||
}
|
||||
} else {
|
||||
// Check if the revision counter of the connection needs updating.
|
||||
revCnt := layeredProfile.RevisionCnt()
|
||||
if conn.ProfileRevisionCounter != revCnt {
|
||||
conn.ProfileRevisionCounter = revCnt
|
||||
conn.SaveWhenFinished()
|
||||
}
|
||||
}
|
||||
|
||||
// prepare the entity and resolve all filterlist matches
|
||||
conn.Entity.ResolveSubDomainLists(ctx, layeredProfile.FilterSubDomains())
|
||||
conn.Entity.EnableCNAMECheck(ctx, layeredProfile.FilterCNAMEs())
|
||||
conn.Entity.LoadLists(ctx)
|
||||
|
||||
// Run all deciders and return if they came to a conclusion.
|
||||
done, defaultAction := runDeciders(ctx, defaultDeciders, conn, layeredProfile, pkt)
|
||||
if done {
|
||||
return
|
||||
}
|
||||
|
||||
// DNS Request are always default allowed, as the endpoint lists could not
|
||||
// be checked fully.
|
||||
if conn.Type == network.DNSRequest {
|
||||
conn.Accept("allowing dns request", noReasonOptionKey)
|
||||
return
|
||||
}
|
||||
|
||||
// Deciders did not conclude, use default action.
|
||||
switch defaultAction {
|
||||
case profile.DefaultActionPermit:
|
||||
conn.Accept("allowed by default action", profile.CfgOptionDefaultActionKey)
|
||||
case profile.DefaultActionAsk:
|
||||
// Only prompt if there has not been a decision already.
|
||||
// This prevents prompts from being created when re-evaluating connections.
|
||||
if conn.Verdict == network.VerdictUndecided {
|
||||
prompt(ctx, conn)
|
||||
}
|
||||
default:
|
||||
conn.Deny("blocked by default action", profile.CfgOptionDefaultActionKey)
|
||||
}
|
||||
}
|
||||
|
||||
func runDeciders(ctx context.Context, selectedDeciders []deciderFn, conn *network.Connection, layeredProfile *profile.LayeredProfile, pkt packet.Packet) (done bool, defaultAction uint8) {
|
||||
// Read-lock all the profiles.
|
||||
layeredProfile.LockForUsage()
|
||||
defer layeredProfile.UnlockForUsage()
|
||||
|
||||
// Go though all deciders, return if one sets an action.
|
||||
for _, decider := range selectedDeciders {
|
||||
if decider(ctx, conn, layeredProfile, pkt) {
|
||||
return true, profile.DefaultActionNotSet
|
||||
}
|
||||
}
|
||||
|
||||
// Return the default action.
|
||||
return false, layeredProfile.DefaultAction()
|
||||
}
|
||||
|
||||
// checkPortmasterConnection allows all connection that originate from
|
||||
// portmaster itself.
|
||||
func checkPortmasterConnection(ctx context.Context, conn *network.Connection, _ *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// Grant own outgoing or local connections.
|
||||
|
||||
// Blocking our own connections can lead to a very literal deadlock.
|
||||
// This can currently happen, as fast-tracked connections are also
|
||||
// reset in the OS integration and might show up in the connection
|
||||
// handling if a packet in the other direction hits the firewall first.
|
||||
|
||||
// Ignore other processes.
|
||||
if conn.Process().Pid != ownPID {
|
||||
return false
|
||||
}
|
||||
|
||||
// Ignore inbound connection if non-local.
|
||||
if conn.Inbound {
|
||||
myIP, err := netenv.IsMyIP(conn.Entity.IP)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Debugf("filter: failed to check if %s is own IP for granting own connection: %s", conn.Entity.IP, err)
|
||||
return false
|
||||
}
|
||||
if !myIP {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
log.Tracer(ctx).Infof("filter: granting own connection %s", conn)
|
||||
conn.Accept("connection by Portmaster", noReasonOptionKey)
|
||||
conn.Internal = true
|
||||
return true
|
||||
}
|
||||
|
||||
func checkIfBroadcastReply(ctx context.Context, conn *network.Connection, _ *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// Only check inbound connections.
|
||||
if !conn.Inbound {
|
||||
return false
|
||||
}
|
||||
// Only check if the process has been identified.
|
||||
if !conn.Process().IsIdentified() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the remote IP is part of a local network.
|
||||
localNet, err := netenv.GetLocalNetwork(conn.Entity.IP)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Warningf("filter: failed to get local network: %s", err)
|
||||
return false
|
||||
}
|
||||
if localNet == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Search for a matching requesting connection.
|
||||
requestingConn := network.GetMulticastRequestConn(conn, localNet)
|
||||
if requestingConn == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
conn.Accept(
|
||||
fmt.Sprintf(
|
||||
"response to multi/broadcast query to %s/%s",
|
||||
packet.IPProtocol(requestingConn.Entity.Protocol),
|
||||
net.JoinHostPort(
|
||||
requestingConn.Entity.IP.String(),
|
||||
strconv.Itoa(int(requestingConn.Entity.Port)),
|
||||
),
|
||||
),
|
||||
"",
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
func checkEndpointLists(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// DNS request from the system resolver require a special decision process,
|
||||
// because the original requesting process is not known. Here, we only check
|
||||
// global-only and the most important per-app aspects. The resulting
|
||||
// connection is then blocked when the original requesting process is known.
|
||||
if conn.Type == network.DNSRequest && conn.Process().IsSystemResolver() {
|
||||
return checkEndpointListsForSystemResolverDNSRequests(ctx, conn, p)
|
||||
}
|
||||
|
||||
var result endpoints.EPResult
|
||||
var reason endpoints.Reason
|
||||
|
||||
// check endpoints list
|
||||
var optionKey string
|
||||
if conn.Inbound {
|
||||
result, reason = p.MatchServiceEndpoint(ctx, conn.Entity)
|
||||
optionKey = profile.CfgOptionServiceEndpointsKey
|
||||
} else {
|
||||
result, reason = p.MatchEndpoint(ctx, conn.Entity)
|
||||
optionKey = profile.CfgOptionEndpointsKey
|
||||
}
|
||||
switch result {
|
||||
case endpoints.Denied, endpoints.MatchError:
|
||||
conn.DenyWithContext(reason.String(), optionKey, reason.Context())
|
||||
return true
|
||||
case endpoints.Permitted:
|
||||
conn.AcceptWithContext(reason.String(), optionKey, reason.Context())
|
||||
return true
|
||||
case endpoints.NoMatch:
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// checkEndpointListsForSystemResolverDNSRequests is a special version of
|
||||
// checkEndpointLists that is only meant for DNS queries by the system
|
||||
// resolver. It only checks the endpoint filter list of the local profile and
|
||||
// does not include the global profile.
|
||||
func checkEndpointListsForSystemResolverDNSRequests(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile) bool {
|
||||
var profileEndpoints endpoints.Endpoints
|
||||
var optionKey string
|
||||
if conn.Inbound {
|
||||
profileEndpoints = p.LocalProfileWithoutLocking().GetServiceEndpoints()
|
||||
optionKey = profile.CfgOptionServiceEndpointsKey
|
||||
} else {
|
||||
profileEndpoints = p.LocalProfileWithoutLocking().GetEndpoints()
|
||||
optionKey = profile.CfgOptionEndpointsKey
|
||||
}
|
||||
|
||||
if profileEndpoints.IsSet() {
|
||||
result, reason := profileEndpoints.Match(ctx, conn.Entity)
|
||||
if endpoints.IsDecision(result) {
|
||||
switch result {
|
||||
case endpoints.Denied, endpoints.MatchError:
|
||||
conn.DenyWithContext(reason.String(), optionKey, reason.Context())
|
||||
return true
|
||||
case endpoints.Permitted:
|
||||
conn.AcceptWithContext(reason.String(), optionKey, reason.Context())
|
||||
return true
|
||||
case endpoints.NoMatch:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var p2pFilterLists = []string{"17-P2P"}
|
||||
|
||||
func checkConnectionType(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
switch {
|
||||
// Block incoming connection, if not from localhost.
|
||||
case p.BlockInbound() && conn.Inbound &&
|
||||
!conn.Entity.IPScope.IsLocalhost():
|
||||
conn.Drop("inbound connections blocked", profile.CfgOptionBlockInboundKey)
|
||||
return true
|
||||
|
||||
// Check for P2P and related connections.
|
||||
case p.BlockP2P() && !conn.Inbound:
|
||||
switch {
|
||||
// Block anything that is in the P2P filter list.
|
||||
case conn.Entity.MatchLists(p2pFilterLists):
|
||||
conn.Block("P2P assistive infrastructure blocked based on filter list", profile.CfgOptionBlockP2PKey)
|
||||
return true
|
||||
|
||||
// Remaining P2P deciders only apply to IP connections.
|
||||
case conn.Type != network.IPConnection:
|
||||
return false
|
||||
|
||||
// Block well known ports of P2P assistive infrastructure.
|
||||
case conn.Entity.DstPort() == 3478 || // STUN/TURN
|
||||
conn.Entity.DstPort() == 5349: // STUN/TURN over TLS/DTLS
|
||||
conn.Block("P2P assistive infrastructure blocked based on port", profile.CfgOptionBlockP2PKey)
|
||||
return true
|
||||
|
||||
// Block direct connections with not previous DNS request.
|
||||
case conn.Entity.IPScope.IsGlobal() &&
|
||||
conn.Entity.Domain == "":
|
||||
conn.Block("direct connections (P2P) blocked", profile.CfgOptionBlockP2PKey)
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func checkConnectivityDomain(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
switch {
|
||||
case conn.Entity.Domain == "":
|
||||
// Only applies if a domain is available.
|
||||
return false
|
||||
|
||||
case netenv.GetOnlineStatus() > netenv.StatusPortal:
|
||||
// Special grant only applies if network status is Portal (or even more limited).
|
||||
return false
|
||||
|
||||
case conn.Inbound:
|
||||
// Special grant only applies to outgoing connections.
|
||||
return false
|
||||
|
||||
case p.BlockScopeInternet():
|
||||
// Special grant only applies if application is allowed to connect to the Internet.
|
||||
return false
|
||||
|
||||
case netenv.IsConnectivityDomain(conn.Entity.Domain):
|
||||
// Special grant!
|
||||
conn.Accept("special grant for connectivity domain during network bootstrap", noReasonOptionKey)
|
||||
return true
|
||||
|
||||
default:
|
||||
// Not a special grant domain
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func checkConnectionScope(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// If we are handling a DNS request, check if we can immediately block it.
|
||||
if conn.Type == network.DNSRequest {
|
||||
// DNS is expected to resolve to LAN or Internet addresses.
|
||||
// Localhost queries are immediately responded to by the nameserver.
|
||||
if p.BlockScopeInternet() && p.BlockScopeLAN() {
|
||||
conn.Block("Internet and LAN access blocked", profile.CfgOptionBlockScopeInternetKey)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the network scope is permitted.
|
||||
switch conn.Entity.IPScope {
|
||||
case netutils.Global, netutils.GlobalMulticast:
|
||||
if p.BlockScopeInternet() {
|
||||
conn.Deny("Internet access blocked", profile.CfgOptionBlockScopeInternetKey) // Block Outbound / Drop Inbound
|
||||
return true
|
||||
}
|
||||
case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast:
|
||||
if p.BlockScopeLAN() {
|
||||
conn.Block("LAN access blocked", profile.CfgOptionBlockScopeLANKey) // Block Outbound / Drop Inbound
|
||||
return true
|
||||
}
|
||||
case netutils.HostLocal:
|
||||
if p.BlockScopeLocal() {
|
||||
conn.Block("Localhost access blocked", profile.CfgOptionBlockScopeLocalKey) // Block Outbound / Drop Inbound
|
||||
return true
|
||||
}
|
||||
case netutils.Undefined, netutils.Invalid:
|
||||
// Block Invalid / Undefined IPs _after_ the rules.
|
||||
return false
|
||||
default:
|
||||
conn.Deny("invalid IP", noReasonOptionKey) // Block Outbound / Drop Inbound
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkInvalidIP(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// Only applies to IP connections.
|
||||
if conn.Type != network.IPConnection {
|
||||
return false
|
||||
}
|
||||
|
||||
// Block Invalid / Undefined IPs.
|
||||
switch conn.Entity.IPScope { //nolint:exhaustive // Only looking for specific values.
|
||||
case netutils.Undefined, netutils.Invalid:
|
||||
conn.Deny("invalid IP", noReasonOptionKey) // Block Outbound / Drop Inbound
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkBypassPrevention(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
if p.PreventBypassing() {
|
||||
// check for bypass protection
|
||||
result, reason, reasonCtx := PreventBypassing(ctx, conn)
|
||||
switch result {
|
||||
case endpoints.Denied, endpoints.MatchError:
|
||||
// Also block on MatchError to be on the safe side.
|
||||
// PreventBypassing does not use any data that needs to be loaded, so it should not fail anyway.
|
||||
conn.BlockWithContext("bypass prevention: "+reason, profile.CfgOptionPreventBypassingKey, reasonCtx)
|
||||
return true
|
||||
case endpoints.Permitted:
|
||||
conn.AcceptWithContext("bypass prevention: "+reason, profile.CfgOptionPreventBypassingKey, reasonCtx)
|
||||
return true
|
||||
case endpoints.NoMatch:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkFilterLists(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// apply privacy filter lists
|
||||
result, reason := p.MatchFilterLists(ctx, conn.Entity)
|
||||
switch result {
|
||||
case endpoints.Denied:
|
||||
// If the connection matches a filter list, check if the "unbreak" list matches too and abort blocking.
|
||||
resolvedUnbreakFilterListIDs := filterlists.GetUnbreakFilterListIDs()
|
||||
for _, blockedListID := range conn.Entity.BlockedByLists {
|
||||
for _, unbreakListID := range resolvedUnbreakFilterListIDs {
|
||||
if blockedListID == unbreakListID {
|
||||
log.Tracer(ctx).Debugf("filter: unbreak filter %s matched, ignoring other filter list matches", unbreakListID)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
// Otherwise, continue with blocking.
|
||||
conn.DenyWithContext(reason.String(), profile.CfgOptionFilterListsKey, reason.Context())
|
||||
return true
|
||||
case endpoints.NoMatch:
|
||||
// nothing to do
|
||||
case endpoints.Permitted, endpoints.MatchError:
|
||||
fallthrough
|
||||
default:
|
||||
log.Tracer(ctx).Debugf("filter: filter lists returned unsupported verdict: %s", result)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func checkResolverScope(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// If the IP address was resolved, check the scope of the resolver.
|
||||
switch {
|
||||
case conn.Type != network.IPConnection:
|
||||
// Only applies to IP connections.
|
||||
case !p.RemoveOutOfScopeDNS():
|
||||
// Out of scope checking is not active.
|
||||
case conn.Resolver == nil:
|
||||
// IP address of connection was not resolved.
|
||||
case conn.Resolver.IPScope.IsGlobal() &&
|
||||
(conn.Entity.IPScope.IsLAN() || conn.Entity.IPScope.IsLocalhost()):
|
||||
// Block global resolvers from returning LAN/Localhost IPs.
|
||||
conn.Block("DNS server horizon violation: global DNS server returned local IP address", profile.CfgOptionRemoveOutOfScopeDNSKey)
|
||||
return true
|
||||
case conn.Resolver.IPScope.IsLAN() &&
|
||||
conn.Entity.IPScope.IsLocalhost():
|
||||
// Block LAN resolvers from returning Localhost IPs.
|
||||
conn.Block("DNS server horizon violation: LAN DNS server returned localhost IP address", profile.CfgOptionRemoveOutOfScopeDNSKey)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkDomainHeuristics(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// Don't check domain heuristics if no domain is available.
|
||||
if conn.Entity.Domain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if domain heuristics are enabled.
|
||||
if !p.DomainHeuristics() {
|
||||
return false
|
||||
}
|
||||
|
||||
trimmedDomain := strings.TrimRight(conn.Entity.Domain, ".")
|
||||
etld1, err := publicsuffix.EffectiveTLDPlusOne(trimmedDomain)
|
||||
if err != nil {
|
||||
// Don't run the check if the domain is a TLD.
|
||||
return false
|
||||
}
|
||||
|
||||
domainToCheck := strings.Split(etld1, ".")[0]
|
||||
score := dga.LmsScore(domainToCheck)
|
||||
if score < 5 {
|
||||
log.Tracer(ctx).Debugf(
|
||||
"filter: possible data tunnel by %s in eTLD+1 %s: %s has an lms score of %.2f",
|
||||
conn.Process(),
|
||||
etld1,
|
||||
domainToCheck,
|
||||
score,
|
||||
)
|
||||
conn.Block("possible DGA domain commonly used by malware", profile.CfgOptionDomainHeuristicsKey)
|
||||
return true
|
||||
}
|
||||
log.Tracer(ctx).Tracef("filter: LMS score of eTLD+1 %s is %.2f", etld1, score)
|
||||
|
||||
// 100 is a somewhat arbitrary threshold to ensure we don't mess
|
||||
// around with CDN domain names to early. They use short second-level
|
||||
// domains that would trigger LMS checks but are to small to actually
|
||||
// exfiltrate data.
|
||||
if len(conn.Entity.Domain) > len(etld1)+100 {
|
||||
domainToCheck = trimmedDomain[0:len(etld1)]
|
||||
score := dga.LmsScoreOfDomain(domainToCheck)
|
||||
if score < 10 {
|
||||
log.Tracer(ctx).Debugf(
|
||||
"filter: possible data tunnel by %s in subdomain of %s: %s has an lms score of %.2f",
|
||||
conn.Process(),
|
||||
conn.Entity.Domain,
|
||||
domainToCheck,
|
||||
score,
|
||||
)
|
||||
conn.Block("possible data tunnel for covert communication and protection bypassing", profile.CfgOptionDomainHeuristicsKey)
|
||||
return true
|
||||
}
|
||||
log.Tracer(ctx).Tracef("filter: LMS score of entire domain is %.2f", score)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkAutoPermitRelated(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// Auto permit is disabled for default action permit.
|
||||
if p.DefaultAction() == profile.DefaultActionPermit {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if auto permit is disabled.
|
||||
if p.DisableAutoPermit() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for relation to auto permit.
|
||||
related, reason := checkRelation(conn)
|
||||
if related {
|
||||
conn.Accept(reason, profile.CfgOptionDisableAutoPermitKey)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// checkRelation tries to find a relation between a process and a communication. This is for better out of the box experience and is _not_ meant to thwart intentional malware.
|
||||
func checkRelation(conn *network.Connection) (related bool, reason string) {
|
||||
// Don't check relation if no domain is available.
|
||||
if conn.Entity.Domain == "" {
|
||||
return false, ""
|
||||
}
|
||||
// Don't check for unknown processes.
|
||||
if conn.Process().Pid < 0 {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
pathElements := strings.Split(conn.Process().Path, string(filepath.Separator))
|
||||
// only look at the last two path segments
|
||||
if len(pathElements) > 2 {
|
||||
pathElements = pathElements[len(pathElements)-2:]
|
||||
}
|
||||
domainElements := strings.Split(conn.Entity.Domain, ".")
|
||||
|
||||
var domainElement string
|
||||
var processElement string
|
||||
|
||||
matchLoop:
|
||||
for _, domainElement = range domainElements {
|
||||
for _, pathElement := range pathElements {
|
||||
if levenshtein.Match(domainElement, pathElement, nil) > 0.5 {
|
||||
related = true
|
||||
processElement = pathElement
|
||||
break matchLoop
|
||||
}
|
||||
}
|
||||
if levenshtein.Match(domainElement, conn.Process().Name, nil) > 0.5 {
|
||||
related = true
|
||||
processElement = conn.Process().Name
|
||||
break matchLoop
|
||||
}
|
||||
if levenshtein.Match(domainElement, conn.Process().ExecName, nil) > 0.5 {
|
||||
related = true
|
||||
processElement = conn.Process().ExecName
|
||||
break matchLoop
|
||||
}
|
||||
}
|
||||
|
||||
if related {
|
||||
reason = fmt.Sprintf("auto allowed: domain is related to process: %s is related to %s", domainElement, processElement)
|
||||
}
|
||||
return related, reason
|
||||
}
|
||||
|
||||
func checkCustomFilterList(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
|
||||
// Check if any custom list is loaded at all.
|
||||
if !customlists.IsLoaded() {
|
||||
return false
|
||||
}
|
||||
|
||||
// block if the domain name appears in the custom filter list (check for subdomains if enabled)
|
||||
if conn.Entity.Domain != "" {
|
||||
if ok, match := customlists.LookupDomain(conn.Entity.Domain, p.FilterSubDomains()); ok {
|
||||
conn.Deny(fmt.Sprintf("domain %s matches %s in custom filter list", conn.Entity.Domain, match), customlists.CfgOptionCustomListFileKey)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// block if any of the CNAME appears in the custom filter list (check for subdomains if enabled)
|
||||
if p.FilterCNAMEs() {
|
||||
for _, cname := range conn.Entity.CNAME {
|
||||
if ok, match := customlists.LookupDomain(cname, p.FilterSubDomains()); ok {
|
||||
conn.Deny(fmt.Sprintf("domain alias (CNAME) %s matches %s in custom filter list", cname, match), customlists.CfgOptionCustomListFileKey)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// block if ip addresses appears in the custom filter list
|
||||
if conn.Entity.IP != nil {
|
||||
if customlists.LookupIP(conn.Entity.IP) {
|
||||
conn.Deny("IP address is in the custom filter list", customlists.CfgOptionCustomListFileKey)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// block autonomous system by its number if it appears in the custom filter list
|
||||
if conn.Entity.ASN != 0 {
|
||||
if customlists.LookupASN(conn.Entity.ASN) {
|
||||
conn.Deny("AS is in the custom filter list", customlists.CfgOptionCustomListFileKey)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// block if the country appears in the custom filter list
|
||||
if conn.Entity.Country != "" {
|
||||
if customlists.LookupCountry(conn.Entity.Country) {
|
||||
conn.Deny("country is in the custom filter list", customlists.CfgOptionCustomListFileKey)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
147
service/firewall/module.go
Normal file
147
service/firewall/module.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/modules/subsystems"
|
||||
_ "github.com/safing/portmaster/service/core"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/profile"
|
||||
"github.com/safing/portmaster/spn/access"
|
||||
"github.com/safing/portmaster/spn/captain"
|
||||
)
|
||||
|
||||
var module *modules.Module
|
||||
|
||||
func init() {
|
||||
module = modules.Register("filter", prep, start, stop, "core", "interception", "intel", "netquery")
|
||||
subsystems.Register(
|
||||
"filter",
|
||||
"Privacy Filter",
|
||||
"DNS and Network Filter",
|
||||
module,
|
||||
"config:filter/",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
network.SetDefaultFirewallHandler(verdictHandler)
|
||||
|
||||
// Reset connections every time configuration changes
|
||||
// this will be triggered on spn enable/disable
|
||||
err := module.RegisterEventHook(
|
||||
"config",
|
||||
config.ChangeEvent,
|
||||
"reset connection verdicts after global config change",
|
||||
func(ctx context.Context, _ interface{}) error {
|
||||
resetAllConnectionVerdicts()
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("filter: failed to register event hook: %s", err)
|
||||
}
|
||||
|
||||
// Reset connections every time profile changes
|
||||
err = module.RegisterEventHook(
|
||||
"profiles",
|
||||
profile.ConfigChangeEvent,
|
||||
"reset connection verdicts after profile config change",
|
||||
func(ctx context.Context, eventData interface{}) error {
|
||||
// Expected event data: scoped profile ID.
|
||||
profileID, ok := eventData.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("event data is not a string: %v", eventData)
|
||||
}
|
||||
profileSource, profileID, ok := strings.Cut(profileID, "/")
|
||||
if !ok {
|
||||
return fmt.Errorf("event data does not seem to be a scoped profile ID: %v", eventData)
|
||||
}
|
||||
|
||||
resetProfileConnectionVerdict(profileSource, profileID)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("filter: failed to register event hook: %s", err)
|
||||
}
|
||||
|
||||
// Reset connections when spn is connected
|
||||
// connect and disconnecting is triggered on config change event but connecting takеs more time
|
||||
err = module.RegisterEventHook(
|
||||
"captain",
|
||||
captain.SPNConnectedEvent,
|
||||
"reset connection verdicts on SPN connect",
|
||||
func(ctx context.Context, _ interface{}) error {
|
||||
resetAllConnectionVerdicts()
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("filter: failed to register event hook: %s", err)
|
||||
}
|
||||
|
||||
// Reset connections when account is updated.
|
||||
// This will not change verdicts, but will update the feature flags on connections.
|
||||
err = module.RegisterEventHook(
|
||||
"access",
|
||||
access.AccountUpdateEvent,
|
||||
"update connection feature flags after account update",
|
||||
func(ctx context.Context, _ interface{}) error {
|
||||
resetAllConnectionVerdicts()
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("filter: failed to register event hook: %s", err)
|
||||
}
|
||||
|
||||
err = module.RegisterEventHook(
|
||||
"network",
|
||||
network.ConnectionReattributedEvent,
|
||||
"reset verdict of re-attributed connection",
|
||||
func(ctx context.Context, eventData interface{}) error {
|
||||
// Expected event data: connection ID.
|
||||
connID, ok := eventData.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("event data is not a string: %v", eventData)
|
||||
}
|
||||
resetSingleConnectionVerdict(connID)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("filter: failed to register event hook: %s", err)
|
||||
}
|
||||
|
||||
if err := registerConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return prepAPIAuth()
|
||||
}
|
||||
|
||||
func start() error {
|
||||
getConfig()
|
||||
startAPIAuth()
|
||||
|
||||
module.StartServiceWorker("packet handler", 0, packetHandler)
|
||||
module.StartServiceWorker("bandwidth update handler", 0, bandwidthUpdateHandler)
|
||||
|
||||
// Start stat logger if logging is set to trace.
|
||||
if log.GetLogLevel() == log.TraceLevel {
|
||||
module.StartServiceWorker("stat logger", 0, statLogger)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
return nil
|
||||
}
|
||||
747
service/firewall/packet_handler.go
Normal file
747
service/firewall/packet_handler.go
Normal file
@@ -0,0 +1,747 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/compat"
|
||||
_ "github.com/safing/portmaster/service/core/base"
|
||||
"github.com/safing/portmaster/service/firewall/inspection"
|
||||
"github.com/safing/portmaster/service/firewall/interception"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/netquery"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
"github.com/safing/portmaster/spn/access"
|
||||
)
|
||||
|
||||
var (
|
||||
nameserverIPMatcher func(ip net.IP) bool
|
||||
nameserverIPMatcherSet = abool.New()
|
||||
nameserverIPMatcherReady = abool.New()
|
||||
|
||||
packetsAccepted = new(uint64)
|
||||
packetsBlocked = new(uint64)
|
||||
packetsDropped = new(uint64)
|
||||
packetsFailed = new(uint64)
|
||||
|
||||
blockedIPv4 = net.IPv4(0, 0, 0, 17)
|
||||
blockedIPv6 = net.ParseIP("::17")
|
||||
|
||||
ownPID = os.Getpid()
|
||||
)
|
||||
|
||||
func resetSingleConnectionVerdict(connID string) {
|
||||
// Create tracing context.
|
||||
ctx, tracer := log.AddTracer(context.Background())
|
||||
defer tracer.Submit()
|
||||
|
||||
conn, ok := network.GetConnection(connID)
|
||||
if !ok {
|
||||
conn, ok = network.GetDNSConnection(connID)
|
||||
if !ok {
|
||||
tracer.Debugf("filter: could not find re-attributed connection %s for re-evaluation", connID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resetConnectionVerdict(ctx, conn)
|
||||
}
|
||||
|
||||
func resetProfileConnectionVerdict(profileSource, profileID string) {
|
||||
// Create tracing context.
|
||||
ctx, tracer := log.AddTracer(context.Background())
|
||||
defer tracer.Submit()
|
||||
|
||||
// Resetting will force all the connection to be evaluated by the firewall again
|
||||
// this will set new verdicts if configuration was update or spn has been disabled or enabled.
|
||||
tracer.Infof("filter: re-evaluating connections of %s/%s", profileSource, profileID)
|
||||
|
||||
// Re-evaluate all connections.
|
||||
var changedVerdicts int
|
||||
for _, conn := range network.GetAllConnections() {
|
||||
// Check if connection is complete and attributed to the deleted profile.
|
||||
if conn.DataIsComplete() &&
|
||||
conn.ProcessContext.Profile == profileID &&
|
||||
conn.ProcessContext.Source == profileSource {
|
||||
if resetConnectionVerdict(ctx, conn) {
|
||||
changedVerdicts++
|
||||
}
|
||||
}
|
||||
}
|
||||
tracer.Infof("filter: changed verdict on %d connections", changedVerdicts)
|
||||
}
|
||||
|
||||
func resetAllConnectionVerdicts() {
|
||||
// Create tracing context.
|
||||
ctx, tracer := log.AddTracer(context.Background())
|
||||
defer tracer.Submit()
|
||||
|
||||
// Resetting will force all the connection to be evaluated by the firewall again
|
||||
// this will set new verdicts if configuration was update or spn has been disabled or enabled.
|
||||
tracer.Info("filter: re-evaluating all connections")
|
||||
|
||||
// Re-evaluate all connections.
|
||||
var changedVerdicts int
|
||||
for _, conn := range network.GetAllConnections() {
|
||||
// Skip incomplete connections.
|
||||
if !conn.DataIsComplete() {
|
||||
continue
|
||||
}
|
||||
|
||||
if resetConnectionVerdict(ctx, conn) {
|
||||
changedVerdicts++
|
||||
}
|
||||
}
|
||||
tracer.Infof("filter: changed verdict on %d connections", changedVerdicts)
|
||||
}
|
||||
|
||||
func resetConnectionVerdict(ctx context.Context, conn *network.Connection) (verdictChanged bool) {
|
||||
tracer := log.Tracer(ctx)
|
||||
|
||||
// Remove any active prompt as we settings are being re-evaluated.
|
||||
conn.RemovePrompt()
|
||||
|
||||
conn.Lock()
|
||||
defer conn.Unlock()
|
||||
|
||||
// Update feature flags.
|
||||
if err := conn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
|
||||
tracer.Warningf("filter: failed to update connection feature flags: %s", err)
|
||||
}
|
||||
|
||||
// Skip internal connections:
|
||||
// - Pre-authenticated connections from Portmaster
|
||||
// - Redirected DNS requests
|
||||
// - SPN Uplink to Home Hub
|
||||
if conn.Internal {
|
||||
// tracer.Tracef("filter: skipping internal connection %s", conn)
|
||||
return false
|
||||
}
|
||||
|
||||
tracer.Debugf("filter: re-evaluating verdict of %s", conn)
|
||||
previousVerdict := conn.Verdict
|
||||
|
||||
// Apply privacy filter and check tunneling.
|
||||
FilterConnection(ctx, conn, nil, true, true)
|
||||
|
||||
// Stop existing SPN tunnel if not needed anymore.
|
||||
if conn.Verdict != network.VerdictRerouteToTunnel && conn.TunnelContext != nil {
|
||||
err := conn.TunnelContext.StopTunnel()
|
||||
if err != nil {
|
||||
tracer.Debugf("filter: failed to stopped unneeded tunnel: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save if verdict changed.
|
||||
if conn.Verdict != previousVerdict {
|
||||
err := interception.UpdateVerdictOfConnection(conn)
|
||||
if err != nil {
|
||||
log.Debugf("filter: failed to update connection verdict: %s", err)
|
||||
}
|
||||
conn.Save()
|
||||
tracer.Infof("filter: verdict of connection %s changed from %s to %s", conn, previousVerdict.Verb(), conn.VerdictVerb())
|
||||
|
||||
// Update verdict in OS integration, if an IP connection.
|
||||
if conn.Type == network.IPConnection {
|
||||
err := interception.UpdateVerdictOfConnection(conn)
|
||||
if err != nil {
|
||||
log.Debugf("filter: failed to update connection verdict: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
tracer.Tracef("filter: verdict to connection %s unchanged at %s", conn, conn.VerdictVerb())
|
||||
return false
|
||||
}
|
||||
|
||||
// SetNameserverIPMatcher sets a function that is used to match the internal
|
||||
// nameserver IP(s). Can only bet set once.
|
||||
func SetNameserverIPMatcher(fn func(ip net.IP) bool) error {
|
||||
if !nameserverIPMatcherSet.SetToIf(false, true) {
|
||||
return errors.New("nameserver IP matcher already set")
|
||||
}
|
||||
|
||||
nameserverIPMatcher = fn
|
||||
nameserverIPMatcherReady.Set()
|
||||
return nil
|
||||
}
|
||||
|
||||
func handlePacket(pkt packet.Packet) {
|
||||
// First, check for an existing connection.
|
||||
conn, ok := network.GetConnection(pkt.GetConnectionID())
|
||||
if ok {
|
||||
// Add packet to connection handler queue or apply verdict directly.
|
||||
conn.HandlePacket(pkt)
|
||||
return
|
||||
}
|
||||
|
||||
// Else create new incomplete connection from the packet and start the new handler.
|
||||
conn = network.NewIncompleteConnection(pkt)
|
||||
conn.Lock()
|
||||
defer conn.Unlock()
|
||||
conn.SetFirewallHandler(fastTrackHandler)
|
||||
|
||||
// Let the new connection handler worker handle the packet.
|
||||
conn.HandlePacket(pkt)
|
||||
}
|
||||
|
||||
// fastTrackedPermit quickly permits certain network critical or internal connections.
|
||||
func fastTrackedPermit(conn *network.Connection, pkt packet.Packet) (verdict network.Verdict, permanent bool) {
|
||||
meta := pkt.Info()
|
||||
|
||||
// Check if packed was already fast-tracked by the OS integration.
|
||||
if pkt.FastTrackedByIntegration() {
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: fast-tracked by OS integration: %s", pkt)
|
||||
return network.VerdictAccept, true
|
||||
}
|
||||
|
||||
// Check if connection was already blocked.
|
||||
if meta.Dst.Equal(blockedIPv4) || meta.Dst.Equal(blockedIPv6) {
|
||||
return network.VerdictBlock, true
|
||||
}
|
||||
|
||||
// Some programs do a network self-check where they connects to the same
|
||||
// IP/Port to test network capabilities.
|
||||
// Eg. dig: https://gitlab.isc.org/isc-projects/bind9/-/issues/1140
|
||||
if meta.SrcPort == meta.DstPort &&
|
||||
meta.Src.Equal(meta.Dst) {
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track network self-check: %s", pkt)
|
||||
return network.VerdictAccept, true
|
||||
|
||||
}
|
||||
|
||||
switch meta.Protocol { //nolint:exhaustive // Checking for specific values only.
|
||||
case packet.ICMP, packet.ICMPv6:
|
||||
// Load packet data.
|
||||
err := pkt.LoadPacketData()
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: failed to load ICMP packet data: %s", err)
|
||||
return network.VerdictAccept, true
|
||||
}
|
||||
|
||||
// Submit to ICMP listener.
|
||||
submitted := netenv.SubmitPacketToICMPListener(pkt)
|
||||
if submitted {
|
||||
// If the packet was submitted to the listener, we must not do a
|
||||
// permanent accept, because then we won't see any future packets of that
|
||||
// connection and thus cannot continue to submit them.
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track tracing ICMP/v6: %s", pkt)
|
||||
return network.VerdictAccept, false
|
||||
}
|
||||
|
||||
// Handle echo request and replies regularly.
|
||||
// Other ICMP packets are considered system business.
|
||||
icmpLayers := pkt.Layers().LayerClass(layers.LayerClassIPControl)
|
||||
switch icmpLayer := icmpLayers.(type) {
|
||||
case *layers.ICMPv4:
|
||||
switch icmpLayer.TypeCode.Type() {
|
||||
case layers.ICMPv4TypeEchoRequest,
|
||||
layers.ICMPv4TypeEchoReply:
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
case *layers.ICMPv6:
|
||||
switch icmpLayer.TypeCode.Type() {
|
||||
case layers.ICMPv6TypeEchoRequest,
|
||||
layers.ICMPv6TypeEchoReply:
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
}
|
||||
|
||||
// Permit all ICMP/v6 packets that are not echo requests or replies.
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting ICMP/v6: %s", pkt)
|
||||
return network.VerdictAccept, true
|
||||
|
||||
case packet.UDP, packet.TCP:
|
||||
switch meta.DstPort {
|
||||
|
||||
case 67, 68, 546, 547:
|
||||
// Always allow DHCP, DHCPv6.
|
||||
|
||||
// DHCP and DHCPv6 must be UDP.
|
||||
if meta.Protocol != packet.UDP {
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
// DHCP is only valid in local network scopes.
|
||||
switch netutils.ClassifyIP(meta.Dst) { //nolint:exhaustive // Checking for specific values only.
|
||||
case netutils.HostLocal, netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
|
||||
default:
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
// Log and permit.
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting DHCP: %s", pkt)
|
||||
return network.VerdictAccept, true
|
||||
|
||||
case apiPort:
|
||||
// Always allow direct access to the Portmaster API.
|
||||
|
||||
// Portmaster API is TCP only.
|
||||
if meta.Protocol != packet.TCP {
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
// Check if the api port is even set.
|
||||
if !apiPortSet {
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
// Must be destined for the API IP.
|
||||
if !meta.Dst.Equal(apiIP) {
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
// Only fast-track local requests.
|
||||
isMe, err := netenv.IsMyIP(meta.Src)
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
|
||||
return network.VerdictUndecided, false
|
||||
case !isMe:
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
// Log and permit.
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting api connection: %s", pkt)
|
||||
return network.VerdictAccept, true
|
||||
|
||||
case 53:
|
||||
// Always allow direct access to the Portmaster Nameserver.
|
||||
// DNS is both UDP and TCP.
|
||||
|
||||
// Check if a nameserver IP matcher is set.
|
||||
if !nameserverIPMatcherReady.IsSet() {
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
// Check if packet is destined for a nameserver IP.
|
||||
if !nameserverIPMatcher(meta.Dst) {
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
// Only fast-track local requests.
|
||||
isMe, err := netenv.IsMyIP(meta.Src)
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
|
||||
return network.VerdictUndecided, false
|
||||
case !isMe:
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
// Log and permit.
|
||||
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting local dns: %s", pkt)
|
||||
|
||||
// Add to DNS request connections to attribute DNS request if outgoing.
|
||||
if pkt.IsOutbound() {
|
||||
// Assign PID from packet directly, as processing stops after fast-track.
|
||||
conn.PID = pkt.Info().PID
|
||||
network.SaveDNSRequestConnection(conn, pkt)
|
||||
}
|
||||
|
||||
// Accept local DNS, but only make permanent if we have the PID too.
|
||||
return network.VerdictAccept, conn.PID != process.UndefinedProcessID
|
||||
}
|
||||
|
||||
case compat.SystemIntegrationCheckProtocol:
|
||||
if pkt.Info().Dst.Equal(compat.SystemIntegrationCheckDstIP) {
|
||||
compat.SubmitSystemIntegrationCheckPacket(pkt)
|
||||
return network.VerdictDrop, false
|
||||
}
|
||||
}
|
||||
|
||||
return network.VerdictUndecided, false
|
||||
}
|
||||
|
||||
func fastTrackHandler(conn *network.Connection, pkt packet.Packet) {
|
||||
fastTrackedVerdict, permanent := fastTrackedPermit(conn, pkt)
|
||||
if fastTrackedVerdict != network.VerdictUndecided {
|
||||
// Set verdict on connection.
|
||||
conn.Verdict = fastTrackedVerdict
|
||||
|
||||
// Apply verdict to (real) packet.
|
||||
if !pkt.InfoOnly() {
|
||||
issueVerdict(conn, pkt, fastTrackedVerdict, permanent)
|
||||
}
|
||||
|
||||
// Stop handler if permanent.
|
||||
if permanent {
|
||||
conn.SetVerdict(fastTrackedVerdict, "fast-tracked", "", nil)
|
||||
|
||||
// Do not finalize verdict, as we are missing necessary data.
|
||||
conn.StopFirewallHandler()
|
||||
}
|
||||
|
||||
// Do not continue to next handler.
|
||||
return
|
||||
}
|
||||
|
||||
// If packet is not fast-tracked, continue with gathering more information.
|
||||
conn.UpdateFirewallHandler(gatherDataHandler)
|
||||
gatherDataHandler(conn, pkt)
|
||||
}
|
||||
|
||||
func gatherDataHandler(conn *network.Connection, pkt packet.Packet) {
|
||||
// Get process info
|
||||
_ = conn.GatherConnectionInfo(pkt)
|
||||
// Errors are informational and are logged to the context.
|
||||
|
||||
// Run this handler again if data is not yet complete.
|
||||
if !conn.DataIsComplete() {
|
||||
return
|
||||
}
|
||||
|
||||
// Continue to filter handler, when connection data is complete.
|
||||
conn.UpdateFirewallHandler(filterHandler)
|
||||
filterHandler(conn, pkt)
|
||||
}
|
||||
|
||||
func filterHandler(conn *network.Connection, pkt packet.Packet) {
|
||||
// Skip if data is not complete or packet is info-only.
|
||||
if !conn.DataIsComplete() || pkt.InfoOnly() {
|
||||
return
|
||||
}
|
||||
|
||||
filterConnection := true
|
||||
|
||||
// Check for special (internal) connection cases.
|
||||
switch {
|
||||
case !conn.Inbound && localPortIsPreAuthenticated(conn.Entity.Protocol, conn.LocalPort):
|
||||
// Approve connection.
|
||||
conn.Accept("connection by Portmaster", noReasonOptionKey)
|
||||
conn.Internal = true
|
||||
filterConnection = false
|
||||
log.Tracer(pkt.Ctx()).Infof("filter: granting own pre-authenticated connection %s", conn)
|
||||
|
||||
// Redirect outbound DNS packets if enabled,
|
||||
case dnsQueryInterception() &&
|
||||
pkt.IsOutbound() &&
|
||||
pkt.Info().DstPort == 53 &&
|
||||
// that don't match the address of our nameserver,
|
||||
nameserverIPMatcherReady.IsSet() &&
|
||||
!nameserverIPMatcher(pkt.Info().Dst) &&
|
||||
// and are not broadcast queries by us.
|
||||
// Context:
|
||||
// - Unicast queries by the resolver are pre-authenticated.
|
||||
// - Unicast queries by the compat self-check should be redirected.
|
||||
!(conn.Process().Pid == ownPID &&
|
||||
conn.Entity.IPScope == netutils.LocalMulticast):
|
||||
|
||||
// Reroute rogue dns queries back to Portmaster.
|
||||
conn.SetVerdict(network.VerdictRerouteToNameserver, "redirecting rogue dns query", "", nil)
|
||||
conn.Internal = true
|
||||
log.Tracer(pkt.Ctx()).Infof("filter: redirecting dns query %s to Portmaster", conn)
|
||||
|
||||
// Add to DNS request connections to attribute DNS request.
|
||||
network.SaveDNSRequestConnection(conn, pkt)
|
||||
|
||||
// End directly, as no other processing is necessary.
|
||||
conn.StopFirewallHandler()
|
||||
|
||||
issueVerdict(conn, pkt, 0, true)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply privacy filter and check tunneling.
|
||||
FilterConnection(pkt.Ctx(), conn, pkt, filterConnection, true)
|
||||
|
||||
// Decide how to continue handling connection.
|
||||
switch {
|
||||
case conn.Inspecting:
|
||||
log.Tracer(pkt.Ctx()).Trace("filter: start inspecting")
|
||||
conn.SetFirewallHandler(inspectAndVerdictHandler)
|
||||
inspectAndVerdictHandler(conn, pkt)
|
||||
default:
|
||||
conn.StopFirewallHandler()
|
||||
issueVerdict(conn, pkt, 0, true)
|
||||
}
|
||||
}
|
||||
|
||||
// FilterConnection runs all the filtering (and tunneling) procedures.
|
||||
func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet, checkFilter, checkTunnel bool) {
|
||||
// Skip if data is not complete.
|
||||
if !conn.DataIsComplete() {
|
||||
return
|
||||
}
|
||||
|
||||
if checkFilter {
|
||||
if filterEnabled() {
|
||||
log.Tracer(ctx).Trace("filter: starting decision process")
|
||||
decideOnConnection(ctx, conn, pkt)
|
||||
} else {
|
||||
conn.Accept("privacy filter disabled", noReasonOptionKey)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Enable inspection framework again.
|
||||
conn.Inspecting = false
|
||||
|
||||
// TODO: Quick fix for the SPN.
|
||||
// Use inspection framework for proper encryption detection.
|
||||
switch conn.Entity.DstPort() {
|
||||
case
|
||||
22, // SSH
|
||||
443, // HTTPS
|
||||
465, // SMTP-SSL
|
||||
853, // DoT
|
||||
993, // IMAP-SSL
|
||||
995: // POP3-SSL
|
||||
conn.Encrypted = true
|
||||
}
|
||||
|
||||
// Check if connection should be tunneled.
|
||||
if checkTunnel {
|
||||
checkTunneling(ctx, conn)
|
||||
}
|
||||
|
||||
// Request tunneling if no tunnel is set and connection should be tunneled.
|
||||
if conn.Verdict == network.VerdictRerouteToTunnel &&
|
||||
conn.TunnelContext == nil {
|
||||
err := requestTunneling(ctx, conn)
|
||||
if err == nil {
|
||||
conn.ConnectionEstablished = true
|
||||
} else {
|
||||
// Set connection to failed, but keep tunneling data.
|
||||
// The tunneling data makes connection easy to recognize as a failed SPN
|
||||
// connection and the data will help with debugging and displaying in the UI.
|
||||
conn.Failed(fmt.Sprintf("failed to request tunneling: %s", err), "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func verdictHandler(conn *network.Connection, pkt packet.Packet) {
|
||||
// Ignore info-only packets in this handler.
|
||||
if pkt.InfoOnly() {
|
||||
return
|
||||
}
|
||||
|
||||
issueVerdict(conn, pkt, 0, true)
|
||||
}
|
||||
|
||||
func inspectAndVerdictHandler(conn *network.Connection, pkt packet.Packet) {
|
||||
// Ignore info-only packets in this handler.
|
||||
if pkt.InfoOnly() {
|
||||
return
|
||||
}
|
||||
|
||||
// Run inspectors.
|
||||
pktVerdict, continueInspection := inspection.RunInspectors(conn, pkt)
|
||||
if continueInspection {
|
||||
issueVerdict(conn, pkt, pktVerdict, false)
|
||||
return
|
||||
}
|
||||
|
||||
// we are done with inspecting
|
||||
conn.StopFirewallHandler()
|
||||
issueVerdict(conn, pkt, 0, true)
|
||||
}
|
||||
|
||||
func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.Verdict, allowPermanent bool) {
|
||||
// Check if packed was already fast-tracked by the OS integration.
|
||||
if pkt.FastTrackedByIntegration() {
|
||||
return
|
||||
}
|
||||
|
||||
// enable permanent verdict
|
||||
if allowPermanent && !conn.VerdictPermanent {
|
||||
conn.VerdictPermanent = permanentVerdicts()
|
||||
if conn.VerdictPermanent {
|
||||
conn.SaveWhenFinished()
|
||||
}
|
||||
}
|
||||
|
||||
// do not allow to circumvent decision: e.g. to ACCEPT packets from a DROP-ed connection
|
||||
if verdict < conn.Verdict {
|
||||
verdict = conn.Verdict
|
||||
}
|
||||
|
||||
var err error
|
||||
switch verdict {
|
||||
case network.VerdictAccept:
|
||||
atomic.AddUint64(packetsAccepted, 1)
|
||||
if conn.VerdictPermanent {
|
||||
err = pkt.PermanentAccept()
|
||||
} else {
|
||||
err = pkt.Accept()
|
||||
}
|
||||
case network.VerdictBlock:
|
||||
atomic.AddUint64(packetsBlocked, 1)
|
||||
if conn.VerdictPermanent {
|
||||
err = pkt.PermanentBlock()
|
||||
} else {
|
||||
err = pkt.Block()
|
||||
}
|
||||
case network.VerdictDrop:
|
||||
atomic.AddUint64(packetsDropped, 1)
|
||||
if conn.VerdictPermanent {
|
||||
err = pkt.PermanentDrop()
|
||||
} else {
|
||||
err = pkt.Drop()
|
||||
}
|
||||
case network.VerdictRerouteToNameserver:
|
||||
err = pkt.RerouteToNameserver()
|
||||
case network.VerdictRerouteToTunnel:
|
||||
err = pkt.RerouteToTunnel()
|
||||
case network.VerdictFailed:
|
||||
atomic.AddUint64(packetsFailed, 1)
|
||||
err = pkt.Drop()
|
||||
case network.VerdictUndecided, network.VerdictUndeterminable:
|
||||
log.Tracer(pkt.Ctx()).Warningf("filter: tried to apply verdict %s to pkt %s: dropping instead", verdict, pkt)
|
||||
fallthrough
|
||||
default:
|
||||
atomic.AddUint64(packetsDropped, 1)
|
||||
err = pkt.Drop()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Warningf("filter: failed to apply verdict to pkt %s: %s", pkt, err)
|
||||
}
|
||||
}
|
||||
|
||||
// verdictRating rates the privacy and security aspect of verdicts from worst to best.
|
||||
var verdictRating = []network.Verdict{
|
||||
network.VerdictAccept, // Connection allowed in the open.
|
||||
network.VerdictRerouteToTunnel, // Connection allowed, but protected.
|
||||
network.VerdictRerouteToNameserver, // Connection allowed, but resolved via Portmaster.
|
||||
network.VerdictBlock, // Connection blocked, with feedback.
|
||||
network.VerdictDrop, // Connection blocked, without feedback.
|
||||
network.VerdictFailed,
|
||||
network.VerdictUndeterminable,
|
||||
network.VerdictUndecided,
|
||||
}
|
||||
|
||||
// func tunnelHandler(pkt packet.Packet) {
|
||||
// tunnelInfo := GetTunnelInfo(pkt.Info().Dst)
|
||||
// if tunnelInfo == nil {
|
||||
// pkt.Block()
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// entry.CreateTunnel(pkt, tunnelInfo.Domain, tunnelInfo.RRCache.ExportAllARecords())
|
||||
// log.Tracef("filter: rerouting %s to tunnel entry point", pkt)
|
||||
// pkt.RerouteToTunnel()
|
||||
// return
|
||||
// }
|
||||
|
||||
func packetHandler(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case pkt := <-interception.Packets:
|
||||
if pkt != nil {
|
||||
handlePacket(pkt)
|
||||
} else {
|
||||
return errors.New("received nil packet from interception")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bandwidthUpdateHandler(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case bwUpdate := <-interception.BandwidthUpdates:
|
||||
if bwUpdate != nil {
|
||||
// DEBUG:
|
||||
// log.Debugf("filter: bandwidth update: %s", bwUpdate)
|
||||
updateBandwidth(ctx, bwUpdate)
|
||||
} else {
|
||||
return errors.New("received nil bandwidth update from interception")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
|
||||
// Check if update makes sense.
|
||||
if bwUpdate.BytesReceived == 0 && bwUpdate.BytesSent == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Get connection.
|
||||
conn, ok := network.GetConnection(bwUpdate.ConnID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Do not wait for connections that are locked.
|
||||
// TODO: Use atomic operations for updating bandwidth stats.
|
||||
if !conn.TryLock() {
|
||||
// DEBUG:
|
||||
// log.Warningf("filter: failed to lock connection for bandwidth update: %s", conn)
|
||||
return
|
||||
}
|
||||
defer conn.Unlock()
|
||||
|
||||
bytesIn := bwUpdate.BytesReceived
|
||||
bytesOut := bwUpdate.BytesSent
|
||||
|
||||
// Update stats according to method.
|
||||
switch bwUpdate.Method {
|
||||
case packet.Absolute:
|
||||
bytesIn = bwUpdate.BytesReceived - conn.BytesReceived
|
||||
bytesOut = bwUpdate.BytesSent - conn.BytesSent
|
||||
|
||||
conn.BytesReceived = bwUpdate.BytesReceived
|
||||
conn.BytesSent = bwUpdate.BytesSent
|
||||
case packet.Additive:
|
||||
conn.BytesReceived += bwUpdate.BytesReceived
|
||||
conn.BytesSent += bwUpdate.BytesSent
|
||||
default:
|
||||
log.Warningf("filter: unsupported bandwidth update method: %d", bwUpdate.Method)
|
||||
return
|
||||
}
|
||||
|
||||
// Update bandwidth in the netquery module.
|
||||
if netquery.DefaultModule != nil && conn.BandwidthEnabled {
|
||||
if err := netquery.DefaultModule.Store.UpdateBandwidth(
|
||||
ctx,
|
||||
conn.HistoryEnabled,
|
||||
fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile),
|
||||
conn.Process().GetKey(),
|
||||
conn.ID,
|
||||
bytesIn,
|
||||
bytesOut,
|
||||
); err != nil {
|
||||
log.Errorf("filter: failed to persist bandwidth data: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func statLogger(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(10 * time.Second):
|
||||
log.Tracef(
|
||||
"filter: packets accepted %d, blocked %d, dropped %d, failed %d",
|
||||
atomic.LoadUint64(packetsAccepted),
|
||||
atomic.LoadUint64(packetsBlocked),
|
||||
atomic.LoadUint64(packetsDropped),
|
||||
atomic.LoadUint64(packetsFailed),
|
||||
)
|
||||
atomic.StoreUint64(packetsAccepted, 0)
|
||||
atomic.StoreUint64(packetsBlocked, 0)
|
||||
atomic.StoreUint64(packetsDropped, 0)
|
||||
atomic.StoreUint64(packetsFailed, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
111
service/firewall/preauth.go
Normal file
111
service/firewall/preauth.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
)
|
||||
|
||||
var (
|
||||
preAuthenticatedPorts = make(map[string]struct{})
|
||||
preAuthenticatedPortsLock sync.Mutex
|
||||
)
|
||||
|
||||
func init() {
|
||||
resolver.SetLocalAddrFactory(PermittedAddr)
|
||||
netenv.SetLocalAddrFactory(PermittedAddr)
|
||||
}
|
||||
|
||||
// PermittedAddr returns an already permitted local address for the given network for reliable connectivity.
|
||||
// Returns nil in case of error.
|
||||
func PermittedAddr(network string) net.Addr {
|
||||
switch network {
|
||||
case "udp":
|
||||
return PermittedUDPAddr()
|
||||
case "tcp":
|
||||
return PermittedTCPAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PermittedUDPAddr returns an already permitted local udp address for reliable connectivity.
|
||||
// Returns nil in case of error.
|
||||
func PermittedUDPAddr() *net.UDPAddr {
|
||||
preAuthdPort := GetPermittedPort(packet.UDP)
|
||||
if preAuthdPort == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", preAuthdPort))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// PermittedTCPAddr returns an already permitted local tcp address for reliable connectivity.
|
||||
// Returns nil in case of error.
|
||||
func PermittedTCPAddr() *net.TCPAddr {
|
||||
preAuthdPort := GetPermittedPort(packet.TCP)
|
||||
if preAuthdPort == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf(":%d", preAuthdPort))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// GetPermittedPort returns a local port number that is already permitted for communication.
|
||||
// This bypasses the process attribution step to guarantee connectivity.
|
||||
// Communication on the returned port is attributed to the Portmaster.
|
||||
// Every pre-authenticated port is only valid once.
|
||||
// If no unused local port number can be found, it will return 0, which is
|
||||
// expected to trigger automatic port selection by the underlying OS.
|
||||
func GetPermittedPort(protocol packet.IPProtocol) uint16 {
|
||||
port, ok := network.GetUnusedLocalPort(uint8(protocol))
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
preAuthenticatedPortsLock.Lock()
|
||||
defer preAuthenticatedPortsLock.Unlock()
|
||||
|
||||
// Save generated port.
|
||||
key := generateLocalPreAuthKey(uint8(protocol), port)
|
||||
preAuthenticatedPorts[key] = struct{}{}
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
// localPortIsPreAuthenticated checks if the given protocol and port are
|
||||
// pre-authenticated and should be attributed to the Portmaster itself.
|
||||
func localPortIsPreAuthenticated(protocol uint8, port uint16) bool {
|
||||
preAuthenticatedPortsLock.Lock()
|
||||
defer preAuthenticatedPortsLock.Unlock()
|
||||
|
||||
// Check if the given protocol and port are pre-authenticated.
|
||||
key := generateLocalPreAuthKey(protocol, port)
|
||||
_, ok := preAuthenticatedPorts[key]
|
||||
if ok {
|
||||
// Immediately remove pre authenticated port.
|
||||
delete(preAuthenticatedPorts, key)
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// generateLocalPreAuthKey creates a map key for the pre-authenticated ports.
|
||||
func generateLocalPreAuthKey(protocol uint8, port uint16) string {
|
||||
return strconv.Itoa(int(protocol)) + ":" + strconv.Itoa(int(port))
|
||||
}
|
||||
327
service/firewall/prompt.go
Normal file
327
service/firewall/prompt.go
Normal file
@@ -0,0 +1,327 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/notifications"
|
||||
"github.com/safing/portmaster/service/intel"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/profile"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
)
|
||||
|
||||
const (
|
||||
// notification action IDs.
|
||||
allowDomainAll = "allow-domain-all"
|
||||
allowDomainDistinct = "allow-domain-distinct"
|
||||
blockDomainAll = "block-domain-all"
|
||||
blockDomainDistinct = "block-domain-distinct"
|
||||
|
||||
allowIP = "allow-ip"
|
||||
blockIP = "block-ip"
|
||||
allowServingIP = "allow-serving-ip"
|
||||
blockServingIP = "block-serving-ip"
|
||||
|
||||
cancelPrompt = "cancel"
|
||||
)
|
||||
|
||||
var (
|
||||
promptNotificationCreation sync.Mutex
|
||||
|
||||
decisionTimeout int64 = 10 // in seconds
|
||||
)
|
||||
|
||||
type promptData struct {
|
||||
Entity *intel.Entity
|
||||
Profile promptProfile
|
||||
}
|
||||
|
||||
type promptProfile struct {
|
||||
Source string
|
||||
ID string
|
||||
LinkedPath string
|
||||
}
|
||||
|
||||
func prompt(ctx context.Context, conn *network.Connection) {
|
||||
// Create notification.
|
||||
n := createPrompt(ctx, conn)
|
||||
if n == nil {
|
||||
// createPrompt returns nil when no further action should be taken.
|
||||
return
|
||||
}
|
||||
|
||||
// Add prompt to connection.
|
||||
conn.SetPrompt(n)
|
||||
|
||||
// Get decision timeout and make sure it does not exceed the ask timeout.
|
||||
timeout := decisionTimeout
|
||||
if timeout > askTimeout() {
|
||||
timeout = askTimeout()
|
||||
}
|
||||
|
||||
// wait for response/timeout
|
||||
select {
|
||||
case promptResponse := <-n.Response():
|
||||
switch promptResponse {
|
||||
case allowDomainAll, allowDomainDistinct, allowIP, allowServingIP:
|
||||
// Accept
|
||||
conn.Accept("allowed via prompt", profile.CfgOptionEndpointsKey)
|
||||
case "":
|
||||
// Dismissed
|
||||
conn.Deny("prompting canceled, waiting for new decision", profile.CfgOptionDefaultActionKey)
|
||||
default:
|
||||
// Deny
|
||||
conn.Deny("blocked via prompt", profile.CfgOptionEndpointsKey)
|
||||
}
|
||||
|
||||
case <-time.After(time.Duration(timeout) * time.Second):
|
||||
log.Tracer(ctx).Debugf("filter: continuing prompting async")
|
||||
conn.Deny("prompting in progress, please respond to prompt", profile.CfgOptionDefaultActionKey)
|
||||
|
||||
case <-ctx.Done():
|
||||
log.Tracer(ctx).Debugf("filter: aborting prompting because of shutdown")
|
||||
conn.Drop("shutting down", noReasonOptionKey)
|
||||
}
|
||||
}
|
||||
|
||||
// promptIDPrefix is an identifier for privacy filter prompts. This is also used
|
||||
// in the UI, so don't change!
|
||||
const promptIDPrefix = "filter:prompt"
|
||||
|
||||
func createPrompt(ctx context.Context, conn *network.Connection) (n *notifications.Notification) {
|
||||
expires := time.Now().Add(time.Duration(askTimeout()) * time.Second).Unix()
|
||||
|
||||
// Get local profile.
|
||||
layeredProfile := conn.Process().Profile()
|
||||
if layeredProfile == nil {
|
||||
log.Tracer(ctx).Warningf("filter: tried creating prompt for connection without profile")
|
||||
return nil
|
||||
}
|
||||
localProfile := layeredProfile.LocalProfile()
|
||||
if localProfile == nil {
|
||||
log.Tracer(ctx).Warningf("filter: tried creating prompt for connection without local profile")
|
||||
return nil
|
||||
}
|
||||
|
||||
// first check if there is an existing notification for this.
|
||||
// build notification ID
|
||||
var nID string
|
||||
switch {
|
||||
case conn.Inbound, conn.Entity.Domain == "": // connection to/from IP
|
||||
nID = fmt.Sprintf(
|
||||
"%s-%s-%v-%s",
|
||||
promptIDPrefix,
|
||||
localProfile.ID,
|
||||
conn.Inbound,
|
||||
conn.Entity.IP,
|
||||
)
|
||||
default: // connection to domain
|
||||
nID = fmt.Sprintf(
|
||||
"%s-%s-%s",
|
||||
promptIDPrefix,
|
||||
localProfile.ID,
|
||||
conn.Entity.Domain,
|
||||
)
|
||||
}
|
||||
|
||||
// Only handle one notification at a time.
|
||||
promptNotificationCreation.Lock()
|
||||
defer promptNotificationCreation.Unlock()
|
||||
|
||||
n = notifications.Get(nID)
|
||||
|
||||
// If there already is a notification, just update the expiry.
|
||||
if n != nil {
|
||||
// Get notification state and action.
|
||||
n.Lock()
|
||||
state := n.State
|
||||
action := n.SelectedActionID
|
||||
n.Unlock()
|
||||
|
||||
// If the notification is still active, extend and return.
|
||||
// This can happen because user input (prompts changing the endpoint
|
||||
// lists) can happen any time - also between checking the endpoint lists
|
||||
// and now.
|
||||
if state == notifications.Active {
|
||||
n.Update(expires)
|
||||
log.Tracer(ctx).Debugf("filter: updated existing prompt notification")
|
||||
return n
|
||||
}
|
||||
|
||||
// The notification is not active anymore, let's check if there is an
|
||||
// action we can perform.
|
||||
// If there already is an action defined, we won't be fast enough to
|
||||
// receive the action with n.Response(), so we take direct action here.
|
||||
if action != "" {
|
||||
switch action {
|
||||
case allowDomainAll, allowDomainDistinct, allowIP, allowServingIP:
|
||||
conn.Accept("allowed via prompt", profile.CfgOptionEndpointsKey)
|
||||
default: // deny
|
||||
conn.Deny("blocked via prompt", profile.CfgOptionEndpointsKey)
|
||||
}
|
||||
return nil // Do not take further action.
|
||||
}
|
||||
|
||||
// Continue to create a new notification because the previous one is not
|
||||
// active and not actionable.
|
||||
}
|
||||
|
||||
// Reference relevant data for save function
|
||||
entity := conn.Entity
|
||||
// Also needed: localProfile
|
||||
|
||||
// Create new notification.
|
||||
n = ¬ifications.Notification{
|
||||
EventID: nID,
|
||||
Type: notifications.Prompt,
|
||||
Title: "Connection Prompt",
|
||||
Category: "Privacy Filter",
|
||||
ShowOnSystem: askWithSystemNotifications(),
|
||||
EventData: &promptData{
|
||||
Entity: entity,
|
||||
Profile: promptProfile{
|
||||
Source: string(localProfile.Source),
|
||||
ID: localProfile.ID,
|
||||
// LinkedPath is used to enhance the display of the prompt in the UI.
|
||||
// TODO: Using the process path is a workaround. Find a cleaner solution.
|
||||
LinkedPath: conn.Process().Path,
|
||||
},
|
||||
},
|
||||
Expires: expires,
|
||||
}
|
||||
|
||||
// Set action function.
|
||||
n.SetActionFunction(func(_ context.Context, n *notifications.Notification) error {
|
||||
return saveResponse(
|
||||
localProfile,
|
||||
entity,
|
||||
n.SelectedActionID,
|
||||
)
|
||||
})
|
||||
|
||||
// Get name of profile for notification. The profile is read-locked by the firewall handler.
|
||||
profileName := localProfile.Name
|
||||
|
||||
// add message and actions
|
||||
switch {
|
||||
case conn.Inbound:
|
||||
n.Message = fmt.Sprintf("%s wants to accept connections from %s (%d/%d)", profileName, conn.Entity.IP.String(), conn.Entity.Protocol, conn.Entity.Port)
|
||||
n.AvailableActions = []*notifications.Action{
|
||||
{
|
||||
ID: allowServingIP,
|
||||
Text: "Allow",
|
||||
},
|
||||
{
|
||||
ID: blockServingIP,
|
||||
Text: "Block",
|
||||
},
|
||||
}
|
||||
case conn.Entity.Domain == "": // direct connection
|
||||
n.Message = fmt.Sprintf("%s wants to connect to %s (%d/%d)", profileName, conn.Entity.IP.String(), conn.Entity.Protocol, conn.Entity.Port)
|
||||
n.AvailableActions = []*notifications.Action{
|
||||
{
|
||||
ID: allowIP,
|
||||
Text: "Allow",
|
||||
},
|
||||
{
|
||||
ID: blockIP,
|
||||
Text: "Block",
|
||||
},
|
||||
}
|
||||
default: // connection to domain
|
||||
n.Message = fmt.Sprintf("%s wants to connect to %s", profileName, conn.Entity.Domain)
|
||||
n.AvailableActions = []*notifications.Action{
|
||||
{
|
||||
ID: allowDomainAll,
|
||||
Text: "Allow",
|
||||
},
|
||||
{
|
||||
ID: blockDomainAll,
|
||||
Text: "Block",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
n.Save()
|
||||
log.Tracer(ctx).Debugf("filter: sent prompt notification")
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// promptSavingLock makes sure that only one prompt is saved at a time.
|
||||
// Should prompts be persisted in bulk, the next save process might load an
|
||||
// outdated profile and save it, losing config data.
|
||||
var promptSavingLock sync.Mutex
|
||||
|
||||
func saveResponse(p *profile.Profile, entity *intel.Entity, promptResponse string) error {
|
||||
if promptResponse == cancelPrompt {
|
||||
return nil
|
||||
}
|
||||
|
||||
promptSavingLock.Lock()
|
||||
defer promptSavingLock.Unlock()
|
||||
|
||||
// Update the profile if necessary.
|
||||
if p.IsOutdated() {
|
||||
var err error
|
||||
p, err = profile.GetLocalProfile(p.ID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var ep endpoints.Endpoint
|
||||
switch promptResponse {
|
||||
case allowDomainAll:
|
||||
ep = &endpoints.EndpointDomain{
|
||||
EndpointBase: endpoints.EndpointBase{Permitted: true},
|
||||
OriginalValue: "." + entity.Domain,
|
||||
}
|
||||
case allowDomainDistinct:
|
||||
ep = &endpoints.EndpointDomain{
|
||||
EndpointBase: endpoints.EndpointBase{Permitted: true},
|
||||
OriginalValue: entity.Domain,
|
||||
}
|
||||
case blockDomainAll:
|
||||
ep = &endpoints.EndpointDomain{
|
||||
EndpointBase: endpoints.EndpointBase{Permitted: false},
|
||||
OriginalValue: "." + entity.Domain,
|
||||
}
|
||||
case blockDomainDistinct:
|
||||
ep = &endpoints.EndpointDomain{
|
||||
EndpointBase: endpoints.EndpointBase{Permitted: false},
|
||||
OriginalValue: entity.Domain,
|
||||
}
|
||||
case allowIP, allowServingIP:
|
||||
ep = &endpoints.EndpointIP{
|
||||
EndpointBase: endpoints.EndpointBase{Permitted: true},
|
||||
IP: entity.IP,
|
||||
}
|
||||
case blockIP, blockServingIP:
|
||||
ep = &endpoints.EndpointIP{
|
||||
EndpointBase: endpoints.EndpointBase{Permitted: false},
|
||||
IP: entity.IP,
|
||||
}
|
||||
case cancelPrompt:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unknown prompt response: %s", promptResponse)
|
||||
}
|
||||
|
||||
switch promptResponse {
|
||||
case allowServingIP, blockServingIP:
|
||||
p.AddServiceEndpoint(ep.String())
|
||||
log.Infof("filter: added incoming rule to profile %s (LP Rev. %d): %q",
|
||||
p, p.LayeredProfile().RevisionCnt(), ep.String())
|
||||
default:
|
||||
p.AddEndpoint(ep.String())
|
||||
log.Infof("filter: added outgoing rule to profile %s (LP Rev. %d): %q",
|
||||
p, p.LayeredProfile().RevisionCnt(), ep.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
215
service/firewall/tunnel.go
Normal file
215
service/firewall/tunnel.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/intel"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/service/process"
|
||||
"github.com/safing/portmaster/service/profile"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
"github.com/safing/portmaster/service/resolver"
|
||||
"github.com/safing/portmaster/spn/captain"
|
||||
"github.com/safing/portmaster/spn/crew"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
"github.com/safing/portmaster/spn/sluice"
|
||||
)
|
||||
|
||||
func checkTunneling(ctx context.Context, conn *network.Connection) {
|
||||
// Check if the connection should be tunneled at all.
|
||||
switch {
|
||||
case !tunnelEnabled():
|
||||
// Tunneling is disabled.
|
||||
return
|
||||
case !conn.Entity.IPScope.IsGlobal():
|
||||
// Can't tunnel Local/LAN connections.
|
||||
return
|
||||
case conn.Inbound:
|
||||
// Can't tunnel incoming connections.
|
||||
return
|
||||
case conn.Verdict != network.VerdictAccept:
|
||||
// Connection will be blocked.
|
||||
return
|
||||
case conn.IPProtocol != packet.TCP && conn.IPProtocol != packet.UDP:
|
||||
// Unsupported protocol.
|
||||
return
|
||||
case conn.Process().Pid == ownPID:
|
||||
// Bypass tunneling for certain own connections.
|
||||
switch {
|
||||
case !captain.ClientReady():
|
||||
return
|
||||
case captain.IsExcepted(conn.Entity.IP):
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check more extensively for Local/LAN connections.
|
||||
localNet, err := netenv.GetLocalNetwork(conn.Entity.IP)
|
||||
if err != nil {
|
||||
log.Warningf("firewall: failed to check if %s is in my net: %s", conn.Entity.IP, err)
|
||||
} else if localNet != nil {
|
||||
// With IPv6, just checking the IP scope is not enough, as the host very
|
||||
// likely has a public IPv6 address.
|
||||
// Don't tunnel LAN connections.
|
||||
|
||||
// TODO: We currently don't check the full LAN scope, but only the
|
||||
// broadcast domain of the host - ie. the networks that the host is
|
||||
// directly attached to.
|
||||
return
|
||||
}
|
||||
|
||||
// Get profile.
|
||||
layeredProfile := conn.Process().Profile()
|
||||
if layeredProfile == nil {
|
||||
conn.Failed("no profile set", "")
|
||||
return
|
||||
}
|
||||
|
||||
// Update profile.
|
||||
if layeredProfile.NeedsUpdate() {
|
||||
// Update revision counter in connection.
|
||||
conn.ProfileRevisionCounter = layeredProfile.Update(
|
||||
conn.Process().MatchingData(),
|
||||
conn.Process().CreateProfileCallback,
|
||||
)
|
||||
conn.SaveWhenFinished()
|
||||
} else {
|
||||
// Check if the revision counter of the connection needs updating.
|
||||
revCnt := layeredProfile.RevisionCnt()
|
||||
if conn.ProfileRevisionCounter != revCnt {
|
||||
conn.ProfileRevisionCounter = revCnt
|
||||
conn.SaveWhenFinished()
|
||||
}
|
||||
}
|
||||
|
||||
// Check if tunneling is enabled for this app at all.
|
||||
if !layeredProfile.UseSPN() {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if tunneling is enabled for entity.
|
||||
conn.Entity.FetchData(ctx)
|
||||
result, _ := layeredProfile.MatchSPNUsagePolicy(ctx, conn.Entity)
|
||||
switch result {
|
||||
case endpoints.MatchError:
|
||||
conn.Failed("failed to check SPN rules", profile.CfgOptionSPNUsagePolicyKey)
|
||||
return
|
||||
case endpoints.Denied:
|
||||
return
|
||||
case endpoints.Permitted, endpoints.NoMatch:
|
||||
// Continue
|
||||
}
|
||||
|
||||
// Tunnel all the things!
|
||||
conn.SaveWhenFinished()
|
||||
|
||||
// Check if ready.
|
||||
if !captain.ClientReady() {
|
||||
// Block connection as SPN is not ready yet.
|
||||
log.Tracer(ctx).Trace("SPN not ready for tunneling")
|
||||
conn.Failed("SPN not ready for tunneling", "")
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetVerdictDirectly(network.VerdictRerouteToTunnel)
|
||||
conn.Tunneled = true
|
||||
}
|
||||
|
||||
func requestTunneling(ctx context.Context, conn *network.Connection) error {
|
||||
// Get profile.
|
||||
layeredProfile := conn.Process().Profile()
|
||||
if layeredProfile == nil {
|
||||
return errors.New("no profile set")
|
||||
}
|
||||
|
||||
// Get tunnel options.
|
||||
conn.TunnelOpts = DeriveTunnelOptions(layeredProfile, conn.Process(), conn.Entity, conn.Encrypted)
|
||||
|
||||
// Queue request in sluice.
|
||||
err := sluice.AwaitRequest(conn, crew.HandleSluiceRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracer(ctx).Trace("filter: tunneling requested")
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
navigator.DeriveTunnelOptions = func(lp *profile.LayeredProfile, destination *intel.Entity, connEncrypted bool) *navigator.Options {
|
||||
return DeriveTunnelOptions(lp, nil, destination, connEncrypted)
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveTunnelOptions derives and returns the tunnel options from the connection and profile.
|
||||
func DeriveTunnelOptions(lp *profile.LayeredProfile, proc *process.Process, destination *intel.Entity, connEncrypted bool) *navigator.Options {
|
||||
// Set options.
|
||||
tunnelOpts := &navigator.Options{
|
||||
Transit: &navigator.TransitHubOptions{
|
||||
HubPolicies: lp.StackedTransitHubPolicies(),
|
||||
},
|
||||
Destination: &navigator.DestinationHubOptions{
|
||||
HubPolicies: lp.StackedExitHubPolicies(),
|
||||
CheckHubPolicyWith: destination,
|
||||
},
|
||||
RoutingProfile: lp.SPNRoutingAlgorithm(),
|
||||
}
|
||||
if !connEncrypted {
|
||||
tunnelOpts.Destination.Regard = tunnelOpts.Destination.Regard.Add(navigator.StateTrusted)
|
||||
// TODO: Add this when all Hubs are on v0.6.21+
|
||||
// tunnelOpts.Destination.Regard = tunnelOpts.Destination.Regard.Add(navigator.StateAllowUnencrypted)
|
||||
}
|
||||
|
||||
// Add required verified owners if community nodes should not be used.
|
||||
if !useCommunityNodes() {
|
||||
tunnelOpts.Transit.RequireVerifiedOwners = captain.NonCommunityVerifiedOwners
|
||||
tunnelOpts.Destination.RequireVerifiedOwners = captain.NonCommunityVerifiedOwners
|
||||
}
|
||||
|
||||
// Get routing profile for checking for upgrades.
|
||||
routingProfile := navigator.GetRoutingProfile(tunnelOpts.RoutingProfile)
|
||||
|
||||
// If we have any exit hub policies, we must be able to hop in order to follow the policy.
|
||||
// Switch to single-hop routing to allow for routing with hub selection.
|
||||
if routingProfile.MaxHops <= 1 && navigator.HubPoliciesAreSet(tunnelOpts.Destination.HubPolicies) {
|
||||
tunnelOpts.RoutingProfile = navigator.RoutingProfileSingleHopID
|
||||
}
|
||||
|
||||
// If the current home node is not trusted, then upgrade at least to two hops.
|
||||
if routingProfile.MinHops < 2 {
|
||||
homeNode, _ := navigator.Main.GetHome()
|
||||
if homeNode != nil && !homeNode.State.Has(navigator.StateTrusted) {
|
||||
tunnelOpts.RoutingProfile = navigator.RoutingProfileDoubleHopID
|
||||
}
|
||||
}
|
||||
|
||||
// Special handling for the internal DNS resolver.
|
||||
if proc != nil && proc.Pid == ownPID && resolver.IsResolverAddress(destination.IP, destination.Port) {
|
||||
dnsExitHubPolicy, err := captain.GetDNSExitHubPolicy()
|
||||
if err != nil {
|
||||
log.Errorf("firewall: failed to get dns exit hub policy: %s", err)
|
||||
}
|
||||
|
||||
if err == nil && dnsExitHubPolicy.IsSet() {
|
||||
// Apply the dns exit hub policy, if set.
|
||||
tunnelOpts.Destination.HubPolicies = []endpoints.Endpoints{dnsExitHubPolicy}
|
||||
// Use the routing algorithm from the profile, as the home profile won't work with the policy.
|
||||
tunnelOpts.RoutingProfile = lp.SPNRoutingAlgorithm()
|
||||
// Raise the routing algorithm at least to single-hop.
|
||||
if tunnelOpts.RoutingProfile == navigator.RoutingProfileHomeID {
|
||||
tunnelOpts.RoutingProfile = navigator.RoutingProfileSingleHopID
|
||||
}
|
||||
} else {
|
||||
// Disable any policies for the internal DNS resolver.
|
||||
tunnelOpts.Destination.HubPolicies = nil
|
||||
// Always use the home routing profile for the internal DNS resolver.
|
||||
tunnelOpts.RoutingProfile = navigator.RoutingProfileHomeID
|
||||
}
|
||||
}
|
||||
|
||||
return tunnelOpts
|
||||
}
|
||||
103
service/intel/block_reason.go
Normal file
103
service/intel/block_reason.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package intel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/nameserver/nsutil"
|
||||
)
|
||||
|
||||
// ListMatch represents an entity that has been
|
||||
// matched against filterlists.
|
||||
type ListMatch struct {
|
||||
Entity string
|
||||
ActiveLists []string
|
||||
InactiveLists []string
|
||||
}
|
||||
|
||||
func (lm *ListMatch) String() string {
|
||||
inactive := ""
|
||||
if len(lm.InactiveLists) > 0 {
|
||||
inactive = " and in deactivated lists " + strings.Join(lm.InactiveLists, ", ")
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"%s in activated lists %s%s",
|
||||
lm.Entity,
|
||||
strings.Join(lm.ActiveLists, ","),
|
||||
inactive,
|
||||
)
|
||||
}
|
||||
|
||||
// ListBlockReason is a list of list matches.
|
||||
type ListBlockReason []ListMatch
|
||||
|
||||
func (br ListBlockReason) String() string {
|
||||
if len(br) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
matches := make([]string, len(br))
|
||||
for idx, lm := range br {
|
||||
matches[idx] = lm.String()
|
||||
}
|
||||
|
||||
return strings.Join(matches, " and ")
|
||||
}
|
||||
|
||||
// Context returns br wrapped into a map. It implements
|
||||
// the endpoints.Reason interface.
|
||||
func (br ListBlockReason) Context() interface{} {
|
||||
return br
|
||||
}
|
||||
|
||||
// MarshalJSON marshals the list block reason into a map
|
||||
// prefixed with filterlists.
|
||||
func (br ListBlockReason) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(map[string]interface{}{
|
||||
// we convert to []ListMatch to avoid recursing
|
||||
// here.
|
||||
"filterlists": []ListMatch(br),
|
||||
})
|
||||
}
|
||||
|
||||
// GetExtraRRs implements the nsutil.RRProvider interface
|
||||
// and adds additional TXT records justifying the reason
|
||||
// the request was blocked.
|
||||
func (br ListBlockReason) GetExtraRRs(ctx context.Context, _ *dns.Msg) []dns.RR {
|
||||
rrs := make([]dns.RR, 0, len(br))
|
||||
|
||||
for _, lm := range br {
|
||||
blockedBy, err := nsutil.MakeMessageRecord(log.InfoLevel, fmt.Sprintf(
|
||||
"%s is blocked by filter lists %s",
|
||||
lm.Entity,
|
||||
strings.Join(lm.ActiveLists, ", "),
|
||||
))
|
||||
if err == nil {
|
||||
rrs = append(rrs, blockedBy)
|
||||
} else {
|
||||
log.Tracer(ctx).Errorf("intel: failed to create TXT RR for block reason: %s", err)
|
||||
}
|
||||
|
||||
if len(lm.InactiveLists) > 0 {
|
||||
wouldBeBlockedBy, err := nsutil.MakeMessageRecord(log.InfoLevel, fmt.Sprintf(
|
||||
"%s would be blocked by filter lists %s",
|
||||
lm.Entity,
|
||||
strings.Join(lm.InactiveLists, ", "),
|
||||
))
|
||||
if err == nil {
|
||||
rrs = append(rrs, wouldBeBlockedBy)
|
||||
} else {
|
||||
log.Tracer(ctx).Errorf("intel: failed to create TXT RR for block reason: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return rrs
|
||||
}
|
||||
|
||||
var _ nsutil.RRProvider = ListBlockReason(nil)
|
||||
55
service/intel/customlists/config.go
Normal file
55
service/intel/customlists/config.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package customlists
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/config"
|
||||
)
|
||||
|
||||
var (
|
||||
// CfgOptionCustomListFileKey is the config key for custom filter list file.
|
||||
CfgOptionCustomListFileKey = "filter/customListFile"
|
||||
cfgOptionCustomListFileOrder = 35
|
||||
cfgOptionCustomListCategoryAnnotation = "Filter Lists"
|
||||
)
|
||||
|
||||
var getFilePath config.StringOption
|
||||
|
||||
func registerConfig() error {
|
||||
help := `The file (.txt) is checked every couple minutes and will be automatically reloaded when it has changed.
|
||||
|
||||
Entries (one per line) may be one of:
|
||||
- Domain: "example.com"
|
||||
- IP Address: "10.0.0.1"
|
||||
- Country Code (based on IP): "US"
|
||||
- AS (Autonomous System): "AS1234"
|
||||
|
||||
Everything after the first element of a line, comments starting with a '#', and empty lines are ignored.
|
||||
The settings "Block Subdomains of Filter List Entries" and "Block Domain Aliases" also apply to the custom filter list.
|
||||
Lists in the "Hosts" format are not supported.
|
||||
|
||||
Please note that the custom filter list is fully loaded into memory. This can have a negative impact on your device if big lists are loaded.`
|
||||
|
||||
// Register a setting for the file path in the ui
|
||||
err := config.Register(&config.Option{
|
||||
Name: "Custom Filter List",
|
||||
Key: CfgOptionCustomListFileKey,
|
||||
Description: "Specify the file path to a custom filter list (.txt), which will be automatically refreshed. Any connections matching a domain, IP address, Country or ASN in the file will be blocked.",
|
||||
Help: help,
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
DefaultValue: "",
|
||||
RequiresRestart: false,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionCustomListFileOrder,
|
||||
config.CategoryAnnotation: cfgOptionCustomListCategoryAnnotation,
|
||||
config.DisplayHintAnnotation: config.DisplayHintFilePicker,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
getFilePath = config.GetAsString(CfgOptionCustomListFileKey, "")
|
||||
|
||||
return nil
|
||||
}
|
||||
195
service/intel/customlists/lists.go
Normal file
195
service/intel/customlists/lists.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package customlists
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/notifications"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
)
|
||||
|
||||
var (
|
||||
countryCodesFilterList map[string]struct{}
|
||||
ipAddressesFilterList map[string]struct{}
|
||||
autonomousSystemsFilterList map[uint]struct{}
|
||||
domainsFilterList map[string]struct{}
|
||||
)
|
||||
|
||||
const (
|
||||
rationForInvalidLinesUntilWarning = 0.1
|
||||
parseStatusNotificationID = "customlists:parse-status"
|
||||
parseWarningNotificationID = "customlists:parse-warning"
|
||||
zeroIPNotificationID = "customlists:too-many-zero-ips"
|
||||
)
|
||||
|
||||
func initFilterLists() {
|
||||
countryCodesFilterList = make(map[string]struct{})
|
||||
ipAddressesFilterList = make(map[string]struct{})
|
||||
autonomousSystemsFilterList = make(map[uint]struct{})
|
||||
domainsFilterList = make(map[string]struct{})
|
||||
}
|
||||
|
||||
// IsLoaded returns whether a custom filter list is loaded.
|
||||
func IsLoaded() bool {
|
||||
filterListLock.RLock()
|
||||
defer filterListLock.RUnlock()
|
||||
|
||||
switch {
|
||||
case len(domainsFilterList) > 0:
|
||||
return true
|
||||
case len(ipAddressesFilterList) > 0:
|
||||
return true
|
||||
case len(countryCodesFilterList) > 0:
|
||||
return true
|
||||
case len(autonomousSystemsFilterList) > 0:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func parseFile(filePath string) error {
|
||||
// Reset all maps, previous (if any) settings will be lost.
|
||||
for key := range countryCodesFilterList {
|
||||
delete(countryCodesFilterList, key)
|
||||
}
|
||||
for key := range ipAddressesFilterList {
|
||||
delete(ipAddressesFilterList, key)
|
||||
}
|
||||
for key := range autonomousSystemsFilterList {
|
||||
delete(autonomousSystemsFilterList, key)
|
||||
}
|
||||
for key := range domainsFilterList {
|
||||
delete(domainsFilterList, key)
|
||||
}
|
||||
|
||||
// Ignore empty file path.
|
||||
if filePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Open the file if possible
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
log.Warningf("intel/customlists: failed to parse file %s", err)
|
||||
module.Warning(parseWarningNotificationID, "Failed to open custom filter list", err.Error())
|
||||
return err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
var allLinesCount uint64
|
||||
var invalidLinesCount uint64
|
||||
|
||||
// Read filter file line by line.
|
||||
scanner := bufio.NewScanner(file)
|
||||
// The scanner will error out if the line is greater than 64K, in this case it is enough.
|
||||
for scanner.Scan() {
|
||||
allLinesCount++
|
||||
// Parse and count invalid lines (comment, empty lines, zero IPs...)
|
||||
if !parseLine(scanner.Text()) {
|
||||
invalidLinesCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Check for scanner error.
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
invalidLinesRation := float32(invalidLinesCount) / float32(allLinesCount)
|
||||
|
||||
if invalidLinesRation > rationForInvalidLinesUntilWarning {
|
||||
log.Warning("intel/customlists: Too many invalid lines")
|
||||
module.Warning(zeroIPNotificationID, "Custom filter list has many invalid lines",
|
||||
fmt.Sprintf(`%d out of %d lines are invalid.
|
||||
Check if you are using the correct file format and if the path to the custom filter list is correct.`, invalidLinesCount, allLinesCount))
|
||||
} else {
|
||||
module.Resolve(zeroIPNotificationID)
|
||||
}
|
||||
|
||||
allEntriesCount := len(domainsFilterList) + len(ipAddressesFilterList) + len(autonomousSystemsFilterList) + len(countryCodesFilterList)
|
||||
log.Infof("intel/customlists: loaded %d entries from %s", allEntriesCount, filePath)
|
||||
|
||||
notifications.NotifyInfo(parseStatusNotificationID,
|
||||
"Custom filter list loaded successfully.",
|
||||
fmt.Sprintf(`Custom filter list loaded from file %s:
|
||||
%d Domains
|
||||
%d IPs
|
||||
%d Autonomous Systems
|
||||
%d Countries`,
|
||||
filePath,
|
||||
len(domainsFilterList),
|
||||
len(ipAddressesFilterList),
|
||||
len(autonomousSystemsFilterList),
|
||||
len(countryCodesFilterList)))
|
||||
|
||||
module.Resolve(parseWarningNotificationID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseLine(line string) (valid bool) {
|
||||
// Everything after the first field will be ignored.
|
||||
fields := strings.Fields(line)
|
||||
|
||||
// Ignore empty lines.
|
||||
if len(fields) == 0 {
|
||||
return true // Not an entry, but a valid line.
|
||||
}
|
||||
|
||||
field := fields[0]
|
||||
|
||||
// Ignore comments
|
||||
if strings.HasPrefix(field, "#") {
|
||||
return true // Not an entry, but a valid line.
|
||||
}
|
||||
|
||||
// Go through all possible field types.
|
||||
// Parsing is ordered by
|
||||
// 1. Parsing options (ie. the domain has most variation and goes last.)
|
||||
// 2. Speed
|
||||
|
||||
// Check if it'a a country code.
|
||||
if isCountryCode(field) {
|
||||
countryCodesFilterList[field] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's a Autonomous system (example AS123).
|
||||
if isAutonomousSystem(field) {
|
||||
asNumber, err := strconv.ParseUint(field[2:], 10, 32)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
autonomousSystemsFilterList[uint(asNumber)] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
// Try to parse IP address.
|
||||
ip := net.ParseIP(field)
|
||||
if ip != nil {
|
||||
// Check for zero ip.
|
||||
if net.IP.Equal(ip, net.IPv4zero) || net.IP.Equal(ip, net.IPv6zero) {
|
||||
return false
|
||||
}
|
||||
|
||||
ipAddressesFilterList[ip.String()] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's a domain.
|
||||
domain := dns.Fqdn(field)
|
||||
if netutils.IsValidFqdn(domain) {
|
||||
domainsFilterList[domain] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
207
service/intel/customlists/module.go
Normal file
207
service/intel/customlists/module.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package customlists
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/publicsuffix"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
var module *modules.Module
|
||||
|
||||
const (
|
||||
configModuleName = "config"
|
||||
configChangeEvent = "config change"
|
||||
)
|
||||
|
||||
// Helper variables for parsing the input file.
|
||||
var (
|
||||
isCountryCode = regexp.MustCompile("^[A-Z]{2}$").MatchString
|
||||
isAutonomousSystem = regexp.MustCompile(`^AS[0-9]+$`).MatchString
|
||||
)
|
||||
|
||||
var (
|
||||
filterListFilePath string
|
||||
filterListFileModifiedTime time.Time
|
||||
|
||||
filterListLock sync.RWMutex
|
||||
parserTask *modules.Task
|
||||
|
||||
// ErrNotConfigured is returned when updating the custom filter list, but it
|
||||
// is not configured.
|
||||
ErrNotConfigured = errors.New("custom filter list not configured")
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("customlists", prep, start, nil, "base")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
initFilterLists()
|
||||
|
||||
// Register the config in the ui.
|
||||
err := registerConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register api endpoint for updating the filter list.
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "customlists/update",
|
||||
Write: api.PermitUser,
|
||||
BelongsTo: module,
|
||||
ActionFunc: func(ar *api.Request) (msg string, err error) {
|
||||
errCheck := checkAndUpdateFilterList()
|
||||
if errCheck != nil {
|
||||
return "", errCheck
|
||||
}
|
||||
return "Custom filter list loaded successfully.", nil
|
||||
},
|
||||
Name: "Update custom filter list",
|
||||
Description: "Reload the filter list from the configured file.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
// Register to hook to update after config change.
|
||||
if err := module.RegisterEventHook(
|
||||
configModuleName,
|
||||
configChangeEvent,
|
||||
"update custom filter list",
|
||||
func(ctx context.Context, obj interface{}) error {
|
||||
if err := checkAndUpdateFilterList(); !errors.Is(err, ErrNotConfigured) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create parser task and enqueue for execution. "checkAndUpdateFilterList" will schedule the next execution.
|
||||
parserTask = module.NewTask("intel/customlists:file-update-check", func(context.Context, *modules.Task) error {
|
||||
_ = checkAndUpdateFilterList()
|
||||
return nil
|
||||
}).Schedule(time.Now().Add(20 * time.Second))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkAndUpdateFilterList() error {
|
||||
filterListLock.Lock()
|
||||
defer filterListLock.Unlock()
|
||||
|
||||
// Get path and return error if empty
|
||||
filePath := getFilePath()
|
||||
if filePath == "" {
|
||||
return ErrNotConfigured
|
||||
}
|
||||
|
||||
// Schedule next update check
|
||||
parserTask.Schedule(time.Now().Add(1 * time.Minute))
|
||||
|
||||
// Try to get file info
|
||||
modifiedTime := time.Now()
|
||||
if fileInfo, err := os.Stat(filePath); err == nil {
|
||||
modifiedTime = fileInfo.ModTime()
|
||||
}
|
||||
|
||||
// Check if file path has changed or if modified time has changed
|
||||
if filterListFilePath != filePath || !filterListFileModifiedTime.Equal(modifiedTime) {
|
||||
err := parseFile(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filterListFileModifiedTime = modifiedTime
|
||||
filterListFilePath = filePath
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookupIP checks if the IP address is in a custom filter list.
|
||||
func LookupIP(ip net.IP) bool {
|
||||
filterListLock.RLock()
|
||||
defer filterListLock.RUnlock()
|
||||
|
||||
_, ok := ipAddressesFilterList[ip.String()]
|
||||
return ok
|
||||
}
|
||||
|
||||
// LookupDomain checks if the Domain is in a custom filter list.
|
||||
func LookupDomain(fullDomain string, filterSubdomains bool) (bool, string) {
|
||||
filterListLock.RLock()
|
||||
defer filterListLock.RUnlock()
|
||||
|
||||
if filterSubdomains {
|
||||
// Check if domain is in the list and all its subdomains.
|
||||
listOfDomains := splitDomain(fullDomain)
|
||||
for _, domain := range listOfDomains {
|
||||
_, ok := domainsFilterList[domain]
|
||||
if ok {
|
||||
return true, domain
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Check only if the domain is in the list
|
||||
_, ok := domainsFilterList[fullDomain]
|
||||
return ok, fullDomain
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// LookupASN checks if the Autonomous system number is in a custom filter list.
|
||||
func LookupASN(number uint) bool {
|
||||
filterListLock.RLock()
|
||||
defer filterListLock.RUnlock()
|
||||
|
||||
_, ok := autonomousSystemsFilterList[number]
|
||||
return ok
|
||||
}
|
||||
|
||||
// LookupCountry checks if the country code is in a custom filter list.
|
||||
func LookupCountry(countryCode string) bool {
|
||||
filterListLock.RLock()
|
||||
defer filterListLock.RUnlock()
|
||||
|
||||
_, ok := countryCodesFilterList[countryCode]
|
||||
return ok
|
||||
}
|
||||
|
||||
func splitDomain(domain string) []string {
|
||||
domain = strings.Trim(domain, ".")
|
||||
suffix, _ := publicsuffix.PublicSuffix(domain)
|
||||
if suffix == domain {
|
||||
return []string{domain}
|
||||
}
|
||||
|
||||
domainWithoutSuffix := domain[:len(domain)-len(suffix)]
|
||||
domainWithoutSuffix = strings.Trim(domainWithoutSuffix, ".")
|
||||
|
||||
splitted := strings.FieldsFunc(domainWithoutSuffix, func(r rune) bool {
|
||||
return r == '.'
|
||||
})
|
||||
|
||||
domains := make([]string, 0, len(splitted))
|
||||
for idx := range splitted {
|
||||
|
||||
d := strings.Join(splitted[idx:], ".") + "." + suffix
|
||||
if d[len(d)-1] != '.' {
|
||||
d += "."
|
||||
}
|
||||
domains = append(domains, d)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
614
service/intel/entity.go
Normal file
614
service/intel/entity.go
Normal file
@@ -0,0 +1,614 @@
|
||||
package intel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/publicsuffix"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/intel/filterlists"
|
||||
"github.com/safing/portmaster/service/intel/geoip"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
)
|
||||
|
||||
// Entity describes a remote endpoint in many different ways.
|
||||
// It embeddes a sync.Mutex but none of the endpoints own
|
||||
// functions performs locking. The caller MUST ENSURE
|
||||
// proper locking and synchronization when accessing
|
||||
// any properties of Entity.
|
||||
type Entity struct { //nolint:maligned
|
||||
sync.Mutex
|
||||
|
||||
// lists exist for most entity information and
|
||||
// we need to know which one we loaded
|
||||
domainListLoaded bool
|
||||
ipListLoaded bool
|
||||
countryListLoaded bool
|
||||
asnListLoaded bool
|
||||
reverseResolveEnabled bool
|
||||
resolveSubDomainLists bool
|
||||
checkCNAMEs bool
|
||||
|
||||
// IP is the IP address of the connection. If domain is
|
||||
// set, IP has been resolved by following all CNAMEs.
|
||||
IP net.IP
|
||||
|
||||
// IPScope holds the network scope of the IP.
|
||||
// For DNS requests, this signifies in which scope the DNS request was resolved.
|
||||
IPScope netutils.IPScope
|
||||
|
||||
// Protocol is the protcol number used by the connection.
|
||||
Protocol uint8
|
||||
|
||||
// Port is the remote port of the connection
|
||||
Port uint16
|
||||
|
||||
// dstPort is the destination port of the connection
|
||||
dstPort uint16
|
||||
|
||||
// Domain is the target domain of the connection.
|
||||
Domain string
|
||||
|
||||
// ReverseDomain is the domain the IP address points to. This is only
|
||||
// resolved and populated when needed.
|
||||
ReverseDomain string
|
||||
|
||||
// CNAME is a list of domain names that have been
|
||||
// resolved for Domain.
|
||||
CNAME []string
|
||||
|
||||
// Country holds the country the IP address (ASN) is
|
||||
// located in.
|
||||
Country string
|
||||
|
||||
// Coordinates holds the approximate coordinates of the IP address.
|
||||
Coordinates *geoip.Coordinates
|
||||
|
||||
// ASN holds the autonomous system number of the IP.
|
||||
ASN uint
|
||||
|
||||
// ASOrg holds the owner's name of the autonomous system.
|
||||
ASOrg string
|
||||
|
||||
// LocationError holds an error message if fetching the location failed.
|
||||
LocationError string
|
||||
|
||||
location *geoip.Location
|
||||
|
||||
// BlockedByLists holds list source IDs that
|
||||
// are used to block the entity.
|
||||
BlockedByLists []string
|
||||
|
||||
// BlockedEntities holds a list of entities that
|
||||
// have been blocked. Values can be used as a key
|
||||
// for the ListOccurences map.
|
||||
BlockedEntities []string
|
||||
|
||||
// ListOccurences is a map that matches an entity (Domain, IPs, ASN, Country, Sub-domain)
|
||||
// to a list of sources where the entity has been observed in.
|
||||
ListOccurences map[string][]string
|
||||
|
||||
// ListsError holds an error message if fetching the lists failed.
|
||||
ListsError string
|
||||
|
||||
// we only load each data above at most once
|
||||
fetchLocationOnce sync.Once
|
||||
reverseResolveOnce sync.Once
|
||||
loadDomainListOnce sync.Once
|
||||
loadIPListOnce sync.Once
|
||||
loadCountryListOnce sync.Once
|
||||
loadAsnListOnce sync.Once
|
||||
}
|
||||
|
||||
// Init initializes internal metadata about the entity.
|
||||
// If the entity does not describe a destination, you can supply a different
|
||||
// destination port for endpoint matching.
|
||||
// It returns the entity itself for single line formatting.
|
||||
func (e *Entity) Init(dstPort uint16) *Entity {
|
||||
// Get IP scope.
|
||||
if e.IP != nil {
|
||||
e.IPScope = netutils.GetIPScope(e.IP)
|
||||
} else {
|
||||
e.IPScope = netutils.Undefined
|
||||
}
|
||||
|
||||
// Set dst port to given value or fall back to entity.
|
||||
if dstPort > 0 {
|
||||
e.dstPort = dstPort
|
||||
} else {
|
||||
e.dstPort = e.Port
|
||||
}
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
// DstPort returns the destination port.
|
||||
func (e *Entity) DstPort() uint16 {
|
||||
return e.dstPort
|
||||
}
|
||||
|
||||
// FetchData fetches additional information, meant to be called before persisting an entity record.
|
||||
func (e *Entity) FetchData(ctx context.Context) {
|
||||
e.getLocation(ctx)
|
||||
e.getLists(ctx)
|
||||
}
|
||||
|
||||
// ResetLists resets the current list data and forces
|
||||
// all list sources to be re-acquired when calling GetLists().
|
||||
func (e *Entity) ResetLists() {
|
||||
// TODO(ppacher): our actual goal is to reset the domain
|
||||
// list right now so we could be more efficient by keeping
|
||||
// the other lists around.
|
||||
|
||||
e.BlockedByLists = nil
|
||||
e.BlockedEntities = nil
|
||||
e.ListOccurences = nil
|
||||
|
||||
e.domainListLoaded = false
|
||||
e.ipListLoaded = false
|
||||
e.countryListLoaded = false
|
||||
e.asnListLoaded = false
|
||||
e.resolveSubDomainLists = false
|
||||
e.checkCNAMEs = false
|
||||
e.loadDomainListOnce = sync.Once{}
|
||||
e.loadIPListOnce = sync.Once{}
|
||||
e.loadCountryListOnce = sync.Once{}
|
||||
e.loadAsnListOnce = sync.Once{}
|
||||
}
|
||||
|
||||
// ResolveSubDomainLists enables or disables list lookups for
|
||||
// sub-domains.
|
||||
func (e *Entity) ResolveSubDomainLists(ctx context.Context, enabled bool) {
|
||||
if e.domainListLoaded && enabled != e.resolveSubDomainLists {
|
||||
log.Tracer(ctx).Warningf("intel/filterlists: tried to change sub-domain resolving for %s but lists are already fetched", e.Domain)
|
||||
}
|
||||
e.resolveSubDomainLists = enabled
|
||||
}
|
||||
|
||||
// EnableCNAMECheck enalbes or disables list lookups for
|
||||
// entity CNAMEs.
|
||||
func (e *Entity) EnableCNAMECheck(ctx context.Context, enabled bool) {
|
||||
if e.domainListLoaded && enabled != e.checkCNAMEs {
|
||||
log.Tracer(ctx).Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain)
|
||||
}
|
||||
e.checkCNAMEs = enabled
|
||||
}
|
||||
|
||||
// CNAMECheckEnabled returns true if the entities CNAMEs should
|
||||
// also be checked.
|
||||
func (e *Entity) CNAMECheckEnabled() bool {
|
||||
return e.checkCNAMEs
|
||||
}
|
||||
|
||||
// Domain and IP
|
||||
|
||||
// EnableReverseResolving enables reverse resolving the domain from the IP on demand.
|
||||
func (e *Entity) EnableReverseResolving() {
|
||||
e.reverseResolveEnabled = true
|
||||
}
|
||||
|
||||
func (e *Entity) reverseResolve(ctx context.Context) {
|
||||
e.reverseResolveOnce.Do(func() {
|
||||
// need IP!
|
||||
if e.IP == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// reverse resolve
|
||||
if reverseResolver == nil {
|
||||
return
|
||||
}
|
||||
// TODO: security level
|
||||
domain, err := reverseResolver(ctx, e.IP.String())
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Warningf("intel: failed to resolve IP %s: %s", e.IP, err)
|
||||
return
|
||||
}
|
||||
e.ReverseDomain = domain
|
||||
})
|
||||
}
|
||||
|
||||
// GetDomain returns the domain and whether it is set.
|
||||
func (e *Entity) GetDomain(ctx context.Context, mayUseReverseDomain bool) (string, bool) {
|
||||
if mayUseReverseDomain && e.reverseResolveEnabled {
|
||||
e.reverseResolve(ctx)
|
||||
|
||||
if e.ReverseDomain == "" {
|
||||
return "", false
|
||||
}
|
||||
return e.ReverseDomain, true
|
||||
}
|
||||
|
||||
if e.Domain == "" {
|
||||
return "", false
|
||||
}
|
||||
return e.Domain, true
|
||||
}
|
||||
|
||||
// GetIP returns the IP and whether it is set.
|
||||
func (e *Entity) GetIP() (net.IP, bool) {
|
||||
if e.IP == nil {
|
||||
return nil, false
|
||||
}
|
||||
return e.IP, true
|
||||
}
|
||||
|
||||
// Location
|
||||
|
||||
func (e *Entity) getLocation(ctx context.Context) {
|
||||
e.fetchLocationOnce.Do(func() {
|
||||
// Only check if we have a global IP address.
|
||||
if e.IP == nil || !e.IPScope.IsGlobal() {
|
||||
return
|
||||
}
|
||||
|
||||
// get location data
|
||||
loc, err := geoip.GetLocation(e.IP)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Warningf("intel: failed to get location data for %s: %s", e.IP, err)
|
||||
e.LocationError = err.Error()
|
||||
return
|
||||
}
|
||||
e.location = loc
|
||||
e.Country = loc.Country.Code
|
||||
e.Coordinates = &loc.Coordinates
|
||||
e.ASN = loc.AutonomousSystemNumber
|
||||
e.ASOrg = loc.AutonomousSystemOrganization
|
||||
|
||||
// Log result.
|
||||
if log.GetLogLevel() == log.TraceLevel {
|
||||
// Build flags
|
||||
var flags string
|
||||
if loc.IsAnycast {
|
||||
flags += " anycast"
|
||||
}
|
||||
if loc.IsSatelliteProvider {
|
||||
flags += " satellite"
|
||||
}
|
||||
if loc.IsAnonymousProxy {
|
||||
flags += " anonymous"
|
||||
}
|
||||
|
||||
// Log location
|
||||
log.Tracer(ctx).Tracef(
|
||||
"intel: located %s in %s (%s), as part of AS%d by %s%s",
|
||||
e.IP,
|
||||
loc.Country.Name,
|
||||
loc.Country.Code,
|
||||
loc.AutonomousSystemNumber,
|
||||
loc.AutonomousSystemOrganization,
|
||||
flags,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GetLocation returns the raw location data and whether it is set.
|
||||
func (e *Entity) GetLocation(ctx context.Context) (*geoip.Location, bool) {
|
||||
e.getLocation(ctx)
|
||||
|
||||
if e.location == nil {
|
||||
return nil, false
|
||||
}
|
||||
return e.location, true
|
||||
}
|
||||
|
||||
// GetCountry returns the two letter ISO country code and whether it is set.
|
||||
func (e *Entity) GetCountry(ctx context.Context) (string, bool) {
|
||||
e.getLocation(ctx)
|
||||
|
||||
if e.LocationError != "" {
|
||||
return "", false
|
||||
}
|
||||
return e.Country, true
|
||||
}
|
||||
|
||||
// GetCountryInfo returns the two letter ISO country code and whether it is set.
|
||||
func (e *Entity) GetCountryInfo(ctx context.Context) *geoip.CountryInfo {
|
||||
e.getLocation(ctx)
|
||||
|
||||
if e.LocationError != "" {
|
||||
return nil
|
||||
}
|
||||
return &e.location.Country
|
||||
}
|
||||
|
||||
// GetASN returns the AS number and whether it is set.
|
||||
func (e *Entity) GetASN(ctx context.Context) (uint, bool) {
|
||||
e.getLocation(ctx)
|
||||
|
||||
if e.LocationError != "" {
|
||||
return 0, false
|
||||
}
|
||||
return e.ASN, true
|
||||
}
|
||||
|
||||
// Lists
|
||||
|
||||
func (e *Entity) getLists(ctx context.Context) {
|
||||
e.getDomainLists(ctx)
|
||||
e.getASNLists(ctx)
|
||||
e.getIPLists(ctx)
|
||||
e.getCountryLists(ctx)
|
||||
}
|
||||
|
||||
func (e *Entity) mergeList(key string, list []string) {
|
||||
if len(list) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if e.ListOccurences == nil {
|
||||
e.ListOccurences = make(map[string][]string)
|
||||
}
|
||||
|
||||
e.ListOccurences[key] = mergeStringList(e.ListOccurences[key], list)
|
||||
}
|
||||
|
||||
func (e *Entity) getDomainLists(ctx context.Context) {
|
||||
if e.domainListLoaded {
|
||||
return
|
||||
}
|
||||
|
||||
domain, ok := e.GetDomain(ctx, false /* mayUseReverseDomain */)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
e.loadDomainListOnce.Do(func() {
|
||||
domainsToInspect := []string{domain}
|
||||
|
||||
if e.checkCNAMEs && len(e.CNAME) > 0 {
|
||||
log.Tracer(ctx).Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME)
|
||||
domainsToInspect = append(domainsToInspect, e.CNAME...)
|
||||
}
|
||||
|
||||
var domains []string
|
||||
if e.resolveSubDomainLists {
|
||||
for _, domain := range domainsToInspect {
|
||||
subdomains := splitDomain(domain)
|
||||
domains = append(domains, subdomains...)
|
||||
}
|
||||
} else {
|
||||
domains = domainsToInspect
|
||||
}
|
||||
|
||||
domains = makeDistinct(domains)
|
||||
|
||||
for _, d := range domains {
|
||||
list, err := filterlists.LookupDomain(d)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Errorf("intel: failed to get domain blocklists for %s: %s", d, err)
|
||||
e.ListsError = err.Error()
|
||||
return
|
||||
}
|
||||
|
||||
if len(list) > 0 {
|
||||
log.Tracer(ctx).Tracef("intel: loaded domain lists for %s: %s", d, strings.Join(list, ", "))
|
||||
e.mergeList(d, list)
|
||||
}
|
||||
}
|
||||
e.domainListLoaded = true
|
||||
})
|
||||
}
|
||||
|
||||
func splitDomain(domain string) []string {
|
||||
domain = strings.Trim(domain, ".")
|
||||
suffix, _ := publicsuffix.PublicSuffix(domain)
|
||||
if suffix == domain {
|
||||
return []string{domain}
|
||||
}
|
||||
|
||||
domainWithoutSuffix := domain[:len(domain)-len(suffix)]
|
||||
domainWithoutSuffix = strings.Trim(domainWithoutSuffix, ".")
|
||||
|
||||
splitted := strings.FieldsFunc(domainWithoutSuffix, func(r rune) bool {
|
||||
return r == '.'
|
||||
})
|
||||
|
||||
domains := make([]string, 0, len(splitted))
|
||||
for idx := range splitted {
|
||||
|
||||
d := strings.Join(splitted[idx:], ".") + "." + suffix
|
||||
if d[len(d)-1] != '.' {
|
||||
d += "."
|
||||
}
|
||||
domains = append(domains, d)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
func (e *Entity) getASNLists(ctx context.Context) {
|
||||
if e.asnListLoaded {
|
||||
return
|
||||
}
|
||||
|
||||
asn, ok := e.GetASN(ctx)
|
||||
if !ok || asn == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
e.loadAsnListOnce.Do(func() {
|
||||
asnStr := fmt.Sprintf("%d", asn)
|
||||
list, err := filterlists.LookupASNString(asnStr)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Errorf("intel: failed to get ASN blocklist for %d: %s", asn, err)
|
||||
e.ListsError = err.Error()
|
||||
return
|
||||
}
|
||||
|
||||
if len(list) > 0 {
|
||||
log.Tracer(ctx).Tracef("intel: loaded ASN lists for %s: %s", asnStr, strings.Join(list, ", "))
|
||||
e.mergeList(asnStr, list)
|
||||
}
|
||||
|
||||
e.asnListLoaded = true
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Entity) getCountryLists(ctx context.Context) {
|
||||
if e.countryListLoaded {
|
||||
return
|
||||
}
|
||||
|
||||
country, ok := e.GetCountry(ctx)
|
||||
if !ok || country == "" {
|
||||
return
|
||||
}
|
||||
|
||||
e.loadCountryListOnce.Do(func() {
|
||||
list, err := filterlists.LookupCountry(country)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Errorf("intel: failed to load country blocklist for %s: %s", country, err)
|
||||
e.ListsError = err.Error()
|
||||
return
|
||||
}
|
||||
|
||||
if len(list) > 0 {
|
||||
log.Tracer(ctx).Tracef("intel: loaded country lists for %s: %s", country, strings.Join(list, ", "))
|
||||
e.mergeList(country, list)
|
||||
}
|
||||
|
||||
e.countryListLoaded = true
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Entity) getIPLists(ctx context.Context) {
|
||||
if e.ipListLoaded {
|
||||
return
|
||||
}
|
||||
|
||||
ip, ok := e.GetIP()
|
||||
if !ok || ip == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// only load lists for IP addresses that are classified as global.
|
||||
if !e.IPScope.IsGlobal() {
|
||||
return
|
||||
}
|
||||
|
||||
e.loadIPListOnce.Do(func() {
|
||||
list, err := filterlists.LookupIP(ip)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Errorf("intel: failed to get IP blocklist for %s: %s", ip.String(), err)
|
||||
e.ListsError = err.Error()
|
||||
return
|
||||
}
|
||||
|
||||
if len(list) > 0 {
|
||||
log.Tracer(ctx).Tracef("intel: loaded IP lists for %s: %s", ip.String(), strings.Join(list, ", "))
|
||||
e.mergeList(ip.String(), list)
|
||||
}
|
||||
|
||||
e.ipListLoaded = true
|
||||
})
|
||||
}
|
||||
|
||||
// LoadLists searches all filterlists for all occurrences of
|
||||
// this entity.
|
||||
func (e *Entity) LoadLists(ctx context.Context) {
|
||||
e.getLists(ctx)
|
||||
}
|
||||
|
||||
// MatchLists matches the entities lists against a slice
|
||||
// of source IDs and updates various entity properties
|
||||
// like BlockedByLists, ListOccurences and BlockedEntitites.
|
||||
func (e *Entity) MatchLists(lists []string) bool {
|
||||
if len(lists) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
e.BlockedByLists = nil
|
||||
e.BlockedEntities = nil
|
||||
|
||||
lm := makeMap(lists)
|
||||
for key, keyLists := range e.ListOccurences {
|
||||
for _, keyListID := range keyLists {
|
||||
if _, ok := lm[keyListID]; ok {
|
||||
e.BlockedByLists = append(e.BlockedByLists, keyListID)
|
||||
e.BlockedEntities = append(e.BlockedEntities, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
e.BlockedByLists = makeDistinct(e.BlockedByLists)
|
||||
e.BlockedEntities = makeDistinct(e.BlockedEntities)
|
||||
|
||||
return len(e.BlockedByLists) > 0
|
||||
}
|
||||
|
||||
// ListBlockReason returns the block reason for this entity.
|
||||
func (e *Entity) ListBlockReason() ListBlockReason {
|
||||
blockedBy := make([]ListMatch, len(e.BlockedEntities))
|
||||
|
||||
lm := makeMap(e.BlockedByLists)
|
||||
|
||||
for idx, blockedEntity := range e.BlockedEntities {
|
||||
if entityLists, ok := e.ListOccurences[blockedEntity]; ok {
|
||||
var activeLists []string
|
||||
var inactiveLists []string
|
||||
|
||||
for _, l := range entityLists {
|
||||
if _, ok := lm[l]; ok {
|
||||
activeLists = append(activeLists, l)
|
||||
} else {
|
||||
inactiveLists = append(inactiveLists, l)
|
||||
}
|
||||
}
|
||||
|
||||
blockedBy[idx] = ListMatch{
|
||||
Entity: blockedEntity,
|
||||
ActiveLists: activeLists,
|
||||
InactiveLists: inactiveLists,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return blockedBy
|
||||
}
|
||||
|
||||
func mergeStringList(a, b []string) []string {
|
||||
listMap := make(map[string]struct{})
|
||||
for _, s := range a {
|
||||
listMap[s] = struct{}{}
|
||||
}
|
||||
for _, s := range b {
|
||||
listMap[s] = struct{}{}
|
||||
}
|
||||
|
||||
res := make([]string, 0, len(listMap))
|
||||
for s := range listMap {
|
||||
res = append(res, s)
|
||||
}
|
||||
sort.Strings(res)
|
||||
return res
|
||||
}
|
||||
|
||||
func makeDistinct(slice []string) []string {
|
||||
m := make(map[string]struct{}, len(slice))
|
||||
result := make([]string, 0, len(slice))
|
||||
|
||||
for _, v := range slice {
|
||||
if _, ok := m[v]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
m[v] = struct{}{}
|
||||
result = append(result, v)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func makeMap(slice []string) map[string]struct{} {
|
||||
lm := make(map[string]struct{})
|
||||
for _, v := range slice {
|
||||
lm[v] = struct{}{}
|
||||
}
|
||||
return lm
|
||||
}
|
||||
218
service/intel/filterlists/bloom.go
Normal file
218
service/intel/filterlists/bloom.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package filterlists
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tannerryan/ring"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
var defaultFilter = newScopedBloom()
|
||||
|
||||
// scopedBloom is a wrapper around a bloomfilter implementation
|
||||
// providing scoped filters for different entity types.
|
||||
type scopedBloom struct {
|
||||
rw sync.RWMutex
|
||||
domain *ring.Ring
|
||||
asn *ring.Ring
|
||||
country *ring.Ring
|
||||
ipv4 *ring.Ring
|
||||
ipv6 *ring.Ring
|
||||
}
|
||||
|
||||
func newScopedBloom() *scopedBloom {
|
||||
mustInit := func(size int) *ring.Ring {
|
||||
f, err := ring.Init(size, bfFalsePositiveRate)
|
||||
if err != nil {
|
||||
// we panic here as those values cannot be controlled
|
||||
// by the user and invalid values shouldn't be
|
||||
// in a release anyway.
|
||||
panic("Invalid bloom filter parameters!")
|
||||
}
|
||||
return f
|
||||
}
|
||||
return &scopedBloom{
|
||||
domain: mustInit(domainBfSize),
|
||||
asn: mustInit(asnBfSize),
|
||||
country: mustInit(countryBfSize),
|
||||
ipv4: mustInit(ipv4BfSize),
|
||||
ipv6: mustInit(ipv6BfSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (bf *scopedBloom) getBloomForType(entityType string) (*ring.Ring, error) {
|
||||
var r *ring.Ring
|
||||
|
||||
switch strings.ToLower(entityType) {
|
||||
case "domain":
|
||||
r = bf.domain
|
||||
case "asn":
|
||||
r = bf.asn
|
||||
case "ipv4":
|
||||
r = bf.ipv4
|
||||
case "ipv6":
|
||||
r = bf.ipv6
|
||||
case "country":
|
||||
r = bf.country
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported filterlists entity type %q", entityType)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (bf *scopedBloom) add(scope, value string) {
|
||||
bf.rw.Lock()
|
||||
defer bf.rw.Unlock()
|
||||
|
||||
r, err := bf.getBloomForType(scope)
|
||||
if err != nil {
|
||||
// If we don't have a bloom filter for that scope
|
||||
// we are probably running an older version that does
|
||||
// not have support for it. We just drop the value
|
||||
// as a call to Test() for that scope will always
|
||||
// return "true"
|
||||
log.Warningf("failed to add unknown entity type %q with value %q", scope, value)
|
||||
return
|
||||
}
|
||||
|
||||
r.Add([]byte(value))
|
||||
}
|
||||
|
||||
func (bf *scopedBloom) test(scope, value string) bool {
|
||||
bf.rw.RLock()
|
||||
defer bf.rw.RUnlock()
|
||||
|
||||
r, err := bf.getBloomForType(scope)
|
||||
if err != nil {
|
||||
log.Warningf("testing for unknown entity type %q", scope)
|
||||
return true // simulate a match to the caller
|
||||
}
|
||||
|
||||
return r.Test([]byte(value))
|
||||
}
|
||||
|
||||
func (bf *scopedBloom) loadFromCache() error {
|
||||
bf.rw.Lock()
|
||||
defer bf.rw.Unlock()
|
||||
|
||||
if err := loadBloomFromCache(bf.domain, "domain"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := loadBloomFromCache(bf.asn, "asn"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := loadBloomFromCache(bf.country, "country"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := loadBloomFromCache(bf.ipv4, "ipv4"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := loadBloomFromCache(bf.ipv6, "ipv6"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bf *scopedBloom) saveToCache() error {
|
||||
bf.rw.RLock()
|
||||
defer bf.rw.RUnlock()
|
||||
|
||||
if err := saveBloomToCache(bf.domain, "domain"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := saveBloomToCache(bf.asn, "asn"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := saveBloomToCache(bf.country, "country"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := saveBloomToCache(bf.ipv4, "ipv4"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := saveBloomToCache(bf.ipv6, "ipv6"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bf *scopedBloom) replaceWith(other *scopedBloom) {
|
||||
bf.rw.Lock()
|
||||
defer bf.rw.Unlock()
|
||||
|
||||
other.rw.RLock()
|
||||
defer other.rw.RUnlock()
|
||||
|
||||
bf.domain = other.domain
|
||||
bf.asn = other.asn
|
||||
bf.country = other.country
|
||||
bf.ipv4 = other.ipv4
|
||||
bf.ipv6 = other.ipv6
|
||||
}
|
||||
|
||||
type bloomFilterRecord struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
Filter string
|
||||
}
|
||||
|
||||
// loadBloomFromCache loads the bloom filter stored under scope
|
||||
// into bf.
|
||||
func loadBloomFromCache(bf *ring.Ring, scope string) error {
|
||||
r, err := cache.Get(makeBloomCacheKey(scope))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var filterRecord *bloomFilterRecord
|
||||
if r.IsWrapped() {
|
||||
filterRecord = new(bloomFilterRecord)
|
||||
if err := record.Unwrap(r, filterRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
var ok bool
|
||||
filterRecord, ok = r.(*bloomFilterRecord)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected bloomFilterRecord but got %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
blob, err := hex.DecodeString(filterRecord.Filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := bf.UnmarshalBinary(blob); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveBloomToCache saves the bitset of the bloomfilter bf
|
||||
// in the cache db.
|
||||
func saveBloomToCache(bf *ring.Ring, scope string) error {
|
||||
blob, err := bf.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filter := hex.EncodeToString(blob)
|
||||
|
||||
r := &bloomFilterRecord{
|
||||
Filter: filter,
|
||||
}
|
||||
|
||||
r.SetKey(makeBloomCacheKey(scope))
|
||||
|
||||
return cache.Put(r)
|
||||
}
|
||||
67
service/intel/filterlists/cache_version.go
Normal file
67
service/intel/filterlists/cache_version.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package filterlists
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/record"
|
||||
)
|
||||
|
||||
const resetVersion = "v0.6.0"
|
||||
|
||||
type cacheVersionRecord struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
Version string
|
||||
Reset string
|
||||
}
|
||||
|
||||
// getCacheDatabaseVersion reads and returns the cache
|
||||
// database version record.
|
||||
func getCacheDatabaseVersion() (*version.Version, error) {
|
||||
r, err := cache.Get(filterListCacheVersionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var verRecord *cacheVersionRecord
|
||||
if r.IsWrapped() {
|
||||
verRecord = new(cacheVersionRecord)
|
||||
if err := record.Unwrap(r, verRecord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
var ok bool
|
||||
verRecord, ok = r.(*cacheVersionRecord)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid type, expected cacheVersionRecord but got %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
if verRecord.Reset != resetVersion {
|
||||
return nil, database.ErrNotFound
|
||||
}
|
||||
|
||||
ver, err := version.NewSemver(verRecord.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ver, nil
|
||||
}
|
||||
|
||||
// setCacheDatabaseVersion updates the cache database
|
||||
// version record to ver.
|
||||
func setCacheDatabaseVersion(ver string) error {
|
||||
verRecord := &cacheVersionRecord{
|
||||
Version: ver,
|
||||
Reset: resetVersion,
|
||||
}
|
||||
|
||||
verRecord.SetKey(filterListCacheVersionKey)
|
||||
return cache.Put(verRecord)
|
||||
}
|
||||
240
service/intel/filterlists/database.go
Normal file
240
service/intel/filterlists/database.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package filterlists
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/updater"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
)
|
||||
|
||||
const (
|
||||
baseListFilePath = "intel/lists/base.dsdl"
|
||||
intermediateListFilePath = "intel/lists/intermediate.dsdl"
|
||||
urgentListFilePath = "intel/lists/urgent.dsdl"
|
||||
listIndexFilePath = "intel/lists/index.dsd"
|
||||
)
|
||||
|
||||
// default bloomfilter element sizes (estimated).
|
||||
const (
|
||||
domainBfSize = 1000000
|
||||
asnBfSize = 1000
|
||||
countryBfSize = 100
|
||||
ipv4BfSize = 100
|
||||
ipv6BfSize = 100
|
||||
)
|
||||
|
||||
const bfFalsePositiveRate = 0.001
|
||||
|
||||
var (
|
||||
filterListLock sync.RWMutex
|
||||
|
||||
// Updater files for tracking upgrades.
|
||||
baseFile *updater.File
|
||||
intermediateFile *updater.File
|
||||
urgentFile *updater.File
|
||||
|
||||
filterListsLoaded chan struct{}
|
||||
)
|
||||
|
||||
var cache = database.NewInterface(&database.Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
CacheSize: 2 ^ 8,
|
||||
})
|
||||
|
||||
// getFileFunc is the function used to get a file from
|
||||
// the updater. It's basically updates.GetFile and used
|
||||
// for unit testing.
|
||||
type getFileFunc func(string) (*updater.File, error)
|
||||
|
||||
// getFile points to updates.GetFile but may be set to
|
||||
// something different during unit testing.
|
||||
var getFile getFileFunc = updates.GetFile
|
||||
|
||||
func init() {
|
||||
filterListsLoaded = make(chan struct{})
|
||||
}
|
||||
|
||||
// isLoaded returns true if the filterlists have been
|
||||
// loaded.
|
||||
func isLoaded() bool {
|
||||
select {
|
||||
case <-filterListsLoaded:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// processListFile opens the latest version of file and decodes it's DSDL
|
||||
// content. It calls processEntry for each decoded filterlists entry.
|
||||
func processListFile(ctx context.Context, filter *scopedBloom, file *updater.File) error {
|
||||
f, err := os.Open(file.Path())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
values := make(chan *listEntry, 100)
|
||||
records := make(chan record.Record, 100)
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
// startSafe runs fn inside the error group but wrapped
|
||||
// in recovered function.
|
||||
startSafe := func(fn func() error) {
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if x := recover(); x != nil {
|
||||
if e, ok := x.(error); ok {
|
||||
err = e
|
||||
} else {
|
||||
err = fmt.Errorf("%v", x)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = fn()
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
startSafe(func() (err error) {
|
||||
defer close(values)
|
||||
|
||||
err = decodeFile(ctx, f, values)
|
||||
return
|
||||
})
|
||||
|
||||
startSafe(func() error {
|
||||
defer close(records)
|
||||
for entry := range values {
|
||||
if err := processEntry(ctx, filter, entry, records); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
persistRecords(startSafe, records)
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func persistRecords(startJob func(func() error), records <-chan record.Record) {
|
||||
var cnt int
|
||||
start := time.Now()
|
||||
logProgress := func() {
|
||||
if cnt == 0 {
|
||||
// protection against panic
|
||||
return
|
||||
}
|
||||
|
||||
timePerEntity := time.Since(start) / time.Duration(cnt)
|
||||
speed := float64(time.Second) / float64(timePerEntity)
|
||||
log.Debugf("processed %d entities in %s with %s / entity (%.2f entities/second)", cnt, time.Since(start), timePerEntity, speed)
|
||||
}
|
||||
|
||||
batch := database.NewInterface(&database.Options{Local: true, Internal: true})
|
||||
|
||||
var processBatch func() error
|
||||
processBatch = func() error {
|
||||
batchPut := batch.PutMany("cache")
|
||||
for r := range records {
|
||||
if err := batchPut(r); err != nil {
|
||||
return err
|
||||
}
|
||||
cnt++
|
||||
|
||||
if cnt%10000 == 0 {
|
||||
logProgress()
|
||||
}
|
||||
|
||||
if cnt%1000 == 0 {
|
||||
if err := batchPut(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startJob(processBatch)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// log final batch
|
||||
if cnt%10000 != 0 { // avoid duplicate logging
|
||||
logProgress()
|
||||
}
|
||||
return batchPut(nil)
|
||||
}
|
||||
|
||||
startJob(processBatch)
|
||||
}
|
||||
|
||||
func normalizeEntry(entry *listEntry) {
|
||||
switch strings.ToLower(entry.Type) { //
|
||||
case "domain":
|
||||
entry.Entity = strings.ToLower(entry.Entity)
|
||||
if entry.Entity[len(entry.Entity)-1] != '.' {
|
||||
// ensure domains from the filter list are fully qualified and end in dot.
|
||||
entry.Entity += "."
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func processEntry(ctx context.Context, filter *scopedBloom, entry *listEntry, records chan<- record.Record) error {
|
||||
normalizeEntry(entry)
|
||||
|
||||
// Only add the entry to the bloom filter if it has any sources.
|
||||
if len(entry.Resources) > 0 {
|
||||
filter.add(entry.Type, entry.Entity)
|
||||
}
|
||||
|
||||
r := &entityRecord{
|
||||
Value: entry.Entity,
|
||||
Type: entry.Type,
|
||||
Sources: entry.getSources(),
|
||||
UpdatedAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
// If the entry is a "delete" update, actually delete it to save space.
|
||||
if entry.Whitelist {
|
||||
r.CreateMeta()
|
||||
r.Meta().Delete()
|
||||
}
|
||||
|
||||
key := makeListCacheKey(strings.ToLower(r.Type), r.Value)
|
||||
r.SetKey(key)
|
||||
|
||||
select {
|
||||
case records <- r:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func mapKeys(m map[string]struct{}) []string {
|
||||
sl := make([]string, 0, len(m))
|
||||
for s := range m {
|
||||
sl = append(sl, s)
|
||||
}
|
||||
|
||||
sort.Strings(sl)
|
||||
return sl
|
||||
}
|
||||
143
service/intel/filterlists/decoder.go
Normal file
143
service/intel/filterlists/decoder.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package filterlists
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/utils"
|
||||
)
|
||||
|
||||
type listEntry struct {
|
||||
Type string `json:"type"`
|
||||
Entity string `json:"entity"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
Resources []entryResource `json:"resources"`
|
||||
}
|
||||
|
||||
type entryResource struct {
|
||||
SourceID string `json:"sourceID"`
|
||||
ResourceID string `json:"resourceID"`
|
||||
}
|
||||
|
||||
func (entry *listEntry) getSources() (sourceIDs []string) {
|
||||
sourceIDs = make([]string, 0, len(entry.Resources))
|
||||
|
||||
for _, resource := range entry.Resources {
|
||||
if !utils.StringInSlice(sourceIDs, resource.SourceID) {
|
||||
sourceIDs = append(sourceIDs, resource.SourceID)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// decodeFile decodes a DSDL filterlists file and sends decoded entities to
|
||||
// ch. It blocks until all list entries have been consumed or ctx is cancelled.
|
||||
func decodeFile(ctx context.Context, r io.Reader, ch chan<- *listEntry) error {
|
||||
compressed, format, err := parseHeader(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parser header: %w", err)
|
||||
}
|
||||
|
||||
if compressed {
|
||||
r, err = gzip.NewReader(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open gzip reader: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// we need a reader that supports io.ByteReader
|
||||
reader := &byteReader{r}
|
||||
var entryCount int
|
||||
for {
|
||||
entryCount++
|
||||
length, readErr := binary.ReadUvarint(reader)
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to load varint entity length: %w", readErr)
|
||||
}
|
||||
|
||||
blob := make([]byte, length)
|
||||
_, readErr = io.ReadFull(reader, blob)
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
// there shouldn't be an EOF here because
|
||||
// we actually got a length above. Return
|
||||
// ErrUnexpectedEOF instead of just EOF.
|
||||
// io.ReadFull already returns ErrUnexpectedEOF
|
||||
// if it failed to read blob as a whole but my
|
||||
// return io.EOF if it read exactly 0 bytes.
|
||||
readErr = io.ErrUnexpectedEOF
|
||||
}
|
||||
return readErr
|
||||
}
|
||||
|
||||
// we don't really care about the format here but it must be
|
||||
// something that can encode/decode complex structures like
|
||||
// JSON, BSON or GenCode. So LoadAsFormat MUST return the value
|
||||
// passed as the third parameter. String or RAW encoding IS AN
|
||||
// error here.
|
||||
entry := &listEntry{}
|
||||
err := dsd.LoadAsFormat(blob, format, entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decoded DSD encoded entity: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case ch <- entry:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseHeader(r io.Reader) (compressed bool, format byte, err error) {
|
||||
var listHeader [1]byte
|
||||
if _, err = r.Read(listHeader[:]); err != nil {
|
||||
// if we have an error here we can safely abort because
|
||||
// the file must be broken
|
||||
return
|
||||
}
|
||||
|
||||
if listHeader[0] != dsd.LIST {
|
||||
err = fmt.Errorf("unexpected file type: %d (%c), expected dsd list", listHeader[0], listHeader[0])
|
||||
return
|
||||
}
|
||||
|
||||
var compression [1]byte
|
||||
if _, err = r.Read(compression[:]); err != nil {
|
||||
// same here, a DSDL file must have at least 2 bytes header
|
||||
return
|
||||
}
|
||||
|
||||
if compression[0] == dsd.GZIP {
|
||||
compressed = true
|
||||
|
||||
var formatSlice [1]byte
|
||||
if _, err = r.Read(formatSlice[:]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
format = formatSlice[0]
|
||||
return
|
||||
}
|
||||
|
||||
format = compression[0]
|
||||
return // nolint:nakedret
|
||||
}
|
||||
|
||||
// byteReader extends an io.Reader to implement the ByteReader interface.
|
||||
type byteReader struct{ io.Reader }
|
||||
|
||||
func (br *byteReader) ReadByte() (byte, error) {
|
||||
var b [1]byte
|
||||
_, err := br.Read(b[:])
|
||||
return b[0], err
|
||||
}
|
||||
288
service/intel/filterlists/index.go
Normal file
288
service/intel/filterlists/index.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package filterlists
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/updater"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
)
|
||||
|
||||
// the following definitions are copied from the intelhub repository
|
||||
// and stripped down to only include data required by portmaster.
|
||||
|
||||
// Category is used to group different list sources by the type
|
||||
// of entity they are blocking. Categories may be nested using
|
||||
// the Parent field.
|
||||
type Category struct {
|
||||
// ID is a unique ID for the category. For sub-categories
|
||||
// this ID must be used in the Parent field of any directly
|
||||
// nesteded categories.
|
||||
ID string `json:"id"`
|
||||
|
||||
// Parent may hold the ID of another category. If set, this
|
||||
// category is made a sub-category of it's parent.
|
||||
Parent string `json:"parent,omitempty"`
|
||||
|
||||
// Name is a human readable name for the category and can
|
||||
// be used in user interfaces.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Description is a human readable description that may be
|
||||
// displayed in user interfaces.
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// Source defines an external filterlists source.
|
||||
type Source struct {
|
||||
// ID is a unique ID for the source. Entities always reference the
|
||||
// sources they have been observed in using this ID. Refer to the
|
||||
// Entry struct for more information.
|
||||
ID string `json:"id"`
|
||||
|
||||
// Name is a human readable name for the source and can be used
|
||||
// in user interfaces.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Description may hold a human readable description for the source.
|
||||
// It may be used in user interfaces.
|
||||
Description string `json:"description"`
|
||||
|
||||
// Type describes the type of entities the source provides. Refer
|
||||
// to the Type definition for more information and well-known types.
|
||||
Type string `json:"type"`
|
||||
|
||||
// URL points to the filterlists file.
|
||||
URL string `json:"url"`
|
||||
|
||||
// Category holds the unique ID of a category the source belongs to. Since
|
||||
// categories can be nested the source is automatically part of all categories
|
||||
// in the hierarchy. Refer to the Category struct for more information.
|
||||
Category string `json:"category"`
|
||||
|
||||
// Website may holds the URL of the source maintainers website.
|
||||
Website string `json:"website,omitempty"`
|
||||
|
||||
// License holds the license that is used for the source.
|
||||
License string `json:"license"`
|
||||
|
||||
// Contribute may hold an opaque string that informs a user on how to
|
||||
// contribute to the source. This may be a URL or mail address.
|
||||
Contribute string `json:"contribute"`
|
||||
}
|
||||
|
||||
// ListIndexFile describes the structure of the released list
|
||||
// index file.
|
||||
type ListIndexFile struct {
|
||||
record.Base
|
||||
sync.RWMutex
|
||||
|
||||
Version string `json:"version"`
|
||||
SchemaVersion string `json:"schemaVersion"`
|
||||
Categories []Category `json:"categories"`
|
||||
Sources []Source `json:"sources"`
|
||||
}
|
||||
|
||||
func (index *ListIndexFile) getCategorySources(id string) []string {
|
||||
ids := make(map[string]struct{})
|
||||
|
||||
// find all sources that match against cat
|
||||
for _, s := range index.Sources {
|
||||
if s.Category == id {
|
||||
ids[s.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// find all child-categories recursing into getCategorySources.
|
||||
for _, c := range index.Categories {
|
||||
if c.Parent == id {
|
||||
for _, sid := range index.getCategorySources(c.ID) {
|
||||
ids[sid] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mapKeys(ids)
|
||||
}
|
||||
|
||||
func (index *ListIndexFile) getSourcesMatching(id string) []string {
|
||||
// if id is already a source ID we just return it
|
||||
for _, s := range index.Sources {
|
||||
if s.ID == id {
|
||||
return []string{s.ID}
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise we need to check the category tree
|
||||
return index.getCategorySources(id)
|
||||
}
|
||||
|
||||
func (index *ListIndexFile) getDistictSourceIDs(ids ...string) []string {
|
||||
index.RLock()
|
||||
defer index.RUnlock()
|
||||
|
||||
distinctIDs := make(map[string]struct{})
|
||||
|
||||
for _, id := range ids {
|
||||
for _, sid := range index.getSourcesMatching(id) {
|
||||
distinctIDs[sid] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return mapKeys(distinctIDs)
|
||||
}
|
||||
|
||||
func getListIndexFromCache() (*ListIndexFile, error) {
|
||||
r, err := cache.Get(filterListIndexKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var index *ListIndexFile
|
||||
if r.IsWrapped() {
|
||||
index = new(ListIndexFile)
|
||||
if err := record.Unwrap(r, index); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
var ok bool
|
||||
index, ok = r.(*ListIndexFile)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid type, expected ListIndexFile but got %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
var (
|
||||
// listIndexUpdate must only be used by updateListIndex.
|
||||
listIndexUpdate *updater.File
|
||||
listIndexUpdateLock sync.Mutex
|
||||
)
|
||||
|
||||
func updateListIndex() error {
|
||||
listIndexUpdateLock.Lock()
|
||||
defer listIndexUpdateLock.Unlock()
|
||||
|
||||
// Check if an update is needed.
|
||||
switch {
|
||||
case listIndexUpdate == nil:
|
||||
// This is the first time this function is run, get updater file for index.
|
||||
var err error
|
||||
listIndexUpdate, err = updates.GetFile(listIndexFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the version in the cache is current.
|
||||
index, err := getListIndexFromCache()
|
||||
switch {
|
||||
case errors.Is(err, database.ErrNotFound):
|
||||
log.Info("filterlists: index not in cache, starting update")
|
||||
case err != nil:
|
||||
log.Warningf("filterlists: failed to load index from cache, starting update: %s", err)
|
||||
case !listIndexUpdate.EqualsVersion(strings.TrimPrefix(index.Version, "v")):
|
||||
log.Infof(
|
||||
"filterlists: index from cache is outdated, starting update (%s != %s)",
|
||||
strings.TrimPrefix(index.Version, "v"),
|
||||
listIndexUpdate.Version(),
|
||||
)
|
||||
default:
|
||||
// List is in cache and current, there is nothing to do.
|
||||
log.Debug("filterlists: index is up to date")
|
||||
|
||||
// Update the unbreak filter list IDs on initial load.
|
||||
updateUnbreakFilterListIDs()
|
||||
|
||||
return nil
|
||||
}
|
||||
case listIndexUpdate.UpgradeAvailable():
|
||||
log.Info("filterlists: index update available, starting update")
|
||||
default:
|
||||
// Index is loaded and no update is available, there is nothing to do.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update list index from updates.
|
||||
blob, err := os.ReadFile(listIndexUpdate.Path())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
index := &ListIndexFile{}
|
||||
_, err = dsd.Load(blob, index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
index.SetKey(filterListIndexKey)
|
||||
|
||||
if err := cache.Put(index); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("intel/filterlists: updated list index in cache to %s", index.Version)
|
||||
|
||||
// Update the unbreak filter list IDs after an update.
|
||||
updateUnbreakFilterListIDs()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveListIDs resolves a slice of source or category IDs into
|
||||
// a slice of distinct source IDs.
|
||||
func ResolveListIDs(ids []string) ([]string, error) {
|
||||
index, err := getListIndexFromCache()
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNotFound) {
|
||||
if err := updateListIndex(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// retry resolving IDs
|
||||
return ResolveListIDs(ids)
|
||||
}
|
||||
|
||||
log.Errorf("failed to resolved ids %v: %s", ids, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resolved := index.getDistictSourceIDs(ids...)
|
||||
|
||||
log.Debugf("intel/filterlists: resolved ids %v to %v", ids, resolved)
|
||||
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
var (
|
||||
unbreakCategoryIDs = []string{"UNBREAK"}
|
||||
|
||||
unbreakIDs []string
|
||||
unbreakIDsLock sync.Mutex
|
||||
)
|
||||
|
||||
// GetUnbreakFilterListIDs returns the resolved list of all unbreak filter lists.
|
||||
func GetUnbreakFilterListIDs() []string {
|
||||
unbreakIDsLock.Lock()
|
||||
defer unbreakIDsLock.Unlock()
|
||||
|
||||
return unbreakIDs
|
||||
}
|
||||
|
||||
func updateUnbreakFilterListIDs() {
|
||||
unbreakIDsLock.Lock()
|
||||
defer unbreakIDsLock.Unlock()
|
||||
|
||||
resolvedIDs, err := ResolveListIDs(unbreakCategoryIDs)
|
||||
if err != nil {
|
||||
log.Warningf("filter: failed to resolve unbreak filter list IDs: %s", err)
|
||||
} else {
|
||||
unbreakIDs = resolvedIDs
|
||||
}
|
||||
}
|
||||
26
service/intel/filterlists/keys.go
Normal file
26
service/intel/filterlists/keys.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package filterlists
|
||||
|
||||
const (
|
||||
cacheDBPrefix = "cache:intel/filterlists"
|
||||
|
||||
// filterListCacheVersionKey is used to store the highest version
|
||||
// of a filterlists file (base, intermediate or urgent) in the
|
||||
// cache database. It's used to decide if the cache database and
|
||||
// bloomfilters need to be resetted and rebuilt.
|
||||
filterListCacheVersionKey = cacheDBPrefix + "/version"
|
||||
|
||||
// filterListIndexKey is used to store the filterlists index.
|
||||
filterListIndexKey = cacheDBPrefix + "/index"
|
||||
|
||||
// filterListKeyPrefix is the prefix inside that cache database
|
||||
// used for filter list entries.
|
||||
filterListKeyPrefix = cacheDBPrefix + "/lists/"
|
||||
)
|
||||
|
||||
func makeBloomCacheKey(scope string) string {
|
||||
return cacheDBPrefix + "/bloom/" + scope
|
||||
}
|
||||
|
||||
func makeListCacheKey(scope, key string) string {
|
||||
return filterListKeyPrefix + scope + "/" + key
|
||||
}
|
||||
129
service/intel/filterlists/lookup.go
Normal file
129
service/intel/filterlists/lookup.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package filterlists
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
// lookupBlockLists loads the entity record for key from
|
||||
// cache and returns the list of blocklist sources the
|
||||
// key is part of. It is not considered an error if
|
||||
// key does not exist, instead, an empty slice is
|
||||
// returned.
|
||||
func lookupBlockLists(entity, value string) ([]string, error) {
|
||||
key := makeListCacheKey(entity, value)
|
||||
if !isLoaded() {
|
||||
log.Warningf("intel/filterlists: not searching for %s because filterlists not loaded", key)
|
||||
// filterLists have not yet been loaded so
|
||||
// there's no point querying into the cache
|
||||
// database.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
filterListLock.RLock()
|
||||
defer filterListLock.RUnlock()
|
||||
|
||||
if !defaultFilter.test(entity, value) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// log.Debugf("intel/filterlists: searching for entries with %s", key)
|
||||
entry, err := getEntityRecordByKey(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
log.Errorf("intel/filterlists: failed to get entries for key %s: %s", key, err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entry.Sources, nil
|
||||
}
|
||||
|
||||
// LookupCountry returns a list of sources that mark the country
|
||||
// as blocked. If country is not stored in the cache database
|
||||
// a nil slice is returned.
|
||||
func LookupCountry(country string) ([]string, error) {
|
||||
return lookupBlockLists("country", country)
|
||||
}
|
||||
|
||||
// LookupDomain returns a list of sources that mark the domain
|
||||
// as blocked. If domain is not stored in the cache database
|
||||
// a nil slice is returned. The caller is responsible for making
|
||||
// sure that the given domain is valid and canonical.
|
||||
func LookupDomain(domain string) ([]string, error) {
|
||||
switch domain {
|
||||
case "", ".":
|
||||
// Return no lists for empty domains and the root zone.
|
||||
return nil, nil
|
||||
default:
|
||||
return lookupBlockLists("domain", domain)
|
||||
}
|
||||
}
|
||||
|
||||
// LookupASNString returns a list of sources that mark the ASN
|
||||
// as blocked. If ASN is not stored in the cache database
|
||||
// a nil slice is returned.
|
||||
func LookupASNString(asn string) ([]string, error) {
|
||||
return lookupBlockLists("asn", asn)
|
||||
}
|
||||
|
||||
// LookupIP returns a list of block sources that contain
|
||||
// a reference to ip. LookupIP automatically checks the IPv4 or
|
||||
// IPv6 lists respectively.
|
||||
func LookupIP(ip net.IP) ([]string, error) {
|
||||
if ip.To4() == nil {
|
||||
return LookupIPv6(ip)
|
||||
}
|
||||
|
||||
return LookupIPv4(ip)
|
||||
}
|
||||
|
||||
// LookupIPString is like LookupIP but accepts an IPv4 or
|
||||
// IPv6 address in their string representations.
|
||||
func LookupIPString(ipStr string) ([]string, error) {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return nil, errors.New("invalid IP")
|
||||
}
|
||||
|
||||
return LookupIP(ip)
|
||||
}
|
||||
|
||||
// LookupIPv4String returns a list of block sources that
|
||||
// contain a reference to ip. If the IP is not stored in the
|
||||
// cache database a nil slice is returned.
|
||||
func LookupIPv4String(ipv4 string) ([]string, error) {
|
||||
return lookupBlockLists("ipv4", ipv4)
|
||||
}
|
||||
|
||||
// LookupIPv4 is like LookupIPv4String but accepts a net.IP.
|
||||
func LookupIPv4(ipv4 net.IP) ([]string, error) {
|
||||
ip := ipv4.To4()
|
||||
if ip == nil {
|
||||
return nil, errors.New("invalid IPv4")
|
||||
}
|
||||
|
||||
return LookupIPv4String(ip.String())
|
||||
}
|
||||
|
||||
// LookupIPv6String returns a list of block sources that
|
||||
// contain a reference to ip. If the IP is not stored in the
|
||||
// cache database a nil slice is returned.
|
||||
func LookupIPv6String(ipv6 string) ([]string, error) {
|
||||
return lookupBlockLists("ipv6", ipv6)
|
||||
}
|
||||
|
||||
// LookupIPv6 is like LookupIPv6String but accepts a net.IP.
|
||||
func LookupIPv6(ipv6 net.IP) ([]string, error) {
|
||||
ip := ipv6.To16()
|
||||
if ip == nil {
|
||||
return nil, errors.New("invalid IPv6")
|
||||
}
|
||||
|
||||
return LookupIPv6String(ip.String())
|
||||
}
|
||||
110
service/intel/filterlists/module.go
Normal file
110
service/intel/filterlists/module.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package filterlists
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
)
|
||||
|
||||
var module *modules.Module
|
||||
|
||||
const (
|
||||
filterlistsDisabled = "filterlists:disabled"
|
||||
filterlistsUpdateFailed = "filterlists:update-failed"
|
||||
filterlistsStaleDataSurvived = "filterlists:staledata"
|
||||
)
|
||||
|
||||
// booleans mainly used to decouple the module
|
||||
// during testing.
|
||||
var (
|
||||
ignoreUpdateEvents = abool.New()
|
||||
ignoreNetEnvEvents = abool.New()
|
||||
)
|
||||
|
||||
func init() {
|
||||
ignoreNetEnvEvents.Set()
|
||||
|
||||
module = modules.Register("filterlists", prep, start, stop, "base", "updates")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
if err := module.RegisterEventHook(
|
||||
updates.ModuleName,
|
||||
updates.ResourceUpdateEvent,
|
||||
"Check for blocklist updates",
|
||||
func(ctx context.Context, _ interface{}) error {
|
||||
if ignoreUpdateEvents.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tryListUpdate(ctx)
|
||||
},
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to register resource update event handler: %w", err)
|
||||
}
|
||||
|
||||
if err := module.RegisterEventHook(
|
||||
netenv.ModuleName,
|
||||
netenv.OnlineStatusChangedEvent,
|
||||
"Check for blocklist updates",
|
||||
func(ctx context.Context, _ interface{}) error {
|
||||
if ignoreNetEnvEvents.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Nothing to do if we went offline.
|
||||
if !netenv.Online() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tryListUpdate(ctx)
|
||||
},
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to register online status changed event handler: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
filterListLock.Lock()
|
||||
defer filterListLock.Unlock()
|
||||
|
||||
ver, err := getCacheDatabaseVersion()
|
||||
if err == nil {
|
||||
log.Debugf("intel/filterlists: cache database has version %s", ver.String())
|
||||
|
||||
if err = defaultFilter.loadFromCache(); err != nil {
|
||||
err = fmt.Errorf("failed to initialize bloom filters: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("intel/filterlists: blocklists disabled, waiting for update (%s)", err)
|
||||
warnAboutDisabledFilterLists()
|
||||
} else {
|
||||
log.Debugf("intel/filterlists: using cache database")
|
||||
close(filterListsLoaded)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
filterListsLoaded = make(chan struct{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func warnAboutDisabledFilterLists() {
|
||||
module.Warning(
|
||||
filterlistsDisabled,
|
||||
"Filter Lists Are Initializing",
|
||||
"Filter lists are being downloaded and set up in the background. They will be activated as configured when finished.",
|
||||
)
|
||||
}
|
||||
86
service/intel/filterlists/module_test.go
Normal file
86
service/intel/filterlists/module_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package filterlists
|
||||
|
||||
/*
|
||||
func TestMain(m *testing.M) {
|
||||
// we completely ignore netenv events during testing.
|
||||
ignoreNetEnvEvents.Set()
|
||||
|
||||
if err := updates.DisableUpdateSchedule(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to disable update schedule: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
pmtesting.TestMainWithHooks(m, module, loadOnStart, nil)
|
||||
}
|
||||
|
||||
func loadOnStart() error {
|
||||
log.SetLogLevel(log.TraceLevel)
|
||||
|
||||
ch := make(chan struct{})
|
||||
defer close(ch)
|
||||
|
||||
if err := updates.TriggerUpdate(); err != nil {
|
||||
return fmt.Errorf("failed to trigger update: %w", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ch:
|
||||
return
|
||||
|
||||
case <-time.After(time.Minute):
|
||||
err = fmt.Errorf("timeout loading")
|
||||
close(filterListsLoaded) // let waitUntilLoaded() return
|
||||
}
|
||||
}()
|
||||
|
||||
waitUntilLoaded()
|
||||
time.Sleep(time.Second * 10)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
failureStatus, failureID, failureMsg := module.FailureStatus()
|
||||
if failureStatus == modules.FailureError || failureStatus == modules.FailureWarning {
|
||||
return fmt.Errorf("module in failure state: %s %q", failureID, failureMsg)
|
||||
}
|
||||
|
||||
// ignore update events from now on during testing.
|
||||
ignoreUpdateEvents.Set()
|
||||
|
||||
testSources := []string{"TEST"}
|
||||
testEntries := []*listEntry{
|
||||
{
|
||||
Entity: "example.com",
|
||||
Sources: testSources,
|
||||
Type: "Domain",
|
||||
},
|
||||
{
|
||||
Entity: "1.1.1.1",
|
||||
Sources: testSources,
|
||||
Type: "IPv4",
|
||||
},
|
||||
{
|
||||
Entity: "AT",
|
||||
Sources: testSources,
|
||||
Type: "Country",
|
||||
},
|
||||
{
|
||||
Entity: "123",
|
||||
Sources: testSources,
|
||||
Type: "ASN",
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range testEntries {
|
||||
// add some test entries
|
||||
if err := processEntry(e); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
*/
|
||||
40
service/intel/filterlists/record.go
Normal file
40
service/intel/filterlists/record.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package filterlists
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
)
|
||||
|
||||
type entityRecord struct {
|
||||
record.Base `json:"-"`
|
||||
sync.Mutex `json:"-"`
|
||||
|
||||
Value string
|
||||
Sources []string
|
||||
Type string
|
||||
UpdatedAt int64
|
||||
}
|
||||
|
||||
func getEntityRecordByKey(key string) (*entityRecord, error) {
|
||||
r, err := cache.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if r.IsWrapped() {
|
||||
newER := &entityRecord{}
|
||||
if err := record.Unwrap(r, newER); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newER, nil
|
||||
}
|
||||
|
||||
newER, ok := r.(*entityRecord)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *entityRecord, but %T", r)
|
||||
}
|
||||
return newER, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user