Merge branch 'develop' into feature/new-kext

This commit is contained in:
Vladimir Stoilov
2024-04-18 10:14:49 +03:00
1563 changed files with 133676 additions and 1525 deletions

222
service/firewall/api.go Normal file
View File

@@ -0,0 +1,222 @@
package firewall
import (
"context"
"fmt"
"net"
"net/http"
"path/filepath"
"slices"
"strings"
"time"
"github.com/safing/portbase/api"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/service/updates"
)
const (
deniedMsgUnidentified = `%wFailed to identify the requesting process. Reload to try again.`
deniedMsgSystem = `%wSystem access to the Portmaster API is not permitted.
You can enable the Development Mode to disable API authentication for development purposes.`
deniedMsgUnauthorized = `%wThe requesting process is not authorized to access the Portmaster API.
Checked process paths:
%s
The authorized root path is %s.
You can enable the Development Mode to disable API authentication for development purposes.
For production use please create an API key in the settings.`
deniedMsgMisconfigured = `%wThe authentication system is misconfigured.`
)
var (
dataRoot *utils.DirStructure
apiPortSet bool
apiIP net.IP
apiPort uint16
)
func prepAPIAuth() error {
dataRoot = dataroot.Root()
return api.SetAuthenticator(apiAuthenticator)
}
func startAPIAuth() {
var err error
apiIP, apiPort, err = netutils.ParseIPPort(apiListenAddress())
if err != nil {
log.Warningf("filter: failed to parse API address for improved api auth mechanism: %s", err)
return
}
apiPortSet = true
log.Tracef("filter: api port set to %d", apiPort)
}
func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, err error) {
if configReady.IsSet() && devMode() {
return &api.AuthToken{
Read: api.PermitSelf,
Write: api.PermitSelf,
}, nil
}
// get local IP/Port
localIP, localPort, err := netutils.ParseIPPort(s.Addr)
if err != nil {
return nil, fmt.Errorf("failed to get local IP/Port: %w", err)
}
// get remote IP/Port
remoteIP, remotePort, err := netutils.ParseIPPort(r.RemoteAddr)
if err != nil {
return nil, fmt.Errorf("failed to get remote IP/Port: %w", err)
}
// Check if the request is even local.
myIP, err := netenv.IsMyIP(remoteIP)
if err == nil && !myIP {
// Return to caller that the request was not handled.
return nil, nil
}
log.Tracer(r.Context()).Tracef("filter: authenticating API request from %s", r.RemoteAddr)
// It is important that this works, retry 5 times: every 500ms for 2.5s.
var retry bool
for tries := 0; tries < 5; tries++ {
retry, err = authenticateAPIRequest(
r.Context(),
&packet.Info{
Inbound: false, // outbound as we are looking for the process of the source address
Version: packet.IPv4,
Protocol: packet.TCP,
Src: remoteIP, // source as in the process we are looking for
SrcPort: remotePort, // source as in the process we are looking for
Dst: localIP,
DstPort: localPort,
PID: process.UndefinedProcessID,
},
)
if !retry {
break
}
// wait a little
time.Sleep(500 * time.Millisecond)
}
if err != nil {
return nil, err
}
return &api.AuthToken{
Read: api.PermitSelf,
Write: api.PermitSelf,
}, nil
}
func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bool, err error) {
var procsChecked []string
var originalPid int
// Get authenticated path.
authenticatedPath := updates.RootPath()
if authenticatedPath == "" {
return false, fmt.Errorf(deniedMsgMisconfigured, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
}
// Get real path.
authenticatedPath, err = filepath.EvalSymlinks(authenticatedPath)
if err != nil {
return false, fmt.Errorf(deniedMsgUnidentified, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
}
// Add filepath separator to confine to directory.
authenticatedPath += string(filepath.Separator)
// Get process of request.
pid, _, _ := process.GetPidOfConnection(ctx, pktInfo)
if pid < 0 {
return false, fmt.Errorf(deniedMsgUnidentified, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
}
proc, err := process.GetOrFindProcess(ctx, pid)
if err != nil {
log.Tracer(ctx).Debugf("filter: failed to get process of api request: %s", err)
originalPid = process.UnidentifiedProcessID
} else {
originalPid = proc.Pid
var previousPid int
// Find parent for up to two levels, if we don't match the path.
checkLevels := 2
checkLevelsLoop:
for i := 0; i < checkLevels+1; i++ {
// Check for eligible path.
switch proc.Pid {
case process.UnidentifiedProcessID, process.SystemProcessID:
break checkLevelsLoop
default: // normal process
// Check if the requesting process is in database root / updates dir.
if realPath, err := filepath.EvalSymlinks(proc.Path); err == nil {
// check if the client has been allowed by flag
if slices.Contains(allowedClients, realPath) {
return false, nil
}
if strings.HasPrefix(realPath, authenticatedPath) {
return false, nil
}
}
}
// Add checked path to list.
procsChecked = append(procsChecked, proc.Path)
// Get the parent process.
if i < checkLevels {
// save previous PID
previousPid = proc.Pid
// get parent process
proc, err = process.GetOrFindProcess(ctx, proc.ParentPid)
if err != nil {
log.Tracer(ctx).Debugf("filter: failed to get parent process of api request: %s", err)
break
}
// abort if we are looping
if proc.Pid == previousPid {
// this also catches -1 pid loops
break
}
}
}
}
switch originalPid {
case process.UnidentifiedProcessID:
log.Tracer(ctx).Warningf("filter: denying api access: failed to identify process")
return true, fmt.Errorf(deniedMsgUnidentified, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
case process.SystemProcessID:
log.Tracer(ctx).Warningf("filter: denying api access: request by system")
return false, fmt.Errorf(deniedMsgSystem, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
default: // normal process
log.Tracer(ctx).Warningf("filter: denying api access to %s - also checked %s (trusted root is %s)", procsChecked[0], strings.Join(procsChecked[1:], " "), dataRoot.Path)
return false, fmt.Errorf( //nolint:stylecheck // message for user
deniedMsgUnauthorized,
api.ErrAPIAccessDeniedMessage,
strings.Join(procsChecked, "\n"),
authenticatedPath,
)
}
}

View File

@@ -0,0 +1,64 @@
package firewall
import (
"context"
"strings"
"github.com/safing/portmaster/service/compat"
"github.com/safing/portmaster/service/nameserver/nsutil"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/profile/endpoints"
)
var resolverFilterLists = []string{"17-DNS"}
// PreventBypassing checks if the connection should be denied or permitted
// based on some bypass protection checks.
func PreventBypassing(ctx context.Context, conn *network.Connection) (endpoints.EPResult, string, nsutil.Responder) {
// Exclude incoming connections.
if conn.Inbound {
return endpoints.NoMatch, "", nil
}
// Exclude ICMP.
switch packet.IPProtocol(conn.Entity.Protocol) { //nolint:exhaustive // Checking for specific values only.
case packet.ICMP, packet.ICMPv6:
return endpoints.NoMatch, "", nil
}
// Block firefox canary domain to disable DoH.
// This MUST also affect the System Resolver, because the return value must
// be correct for this to work.
if strings.ToLower(conn.Entity.Domain) == "use-application-dns.net." {
return endpoints.Denied,
"blocked canary domain to prevent enabling of DNS-over-HTTPs",
nsutil.NxDomain()
}
// Exclude DNS requests coming from the System Resolver.
// This MUST also affect entities in the secure dns filter list, else the
// System Resolver is wrongly accused of bypassing.
if conn.Type == network.DNSRequest && conn.Process().IsSystemResolver() {
return endpoints.NoMatch, "", nil
}
// Block bypass attempts using an (encrypted) DNS server.
switch {
case conn.Entity.Port == 53:
return endpoints.Denied,
"blocked DNS query, manual dns setup required",
nsutil.BlockIP()
case conn.Entity.Port == 853:
// Block connections to port 853 - DNS over TLS.
fallthrough
case conn.Entity.MatchLists(resolverFilterLists):
// Block connection entities in the secure dns filter list.
compat.ReportSecureDNSBypassIssue(conn.Process())
return endpoints.Denied,
"blocked rogue connection to DNS resolver",
nsutil.BlockIP()
}
return endpoints.NoMatch, "", nil
}

153
service/firewall/config.go Normal file
View File

@@ -0,0 +1,153 @@
package firewall
import (
"github.com/tevino/abool"
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/notifications"
"github.com/safing/portmaster/service/core"
"github.com/safing/portmaster/spn/captain"
)
// Configuration Keys.
var (
CfgOptionEnableFilterKey = "filter/enable"
filterEnabled config.BoolOption
CfgOptionAskWithSystemNotificationsKey = "filter/askWithSystemNotifications"
cfgOptionAskWithSystemNotificationsOrder = 2
askWithSystemNotifications config.BoolOption
CfgOptionAskTimeoutKey = "filter/askTimeout"
cfgOptionAskTimeoutOrder = 3
askTimeout config.IntOption
CfgOptionPermanentVerdictsKey = "filter/permanentVerdicts"
cfgOptionPermanentVerdictsOrder = 80
permanentVerdicts config.BoolOption
CfgOptionDNSQueryInterceptionKey = "filter/dnsQueryInterception"
cfgOptionDNSQueryInterceptionOrder = 81
dnsQueryInterception config.BoolOption
)
func registerConfig() error {
err := config.Register(&config.Option{
Name: "Enable Privacy Filter",
Key: CfgOptionEnableFilterKey,
Description: "Enable the Privacy Filter. If turned off, all privacy filter protections are fully disabled on this device. Not meant to be disabled in production - only turn off for testing.",
OptType: config.OptTypeBool,
ExpertiseLevel: config.ExpertiseLevelDeveloper,
ReleaseLevel: config.ReleaseLevelExperimental,
DefaultValue: true,
Annotations: config.Annotations{
config.CategoryAnnotation: "General",
},
})
if err != nil {
return err
}
filterEnabled = config.Concurrent.GetAsBool(CfgOptionEnableFilterKey, true)
err = config.Register(&config.Option{
Name: "Permanent Verdicts",
Key: CfgOptionPermanentVerdictsKey,
Description: "The Portmaster's system integration intercepts every single packet. Usually the first packet is enough for the Portmaster to set the verdict for a connection - ie. to allow or deny it. Making these verdicts permanent means that the Portmaster will tell the system integration that is does not want to see any more packets of that single connection. This brings a major performance increase.",
OptType: config.OptTypeBool,
ExpertiseLevel: config.ExpertiseLevelDeveloper,
ReleaseLevel: config.ReleaseLevelExperimental,
DefaultValue: true,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionPermanentVerdictsOrder,
config.CategoryAnnotation: "Advanced",
},
})
if err != nil {
return err
}
permanentVerdicts = config.Concurrent.GetAsBool(CfgOptionPermanentVerdictsKey, true)
err = config.Register(&config.Option{
Name: "Seamless DNS Integration",
Key: CfgOptionDNSQueryInterceptionKey,
Description: "Intercept and redirect astray DNS queries to the Portmaster's internal DNS server. This enables seamless DNS integration without having to configure the system or other software. However, this may lead to compatibility issues with other software that attempts the same.",
OptType: config.OptTypeBool,
ExpertiseLevel: config.ExpertiseLevelDeveloper,
ReleaseLevel: config.ReleaseLevelExperimental,
DefaultValue: true,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionDNSQueryInterceptionOrder,
config.CategoryAnnotation: "Advanced",
},
})
if err != nil {
return err
}
dnsQueryInterception = config.Concurrent.GetAsBool(CfgOptionDNSQueryInterceptionKey, true)
err = config.Register(&config.Option{
Name: "Prompt Desktop Notifications",
Key: CfgOptionAskWithSystemNotificationsKey,
Description: `In addition to showing prompt notifications in the Portmaster App, also send them to the Desktop. This requires the Portmaster Notifier to be running. Requires Desktop Notifications to be enabled.`,
OptType: config.OptTypeBool,
ExpertiseLevel: config.ExpertiseLevelUser,
DefaultValue: true,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionAskWithSystemNotificationsOrder,
config.CategoryAnnotation: "General",
config.RequiresAnnotation: config.ValueRequirement{
Key: notifications.CfgUseSystemNotificationsKey,
Value: true,
},
},
})
if err != nil {
return err
}
askWithSystemNotifications = config.Concurrent.GetAsBool(CfgOptionAskWithSystemNotificationsKey, true)
err = config.Register(&config.Option{
Name: "Prompt Timeout",
Key: CfgOptionAskTimeoutKey,
Description: "How long the Portmaster will wait for a reply to a prompt notification. Please note that Desktop Notifications might not respect this or have their own limits.",
OptType: config.OptTypeInt,
ExpertiseLevel: config.ExpertiseLevelUser,
DefaultValue: 60,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionAskTimeoutOrder,
config.UnitAnnotation: "seconds",
config.CategoryAnnotation: "General",
},
ValidationRegex: `^[1-9][0-9]{1,5}$`,
})
if err != nil {
return err
}
askTimeout = config.Concurrent.GetAsInt(CfgOptionAskTimeoutKey, 60)
return nil
}
// Config variables for interception and filter module.
// Everything is registered by the interception module, as the filter module
// can be disabled.
var (
devMode config.BoolOption
apiListenAddress config.StringOption
tunnelEnabled config.BoolOption
useCommunityNodes config.BoolOption
configReady = abool.New()
)
func getConfig() {
devMode = config.Concurrent.GetAsBool(core.CfgDevModeKey, false)
apiListenAddress = config.GetAsString(api.CfgDefaultListenAddressKey, "")
tunnelEnabled = config.Concurrent.GetAsBool(captain.CfgOptionEnableSPNKey, false)
useCommunityNodes = config.Concurrent.GetAsBool(captain.CfgOptionUseCommunityNodesKey, true)
configReady.Set()
}

348
service/firewall/dns.go Normal file
View File

@@ -0,0 +1,348 @@
package firewall
import (
"context"
"errors"
"net"
"strings"
"github.com/miekg/dns"
"github.com/safing/portbase/database"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/profile"
"github.com/safing/portmaster/service/profile/endpoints"
"github.com/safing/portmaster/service/resolver"
)
func filterDNSSection(
ctx context.Context,
entries []dns.RR,
p *profile.LayeredProfile,
resolverScope netutils.IPScope,
sysResolver bool,
) ([]dns.RR, []string, int, string) {
// Will be filled 1:1 most of the time.
goodEntries := make([]dns.RR, 0, len(entries))
// Will stay empty most of the time.
var filteredRecords []string
// keeps track of the number of valid and allowed
// A and AAAA records.
var allowedAddressRecords int
var interveningOptionKey string
for _, rr := range entries {
// get IP and classification
var ip net.IP
switch v := rr.(type) {
case *dns.A:
ip = v.A
case *dns.AAAA:
ip = v.AAAA
default:
// add non A/AAAA entries
// TODO: Add support for dns.SVCB and dns.HTTPS
goodEntries = append(goodEntries, rr)
continue
}
ipScope := netutils.GetIPScope(ip)
if p.RemoveOutOfScopeDNS() {
switch {
case ipScope.IsLocalhost():
// No DNS should return localhost addresses
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionRemoveOutOfScopeDNSKey
log.Tracer(ctx).Tracef("filter: RR violates resolver scope: %s", formatRR(rr))
continue
case resolverScope.IsGlobal() && ipScope.IsLAN() && !sysResolver:
// No global DNS should return LAN addresses
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionRemoveOutOfScopeDNSKey
log.Tracer(ctx).Tracef("filter: RR violates resolver scope: %s", formatRR(rr))
continue
}
}
if p.RemoveBlockedDNS() && !sysResolver {
// filter by flags
switch {
case p.BlockScopeInternet() && ipScope.IsGlobal():
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionBlockScopeInternetKey
log.Tracer(ctx).Tracef("filter: RR is in blocked scope Internet: %s", formatRR(rr))
continue
case p.BlockScopeLAN() && ipScope.IsLAN():
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionBlockScopeLANKey
log.Tracer(ctx).Tracef("filter: RR is in blocked scope LAN: %s", formatRR(rr))
continue
case p.BlockScopeLocal() && ipScope.IsLocalhost():
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionBlockScopeLocalKey
log.Tracer(ctx).Tracef("filter: RR is in blocked scope Localhost: %s", formatRR(rr))
continue
}
// TODO: filter by endpoint list (IP only)
}
// if survived, add to good entries
allowedAddressRecords++
goodEntries = append(goodEntries, rr)
}
return goodEntries, filteredRecords, allowedAddressRecords, interveningOptionKey
}
func filterDNSResponse(
ctx context.Context,
conn *network.Connection,
p *profile.LayeredProfile,
rrCache *resolver.RRCache,
sysResolver bool,
) *resolver.RRCache {
// do not modify own queries
if conn.Process().Pid == ownPID {
return rrCache
}
// check if DNS response filtering is completely turned off
if !p.RemoveOutOfScopeDNS() && !p.RemoveBlockedDNS() {
return rrCache
}
var filteredRecords []string
var validIPs int
var interveningOptionKey string
rrCache.Answer, filteredRecords, validIPs, interveningOptionKey = filterDNSSection(ctx, rrCache.Answer, p, rrCache.Resolver.IPScope, sysResolver)
if len(filteredRecords) > 0 {
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
}
// Don't count the valid IPs in the extra section.
rrCache.Extra, filteredRecords, _, _ = filterDNSSection(ctx, rrCache.Extra, p, rrCache.Resolver.IPScope, sysResolver)
if len(filteredRecords) > 0 {
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
}
if len(rrCache.FilteredEntries) > 0 {
rrCache.Filtered = true
if validIPs == 0 {
switch interveningOptionKey {
case profile.CfgOptionBlockScopeInternetKey:
conn.Block("Internet access blocked", interveningOptionKey)
case profile.CfgOptionBlockScopeLANKey:
conn.Block("LAN access blocked", interveningOptionKey)
case profile.CfgOptionBlockScopeLocalKey:
conn.Block("Localhost access blocked", interveningOptionKey)
case profile.CfgOptionRemoveOutOfScopeDNSKey:
conn.Block("DNS global/private split-view violation", interveningOptionKey)
default:
conn.Block("DNS response only contained to-be-blocked IPs", interveningOptionKey)
}
return rrCache
}
}
return rrCache
}
// FilterResolvedDNS filters a dns response according to the application
// profile and settings.
func FilterResolvedDNS(
ctx context.Context,
conn *network.Connection,
q *resolver.Query,
rrCache *resolver.RRCache,
) *resolver.RRCache {
// Check if we have a process and profile.
layeredProfile := conn.Process().Profile()
if layeredProfile == nil {
log.Tracer(ctx).Warning("unknown process or profile")
return nil
}
// Don't filter env responses.
if rrCache.Resolver.Type == resolver.ServerTypeEnv {
return rrCache
}
// special grant for connectivity domains
if checkConnectivityDomain(ctx, conn, layeredProfile, nil) {
// returns true if check triggered
return rrCache
}
// Only filter critical things if request comes from the system resolver.
sysResolver := conn.Process().IsSystemResolver()
// Filter dns records and return if the query is blocked.
rrCache = filterDNSResponse(ctx, conn, layeredProfile, rrCache, sysResolver)
if conn.Verdict == network.VerdictBlock {
return rrCache
}
// Block by CNAMEs.
if !sysResolver {
mayBlockCNAMEs(ctx, conn, layeredProfile)
}
return rrCache
}
func mayBlockCNAMEs(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile) bool {
// if we have CNAMEs and the profile is configured to filter them
// we need to re-check the lists and endpoints here
if p.FilterCNAMEs() {
conn.Entity.ResetLists()
conn.Entity.EnableCNAMECheck(ctx, true)
result, reason := p.MatchEndpoint(ctx, conn.Entity)
if result == endpoints.Denied {
conn.BlockWithContext(reason.String(), profile.CfgOptionFilterCNAMEKey, reason.Context())
return true
}
if result == endpoints.NoMatch {
result, reason = p.MatchFilterLists(ctx, conn.Entity)
if result == endpoints.Denied {
conn.BlockWithContext(reason.String(), profile.CfgOptionFilterCNAMEKey, reason.Context())
return true
}
}
}
return false
}
// UpdateIPsAndCNAMEs saves all the IP->Name mappings to the cache database and
// updates the CNAMEs in the Connection's Entity.
func UpdateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) {
// Sanity check input, as this is called from defer.
if q == nil || rrCache == nil {
return
}
// Get profileID for scoping IPInfo.
var profileID string
localProfile := conn.Process().Profile().LocalProfile()
switch localProfile.ID {
case profile.UnidentifiedProfileID,
profile.SystemResolverProfileID:
profileID = resolver.IPInfoProfileScopeGlobal
default:
profileID = localProfile.ID
}
// Collect IPs and CNAMEs.
cnames := make(map[string]string)
ips := make([]net.IP, 0, len(rrCache.Answer))
for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
switch v := rr.(type) {
case *dns.CNAME:
cnames[v.Hdr.Name] = v.Target
case *dns.A:
ips = append(ips, v.A)
case *dns.AAAA:
ips = append(ips, v.AAAA)
case *dns.SVCB:
if len(v.Target) >= 2 { // Ignore "" and ".".
cnames[v.Hdr.Name] = v.Target
}
for _, pair := range v.Value {
switch svcbParam := pair.(type) {
case *dns.SVCBIPv4Hint:
ips = append(ips, svcbParam.Hint...)
case *dns.SVCBIPv6Hint:
ips = append(ips, svcbParam.Hint...)
}
}
case *dns.HTTPS:
if len(v.Target) >= 2 { // Ignore "" and ".".
cnames[v.Hdr.Name] = v.Target
}
for _, pair := range v.Value {
switch svcbParam := pair.(type) {
case *dns.SVCBIPv4Hint:
ips = append(ips, svcbParam.Hint...)
case *dns.SVCBIPv6Hint:
ips = append(ips, svcbParam.Hint...)
}
}
}
}
// Package IPs and CNAMEs into IPInfo structs.
for _, ip := range ips {
// Never save domain attributions for localhost IPs.
if netutils.GetIPScope(ip) == netutils.HostLocal {
continue
}
// Create new record for this IP.
record := resolver.ResolvedDomain{
Domain: q.FQDN,
Resolver: rrCache.Resolver,
DNSRequestContext: rrCache.ToDNSRequestContext(),
Expires: rrCache.Expires,
}
// Resolve all CNAMEs in the correct order and add the to the record.
domain := q.FQDN
for {
nextDomain, isCNAME := cnames[domain]
if !isCNAME {
break
}
record.CNAMEs = append(record.CNAMEs, nextDomain)
domain = nextDomain
}
// Update the entity to include the CNAMEs of the query response.
conn.Entity.CNAME = record.CNAMEs
// Check if there is an existing record for this DNS response.
// Else create a new one.
ipString := ip.String()
info, err := resolver.GetIPInfo(profileID, ipString)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
log.Errorf("nameserver: failed to search for IP info record: %s", err)
}
info = &resolver.IPInfo{
IP: ipString,
ProfileID: profileID,
}
}
// Add the new record to the resolved domains for this IP and scope.
info.AddDomain(record)
// Save if the record is new or has been updated.
if err := info.Save(); err != nil {
log.Errorf("nameserver: failed to save IP info record: %s", err)
}
}
}
// formatRR is a friendlier alternative to miekg/dns.RR.String().
func formatRR(rr dns.RR) string {
return strings.ReplaceAll(rr.String(), "\t", " ")
}

View File

@@ -0,0 +1,102 @@
package inspection
import (
"sync"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
)
//nolint:golint,stylecheck // FIXME
const (
DO_NOTHING uint8 = iota
BLOCK_PACKET
DROP_PACKET
BLOCK_CONN
DROP_CONN
STOP_INSPECTING
)
type inspectorFn func(*network.Connection, packet.Packet) uint8
var (
inspectors []inspectorFn
inspectorNames []string
inspectVerdicts []network.Verdict
inspectorsLock sync.Mutex
)
// RegisterInspector registers a traffic inspector.
func RegisterInspector(name string, inspector inspectorFn, inspectVerdict network.Verdict) (index int) {
inspectorsLock.Lock()
defer inspectorsLock.Unlock()
index = len(inspectors)
inspectors = append(inspectors, inspector)
inspectorNames = append(inspectorNames, name)
inspectVerdicts = append(inspectVerdicts, inspectVerdict)
return
}
// RunInspectors runs all the applicable inspectors on the given packet.
func RunInspectors(conn *network.Connection, pkt packet.Packet) (network.Verdict, bool) {
// inspectorsLock.Lock()
// defer inspectorsLock.Unlock()
activeInspectors := conn.GetActiveInspectors()
if activeInspectors == nil {
activeInspectors = make([]bool, len(inspectors))
conn.SetActiveInspectors(activeInspectors)
}
inspectorData := conn.GetInspectorData()
if inspectorData == nil {
inspectorData = make(map[uint8]interface{})
conn.SetInspectorData(inspectorData)
}
continueInspection := false
verdict := network.VerdictUndecided
for key, skip := range activeInspectors {
if skip {
continue
}
// check if the active verdict is already past the inspection criteria.
if conn.Verdict > inspectVerdicts[key] {
activeInspectors[key] = true
continue
}
action := inspectors[key](conn, pkt) // Actually run inspector
switch action {
case DO_NOTHING:
if verdict < network.VerdictAccept {
verdict = network.VerdictAccept
}
continueInspection = true
case BLOCK_PACKET:
if verdict < network.VerdictBlock {
verdict = network.VerdictBlock
}
continueInspection = true
case DROP_PACKET:
verdict = network.VerdictDrop
continueInspection = true
case BLOCK_CONN:
conn.SetVerdict(network.VerdictBlock, "", "", nil)
verdict = conn.Verdict
activeInspectors[key] = true
case DROP_CONN:
conn.SetVerdict(network.VerdictDrop, "", "", nil)
verdict = conn.Verdict
activeInspectors[key] = true
case STOP_INSPECTING:
activeInspectors[key] = true
}
}
return verdict, continueInspection
}

View File

@@ -0,0 +1,147 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
type bpfSkInfo struct {
Rx uint64
Tx uint64
Reported uint64
}
type bpfSkKey struct {
SrcIp [4]uint32
DstIp [4]uint32
SrcPort uint16
DstPort uint16
Protocol uint8
Ipv6 uint8
_ [2]byte
}
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
SocketOperations *ebpf.ProgramSpec `ebpf:"socket_operations"`
UdpRecvmsg *ebpf.ProgramSpec `ebpf:"udp_recvmsg"`
UdpSendmsg *ebpf.ProgramSpec `ebpf:"udp_sendmsg"`
Udpv6Recvmsg *ebpf.ProgramSpec `ebpf:"udpv6_recvmsg"`
Udpv6Sendmsg *ebpf.ProgramSpec `ebpf:"udpv6_sendmsg"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PmBandwidthMap *ebpf.MapSpec `ebpf:"pm_bandwidth_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PmBandwidthMap *ebpf.Map `ebpf:"pm_bandwidth_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PmBandwidthMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
SocketOperations *ebpf.Program `ebpf:"socket_operations"`
UdpRecvmsg *ebpf.Program `ebpf:"udp_recvmsg"`
UdpSendmsg *ebpf.Program `ebpf:"udp_sendmsg"`
Udpv6Recvmsg *ebpf.Program `ebpf:"udpv6_recvmsg"`
Udpv6Sendmsg *ebpf.Program `ebpf:"udpv6_sendmsg"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.SocketOperations,
p.UdpRecvmsg,
p.UdpSendmsg,
p.Udpv6Recvmsg,
p.Udpv6Sendmsg,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

View File

@@ -0,0 +1,147 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || loong64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
type bpfSkInfo struct {
Rx uint64
Tx uint64
Reported uint64
}
type bpfSkKey struct {
SrcIp [4]uint32
DstIp [4]uint32
SrcPort uint16
DstPort uint16
Protocol uint8
Ipv6 uint8
_ [2]byte
}
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
SocketOperations *ebpf.ProgramSpec `ebpf:"socket_operations"`
UdpRecvmsg *ebpf.ProgramSpec `ebpf:"udp_recvmsg"`
UdpSendmsg *ebpf.ProgramSpec `ebpf:"udp_sendmsg"`
Udpv6Recvmsg *ebpf.ProgramSpec `ebpf:"udpv6_recvmsg"`
Udpv6Sendmsg *ebpf.ProgramSpec `ebpf:"udpv6_sendmsg"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PmBandwidthMap *ebpf.MapSpec `ebpf:"pm_bandwidth_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PmBandwidthMap *ebpf.Map `ebpf:"pm_bandwidth_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PmBandwidthMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
SocketOperations *ebpf.Program `ebpf:"socket_operations"`
UdpRecvmsg *ebpf.Program `ebpf:"udp_recvmsg"`
UdpSendmsg *ebpf.Program `ebpf:"udp_sendmsg"`
Udpv6Recvmsg *ebpf.Program `ebpf:"udpv6_recvmsg"`
Udpv6Sendmsg *ebpf.Program `ebpf:"udpv6_sendmsg"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.SocketOperations,
p.UdpRecvmsg,
p.UdpSendmsg,
p.Udpv6Recvmsg,
p.Udpv6Sendmsg,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfel.o
var _BpfBytes []byte

View File

@@ -0,0 +1,192 @@
package ebpf
import (
"context"
"encoding/binary"
"fmt"
"net"
"path/filepath"
"sync/atomic"
"syscall"
"time"
"github.com/cilium/ebpf"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/rlimit"
"golang.org/x/sys/unix"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/packet"
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang -cflags "-O2 -g -Wall -Werror" bpf ../programs/bandwidth.c
var ebpfLoadingFailed atomic.Uint32
// BandwidthStatsWorker monitors connection bandwidth using ebpf.
func BandwidthStatsWorker(ctx context.Context, collectInterval time.Duration, bandwidthUpdates chan *packet.BandwidthUpdate) error {
// Allow the current process to lock memory for eBPF resources.
err := rlimit.RemoveMemlock()
if err != nil {
if ebpfLoadingFailed.Add(1) >= 5 {
log.Warningf("ebpf: failed to remove memlock 5 times, giving up with error %s", err)
return nil
}
return fmt.Errorf("ebpf: failed to remove memlock: %w", err)
}
// Load pre-compiled programs and maps into the kernel.
objs := bpfObjects{}
if err := loadBpfObjects(&objs, nil); err != nil {
if ebpfLoadingFailed.Add(1) >= 5 {
log.Warningf("ebpf: failed to load ebpf object 5 times, giving up with error %s", err)
return nil
}
return fmt.Errorf("ebpf: failed to load ebpf object: %w", err)
}
defer objs.Close() //nolint:errcheck
// Find the cgroup path
path, err := findCgroupPath()
if err != nil {
return fmt.Errorf("ebpf: failed to find cgroup paths: %w", err)
}
// Attach socket options for monitoring connections
sockOptionsLink, err := link.AttachCgroup(link.CgroupOptions{
Path: path,
Program: objs.bpfPrograms.SocketOperations,
Attach: ebpf.AttachCGroupSockOps,
})
if err != nil {
return fmt.Errorf("ebpf: failed to open module sockops: %w", err)
}
defer sockOptionsLink.Close() //nolint:errcheck
// Attach Udp Ipv4 recive message tracing
udpv4RMLink, err := link.AttachTracing(link.TracingOptions{
Program: objs.bpfPrograms.UdpRecvmsg,
})
if err != nil {
return fmt.Errorf("ebpf: failed to open trace Udp IPv4 recvmsg: %w", err)
}
defer udpv4RMLink.Close() //nolint:errcheck
// Attach UDP IPv4 send message tracing
udpv4SMLink, err := link.AttachTracing(link.TracingOptions{
Program: objs.bpfPrograms.UdpSendmsg,
})
if err != nil {
return fmt.Errorf("ebpf: failed to open trace Udp IPv4 sendmsg: %w", err)
}
defer udpv4SMLink.Close() //nolint:errcheck
// Attach UDP IPv6 receive message tracing
udpv6RMLink, err := link.AttachTracing(link.TracingOptions{
Program: objs.bpfPrograms.Udpv6Recvmsg,
})
if err != nil {
return fmt.Errorf("ebpf: failed to open trace Udp IPv6 recvmsg: %w", err)
}
defer udpv6RMLink.Close() //nolint:errcheck
// Attach UDP IPv6 send message tracing
udpv6SMLink, err := link.AttachTracing(link.TracingOptions{
Program: objs.bpfPrograms.Udpv6Sendmsg,
})
if err != nil {
return fmt.Errorf("ebpf: failed to open trace Udp IPv6 sendmsg: %w", err)
}
defer udpv6SMLink.Close() //nolint:errcheck
// Setup ticker.
ticker := time.NewTicker(collectInterval)
defer ticker.Stop()
// Collect bandwidth at every tick.
for {
select {
case <-ticker.C:
reportBandwidth(ctx, objs, bandwidthUpdates)
case <-ctx.Done():
return nil
}
}
}
// reportBandwidth reports the bandwidth to the given updates channel.
func reportBandwidth(ctx context.Context, objs bpfObjects, bandwidthUpdates chan *packet.BandwidthUpdate) {
var (
skKey bpfSkKey
skInfo bpfSkInfo
updated int
skipped int
)
iter := objs.bpfMaps.PmBandwidthMap.Iterate()
for iter.Next(&skKey, &skInfo) {
// Check if already reported.
if skInfo.Reported >= 1 {
skipped++
continue
}
// Mark as reported and update the map.
skInfo.Reported = 1
if err := objs.bpfMaps.PmBandwidthMap.Update(&skKey, &skInfo, ebpf.UpdateExist); err != nil {
log.Debugf("ebpf: failed to mark bandwidth map entry as reported: %s", err)
}
connID := packet.CreateConnectionID(
packet.IPProtocol(skKey.Protocol),
convertArrayToIP(skKey.SrcIp, skKey.Ipv6 == 1), skKey.SrcPort,
convertArrayToIP(skKey.DstIp, skKey.Ipv6 == 1), skKey.DstPort,
false,
)
update := &packet.BandwidthUpdate{
ConnID: connID,
BytesReceived: skInfo.Rx,
BytesSent: skInfo.Tx,
Method: packet.Absolute,
}
select {
case bandwidthUpdates <- update:
updated++
case <-ctx.Done():
return
default:
log.Warningf("ebpf: bandwidth update queue is full (updated=%d, skipped=%d), ignoring rest of batch", updated, skipped)
return
}
}
}
// findCgroupPath returns the default unified path of the cgroup.
func findCgroupPath() (string, error) {
cgroupPath := "/sys/fs/cgroup"
var st syscall.Statfs_t
err := syscall.Statfs(cgroupPath, &st)
if err != nil {
return "", err
}
isCgroupV2Enabled := st.Type == unix.CGROUP2_SUPER_MAGIC
if !isCgroupV2Enabled {
cgroupPath = filepath.Join(cgroupPath, "unified")
}
return cgroupPath, nil
}
// convertArrayToIP converts an array of uint32 values to a net.IP address.
func convertArrayToIP(input [4]uint32, ipv6 bool) net.IP {
if !ipv6 {
addressBuf := make([]byte, 4)
binary.LittleEndian.PutUint32(addressBuf, input[0])
return net.IP(addressBuf)
}
addressBuf := make([]byte, 16)
for i := 0; i < 4; i++ {
binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
}
return net.IP(addressBuf)
}

View File

@@ -0,0 +1,137 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
type bpfEvent struct {
Saddr [4]uint32
Daddr [4]uint32
Sport uint16
Dport uint16
Pid uint32
IpVersion uint8
Protocol uint8
Direction uint8
_ [1]byte
}
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcpConnect *ebpf.ProgramSpec `ebpf:"tcp_connect"`
UdpV4Connect *ebpf.ProgramSpec `ebpf:"udp_v4_connect"`
UdpV6Connect *ebpf.ProgramSpec `ebpf:"udp_v6_connect"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PmConnectionEvents *ebpf.MapSpec `ebpf:"pm_connection_events"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PmConnectionEvents *ebpf.Map `ebpf:"pm_connection_events"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PmConnectionEvents,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcpConnect *ebpf.Program `ebpf:"tcp_connect"`
UdpV4Connect *ebpf.Program `ebpf:"udp_v4_connect"`
UdpV6Connect *ebpf.Program `ebpf:"udp_v6_connect"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcpConnect,
p.UdpV4Connect,
p.UdpV6Connect,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

View File

@@ -0,0 +1,137 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || loong64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
type bpfEvent struct {
Saddr [4]uint32
Daddr [4]uint32
Sport uint16
Dport uint16
Pid uint32
IpVersion uint8
Protocol uint8
Direction uint8
_ [1]byte
}
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcpConnect *ebpf.ProgramSpec `ebpf:"tcp_connect"`
UdpV4Connect *ebpf.ProgramSpec `ebpf:"udp_v4_connect"`
UdpV6Connect *ebpf.ProgramSpec `ebpf:"udp_v6_connect"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PmConnectionEvents *ebpf.MapSpec `ebpf:"pm_connection_events"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PmConnectionEvents *ebpf.Map `ebpf:"pm_connection_events"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PmConnectionEvents,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcpConnect *ebpf.Program `ebpf:"tcp_connect"`
UdpV4Connect *ebpf.Program `ebpf:"udp_v4_connect"`
UdpV6Connect *ebpf.Program `ebpf:"udp_v6_connect"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcpConnect,
p.UdpV4Connect,
p.UdpV6Connect,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfel.o
var _BpfBytes []byte

View File

@@ -0,0 +1,176 @@
package ebpf
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"sync/atomic"
"time"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/ringbuf"
"github.com/cilium/ebpf/rlimit"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/packet"
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang -cflags "-O2 -g -Wall -Werror" -type Event bpf ../programs/monitor.c
var ebpfLoadingFailed atomic.Uint32
// ConnectionListenerWorker listens to new connections using ebpf.
func ConnectionListenerWorker(ctx context.Context, packets chan packet.Packet) error {
// Allow the current process to lock memory for eBPF resources.
if err := rlimit.RemoveMemlock(); err != nil {
if ebpfLoadingFailed.Add(1) >= 5 {
log.Warningf("ebpf: failed to remove memlock 5 times, giving up with error %s", err)
return nil
}
return fmt.Errorf("ebpf: failed to remove ebpf memlock: %w", err)
}
// Load pre-compiled programs and maps into the kernel.
objs := bpfObjects{}
if err := loadBpfObjects(&objs, nil); err != nil {
if ebpfLoadingFailed.Add(1) >= 5 {
log.Warningf("ebpf: failed to load ebpf object 5 times, giving up with error %s", err)
return nil
}
return fmt.Errorf("ebpf: failed to load ebpf object: %w", err)
}
defer objs.Close() //nolint:errcheck
// Create a link to the tcp_connect program.
linkTCPConnect, err := link.AttachTracing(link.TracingOptions{
Program: objs.bpfPrograms.TcpConnect,
})
if err != nil {
return fmt.Errorf("ebpf: failed to attach to tcp_v4_connect: %w", err)
}
defer linkTCPConnect.Close() //nolint:errcheck
// Create a link to the udp_v4_connect program.
linkUDPV4, err := link.AttachTracing(link.TracingOptions{
Program: objs.bpfPrograms.UdpV4Connect,
})
if err != nil {
return fmt.Errorf("ebpf: failed to attach to udp_v4_connect: %w", err)
}
defer linkUDPV4.Close() //nolint:errcheck
// Create a link to the udp_v6_connect program.
linkUDPV6, err := link.AttachTracing(link.TracingOptions{
Program: objs.bpfPrograms.UdpV6Connect,
})
if err != nil {
return fmt.Errorf("ebpf: failed to attach to udp_v6_connect: %w", err)
}
defer linkUDPV6.Close() //nolint:errcheck
// Create new reader to read events.
rd, err := ringbuf.NewReader(objs.bpfMaps.PmConnectionEvents)
if err != nil {
return fmt.Errorf("ebpf: failed to open ring buffer: %w", err)
}
defer rd.Close() //nolint:errcheck
// Start watcher to close the reader when the context is canceled.
// TODO: Can we put this into a worker?
go func() {
<-ctx.Done()
if err := rd.Close(); err != nil {
log.Errorf("ebpf: failed closing ringbuf reader: %s", err)
}
}()
for {
// Read next event
record, err := rd.Read()
if err != nil {
if errors.Is(err, ringbuf.ErrClosed) {
// Normal return
return nil
}
log.Errorf("ebpf: failed to read from ring buffer: %s", err)
continue
}
var event bpfEvent
// Parse the ringbuf event entry into a bpfEvent structure.
if err := binary.Read(bytes.NewBuffer(record.RawSample), binary.BigEndian, &event); err != nil {
log.Errorf("ebpf: failed to parse ringbuf event: %s", err)
continue
}
pkt := packet.NewInfoPacket(packet.Info{
Inbound: event.Direction == 1,
InTunnel: false,
Version: packet.IPVersion(event.IpVersion),
Protocol: packet.IPProtocol(event.Protocol),
SrcPort: event.Sport,
DstPort: event.Dport,
Src: convertArrayToIPv4(event.Saddr, packet.IPVersion(event.IpVersion)),
Dst: convertArrayToIPv4(event.Daddr, packet.IPVersion(event.IpVersion)),
PID: int(event.Pid),
SeenAt: time.Now(),
})
if isEventValid(event) {
// DEBUG:
// log.Debugf("ebpf: received valid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt)
packets <- pkt
} else {
log.Warningf("ebpf: received invalid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt)
}
}
}
// isEventValid checks whether the given bpfEvent is valid or not.
// It returns true if the event is valid, otherwise false.
func isEventValid(event bpfEvent) bool {
// Check if the destination port is 0
if event.Dport == 0 {
return false
}
// Check if the source port is 0
if event.Sport == 0 {
return false
}
// Check if the process ID is 0
if event.Pid == 0 {
return false
}
// If the IP version is IPv4
if event.IpVersion == 4 {
if event.Saddr[0] == 0 {
return false
}
if event.Daddr[0] == 0 {
return false
}
}
return true
}
// convertArrayToIPv4 converts an array of uint32 values to an IPv4 net.IP address.
func convertArrayToIPv4(input [4]uint32, ipVersion packet.IPVersion) net.IP {
if ipVersion == packet.IPv4 {
addressBuf := make([]byte, 4)
binary.LittleEndian.PutUint32(addressBuf, input[0])
return net.IP(addressBuf)
}
addressBuf := make([]byte, 16)
for i := 0; i < 4; i++ {
binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
}
return net.IP(addressBuf)
}

View File

@@ -0,0 +1,119 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
EnterExecve *ebpf.ProgramSpec `ebpf:"enter_execve"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PmExecMap *ebpf.MapSpec `ebpf:"pm_exec_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PmExecMap *ebpf.Map `ebpf:"pm_exec_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PmExecMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
EnterExecve *ebpf.Program `ebpf:"enter_execve"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.EnterExecve,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

Binary file not shown.

View File

@@ -0,0 +1,119 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || loong64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
EnterExecve *ebpf.ProgramSpec `ebpf:"enter_execve"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PmExecMap *ebpf.MapSpec `ebpf:"pm_exec_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PmExecMap *ebpf.Map `ebpf:"pm_exec_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PmExecMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
EnterExecve *ebpf.Program `ebpf:"enter_execve"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.EnterExecve,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfel.o
var _BpfBytes []byte

Binary file not shown.

View File

@@ -0,0 +1,249 @@
package ebpf
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"runtime"
"runtime/debug"
"strings"
"sync"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/ringbuf"
"github.com/cilium/ebpf/rlimit"
"github.com/hashicorp/go-multierror"
"golang.org/x/sys/unix"
"github.com/safing/portbase/log"
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang -cflags "-O2 -g -Wall -Werror" bpf ../programs/exec.c
// These constants are defined in `bpf/handler.c` and must be kept in sync.
const (
arglen = 32
argsize = 1024
)
var errTracerClosed = errors.New("tracer is closed")
// event contains details about each exec call, sent from the eBPF program to
// userspace through a perf ring buffer. This type must be kept in sync with
// `event_t` in `bpf/handler.c`.
type event struct {
// Details about the process being launched.
Filename [argsize]byte
Argv [arglen][argsize]byte
Argc uint32
UID uint32
GID uint32
PID uint32
// Name of the calling process.
Comm [argsize]byte
}
// Event contains data about each exec event with many fields for easy
// filtering and logging.
type Event struct {
Filename string `json:"filename"`
// Argv contains the raw argv supplied to the process, including argv[0]
// (which is equal to `filepath.Base(e.Filename)` in most circumstances).
Argv []string `json:"argv"`
// Truncated is true if we were unable to read all process arguments into
// Argv because there were more than ARGLEN arguments.
Truncated bool `json:"truncated"`
// These values are of the new process. Keep in mind that the exec call may
// fail and the PID will be released in such a case.
PID uint32 `json:"pid"`
UID uint32 `json:"uid"`
GID uint32 `json:"gid"`
// Comm is the "name" of the parent process, usually the filename of the
// executable (but not always).
Comm string `json:"comm"`
}
// Tracer is the exec tracer itself.
// It must be closed after use.
type Tracer struct {
objs bpfObjects
tp link.Link
rb *ringbuf.Reader
closeLock sync.Mutex
closed chan struct{}
}
// New instantiates all of the BPF objects into the running kernel, starts
// tracing, and returns the created Tracer. After calling this successfully, the
// caller should immediately attach a for loop running `h.Read()`.
//
// The returned Tracer MUST be closed when not needed anymore otherwise kernel
// resources may be leaked.
func New() (*Tracer, error) {
t := &Tracer{
tp: nil,
rb: nil,
closeLock: sync.Mutex{},
closed: make(chan struct{}),
}
if err := loadBpfObjects(&t.objs, nil); err != nil {
return nil, fmt.Errorf("ebpf: failed to load ebpf object: %w", err)
}
if err := t.start(); err != nil {
// Best effort.
_ = t.Close()
return nil, fmt.Errorf("start tracer: %w", err)
}
// It could be very bad if someone forgot to close this, so we'll try to
// detect when it doesn't get closed and log a warning.
stack := debug.Stack()
runtime.SetFinalizer(t, func(t *Tracer) {
err := t.Close()
if errors.Is(err, errTracerClosed) {
return
}
log.Infof("tracer was finalized but was not closed, created at: %s", stack)
log.Infof("tracers must be closed when finished with to avoid leaked kernel resources")
if err != nil {
log.Errorf("closing tracer failed: %+v", err)
}
})
return t, nil
}
// start loads the eBPF programs and maps into the kernel and starts them.
// You should immediately attach a for loop running `h.Read()` after calling
// this successfully.
func (t *Tracer) start() error {
// If we don't startup successfully, we need to make sure all of the
// stuff is cleaned up properly or we'll be leaking kernel resources.
ok := false
defer func() {
if !ok {
// Best effort.
_ = t.Close()
}
}()
// Allow the current process to lock memory for eBPF resources. This
// does nothing on 5.11+ kernels which don't need this.
err := rlimit.RemoveMemlock()
if err != nil {
return fmt.Errorf("remove memlock: %w", err)
}
// Attach the eBPF program to the `sys_enter_execve` tracepoint, which
// is triggered at the beginning of each `execve()` syscall.
t.tp, err = link.Tracepoint("syscalls", "sys_enter_execve", t.objs.EnterExecve, nil)
if err != nil {
return fmt.Errorf("open tracepoint: %w", err)
}
// Create the reader for the event ringbuf.
t.rb, err = ringbuf.NewReader(t.objs.PmExecMap)
if err != nil {
return fmt.Errorf("open ringbuf reader: %w", err)
}
ok = true
return nil
}
// Read reads an event from the eBPF program via the ringbuf, parses it and
// returns it. If the *tracer is closed during the blocked call, and error that
// wraps io.EOF will be returned.
func (t *Tracer) Read() (*Event, error) {
rb := t.rb
if rb == nil {
return nil, errors.New("ringbuf reader is not initialized, tracer may not be open or may have been closed")
}
record, err := rb.Read()
if err != nil {
if errors.Is(err, ringbuf.ErrClosed) {
return nil, fmt.Errorf("tracer closed: %w", io.EOF)
}
return nil, fmt.Errorf("read from ringbuf: %w", err)
}
// Parse the ringbuf event entry into an event structure.
var rawEvent event
err = binary.Read(bytes.NewBuffer(record.RawSample), binary.NativeEndian, &rawEvent)
if err != nil {
return nil, fmt.Errorf("parse raw ringbuf entry into event struct: %w", err)
}
ev := &Event{
Filename: unix.ByteSliceToString(rawEvent.Filename[:]),
Argv: []string{}, // populated below
Truncated: rawEvent.Argc == arglen+1,
PID: rawEvent.PID,
UID: rawEvent.UID,
GID: rawEvent.GID,
Comm: unix.ByteSliceToString(rawEvent.Comm[:]),
}
// Copy only the args we're allowed to read from the array. If we read more
// than rawEvent.Argc, we could be copying non-zeroed memory.
argc := int(rawEvent.Argc)
if argc > arglen {
argc = arglen
}
for i := 0; i < argc; i++ {
str := unix.ByteSliceToString(rawEvent.Argv[i][:])
if strings.TrimSpace(str) != "" {
ev.Argv = append(ev.Argv, str)
}
}
return ev, nil
}
// Close gracefully closes and frees all resources associated with the eBPF
// tracepoints, maps and other resources. Any blocked `Read()` operations will
// return an error that wraps `io.EOF`.
func (t *Tracer) Close() error {
t.closeLock.Lock()
defer t.closeLock.Unlock()
select {
case <-t.closed:
return errTracerClosed
default:
}
close(t.closed)
runtime.SetFinalizer(t, nil)
// Close everything started in h.Start() in reverse order.
var merr error
if t.rb != nil {
err := t.rb.Close()
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("close ringbuf reader: %w", err))
}
}
if t.tp != nil {
err := t.tp.Close()
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("close tracepoint: %w", err))
}
}
err := t.objs.Close()
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("close eBPF objects: %w", err))
}
return merr
}

View File

@@ -0,0 +1,219 @@
#include "vmlinux-x86.h"
#include "bpf/bpf_helpers.h"
#include "bpf/bpf_tracing.h"
#include "bpf/bpf_core_read.h"
#define AF_INET 2
#define AF_INET6 10
#define PROTOCOL_TCP 6
#define PROTOCOL_UDP 17
char __license[] SEC("license") = "GPL";
struct sk_key {
u32 src_ip[4];
u32 dst_ip[4];
u16 src_port;
u16 dst_port;
u8 protocol;
u8 ipv6;
};
struct sk_info {
u64 rx;
u64 tx;
u64 reported;
};
// Max number of connections that will be kept. Increse the number if it's not enough.
#define SOCKOPS_MAP_SIZE 5000
struct {
__uint(type, BPF_MAP_TYPE_LRU_HASH);
__uint(max_entries, SOCKOPS_MAP_SIZE);
__type(key, struct sk_key);
__type(value, struct sk_info);
} pm_bandwidth_map SEC(".maps");
SEC("sockops")
int socket_operations(struct bpf_sock_ops *skops) {
switch (skops->op) {
case BPF_SOCK_OPS_TCP_CONNECT_CB: // Outgoing connections
// Set flag so any modification on the socket, will trigger this function.
bpf_sock_ops_cb_flags_set(skops, BPF_SOCK_OPS_ALL_CB_FLAGS);
return 0;
case BPF_SOCK_OPS_TCP_LISTEN_CB: // Listening ports
bpf_sock_ops_cb_flags_set(skops, BPF_SOCK_OPS_ALL_CB_FLAGS);
// No rx tx data for this socket object.
return 0;
case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB: // Incoming connections
// Set flag so any modification on the socket, will trigger this function.
bpf_sock_ops_cb_flags_set(skops, BPF_SOCK_OPS_ALL_CB_FLAGS);
return 0;
default:
break;
}
struct bpf_sock *sk = skops->sk;
if (sk == NULL) {
return 0;
}
struct sk_key key = {0};
key.protocol = PROTOCOL_TCP;
if(sk->family == AF_INET) {
// Generate key for IPv4
key.src_ip[0] = sk->src_ip4;
key.src_port = sk->src_port;
key.dst_ip[0] = sk->dst_ip4;
key.dst_port = __builtin_bswap16(sk->dst_port);
key.ipv6 = 0;
struct sk_info newInfo = {0};
newInfo.rx = skops->bytes_received;
newInfo.tx = skops->bytes_acked;
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
} else if(sk->family == AF_INET6){
// Generate key for IPv6
key.src_ip[0] = sk->src_ip6[0];
key.src_ip[1] = sk->src_ip6[1];
key.src_ip[2] = sk->src_ip6[2];
key.src_ip[3] = sk->src_ip6[3];
key.src_port = sk->src_port;
key.dst_ip[0] = sk->dst_ip6[0];
key.dst_ip[1] = sk->dst_ip6[1];
key.dst_ip[2] = sk->dst_ip6[2];
key.dst_ip[3] = sk->dst_ip6[3];
key.dst_port = __builtin_bswap16(sk->dst_port);
key.ipv6 = 1;
struct sk_info newInfo = {0};
newInfo.rx = skops->bytes_received;
newInfo.tx = skops->bytes_acked;
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
}
return 0;
}
// udp_sendmsg hookes to the respective kernel function and saves the bandwidth data
SEC("fentry/udp_sendmsg")
int BPF_PROG(udp_sendmsg, struct sock *sk, struct msghdr *msg, size_t len) {
struct sock_common *skc = &sk->__sk_common;
// Create a key for the map and set all the nececery information.
struct sk_key key = {0};
key.protocol = PROTOCOL_UDP;
key.src_ip[0] = skc->skc_rcv_saddr;
key.dst_ip[0] = skc->skc_daddr;
key.src_port = skc->skc_num;
key.dst_port = __builtin_bswap16(skc->skc_dport);
key.ipv6 = 0;
// Update the map with the new information
struct sk_info *info = bpf_map_lookup_elem(&pm_bandwidth_map, &key);
if (info != NULL) {
__sync_fetch_and_add(&info->tx, len); // TODO: Use atomic instead.
__sync_fetch_and_and(&info->reported, 0); // TODO: Use atomic instead.
} else {
struct sk_info newInfo = {0};
newInfo.tx = len;
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
}
return 0;
};
// udp_recvmsg hookes to the respective kernel function and saves the bandwidth data
SEC("fentry/udp_recvmsg")
int BPF_PROG(udp_recvmsg, struct sock *sk, struct msghdr *msg, size_t len, int flags, int *addr_len) {
struct sock_common *skc = &sk->__sk_common;
// Create a key for the map and set all the nececery information.
struct sk_key key = {0};
key.protocol = PROTOCOL_UDP;
key.src_ip[0] = skc->skc_rcv_saddr;
key.dst_ip[0] = skc->skc_daddr;
key.src_port = skc->skc_num;
key.dst_port = __builtin_bswap16(skc->skc_dport);
key.ipv6 = 0;
// Update the map with the new information
struct sk_info *info = bpf_map_lookup_elem(&pm_bandwidth_map, &key);
if (info != NULL) {
__sync_fetch_and_add(&info->rx, len); // TODO: Use atomic instead.
__sync_fetch_and_and(&info->reported, 0); // TODO: Use atomic instead.
} else {
struct sk_info newInfo = {0};
newInfo.rx = len;
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
}
return 0;
};
// udpv6_sendmsg hookes to the respective kernel function and saves the bandwidth data
SEC("fentry/udpv6_sendmsg")
int BPF_PROG(udpv6_sendmsg, struct sock *sk, struct msghdr *msg, size_t len) {
struct sock_common *skc = &sk->__sk_common;
// Create a key for the map and set all the nececery information.
struct sk_key key = {0};
key.protocol = PROTOCOL_UDP;
for (int i = 0; i < 4; i++) {
key.src_ip[i] = skc->skc_v6_rcv_saddr.in6_u.u6_addr32[i];
key.dst_ip[i] = skc->skc_v6_rcv_saddr.in6_u.u6_addr32[i];
}
key.src_port = skc->skc_num;
key.dst_port = __builtin_bswap16(skc->skc_dport);
key.ipv6 = 1;
// Update the map with the new information
struct sk_info *info = bpf_map_lookup_elem(&pm_bandwidth_map, &key);
if (info != NULL) {
__sync_fetch_and_add(&info->tx, len); // TODO: Use atomic instead.
__sync_fetch_and_and(&info->reported, 0); // TODO: Use atomic instead.
} else {
struct sk_info newInfo = {0};
newInfo.tx = len;
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
}
return 0;
}
// udpv6_recvmsg hookes to the respective kernel function and saves the bandwidth data
SEC("fentry/udpv6_recvmsg")
int BPF_PROG(udpv6_recvmsg, struct sock *sk, struct msghdr *msg, size_t len, int flags, int *addr_len) {
struct sock_common *skc = &sk->__sk_common;
// Create a key for the map and set all the nececery information.
struct sk_key key = {0};
key.protocol = PROTOCOL_UDP;
for (int i = 0; i < 4; i++) {
key.src_ip[i] = skc->skc_v6_rcv_saddr.in6_u.u6_addr32[i];
key.dst_ip[i] = skc->skc_v6_rcv_saddr.in6_u.u6_addr32[i];
}
key.src_port = skc->skc_num;
key.dst_port = __builtin_bswap16(skc->skc_dport);
key.ipv6 = 1;
// Update the map with the new information
struct sk_info *info = bpf_map_lookup_elem(&pm_bandwidth_map, &key);
if (info != NULL) {
__sync_fetch_and_add(&info->rx, len); // TODO: Use atomic instead.
__sync_fetch_and_and(&info->reported, 0); // TODO: Use atomic instead.
} else {
struct sk_info newInfo = {0};
newInfo.rx = len;
bpf_map_update_elem(&pm_bandwidth_map, &key, &newInfo, BPF_ANY);
}
return 0;
}

View File

@@ -0,0 +1,484 @@
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
#ifndef __BPF_CORE_READ_H__
#define __BPF_CORE_READ_H__
/*
* enum bpf_field_info_kind is passed as a second argument into
* __builtin_preserve_field_info() built-in to get a specific aspect of
* a field, captured as a first argument. __builtin_preserve_field_info(field,
* info_kind) returns __u32 integer and produces BTF field relocation, which
* is understood and processed by libbpf during BPF object loading. See
* selftests/bpf for examples.
*/
enum bpf_field_info_kind {
BPF_FIELD_BYTE_OFFSET = 0, /* field byte offset */
BPF_FIELD_BYTE_SIZE = 1,
BPF_FIELD_EXISTS = 2, /* field existence in target kernel */
BPF_FIELD_SIGNED = 3,
BPF_FIELD_LSHIFT_U64 = 4,
BPF_FIELD_RSHIFT_U64 = 5,
};
/* second argument to __builtin_btf_type_id() built-in */
enum bpf_type_id_kind {
BPF_TYPE_ID_LOCAL = 0, /* BTF type ID in local program */
BPF_TYPE_ID_TARGET = 1, /* BTF type ID in target kernel */
};
/* second argument to __builtin_preserve_type_info() built-in */
enum bpf_type_info_kind {
BPF_TYPE_EXISTS = 0, /* type existence in target kernel */
BPF_TYPE_SIZE = 1, /* type size in target kernel */
BPF_TYPE_MATCHES = 2, /* type match in target kernel */
};
/* second argument to __builtin_preserve_enum_value() built-in */
enum bpf_enum_value_kind {
BPF_ENUMVAL_EXISTS = 0, /* enum value existence in kernel */
BPF_ENUMVAL_VALUE = 1, /* enum value value relocation */
};
#define __CORE_RELO(src, field, info) \
__builtin_preserve_field_info((src)->field, BPF_FIELD_##info)
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
#define __CORE_BITFIELD_PROBE_READ(dst, src, fld) \
bpf_probe_read_kernel( \
(void *)dst, \
__CORE_RELO(src, fld, BYTE_SIZE), \
(const void *)src + __CORE_RELO(src, fld, BYTE_OFFSET))
#else
/* semantics of LSHIFT_64 assumes loading values into low-ordered bytes, so
* for big-endian we need to adjust destination pointer accordingly, based on
* field byte size
*/
#define __CORE_BITFIELD_PROBE_READ(dst, src, fld) \
bpf_probe_read_kernel( \
(void *)dst + (8 - __CORE_RELO(src, fld, BYTE_SIZE)), \
__CORE_RELO(src, fld, BYTE_SIZE), \
(const void *)src + __CORE_RELO(src, fld, BYTE_OFFSET))
#endif
/*
* Extract bitfield, identified by s->field, and return its value as u64.
* All this is done in relocatable manner, so bitfield changes such as
* signedness, bit size, offset changes, this will be handled automatically.
* This version of macro is using bpf_probe_read_kernel() to read underlying
* integer storage. Macro functions as an expression and its return type is
* bpf_probe_read_kernel()'s return value: 0, on success, <0 on error.
*/
#define BPF_CORE_READ_BITFIELD_PROBED(s, field) ({ \
unsigned long long val = 0; \
\
__CORE_BITFIELD_PROBE_READ(&val, s, field); \
val <<= __CORE_RELO(s, field, LSHIFT_U64); \
if (__CORE_RELO(s, field, SIGNED)) \
val = ((long long)val) >> __CORE_RELO(s, field, RSHIFT_U64); \
else \
val = val >> __CORE_RELO(s, field, RSHIFT_U64); \
val; \
})
/*
* Extract bitfield, identified by s->field, and return its value as u64.
* This version of macro is using direct memory reads and should be used from
* BPF program types that support such functionality (e.g., typed raw
* tracepoints).
*/
#define BPF_CORE_READ_BITFIELD(s, field) ({ \
const void *p = (const void *)s + __CORE_RELO(s, field, BYTE_OFFSET); \
unsigned long long val; \
\
/* This is a so-called barrier_var() operation that makes specified \
* variable "a black box" for optimizing compiler. \
* It forces compiler to perform BYTE_OFFSET relocation on p and use \
* its calculated value in the switch below, instead of applying \
* the same relocation 4 times for each individual memory load. \
*/ \
asm volatile("" : "=r"(p) : "0"(p)); \
\
switch (__CORE_RELO(s, field, BYTE_SIZE)) { \
case 1: val = *(const unsigned char *)p; break; \
case 2: val = *(const unsigned short *)p; break; \
case 4: val = *(const unsigned int *)p; break; \
case 8: val = *(const unsigned long long *)p; break; \
} \
val <<= __CORE_RELO(s, field, LSHIFT_U64); \
if (__CORE_RELO(s, field, SIGNED)) \
val = ((long long)val) >> __CORE_RELO(s, field, RSHIFT_U64); \
else \
val = val >> __CORE_RELO(s, field, RSHIFT_U64); \
val; \
})
#define ___bpf_field_ref1(field) (field)
#define ___bpf_field_ref2(type, field) (((typeof(type) *)0)->field)
#define ___bpf_field_ref(args...) \
___bpf_apply(___bpf_field_ref, ___bpf_narg(args))(args)
/*
* Convenience macro to check that field actually exists in target kernel's.
* Returns:
* 1, if matching field is present in target kernel;
* 0, if no matching field found.
*
* Supports two forms:
* - field reference through variable access:
* bpf_core_field_exists(p->my_field);
* - field reference through type and field names:
* bpf_core_field_exists(struct my_type, my_field).
*/
#define bpf_core_field_exists(field...) \
__builtin_preserve_field_info(___bpf_field_ref(field), BPF_FIELD_EXISTS)
/*
* Convenience macro to get the byte size of a field. Works for integers,
* struct/unions, pointers, arrays, and enums.
*
* Supports two forms:
* - field reference through variable access:
* bpf_core_field_size(p->my_field);
* - field reference through type and field names:
* bpf_core_field_size(struct my_type, my_field).
*/
#define bpf_core_field_size(field...) \
__builtin_preserve_field_info(___bpf_field_ref(field), BPF_FIELD_BYTE_SIZE)
/*
* Convenience macro to get field's byte offset.
*
* Supports two forms:
* - field reference through variable access:
* bpf_core_field_offset(p->my_field);
* - field reference through type and field names:
* bpf_core_field_offset(struct my_type, my_field).
*/
#define bpf_core_field_offset(field...) \
__builtin_preserve_field_info(___bpf_field_ref(field), BPF_FIELD_BYTE_OFFSET)
/*
* Convenience macro to get BTF type ID of a specified type, using a local BTF
* information. Return 32-bit unsigned integer with type ID from program's own
* BTF. Always succeeds.
*/
#define bpf_core_type_id_local(type) \
__builtin_btf_type_id(*(typeof(type) *)0, BPF_TYPE_ID_LOCAL)
/*
* Convenience macro to get BTF type ID of a target kernel's type that matches
* specified local type.
* Returns:
* - valid 32-bit unsigned type ID in kernel BTF;
* - 0, if no matching type was found in a target kernel BTF.
*/
#define bpf_core_type_id_kernel(type) \
__builtin_btf_type_id(*(typeof(type) *)0, BPF_TYPE_ID_TARGET)
/*
* Convenience macro to check that provided named type
* (struct/union/enum/typedef) exists in a target kernel.
* Returns:
* 1, if such type is present in target kernel's BTF;
* 0, if no matching type is found.
*/
#define bpf_core_type_exists(type) \
__builtin_preserve_type_info(*(typeof(type) *)0, BPF_TYPE_EXISTS)
/*
* Convenience macro to check that provided named type
* (struct/union/enum/typedef) "matches" that in a target kernel.
* Returns:
* 1, if the type matches in the target kernel's BTF;
* 0, if the type does not match any in the target kernel
*/
#define bpf_core_type_matches(type) \
__builtin_preserve_type_info(*(typeof(type) *)0, BPF_TYPE_MATCHES)
/*
* Convenience macro to get the byte size of a provided named type
* (struct/union/enum/typedef) in a target kernel.
* Returns:
* >= 0 size (in bytes), if type is present in target kernel's BTF;
* 0, if no matching type is found.
*/
#define bpf_core_type_size(type) \
__builtin_preserve_type_info(*(typeof(type) *)0, BPF_TYPE_SIZE)
/*
* Convenience macro to check that provided enumerator value is defined in
* a target kernel.
* Returns:
* 1, if specified enum type and its enumerator value are present in target
* kernel's BTF;
* 0, if no matching enum and/or enum value within that enum is found.
*/
#define bpf_core_enum_value_exists(enum_type, enum_value) \
__builtin_preserve_enum_value(*(typeof(enum_type) *)enum_value, BPF_ENUMVAL_EXISTS)
/*
* Convenience macro to get the integer value of an enumerator value in
* a target kernel.
* Returns:
* 64-bit value, if specified enum type and its enumerator value are
* present in target kernel's BTF;
* 0, if no matching enum and/or enum value within that enum is found.
*/
#define bpf_core_enum_value(enum_type, enum_value) \
__builtin_preserve_enum_value(*(typeof(enum_type) *)enum_value, BPF_ENUMVAL_VALUE)
/*
* bpf_core_read() abstracts away bpf_probe_read_kernel() call and captures
* offset relocation for source address using __builtin_preserve_access_index()
* built-in, provided by Clang.
*
* __builtin_preserve_access_index() takes as an argument an expression of
* taking an address of a field within struct/union. It makes compiler emit
* a relocation, which records BTF type ID describing root struct/union and an
* accessor string which describes exact embedded field that was used to take
* an address. See detailed description of this relocation format and
* semantics in comments to struct bpf_field_reloc in libbpf_internal.h.
*
* This relocation allows libbpf to adjust BPF instruction to use correct
* actual field offset, based on target kernel BTF type that matches original
* (local) BTF, used to record relocation.
*/
#define bpf_core_read(dst, sz, src) \
bpf_probe_read_kernel(dst, sz, (const void *)__builtin_preserve_access_index(src))
/* NOTE: see comments for BPF_CORE_READ_USER() about the proper types use. */
#define bpf_core_read_user(dst, sz, src) \
bpf_probe_read_user(dst, sz, (const void *)__builtin_preserve_access_index(src))
/*
* bpf_core_read_str() is a thin wrapper around bpf_probe_read_str()
* additionally emitting BPF CO-RE field relocation for specified source
* argument.
*/
#define bpf_core_read_str(dst, sz, src) \
bpf_probe_read_kernel_str(dst, sz, (const void *)__builtin_preserve_access_index(src))
/* NOTE: see comments for BPF_CORE_READ_USER() about the proper types use. */
#define bpf_core_read_user_str(dst, sz, src) \
bpf_probe_read_user_str(dst, sz, (const void *)__builtin_preserve_access_index(src))
#define ___concat(a, b) a ## b
#define ___apply(fn, n) ___concat(fn, n)
#define ___nth(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, __11, N, ...) N
/*
* return number of provided arguments; used for switch-based variadic macro
* definitions (see ___last, ___arrow, etc below)
*/
#define ___narg(...) ___nth(_, ##__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
/*
* return 0 if no arguments are passed, N - otherwise; used for
* recursively-defined macros to specify termination (0) case, and generic
* (N) case (e.g., ___read_ptrs, ___core_read)
*/
#define ___empty(...) ___nth(_, ##__VA_ARGS__, N, N, N, N, N, N, N, N, N, N, 0)
#define ___last1(x) x
#define ___last2(a, x) x
#define ___last3(a, b, x) x
#define ___last4(a, b, c, x) x
#define ___last5(a, b, c, d, x) x
#define ___last6(a, b, c, d, e, x) x
#define ___last7(a, b, c, d, e, f, x) x
#define ___last8(a, b, c, d, e, f, g, x) x
#define ___last9(a, b, c, d, e, f, g, h, x) x
#define ___last10(a, b, c, d, e, f, g, h, i, x) x
#define ___last(...) ___apply(___last, ___narg(__VA_ARGS__))(__VA_ARGS__)
#define ___nolast2(a, _) a
#define ___nolast3(a, b, _) a, b
#define ___nolast4(a, b, c, _) a, b, c
#define ___nolast5(a, b, c, d, _) a, b, c, d
#define ___nolast6(a, b, c, d, e, _) a, b, c, d, e
#define ___nolast7(a, b, c, d, e, f, _) a, b, c, d, e, f
#define ___nolast8(a, b, c, d, e, f, g, _) a, b, c, d, e, f, g
#define ___nolast9(a, b, c, d, e, f, g, h, _) a, b, c, d, e, f, g, h
#define ___nolast10(a, b, c, d, e, f, g, h, i, _) a, b, c, d, e, f, g, h, i
#define ___nolast(...) ___apply(___nolast, ___narg(__VA_ARGS__))(__VA_ARGS__)
#define ___arrow1(a) a
#define ___arrow2(a, b) a->b
#define ___arrow3(a, b, c) a->b->c
#define ___arrow4(a, b, c, d) a->b->c->d
#define ___arrow5(a, b, c, d, e) a->b->c->d->e
#define ___arrow6(a, b, c, d, e, f) a->b->c->d->e->f
#define ___arrow7(a, b, c, d, e, f, g) a->b->c->d->e->f->g
#define ___arrow8(a, b, c, d, e, f, g, h) a->b->c->d->e->f->g->h
#define ___arrow9(a, b, c, d, e, f, g, h, i) a->b->c->d->e->f->g->h->i
#define ___arrow10(a, b, c, d, e, f, g, h, i, j) a->b->c->d->e->f->g->h->i->j
#define ___arrow(...) ___apply(___arrow, ___narg(__VA_ARGS__))(__VA_ARGS__)
#define ___type(...) typeof(___arrow(__VA_ARGS__))
#define ___read(read_fn, dst, src_type, src, accessor) \
read_fn((void *)(dst), sizeof(*(dst)), &((src_type)(src))->accessor)
/* "recursively" read a sequence of inner pointers using local __t var */
#define ___rd_first(fn, src, a) ___read(fn, &__t, ___type(src), src, a);
#define ___rd_last(fn, ...) \
___read(fn, &__t, ___type(___nolast(__VA_ARGS__)), __t, ___last(__VA_ARGS__));
#define ___rd_p1(fn, ...) const void *__t; ___rd_first(fn, __VA_ARGS__)
#define ___rd_p2(fn, ...) ___rd_p1(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
#define ___rd_p3(fn, ...) ___rd_p2(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
#define ___rd_p4(fn, ...) ___rd_p3(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
#define ___rd_p5(fn, ...) ___rd_p4(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
#define ___rd_p6(fn, ...) ___rd_p5(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
#define ___rd_p7(fn, ...) ___rd_p6(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
#define ___rd_p8(fn, ...) ___rd_p7(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
#define ___rd_p9(fn, ...) ___rd_p8(fn, ___nolast(__VA_ARGS__)) ___rd_last(fn, __VA_ARGS__)
#define ___read_ptrs(fn, src, ...) \
___apply(___rd_p, ___narg(__VA_ARGS__))(fn, src, __VA_ARGS__)
#define ___core_read0(fn, fn_ptr, dst, src, a) \
___read(fn, dst, ___type(src), src, a);
#define ___core_readN(fn, fn_ptr, dst, src, ...) \
___read_ptrs(fn_ptr, src, ___nolast(__VA_ARGS__)) \
___read(fn, dst, ___type(src, ___nolast(__VA_ARGS__)), __t, \
___last(__VA_ARGS__));
#define ___core_read(fn, fn_ptr, dst, src, a, ...) \
___apply(___core_read, ___empty(__VA_ARGS__))(fn, fn_ptr, dst, \
src, a, ##__VA_ARGS__)
/*
* BPF_CORE_READ_INTO() is a more performance-conscious variant of
* BPF_CORE_READ(), in which final field is read into user-provided storage.
* See BPF_CORE_READ() below for more details on general usage.
*/
#define BPF_CORE_READ_INTO(dst, src, a, ...) ({ \
___core_read(bpf_core_read, bpf_core_read, \
dst, (src), a, ##__VA_ARGS__) \
})
/*
* Variant of BPF_CORE_READ_INTO() for reading from user-space memory.
*
* NOTE: see comments for BPF_CORE_READ_USER() about the proper types use.
*/
#define BPF_CORE_READ_USER_INTO(dst, src, a, ...) ({ \
___core_read(bpf_core_read_user, bpf_core_read_user, \
dst, (src), a, ##__VA_ARGS__) \
})
/* Non-CO-RE variant of BPF_CORE_READ_INTO() */
#define BPF_PROBE_READ_INTO(dst, src, a, ...) ({ \
___core_read(bpf_probe_read_kernel, bpf_probe_read_kernel, \
dst, (src), a, ##__VA_ARGS__) \
})
/* Non-CO-RE variant of BPF_CORE_READ_USER_INTO().
*
* As no CO-RE relocations are emitted, source types can be arbitrary and are
* not restricted to kernel types only.
*/
#define BPF_PROBE_READ_USER_INTO(dst, src, a, ...) ({ \
___core_read(bpf_probe_read_user, bpf_probe_read_user, \
dst, (src), a, ##__VA_ARGS__) \
})
/*
* BPF_CORE_READ_STR_INTO() does same "pointer chasing" as
* BPF_CORE_READ() for intermediate pointers, but then executes (and returns
* corresponding error code) bpf_core_read_str() for final string read.
*/
#define BPF_CORE_READ_STR_INTO(dst, src, a, ...) ({ \
___core_read(bpf_core_read_str, bpf_core_read, \
dst, (src), a, ##__VA_ARGS__) \
})
/*
* Variant of BPF_CORE_READ_STR_INTO() for reading from user-space memory.
*
* NOTE: see comments for BPF_CORE_READ_USER() about the proper types use.
*/
#define BPF_CORE_READ_USER_STR_INTO(dst, src, a, ...) ({ \
___core_read(bpf_core_read_user_str, bpf_core_read_user, \
dst, (src), a, ##__VA_ARGS__) \
})
/* Non-CO-RE variant of BPF_CORE_READ_STR_INTO() */
#define BPF_PROBE_READ_STR_INTO(dst, src, a, ...) ({ \
___core_read(bpf_probe_read_kernel_str, bpf_probe_read_kernel, \
dst, (src), a, ##__VA_ARGS__) \
})
/*
* Non-CO-RE variant of BPF_CORE_READ_USER_STR_INTO().
*
* As no CO-RE relocations are emitted, source types can be arbitrary and are
* not restricted to kernel types only.
*/
#define BPF_PROBE_READ_USER_STR_INTO(dst, src, a, ...) ({ \
___core_read(bpf_probe_read_user_str, bpf_probe_read_user, \
dst, (src), a, ##__VA_ARGS__) \
})
/*
* BPF_CORE_READ() is used to simplify BPF CO-RE relocatable read, especially
* when there are few pointer chasing steps.
* E.g., what in non-BPF world (or in BPF w/ BCC) would be something like:
* int x = s->a.b.c->d.e->f->g;
* can be succinctly achieved using BPF_CORE_READ as:
* int x = BPF_CORE_READ(s, a.b.c, d.e, f, g);
*
* BPF_CORE_READ will decompose above statement into 4 bpf_core_read (BPF
* CO-RE relocatable bpf_probe_read_kernel() wrapper) calls, logically
* equivalent to:
* 1. const void *__t = s->a.b.c;
* 2. __t = __t->d.e;
* 3. __t = __t->f;
* 4. return __t->g;
*
* Equivalence is logical, because there is a heavy type casting/preservation
* involved, as well as all the reads are happening through
* bpf_probe_read_kernel() calls using __builtin_preserve_access_index() to
* emit CO-RE relocations.
*
* N.B. Only up to 9 "field accessors" are supported, which should be more
* than enough for any practical purpose.
*/
#define BPF_CORE_READ(src, a, ...) ({ \
___type((src), a, ##__VA_ARGS__) __r; \
BPF_CORE_READ_INTO(&__r, (src), a, ##__VA_ARGS__); \
__r; \
})
/*
* Variant of BPF_CORE_READ() for reading from user-space memory.
*
* NOTE: all the source types involved are still *kernel types* and need to
* exist in kernel (or kernel module) BTF, otherwise CO-RE relocation will
* fail. Custom user types are not relocatable with CO-RE.
* The typical situation in which BPF_CORE_READ_USER() might be used is to
* read kernel UAPI types from the user-space memory passed in as a syscall
* input argument.
*/
#define BPF_CORE_READ_USER(src, a, ...) ({ \
___type((src), a, ##__VA_ARGS__) __r; \
BPF_CORE_READ_USER_INTO(&__r, (src), a, ##__VA_ARGS__); \
__r; \
})
/* Non-CO-RE variant of BPF_CORE_READ() */
#define BPF_PROBE_READ(src, a, ...) ({ \
___type((src), a, ##__VA_ARGS__) __r; \
BPF_PROBE_READ_INTO(&__r, (src), a, ##__VA_ARGS__); \
__r; \
})
/*
* Non-CO-RE variant of BPF_CORE_READ_USER().
*
* As no CO-RE relocations are emitted, source types can be arbitrary and are
* not restricted to kernel types only.
*/
#define BPF_PROBE_READ_USER(src, a, ...) ({ \
___type((src), a, ##__VA_ARGS__) __r; \
BPF_PROBE_READ_USER_INTO(&__r, (src), a, ##__VA_ARGS__); \
__r; \
})
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,397 @@
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
#ifndef __BPF_HELPERS__
#define __BPF_HELPERS__
/*
* Note that bpf programs need to include either
* vmlinux.h (auto-generated from BTF) or linux/types.h
* in advance since bpf_helper_defs.h uses such types
* as __u64.
*/
#include "bpf_helper_defs.h"
#define __uint(name, val) int (*name)[val]
#define __type(name, val) typeof(val) *name
#define __array(name, val) typeof(val) *name[]
/*
* Helper macro to place programs, maps, license in
* different sections in elf_bpf file. Section names
* are interpreted by libbpf depending on the context (BPF programs, BPF maps,
* extern variables, etc).
* To allow use of SEC() with externs (e.g., for extern .maps declarations),
* make sure __attribute__((unused)) doesn't trigger compilation warning.
*/
#if __GNUC__ && !__clang__
/*
* Pragma macros are broken on GCC
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55578
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90400
*/
#define SEC(name) __attribute__((section(name), used))
#else
#define SEC(name) \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wignored-attributes\"") \
__attribute__((section(name), used)) \
_Pragma("GCC diagnostic pop") \
#endif
/* Avoid 'linux/stddef.h' definition of '__always_inline'. */
#undef __always_inline
#define __always_inline inline __attribute__((always_inline))
#ifndef __noinline
#define __noinline __attribute__((noinline))
#endif
#ifndef __weak
#define __weak __attribute__((weak))
#endif
/*
* Use __hidden attribute to mark a non-static BPF subprogram effectively
* static for BPF verifier's verification algorithm purposes, allowing more
* extensive and permissive BPF verification process, taking into account
* subprogram's caller context.
*/
#define __hidden __attribute__((visibility("hidden")))
/* When utilizing vmlinux.h with BPF CO-RE, user BPF programs can't include
* any system-level headers (such as stddef.h, linux/version.h, etc), and
* commonly-used macros like NULL and KERNEL_VERSION aren't available through
* vmlinux.h. This just adds unnecessary hurdles and forces users to re-define
* them on their own. So as a convenience, provide such definitions here.
*/
#ifndef NULL
#define NULL ((void *)0)
#endif
#ifndef KERNEL_VERSION
#define KERNEL_VERSION(a, b, c) (((a) << 16) + ((b) << 8) + ((c) > 255 ? 255 : (c)))
#endif
/*
* Helper macros to manipulate data structures
*/
#ifndef offsetof
#define offsetof(TYPE, MEMBER) ((unsigned long)&((TYPE *)0)->MEMBER)
#endif
#ifndef container_of
#define container_of(ptr, type, member) \
({ \
void *__mptr = (void *)(ptr); \
((type *)(__mptr - offsetof(type, member))); \
})
#endif
/*
* Compiler (optimization) barrier.
*/
#ifndef barrier
#define barrier() asm volatile("" ::: "memory")
#endif
/* Variable-specific compiler (optimization) barrier. It's a no-op which makes
* compiler believe that there is some black box modification of a given
* variable and thus prevents compiler from making extra assumption about its
* value and potential simplifications and optimizations on this variable.
*
* E.g., compiler might often delay or even omit 32-bit to 64-bit casting of
* a variable, making some code patterns unverifiable. Putting barrier_var()
* in place will ensure that cast is performed before the barrier_var()
* invocation, because compiler has to pessimistically assume that embedded
* asm section might perform some extra operations on that variable.
*
* This is a variable-specific variant of more global barrier().
*/
#ifndef barrier_var
#define barrier_var(var) asm volatile("" : "+r"(var))
#endif
/*
* Helper macro to throw a compilation error if __bpf_unreachable() gets
* built into the resulting code. This works given BPF back end does not
* implement __builtin_trap(). This is useful to assert that certain paths
* of the program code are never used and hence eliminated by the compiler.
*
* For example, consider a switch statement that covers known cases used by
* the program. __bpf_unreachable() can then reside in the default case. If
* the program gets extended such that a case is not covered in the switch
* statement, then it will throw a build error due to the default case not
* being compiled out.
*/
#ifndef __bpf_unreachable
# define __bpf_unreachable() __builtin_trap()
#endif
/*
* Helper function to perform a tail call with a constant/immediate map slot.
*/
#if __clang_major__ >= 8 && defined(__bpf__)
static __always_inline void
bpf_tail_call_static(void *ctx, const void *map, const __u32 slot)
{
if (!__builtin_constant_p(slot))
__bpf_unreachable();
/*
* Provide a hard guarantee that LLVM won't optimize setting r2 (map
* pointer) and r3 (constant map index) from _different paths_ ending
* up at the _same_ call insn as otherwise we won't be able to use the
* jmpq/nopl retpoline-free patching by the x86-64 JIT in the kernel
* given they mismatch. See also d2e4c1e6c294 ("bpf: Constant map key
* tracking for prog array pokes") for details on verifier tracking.
*
* Note on clobber list: we need to stay in-line with BPF calling
* convention, so even if we don't end up using r0, r4, r5, we need
* to mark them as clobber so that LLVM doesn't end up using them
* before / after the call.
*/
asm volatile("r1 = %[ctx]\n\t"
"r2 = %[map]\n\t"
"r3 = %[slot]\n\t"
"call 12"
:: [ctx]"r"(ctx), [map]"r"(map), [slot]"i"(slot)
: "r0", "r1", "r2", "r3", "r4", "r5");
}
#endif
enum libbpf_pin_type {
LIBBPF_PIN_NONE,
/* PIN_BY_NAME: pin maps by name (in /sys/fs/bpf by default) */
LIBBPF_PIN_BY_NAME,
};
enum libbpf_tristate {
TRI_NO = 0,
TRI_YES = 1,
TRI_MODULE = 2,
};
#define __kconfig __attribute__((section(".kconfig")))
#define __ksym __attribute__((section(".ksyms")))
#define __kptr_untrusted __attribute__((btf_type_tag("kptr_untrusted")))
#define __kptr __attribute__((btf_type_tag("kptr")))
#define bpf_ksym_exists(sym) ({ \
_Static_assert(!__builtin_constant_p(!!sym), #sym " should be marked as __weak"); \
!!sym; \
})
#ifndef ___bpf_concat
#define ___bpf_concat(a, b) a ## b
#endif
#ifndef ___bpf_apply
#define ___bpf_apply(fn, n) ___bpf_concat(fn, n)
#endif
#ifndef ___bpf_nth
#define ___bpf_nth(_, _1, _2, _3, _4, _5, _6, _7, _8, _9, _a, _b, _c, N, ...) N
#endif
#ifndef ___bpf_narg
#define ___bpf_narg(...) \
___bpf_nth(_, ##__VA_ARGS__, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
#endif
#define ___bpf_fill0(arr, p, x) do {} while (0)
#define ___bpf_fill1(arr, p, x) arr[p] = x
#define ___bpf_fill2(arr, p, x, args...) arr[p] = x; ___bpf_fill1(arr, p + 1, args)
#define ___bpf_fill3(arr, p, x, args...) arr[p] = x; ___bpf_fill2(arr, p + 1, args)
#define ___bpf_fill4(arr, p, x, args...) arr[p] = x; ___bpf_fill3(arr, p + 1, args)
#define ___bpf_fill5(arr, p, x, args...) arr[p] = x; ___bpf_fill4(arr, p + 1, args)
#define ___bpf_fill6(arr, p, x, args...) arr[p] = x; ___bpf_fill5(arr, p + 1, args)
#define ___bpf_fill7(arr, p, x, args...) arr[p] = x; ___bpf_fill6(arr, p + 1, args)
#define ___bpf_fill8(arr, p, x, args...) arr[p] = x; ___bpf_fill7(arr, p + 1, args)
#define ___bpf_fill9(arr, p, x, args...) arr[p] = x; ___bpf_fill8(arr, p + 1, args)
#define ___bpf_fill10(arr, p, x, args...) arr[p] = x; ___bpf_fill9(arr, p + 1, args)
#define ___bpf_fill11(arr, p, x, args...) arr[p] = x; ___bpf_fill10(arr, p + 1, args)
#define ___bpf_fill12(arr, p, x, args...) arr[p] = x; ___bpf_fill11(arr, p + 1, args)
#define ___bpf_fill(arr, args...) \
___bpf_apply(___bpf_fill, ___bpf_narg(args))(arr, 0, args)
/*
* BPF_SEQ_PRINTF to wrap bpf_seq_printf to-be-printed values
* in a structure.
*/
#define BPF_SEQ_PRINTF(seq, fmt, args...) \
({ \
static const char ___fmt[] = fmt; \
unsigned long long ___param[___bpf_narg(args)]; \
\
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
___bpf_fill(___param, args); \
_Pragma("GCC diagnostic pop") \
\
bpf_seq_printf(seq, ___fmt, sizeof(___fmt), \
___param, sizeof(___param)); \
})
/*
* BPF_SNPRINTF wraps the bpf_snprintf helper with variadic arguments instead of
* an array of u64.
*/
#define BPF_SNPRINTF(out, out_size, fmt, args...) \
({ \
static const char ___fmt[] = fmt; \
unsigned long long ___param[___bpf_narg(args)]; \
\
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
___bpf_fill(___param, args); \
_Pragma("GCC diagnostic pop") \
\
bpf_snprintf(out, out_size, ___fmt, \
___param, sizeof(___param)); \
})
#ifdef BPF_NO_GLOBAL_DATA
#define BPF_PRINTK_FMT_MOD
#else
#define BPF_PRINTK_FMT_MOD static const
#endif
#define __bpf_printk(fmt, ...) \
({ \
BPF_PRINTK_FMT_MOD char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), \
##__VA_ARGS__); \
})
/*
* __bpf_vprintk wraps the bpf_trace_vprintk helper with variadic arguments
* instead of an array of u64.
*/
#define __bpf_vprintk(fmt, args...) \
({ \
static const char ___fmt[] = fmt; \
unsigned long long ___param[___bpf_narg(args)]; \
\
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
___bpf_fill(___param, args); \
_Pragma("GCC diagnostic pop") \
\
bpf_trace_vprintk(___fmt, sizeof(___fmt), \
___param, sizeof(___param)); \
})
/* Use __bpf_printk when bpf_printk call has 3 or fewer fmt args
* Otherwise use __bpf_vprintk
*/
#define ___bpf_pick_printk(...) \
___bpf_nth(_, ##__VA_ARGS__, __bpf_vprintk, __bpf_vprintk, __bpf_vprintk, \
__bpf_vprintk, __bpf_vprintk, __bpf_vprintk, __bpf_vprintk, \
__bpf_vprintk, __bpf_vprintk, __bpf_printk /*3*/, __bpf_printk /*2*/,\
__bpf_printk /*1*/, __bpf_printk /*0*/)
/* Helper macro to print out debug messages */
#define bpf_printk(fmt, args...) ___bpf_pick_printk(args)(fmt, ##args)
struct bpf_iter_num;
extern int bpf_iter_num_new(struct bpf_iter_num *it, int start, int end) __weak __ksym;
extern int *bpf_iter_num_next(struct bpf_iter_num *it) __weak __ksym;
extern void bpf_iter_num_destroy(struct bpf_iter_num *it) __weak __ksym;
#ifndef bpf_for_each
/* bpf_for_each(iter_type, cur_elem, args...) provides generic construct for
* using BPF open-coded iterators without having to write mundane explicit
* low-level loop logic. Instead, it provides for()-like generic construct
* that can be used pretty naturally. E.g., for some hypothetical cgroup
* iterator, you'd write:
*
* struct cgroup *cg, *parent_cg = <...>;
*
* bpf_for_each(cgroup, cg, parent_cg, CG_ITER_CHILDREN) {
* bpf_printk("Child cgroup id = %d", cg->cgroup_id);
* if (cg->cgroup_id == 123)
* break;
* }
*
* I.e., it looks almost like high-level for each loop in other languages,
* supports continue/break, and is verifiable by BPF verifier.
*
* For iterating integers, the difference betwen bpf_for_each(num, i, N, M)
* and bpf_for(i, N, M) is in that bpf_for() provides additional proof to
* verifier that i is in [N, M) range, and in bpf_for_each() case i is `int
* *`, not just `int`. So for integers bpf_for() is more convenient.
*
* Note: this macro relies on C99 feature of allowing to declare variables
* inside for() loop, bound to for() loop lifetime. It also utilizes GCC
* extension: __attribute__((cleanup(<func>))), supported by both GCC and
* Clang.
*/
#define bpf_for_each(type, cur, args...) for ( \
/* initialize and define destructor */ \
struct bpf_iter_##type ___it __attribute__((aligned(8), /* enforce, just in case */, \
cleanup(bpf_iter_##type##_destroy))), \
/* ___p pointer is just to call bpf_iter_##type##_new() *once* to init ___it */ \
*___p __attribute__((unused)) = ( \
bpf_iter_##type##_new(&___it, ##args), \
/* this is a workaround for Clang bug: it currently doesn't emit BTF */ \
/* for bpf_iter_##type##_destroy() when used from cleanup() attribute */ \
(void)bpf_iter_##type##_destroy, (void *)0); \
/* iteration and termination check */ \
(((cur) = bpf_iter_##type##_next(&___it))); \
)
#endif /* bpf_for_each */
#ifndef bpf_for
/* bpf_for(i, start, end) implements a for()-like looping construct that sets
* provided integer variable *i* to values starting from *start* through,
* but not including, *end*. It also proves to BPF verifier that *i* belongs
* to range [start, end), so this can be used for accessing arrays without
* extra checks.
*
* Note: *start* and *end* are assumed to be expressions with no side effects
* and whose values do not change throughout bpf_for() loop execution. They do
* not have to be statically known or constant, though.
*
* Note: similarly to bpf_for_each(), it relies on C99 feature of declaring for()
* loop bound variables and cleanup attribute, supported by GCC and Clang.
*/
#define bpf_for(i, start, end) for ( \
/* initialize and define destructor */ \
struct bpf_iter_num ___it __attribute__((aligned(8), /* enforce, just in case */ \
cleanup(bpf_iter_num_destroy))), \
/* ___p pointer is necessary to call bpf_iter_num_new() *once* to init ___it */ \
*___p __attribute__((unused)) = ( \
bpf_iter_num_new(&___it, (start), (end)), \
/* this is a workaround for Clang bug: it currently doesn't emit BTF */ \
/* for bpf_iter_num_destroy() when used from cleanup() attribute */ \
(void)bpf_iter_num_destroy, (void *)0); \
({ \
/* iteration step */ \
int *___t = bpf_iter_num_next(&___it); \
/* termination and bounds check */ \
(___t && ((i) = *___t, (i) >= (start) && (i) < (end))); \
}); \
)
#endif /* bpf_for */
#ifndef bpf_repeat
/* bpf_repeat(N) performs N iterations without exposing iteration number
*
* Note: similarly to bpf_for_each(), it relies on C99 feature of declaring for()
* loop bound variables and cleanup attribute, supported by GCC and Clang.
*/
#define bpf_repeat(N) for ( \
/* initialize and define destructor */ \
struct bpf_iter_num ___it __attribute__((aligned(8), /* enforce, just in case */ \
cleanup(bpf_iter_num_destroy))), \
/* ___p pointer is necessary to call bpf_iter_num_new() *once* to init ___it */ \
*___p __attribute__((unused)) = ( \
bpf_iter_num_new(&___it, 0, (N)), \
/* this is a workaround for Clang bug: it currently doesn't emit BTF */ \
/* for bpf_iter_num_destroy() when used from cleanup() attribute */ \
(void)bpf_iter_num_destroy, (void *)0); \
bpf_iter_num_next(&___it); \
/* nothing here */ \
)
#endif /* bpf_repeat */
#endif

View File

@@ -0,0 +1,924 @@
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
#ifndef __BPF_TRACING_H__
#define __BPF_TRACING_H__
// #include <bpf/bpf_helpers.h>
#include "bpf_helpers.h"
/* Scan the ARCH passed in from ARCH env variable (see Makefile) */
#if defined(__TARGET_ARCH_x86)
#define bpf_target_x86
#define bpf_target_defined
#elif defined(__TARGET_ARCH_s390)
#define bpf_target_s390
#define bpf_target_defined
#elif defined(__TARGET_ARCH_arm)
#define bpf_target_arm
#define bpf_target_defined
#elif defined(__TARGET_ARCH_arm64)
#define bpf_target_arm64
#define bpf_target_defined
#elif defined(__TARGET_ARCH_mips)
#define bpf_target_mips
#define bpf_target_defined
#elif defined(__TARGET_ARCH_powerpc)
#define bpf_target_powerpc
#define bpf_target_defined
#elif defined(__TARGET_ARCH_sparc)
#define bpf_target_sparc
#define bpf_target_defined
#elif defined(__TARGET_ARCH_riscv)
#define bpf_target_riscv
#define bpf_target_defined
#elif defined(__TARGET_ARCH_arc)
#define bpf_target_arc
#define bpf_target_defined
#elif defined(__TARGET_ARCH_loongarch)
#define bpf_target_loongarch
#define bpf_target_defined
#else
/* Fall back to what the compiler says */
#if defined(__x86_64__)
#define bpf_target_x86
#define bpf_target_defined
#elif defined(__s390__)
#define bpf_target_s390
#define bpf_target_defined
#elif defined(__arm__)
#define bpf_target_arm
#define bpf_target_defined
#elif defined(__aarch64__)
#define bpf_target_arm64
#define bpf_target_defined
#elif defined(__mips__)
#define bpf_target_mips
#define bpf_target_defined
#elif defined(__powerpc__)
#define bpf_target_powerpc
#define bpf_target_defined
#elif defined(__sparc__)
#define bpf_target_sparc
#define bpf_target_defined
#elif defined(__riscv) && __riscv_xlen == 64
#define bpf_target_riscv
#define bpf_target_defined
#elif defined(__arc__)
#define bpf_target_arc
#define bpf_target_defined
#elif defined(__loongarch__)
#define bpf_target_loongarch
#define bpf_target_defined
#endif /* no compiler target */
#endif
#ifndef __BPF_TARGET_MISSING
#define __BPF_TARGET_MISSING "GCC error \"Must specify a BPF target arch via __TARGET_ARCH_xxx\""
#endif
#if defined(bpf_target_x86)
/*
* https://en.wikipedia.org/wiki/X86_calling_conventions#System_V_AMD64_ABI
*/
#if defined(__KERNEL__) || defined(__VMLINUX_H__)
#define __PT_PARM1_REG di
#define __PT_PARM2_REG si
#define __PT_PARM3_REG dx
#define __PT_PARM4_REG cx
#define __PT_PARM5_REG r8
#define __PT_PARM6_REG r9
/*
* Syscall uses r10 for PARM4. See arch/x86/entry/entry_64.S:entry_SYSCALL_64
* comments in Linux sources. And refer to syscall(2) manpage.
*/
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG r10
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
#define __PT_RET_REG sp
#define __PT_FP_REG bp
#define __PT_RC_REG ax
#define __PT_SP_REG sp
#define __PT_IP_REG ip
#else
#ifdef __i386__
/* i386 kernel is built with -mregparm=3 */
#define __PT_PARM1_REG eax
#define __PT_PARM2_REG edx
#define __PT_PARM3_REG ecx
/* i386 syscall ABI is very different, refer to syscall(2) manpage */
#define __PT_PARM1_SYSCALL_REG ebx
#define __PT_PARM2_SYSCALL_REG ecx
#define __PT_PARM3_SYSCALL_REG edx
#define __PT_PARM4_SYSCALL_REG esi
#define __PT_PARM5_SYSCALL_REG edi
#define __PT_PARM6_SYSCALL_REG ebp
#define __PT_RET_REG esp
#define __PT_FP_REG ebp
#define __PT_RC_REG eax
#define __PT_SP_REG esp
#define __PT_IP_REG eip
#else /* __i386__ */
#define __PT_PARM1_REG rdi
#define __PT_PARM2_REG rsi
#define __PT_PARM3_REG rdx
#define __PT_PARM4_REG rcx
#define __PT_PARM5_REG r8
#define __PT_PARM6_REG r9
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG r10
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
#define __PT_RET_REG rsp
#define __PT_FP_REG rbp
#define __PT_RC_REG rax
#define __PT_SP_REG rsp
#define __PT_IP_REG rip
#endif /* __i386__ */
#endif /* __KERNEL__ || __VMLINUX_H__ */
#elif defined(bpf_target_s390)
/*
* https://github.com/IBM/s390x-abi/releases/download/v1.6/lzsabi_s390x.pdf
*/
struct pt_regs___s390 {
unsigned long orig_gpr2;
};
/* s390 provides user_pt_regs instead of struct pt_regs to userspace */
#define __PT_REGS_CAST(x) ((const user_pt_regs *)(x))
#define __PT_PARM1_REG gprs[2]
#define __PT_PARM2_REG gprs[3]
#define __PT_PARM3_REG gprs[4]
#define __PT_PARM4_REG gprs[5]
#define __PT_PARM5_REG gprs[6]
#define __PT_PARM1_SYSCALL_REG orig_gpr2
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
#define __PT_PARM6_SYSCALL_REG gprs[7]
#define PT_REGS_PARM1_SYSCALL(x) PT_REGS_PARM1_CORE_SYSCALL(x)
#define PT_REGS_PARM1_CORE_SYSCALL(x) \
BPF_CORE_READ((const struct pt_regs___s390 *)(x), __PT_PARM1_SYSCALL_REG)
#define __PT_RET_REG gprs[14]
#define __PT_FP_REG gprs[11] /* Works only with CONFIG_FRAME_POINTER */
#define __PT_RC_REG gprs[2]
#define __PT_SP_REG gprs[15]
#define __PT_IP_REG psw.addr
#elif defined(bpf_target_arm)
/*
* https://github.com/ARM-software/abi-aa/blob/main/aapcs32/aapcs32.rst#machine-registers
*/
#define __PT_PARM1_REG uregs[0]
#define __PT_PARM2_REG uregs[1]
#define __PT_PARM3_REG uregs[2]
#define __PT_PARM4_REG uregs[3]
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
#define __PT_PARM5_SYSCALL_REG uregs[4]
#define __PT_PARM6_SYSCALL_REG uregs[5]
#define __PT_PARM7_SYSCALL_REG uregs[6]
#define __PT_RET_REG uregs[14]
#define __PT_FP_REG uregs[11] /* Works only with CONFIG_FRAME_POINTER */
#define __PT_RC_REG uregs[0]
#define __PT_SP_REG uregs[13]
#define __PT_IP_REG uregs[12]
#elif defined(bpf_target_arm64)
/*
* https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#machine-registers
*/
struct pt_regs___arm64 {
unsigned long orig_x0;
};
/* arm64 provides struct user_pt_regs instead of struct pt_regs to userspace */
#define __PT_REGS_CAST(x) ((const struct user_pt_regs *)(x))
#define __PT_PARM1_REG regs[0]
#define __PT_PARM2_REG regs[1]
#define __PT_PARM3_REG regs[2]
#define __PT_PARM4_REG regs[3]
#define __PT_PARM5_REG regs[4]
#define __PT_PARM6_REG regs[5]
#define __PT_PARM7_REG regs[6]
#define __PT_PARM8_REG regs[7]
#define __PT_PARM1_SYSCALL_REG orig_x0
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
#define PT_REGS_PARM1_SYSCALL(x) PT_REGS_PARM1_CORE_SYSCALL(x)
#define PT_REGS_PARM1_CORE_SYSCALL(x) \
BPF_CORE_READ((const struct pt_regs___arm64 *)(x), __PT_PARM1_SYSCALL_REG)
#define __PT_RET_REG regs[30]
#define __PT_FP_REG regs[29] /* Works only with CONFIG_FRAME_POINTER */
#define __PT_RC_REG regs[0]
#define __PT_SP_REG sp
#define __PT_IP_REG pc
#elif defined(bpf_target_mips)
/*
* N64 ABI is assumed right now.
* https://en.wikipedia.org/wiki/MIPS_architecture#Calling_conventions
*/
#define __PT_PARM1_REG regs[4]
#define __PT_PARM2_REG regs[5]
#define __PT_PARM3_REG regs[6]
#define __PT_PARM4_REG regs[7]
#define __PT_PARM5_REG regs[8]
#define __PT_PARM6_REG regs[9]
#define __PT_PARM7_REG regs[10]
#define __PT_PARM8_REG regs[11]
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG /* only N32/N64 */
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG /* only N32/N64 */
#define __PT_RET_REG regs[31]
#define __PT_FP_REG regs[30] /* Works only with CONFIG_FRAME_POINTER */
#define __PT_RC_REG regs[2]
#define __PT_SP_REG regs[29]
#define __PT_IP_REG cp0_epc
#elif defined(bpf_target_powerpc)
/*
* http://refspecs.linux-foundation.org/elf/elfspec_ppc.pdf (page 3-14,
* section "Function Calling Sequence")
*/
#define __PT_PARM1_REG gpr[3]
#define __PT_PARM2_REG gpr[4]
#define __PT_PARM3_REG gpr[5]
#define __PT_PARM4_REG gpr[6]
#define __PT_PARM5_REG gpr[7]
#define __PT_PARM6_REG gpr[8]
#define __PT_PARM7_REG gpr[9]
#define __PT_PARM8_REG gpr[10]
/* powerpc does not select ARCH_HAS_SYSCALL_WRAPPER. */
#define PT_REGS_SYSCALL_REGS(ctx) ctx
#define __PT_PARM1_SYSCALL_REG orig_gpr3
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
#if !defined(__arch64__)
#define __PT_PARM7_SYSCALL_REG __PT_PARM7_REG /* only powerpc (not powerpc64) */
#endif
#define __PT_RET_REG regs[31]
#define __PT_FP_REG __unsupported__
#define __PT_RC_REG gpr[3]
#define __PT_SP_REG sp
#define __PT_IP_REG nip
#elif defined(bpf_target_sparc)
/*
* https://en.wikipedia.org/wiki/Calling_convention#SPARC
*/
#define __PT_PARM1_REG u_regs[UREG_I0]
#define __PT_PARM2_REG u_regs[UREG_I1]
#define __PT_PARM3_REG u_regs[UREG_I2]
#define __PT_PARM4_REG u_regs[UREG_I3]
#define __PT_PARM5_REG u_regs[UREG_I4]
#define __PT_PARM6_REG u_regs[UREG_I5]
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
#define __PT_RET_REG u_regs[UREG_I7]
#define __PT_FP_REG __unsupported__
#define __PT_RC_REG u_regs[UREG_I0]
#define __PT_SP_REG u_regs[UREG_FP]
/* Should this also be a bpf_target check for the sparc case? */
#if defined(__arch64__)
#define __PT_IP_REG tpc
#else
#define __PT_IP_REG pc
#endif
#elif defined(bpf_target_riscv)
/*
* https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-cc.adoc#risc-v-calling-conventions
*/
#define __PT_REGS_CAST(x) ((const struct user_regs_struct *)(x))
#define __PT_PARM1_REG a0
#define __PT_PARM2_REG a1
#define __PT_PARM3_REG a2
#define __PT_PARM4_REG a3
#define __PT_PARM5_REG a4
#define __PT_PARM6_REG a5
#define __PT_PARM7_REG a6
#define __PT_PARM8_REG a7
/* riscv does not select ARCH_HAS_SYSCALL_WRAPPER. */
#define PT_REGS_SYSCALL_REGS(ctx) ctx
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
#define __PT_RET_REG ra
#define __PT_FP_REG s0
#define __PT_RC_REG a0
#define __PT_SP_REG sp
#define __PT_IP_REG pc
#elif defined(bpf_target_arc)
/*
* Section "Function Calling Sequence" (page 24):
* https://raw.githubusercontent.com/wiki/foss-for-synopsys-dwc-arc-processors/toolchain/files/ARCv2_ABI.pdf
*/
/* arc provides struct user_pt_regs instead of struct pt_regs to userspace */
#define __PT_REGS_CAST(x) ((const struct user_regs_struct *)(x))
#define __PT_PARM1_REG scratch.r0
#define __PT_PARM2_REG scratch.r1
#define __PT_PARM3_REG scratch.r2
#define __PT_PARM4_REG scratch.r3
#define __PT_PARM5_REG scratch.r4
#define __PT_PARM6_REG scratch.r5
#define __PT_PARM7_REG scratch.r6
#define __PT_PARM8_REG scratch.r7
/* arc does not select ARCH_HAS_SYSCALL_WRAPPER. */
#define PT_REGS_SYSCALL_REGS(ctx) ctx
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
#define __PT_RET_REG scratch.blink
#define __PT_FP_REG scratch.fp
#define __PT_RC_REG scratch.r0
#define __PT_SP_REG scratch.sp
#define __PT_IP_REG scratch.ret
#elif defined(bpf_target_loongarch)
/*
* https://docs.kernel.org/loongarch/introduction.html
* https://loongson.github.io/LoongArch-Documentation/LoongArch-ELF-ABI-EN.html
*/
/* loongarch provides struct user_pt_regs instead of struct pt_regs to userspace */
#define __PT_REGS_CAST(x) ((const struct user_pt_regs *)(x))
#define __PT_PARM1_REG regs[4]
#define __PT_PARM2_REG regs[5]
#define __PT_PARM3_REG regs[6]
#define __PT_PARM4_REG regs[7]
#define __PT_PARM5_REG regs[8]
#define __PT_PARM6_REG regs[9]
#define __PT_PARM7_REG regs[10]
#define __PT_PARM8_REG regs[11]
/* loongarch does not select ARCH_HAS_SYSCALL_WRAPPER. */
#define PT_REGS_SYSCALL_REGS(ctx) ctx
#define __PT_PARM1_SYSCALL_REG __PT_PARM1_REG
#define __PT_PARM2_SYSCALL_REG __PT_PARM2_REG
#define __PT_PARM3_SYSCALL_REG __PT_PARM3_REG
#define __PT_PARM4_SYSCALL_REG __PT_PARM4_REG
#define __PT_PARM5_SYSCALL_REG __PT_PARM5_REG
#define __PT_PARM6_SYSCALL_REG __PT_PARM6_REG
#define __PT_RET_REG regs[1]
#define __PT_FP_REG regs[22]
#define __PT_RC_REG regs[4]
#define __PT_SP_REG regs[3]
#define __PT_IP_REG csr_era
#endif
#if defined(bpf_target_defined)
struct pt_regs;
/* allow some architectures to override `struct pt_regs` */
#ifndef __PT_REGS_CAST
#define __PT_REGS_CAST(x) (x)
#endif
/*
* Different architectures support different number of arguments passed
* through registers. i386 supports just 3, some arches support up to 8.
*/
#ifndef __PT_PARM4_REG
#define __PT_PARM4_REG __unsupported__
#endif
#ifndef __PT_PARM5_REG
#define __PT_PARM5_REG __unsupported__
#endif
#ifndef __PT_PARM6_REG
#define __PT_PARM6_REG __unsupported__
#endif
#ifndef __PT_PARM7_REG
#define __PT_PARM7_REG __unsupported__
#endif
#ifndef __PT_PARM8_REG
#define __PT_PARM8_REG __unsupported__
#endif
/*
* Similarly, syscall-specific conventions might differ between function call
* conventions within each architecutre. All supported architectures pass
* either 6 or 7 syscall arguments in registers.
*
* See syscall(2) manpage for succinct table with information on each arch.
*/
#ifndef __PT_PARM7_SYSCALL_REG
#define __PT_PARM7_SYSCALL_REG __unsupported__
#endif
#define PT_REGS_PARM1(x) (__PT_REGS_CAST(x)->__PT_PARM1_REG)
#define PT_REGS_PARM2(x) (__PT_REGS_CAST(x)->__PT_PARM2_REG)
#define PT_REGS_PARM3(x) (__PT_REGS_CAST(x)->__PT_PARM3_REG)
#define PT_REGS_PARM4(x) (__PT_REGS_CAST(x)->__PT_PARM4_REG)
#define PT_REGS_PARM5(x) (__PT_REGS_CAST(x)->__PT_PARM5_REG)
#define PT_REGS_PARM6(x) (__PT_REGS_CAST(x)->__PT_PARM6_REG)
#define PT_REGS_PARM7(x) (__PT_REGS_CAST(x)->__PT_PARM7_REG)
#define PT_REGS_PARM8(x) (__PT_REGS_CAST(x)->__PT_PARM8_REG)
#define PT_REGS_RET(x) (__PT_REGS_CAST(x)->__PT_RET_REG)
#define PT_REGS_FP(x) (__PT_REGS_CAST(x)->__PT_FP_REG)
#define PT_REGS_RC(x) (__PT_REGS_CAST(x)->__PT_RC_REG)
#define PT_REGS_SP(x) (__PT_REGS_CAST(x)->__PT_SP_REG)
#define PT_REGS_IP(x) (__PT_REGS_CAST(x)->__PT_IP_REG)
#define PT_REGS_PARM1_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM1_REG)
#define PT_REGS_PARM2_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM2_REG)
#define PT_REGS_PARM3_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM3_REG)
#define PT_REGS_PARM4_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM4_REG)
#define PT_REGS_PARM5_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM5_REG)
#define PT_REGS_PARM6_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM6_REG)
#define PT_REGS_PARM7_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM7_REG)
#define PT_REGS_PARM8_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM8_REG)
#define PT_REGS_RET_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_RET_REG)
#define PT_REGS_FP_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_FP_REG)
#define PT_REGS_RC_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_RC_REG)
#define PT_REGS_SP_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_SP_REG)
#define PT_REGS_IP_CORE(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_IP_REG)
#if defined(bpf_target_powerpc)
#define BPF_KPROBE_READ_RET_IP(ip, ctx) ({ (ip) = (ctx)->link; })
#define BPF_KRETPROBE_READ_RET_IP BPF_KPROBE_READ_RET_IP
#elif defined(bpf_target_sparc)
#define BPF_KPROBE_READ_RET_IP(ip, ctx) ({ (ip) = PT_REGS_RET(ctx); })
#define BPF_KRETPROBE_READ_RET_IP BPF_KPROBE_READ_RET_IP
#else
#define BPF_KPROBE_READ_RET_IP(ip, ctx) \
({ bpf_probe_read_kernel(&(ip), sizeof(ip), (void *)PT_REGS_RET(ctx)); })
#define BPF_KRETPROBE_READ_RET_IP(ip, ctx) \
({ bpf_probe_read_kernel(&(ip), sizeof(ip), (void *)(PT_REGS_FP(ctx) + sizeof(ip))); })
#endif
#ifndef PT_REGS_PARM1_SYSCALL
#define PT_REGS_PARM1_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM1_SYSCALL_REG)
#define PT_REGS_PARM1_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM1_SYSCALL_REG)
#endif
#ifndef PT_REGS_PARM2_SYSCALL
#define PT_REGS_PARM2_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM2_SYSCALL_REG)
#define PT_REGS_PARM2_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM2_SYSCALL_REG)
#endif
#ifndef PT_REGS_PARM3_SYSCALL
#define PT_REGS_PARM3_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM3_SYSCALL_REG)
#define PT_REGS_PARM3_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM3_SYSCALL_REG)
#endif
#ifndef PT_REGS_PARM4_SYSCALL
#define PT_REGS_PARM4_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM4_SYSCALL_REG)
#define PT_REGS_PARM4_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM4_SYSCALL_REG)
#endif
#ifndef PT_REGS_PARM5_SYSCALL
#define PT_REGS_PARM5_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM5_SYSCALL_REG)
#define PT_REGS_PARM5_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM5_SYSCALL_REG)
#endif
#ifndef PT_REGS_PARM6_SYSCALL
#define PT_REGS_PARM6_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM6_SYSCALL_REG)
#define PT_REGS_PARM6_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM6_SYSCALL_REG)
#endif
#ifndef PT_REGS_PARM7_SYSCALL
#define PT_REGS_PARM7_SYSCALL(x) (__PT_REGS_CAST(x)->__PT_PARM7_SYSCALL_REG)
#define PT_REGS_PARM7_CORE_SYSCALL(x) BPF_CORE_READ(__PT_REGS_CAST(x), __PT_PARM7_SYSCALL_REG)
#endif
#else /* defined(bpf_target_defined) */
#define PT_REGS_PARM1(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM2(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM3(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM4(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM5(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM6(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM7(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM8(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_RET(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_FP(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_RC(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_SP(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_IP(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM1_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM2_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM3_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM4_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM5_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM6_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM7_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM8_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_RET_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_FP_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_RC_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_SP_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_IP_CORE(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define BPF_KPROBE_READ_RET_IP(ip, ctx) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define BPF_KRETPROBE_READ_RET_IP(ip, ctx) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM1_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM2_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM3_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM4_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM5_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM6_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM7_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM1_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM2_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM3_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM4_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM5_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM6_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#define PT_REGS_PARM7_CORE_SYSCALL(x) ({ _Pragma(__BPF_TARGET_MISSING); 0l; })
#endif /* defined(bpf_target_defined) */
/*
* When invoked from a syscall handler kprobe, returns a pointer to a
* struct pt_regs containing syscall arguments and suitable for passing to
* PT_REGS_PARMn_SYSCALL() and PT_REGS_PARMn_CORE_SYSCALL().
*/
#ifndef PT_REGS_SYSCALL_REGS
/* By default, assume that the arch selects ARCH_HAS_SYSCALL_WRAPPER. */
#define PT_REGS_SYSCALL_REGS(ctx) ((struct pt_regs *)PT_REGS_PARM1(ctx))
#endif
#ifndef ___bpf_concat
#define ___bpf_concat(a, b) a ## b
#endif
#ifndef ___bpf_apply
#define ___bpf_apply(fn, n) ___bpf_concat(fn, n)
#endif
#ifndef ___bpf_nth
#define ___bpf_nth(_, _1, _2, _3, _4, _5, _6, _7, _8, _9, _a, _b, _c, N, ...) N
#endif
#ifndef ___bpf_narg
#define ___bpf_narg(...) ___bpf_nth(_, ##__VA_ARGS__, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
#endif
#define ___bpf_ctx_cast0() ctx
#define ___bpf_ctx_cast1(x) ___bpf_ctx_cast0(), (void *)ctx[0]
#define ___bpf_ctx_cast2(x, args...) ___bpf_ctx_cast1(args), (void *)ctx[1]
#define ___bpf_ctx_cast3(x, args...) ___bpf_ctx_cast2(args), (void *)ctx[2]
#define ___bpf_ctx_cast4(x, args...) ___bpf_ctx_cast3(args), (void *)ctx[3]
#define ___bpf_ctx_cast5(x, args...) ___bpf_ctx_cast4(args), (void *)ctx[4]
#define ___bpf_ctx_cast6(x, args...) ___bpf_ctx_cast5(args), (void *)ctx[5]
#define ___bpf_ctx_cast7(x, args...) ___bpf_ctx_cast6(args), (void *)ctx[6]
#define ___bpf_ctx_cast8(x, args...) ___bpf_ctx_cast7(args), (void *)ctx[7]
#define ___bpf_ctx_cast9(x, args...) ___bpf_ctx_cast8(args), (void *)ctx[8]
#define ___bpf_ctx_cast10(x, args...) ___bpf_ctx_cast9(args), (void *)ctx[9]
#define ___bpf_ctx_cast11(x, args...) ___bpf_ctx_cast10(args), (void *)ctx[10]
#define ___bpf_ctx_cast12(x, args...) ___bpf_ctx_cast11(args), (void *)ctx[11]
#define ___bpf_ctx_cast(args...) ___bpf_apply(___bpf_ctx_cast, ___bpf_narg(args))(args)
/*
* BPF_PROG is a convenience wrapper for generic tp_btf/fentry/fexit and
* similar kinds of BPF programs, that accept input arguments as a single
* pointer to untyped u64 array, where each u64 can actually be a typed
* pointer or integer of different size. Instead of requring user to write
* manual casts and work with array elements by index, BPF_PROG macro
* allows user to declare a list of named and typed input arguments in the
* same syntax as for normal C function. All the casting is hidden and
* performed transparently, while user code can just assume working with
* function arguments of specified type and name.
*
* Original raw context argument is preserved as well as 'ctx' argument.
* This is useful when using BPF helpers that expect original context
* as one of the parameters (e.g., for bpf_perf_event_output()).
*/
#define BPF_PROG(name, args...) \
name(unsigned long long *ctx); \
static __always_inline typeof(name(0)) \
____##name(unsigned long long *ctx, ##args); \
typeof(name(0)) name(unsigned long long *ctx) \
{ \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
return ____##name(___bpf_ctx_cast(args)); \
_Pragma("GCC diagnostic pop") \
} \
static __always_inline typeof(name(0)) \
____##name(unsigned long long *ctx, ##args)
#ifndef ___bpf_nth2
#define ___bpf_nth2(_, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \
_14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, N, ...) N
#endif
#ifndef ___bpf_narg2
#define ___bpf_narg2(...) \
___bpf_nth2(_, ##__VA_ARGS__, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, \
6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0)
#endif
#define ___bpf_treg_cnt(t) \
__builtin_choose_expr(sizeof(t) == 1, 1, \
__builtin_choose_expr(sizeof(t) == 2, 1, \
__builtin_choose_expr(sizeof(t) == 4, 1, \
__builtin_choose_expr(sizeof(t) == 8, 1, \
__builtin_choose_expr(sizeof(t) == 16, 2, \
(void)0)))))
#define ___bpf_reg_cnt0() (0)
#define ___bpf_reg_cnt1(t, x) (___bpf_reg_cnt0() + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt2(t, x, args...) (___bpf_reg_cnt1(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt3(t, x, args...) (___bpf_reg_cnt2(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt4(t, x, args...) (___bpf_reg_cnt3(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt5(t, x, args...) (___bpf_reg_cnt4(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt6(t, x, args...) (___bpf_reg_cnt5(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt7(t, x, args...) (___bpf_reg_cnt6(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt8(t, x, args...) (___bpf_reg_cnt7(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt9(t, x, args...) (___bpf_reg_cnt8(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt10(t, x, args...) (___bpf_reg_cnt9(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt11(t, x, args...) (___bpf_reg_cnt10(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt12(t, x, args...) (___bpf_reg_cnt11(args) + ___bpf_treg_cnt(t))
#define ___bpf_reg_cnt(args...) ___bpf_apply(___bpf_reg_cnt, ___bpf_narg2(args))(args)
#define ___bpf_union_arg(t, x, n) \
__builtin_choose_expr(sizeof(t) == 1, ({ union { __u8 z[1]; t x; } ___t = { .z = {ctx[n]}}; ___t.x; }), \
__builtin_choose_expr(sizeof(t) == 2, ({ union { __u16 z[1]; t x; } ___t = { .z = {ctx[n]} }; ___t.x; }), \
__builtin_choose_expr(sizeof(t) == 4, ({ union { __u32 z[1]; t x; } ___t = { .z = {ctx[n]} }; ___t.x; }), \
__builtin_choose_expr(sizeof(t) == 8, ({ union { __u64 z[1]; t x; } ___t = {.z = {ctx[n]} }; ___t.x; }), \
__builtin_choose_expr(sizeof(t) == 16, ({ union { __u64 z[2]; t x; } ___t = {.z = {ctx[n], ctx[n + 1]} }; ___t.x; }), \
(void)0)))))
#define ___bpf_ctx_arg0(n, args...)
#define ___bpf_ctx_arg1(n, t, x) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt1(t, x))
#define ___bpf_ctx_arg2(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt2(t, x, args)) ___bpf_ctx_arg1(n, args)
#define ___bpf_ctx_arg3(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt3(t, x, args)) ___bpf_ctx_arg2(n, args)
#define ___bpf_ctx_arg4(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt4(t, x, args)) ___bpf_ctx_arg3(n, args)
#define ___bpf_ctx_arg5(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt5(t, x, args)) ___bpf_ctx_arg4(n, args)
#define ___bpf_ctx_arg6(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt6(t, x, args)) ___bpf_ctx_arg5(n, args)
#define ___bpf_ctx_arg7(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt7(t, x, args)) ___bpf_ctx_arg6(n, args)
#define ___bpf_ctx_arg8(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt8(t, x, args)) ___bpf_ctx_arg7(n, args)
#define ___bpf_ctx_arg9(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt9(t, x, args)) ___bpf_ctx_arg8(n, args)
#define ___bpf_ctx_arg10(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt10(t, x, args)) ___bpf_ctx_arg9(n, args)
#define ___bpf_ctx_arg11(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt11(t, x, args)) ___bpf_ctx_arg10(n, args)
#define ___bpf_ctx_arg12(n, t, x, args...) , ___bpf_union_arg(t, x, n - ___bpf_reg_cnt12(t, x, args)) ___bpf_ctx_arg11(n, args)
#define ___bpf_ctx_arg(args...) ___bpf_apply(___bpf_ctx_arg, ___bpf_narg2(args))(___bpf_reg_cnt(args), args)
#define ___bpf_ctx_decl0()
#define ___bpf_ctx_decl1(t, x) , t x
#define ___bpf_ctx_decl2(t, x, args...) , t x ___bpf_ctx_decl1(args)
#define ___bpf_ctx_decl3(t, x, args...) , t x ___bpf_ctx_decl2(args)
#define ___bpf_ctx_decl4(t, x, args...) , t x ___bpf_ctx_decl3(args)
#define ___bpf_ctx_decl5(t, x, args...) , t x ___bpf_ctx_decl4(args)
#define ___bpf_ctx_decl6(t, x, args...) , t x ___bpf_ctx_decl5(args)
#define ___bpf_ctx_decl7(t, x, args...) , t x ___bpf_ctx_decl6(args)
#define ___bpf_ctx_decl8(t, x, args...) , t x ___bpf_ctx_decl7(args)
#define ___bpf_ctx_decl9(t, x, args...) , t x ___bpf_ctx_decl8(args)
#define ___bpf_ctx_decl10(t, x, args...) , t x ___bpf_ctx_decl9(args)
#define ___bpf_ctx_decl11(t, x, args...) , t x ___bpf_ctx_decl10(args)
#define ___bpf_ctx_decl12(t, x, args...) , t x ___bpf_ctx_decl11(args)
#define ___bpf_ctx_decl(args...) ___bpf_apply(___bpf_ctx_decl, ___bpf_narg2(args))(args)
/*
* BPF_PROG2 is an enhanced version of BPF_PROG in order to handle struct
* arguments. Since each struct argument might take one or two u64 values
* in the trampoline stack, argument type size is needed to place proper number
* of u64 values for each argument. Therefore, BPF_PROG2 has different
* syntax from BPF_PROG. For example, for the following BPF_PROG syntax:
*
* int BPF_PROG(test2, int a, int b) { ... }
*
* the corresponding BPF_PROG2 syntax is:
*
* int BPF_PROG2(test2, int, a, int, b) { ... }
*
* where type and the corresponding argument name are separated by comma.
*
* Use BPF_PROG2 macro if one of the arguments might be a struct/union larger
* than 8 bytes:
*
* int BPF_PROG2(test_struct_arg, struct bpf_testmod_struct_arg_1, a, int, b,
* int, c, int, d, struct bpf_testmod_struct_arg_2, e, int, ret)
* {
* // access a, b, c, d, e, and ret directly
* ...
* }
*/
#define BPF_PROG2(name, args...) \
name(unsigned long long *ctx); \
static __always_inline typeof(name(0)) \
____##name(unsigned long long *ctx ___bpf_ctx_decl(args)); \
typeof(name(0)) name(unsigned long long *ctx) \
{ \
return ____##name(ctx ___bpf_ctx_arg(args)); \
} \
static __always_inline typeof(name(0)) \
____##name(unsigned long long *ctx ___bpf_ctx_decl(args))
struct pt_regs;
#define ___bpf_kprobe_args0() ctx
#define ___bpf_kprobe_args1(x) ___bpf_kprobe_args0(), (void *)PT_REGS_PARM1(ctx)
#define ___bpf_kprobe_args2(x, args...) ___bpf_kprobe_args1(args), (void *)PT_REGS_PARM2(ctx)
#define ___bpf_kprobe_args3(x, args...) ___bpf_kprobe_args2(args), (void *)PT_REGS_PARM3(ctx)
#define ___bpf_kprobe_args4(x, args...) ___bpf_kprobe_args3(args), (void *)PT_REGS_PARM4(ctx)
#define ___bpf_kprobe_args5(x, args...) ___bpf_kprobe_args4(args), (void *)PT_REGS_PARM5(ctx)
#define ___bpf_kprobe_args6(x, args...) ___bpf_kprobe_args5(args), (void *)PT_REGS_PARM6(ctx)
#define ___bpf_kprobe_args7(x, args...) ___bpf_kprobe_args6(args), (void *)PT_REGS_PARM7(ctx)
#define ___bpf_kprobe_args8(x, args...) ___bpf_kprobe_args7(args), (void *)PT_REGS_PARM8(ctx)
#define ___bpf_kprobe_args(args...) ___bpf_apply(___bpf_kprobe_args, ___bpf_narg(args))(args)
/*
* BPF_KPROBE serves the same purpose for kprobes as BPF_PROG for
* tp_btf/fentry/fexit BPF programs. It hides the underlying platform-specific
* low-level way of getting kprobe input arguments from struct pt_regs, and
* provides a familiar typed and named function arguments syntax and
* semantics of accessing kprobe input paremeters.
*
* Original struct pt_regs* context is preserved as 'ctx' argument. This might
* be necessary when using BPF helpers like bpf_perf_event_output().
*/
#define BPF_KPROBE(name, args...) \
name(struct pt_regs *ctx); \
static __always_inline typeof(name(0)) \
____##name(struct pt_regs *ctx, ##args); \
typeof(name(0)) name(struct pt_regs *ctx) \
{ \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
return ____##name(___bpf_kprobe_args(args)); \
_Pragma("GCC diagnostic pop") \
} \
static __always_inline typeof(name(0)) \
____##name(struct pt_regs *ctx, ##args)
#define ___bpf_kretprobe_args0() ctx
#define ___bpf_kretprobe_args1(x) ___bpf_kretprobe_args0(), (void *)PT_REGS_RC(ctx)
#define ___bpf_kretprobe_args(args...) ___bpf_apply(___bpf_kretprobe_args, ___bpf_narg(args))(args)
/*
* BPF_KRETPROBE is similar to BPF_KPROBE, except, it only provides optional
* return value (in addition to `struct pt_regs *ctx`), but no input
* arguments, because they will be clobbered by the time probed function
* returns.
*/
#define BPF_KRETPROBE(name, args...) \
name(struct pt_regs *ctx); \
static __always_inline typeof(name(0)) \
____##name(struct pt_regs *ctx, ##args); \
typeof(name(0)) name(struct pt_regs *ctx) \
{ \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
return ____##name(___bpf_kretprobe_args(args)); \
_Pragma("GCC diagnostic pop") \
} \
static __always_inline typeof(name(0)) ____##name(struct pt_regs *ctx, ##args)
/* If kernel has CONFIG_ARCH_HAS_SYSCALL_WRAPPER, read pt_regs directly */
#define ___bpf_syscall_args0() ctx
#define ___bpf_syscall_args1(x) ___bpf_syscall_args0(), (void *)PT_REGS_PARM1_SYSCALL(regs)
#define ___bpf_syscall_args2(x, args...) ___bpf_syscall_args1(args), (void *)PT_REGS_PARM2_SYSCALL(regs)
#define ___bpf_syscall_args3(x, args...) ___bpf_syscall_args2(args), (void *)PT_REGS_PARM3_SYSCALL(regs)
#define ___bpf_syscall_args4(x, args...) ___bpf_syscall_args3(args), (void *)PT_REGS_PARM4_SYSCALL(regs)
#define ___bpf_syscall_args5(x, args...) ___bpf_syscall_args4(args), (void *)PT_REGS_PARM5_SYSCALL(regs)
#define ___bpf_syscall_args6(x, args...) ___bpf_syscall_args5(args), (void *)PT_REGS_PARM6_SYSCALL(regs)
#define ___bpf_syscall_args7(x, args...) ___bpf_syscall_args6(args), (void *)PT_REGS_PARM7_SYSCALL(regs)
#define ___bpf_syscall_args(args...) ___bpf_apply(___bpf_syscall_args, ___bpf_narg(args))(args)
/* If kernel doesn't have CONFIG_ARCH_HAS_SYSCALL_WRAPPER, we have to BPF_CORE_READ from pt_regs */
#define ___bpf_syswrap_args0() ctx
#define ___bpf_syswrap_args1(x) ___bpf_syswrap_args0(), (void *)PT_REGS_PARM1_CORE_SYSCALL(regs)
#define ___bpf_syswrap_args2(x, args...) ___bpf_syswrap_args1(args), (void *)PT_REGS_PARM2_CORE_SYSCALL(regs)
#define ___bpf_syswrap_args3(x, args...) ___bpf_syswrap_args2(args), (void *)PT_REGS_PARM3_CORE_SYSCALL(regs)
#define ___bpf_syswrap_args4(x, args...) ___bpf_syswrap_args3(args), (void *)PT_REGS_PARM4_CORE_SYSCALL(regs)
#define ___bpf_syswrap_args5(x, args...) ___bpf_syswrap_args4(args), (void *)PT_REGS_PARM5_CORE_SYSCALL(regs)
#define ___bpf_syswrap_args6(x, args...) ___bpf_syswrap_args5(args), (void *)PT_REGS_PARM6_CORE_SYSCALL(regs)
#define ___bpf_syswrap_args7(x, args...) ___bpf_syswrap_args6(args), (void *)PT_REGS_PARM7_CORE_SYSCALL(regs)
#define ___bpf_syswrap_args(args...) ___bpf_apply(___bpf_syswrap_args, ___bpf_narg(args))(args)
/*
* BPF_KSYSCALL is a variant of BPF_KPROBE, which is intended for
* tracing syscall functions, like __x64_sys_close. It hides the underlying
* platform-specific low-level way of getting syscall input arguments from
* struct pt_regs, and provides a familiar typed and named function arguments
* syntax and semantics of accessing syscall input parameters.
*
* Original struct pt_regs * context is preserved as 'ctx' argument. This might
* be necessary when using BPF helpers like bpf_perf_event_output().
*
* At the moment BPF_KSYSCALL does not transparently handle all the calling
* convention quirks for the following syscalls:
*
* - mmap(): __ARCH_WANT_SYS_OLD_MMAP.
* - clone(): CONFIG_CLONE_BACKWARDS, CONFIG_CLONE_BACKWARDS2 and
* CONFIG_CLONE_BACKWARDS3.
* - socket-related syscalls: __ARCH_WANT_SYS_SOCKETCALL.
* - compat syscalls.
*
* This may or may not change in the future. User needs to take extra measures
* to handle such quirks explicitly, if necessary.
*
* This macro relies on BPF CO-RE support and virtual __kconfig externs.
*/
#define BPF_KSYSCALL(name, args...) \
name(struct pt_regs *ctx); \
extern _Bool LINUX_HAS_SYSCALL_WRAPPER __kconfig; \
static __always_inline typeof(name(0)) \
____##name(struct pt_regs *ctx, ##args); \
typeof(name(0)) name(struct pt_regs *ctx) \
{ \
struct pt_regs *regs = LINUX_HAS_SYSCALL_WRAPPER \
? (struct pt_regs *)PT_REGS_PARM1(ctx) \
: ctx; \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
if (LINUX_HAS_SYSCALL_WRAPPER) \
return ____##name(___bpf_syswrap_args(args)); \
else \
return ____##name(___bpf_syscall_args(args)); \
_Pragma("GCC diagnostic pop") \
} \
static __always_inline typeof(name(0)) \
____##name(struct pt_regs *ctx, ##args)
#define BPF_KPROBE_SYSCALL BPF_KSYSCALL
/* BPF_UPROBE and BPF_URETPROBE are identical to BPF_KPROBE and BPF_KRETPROBE,
* but are named way less confusingly for SEC("uprobe") and SEC("uretprobe")
* use cases.
*/
#define BPF_UPROBE(name, args...) BPF_KPROBE(name, ##args)
#define BPF_URETPROBE(name, args...) BPF_KRETPROBE(name, ##args)
#endif

View File

@@ -0,0 +1,116 @@
#include "vmlinux-x86.h"
#include "bpf/bpf_helpers.h"
#include "bpf/bpf_tracing.h"
#define ARGLEN 32 // maximum amount of args in argv we'll copy
#define ARGSIZE 1024 // maximum byte length of each arg in argv we'll copy
char __license[] SEC("license") = "GPL";
// Ring buffer for all connection events
struct {
__uint(type, BPF_MAP_TYPE_RINGBUF);
__uint(max_entries, 1 << 24);
} pm_exec_map SEC(".maps");
// This struct is defined according to
// /sys/kernel/debug/tracing/events/syscalls/sys_enter_execve/format
struct exec_info {
u16 common_type; // offset=0, size=2
u8 common_flags; // offset=2, size=1
u8 common_preempt_count; // offset=3, size=1
s32 common_pid; // offset=4, size=4
s32 syscall_nr; // offset=8, size=4
u32 pad; // offset=12, size=4 (pad)
const u8 *filename; // offset=16, size=8 (ptr)
const u8 *const *argv; // offset=24, size=8 (ptr)
const u8 *const *envp; // offset=32, size=8 (ptr)
};
// The event struct. This struct must be kept in sync with the Golang
// counterpart.
struct event_t {
// Details about the process being launched.
u8 filename[ARGSIZE];
u8 argv[ARGLEN][ARGSIZE];
u32 argc; // set to ARGLEN + 1 if there were more than ARGLEN arguments
u32 uid;
u32 gid;
u32 pid;
// Name of the calling process.
u8 comm[ARGSIZE];
};
// Tracepoint at the top of execve() syscall.
SEC("tracepoint/syscalls/sys_enter_execve")
s32 enter_execve(struct exec_info *ctx) {
// Reserve memory for our event on the `events` ring buffer defined above.
struct event_t *event;
event = bpf_ringbuf_reserve(&pm_exec_map, sizeof(struct event_t), 0);
if (!event) {
bpf_printk("could not reserve ringbuf memory");
return 1;
}
// Store process/calling process details.
u64 uidgid = bpf_get_current_uid_gid();
u64 pidtgid = bpf_get_current_pid_tgid();
event->uid = uidgid; // uid is the first 32 bits
event->gid = uidgid >> 32; // gid is the last 32 bits NOLINT(readability-magic-numbers)
event->pid = pidtgid; // pid is the first 32 bits
s32 ret = bpf_get_current_comm(&event->comm, sizeof(event->comm));
if (ret) {
bpf_printk("could not get current comm: %d", ret);
bpf_ringbuf_discard(event, 0);
return 1;
}
// Write the filename in addition to argv[0] because the filename contains
// the full path to the file which could be more useful in some situations.
ret = bpf_probe_read_user_str(event->filename, sizeof(event->filename), ctx->filename);
if (ret < 0) {
bpf_printk("could not read filename into event struct: %d", ret);
bpf_ringbuf_discard(event, 0);
return 1;
}
// Copy everything from ctx->argv to event->argv, incrementing event->argc
// as we go.
for (s32 i = 0; i < ARGLEN; i++) {
if (!(&ctx->argv[i])) {
goto out;
}
// Copying the arg into it's own variable before copying it into
// event->argv[i] prevents memory corruption.
const u8 *argp = NULL;
ret = bpf_probe_read_user(&argp, sizeof(argp), &ctx->argv[i]);
if (ret || !argp) {
goto out;
}
// Copy argp to event->argv[i].
ret = bpf_probe_read_user_str(event->argv[i], sizeof(event->argv[i]), argp);
if (ret < 0) {
bpf_printk("read argv %d: %d", i, ret);
goto out;
}
event->argc++;
}
// This won't get hit if we `goto out` in the loop above. This is to signify
// to userspace that we couldn't copy all of the arguments because it
// exceeded ARGLEN.
event->argc++;
out:
// Write the event to the ring buffer and notify userspace. This will cause
// the `Read()` call in userspace to return if it was blocked.
bpf_ringbuf_submit(event, 0);
return 0;
}

View File

@@ -0,0 +1,186 @@
#include "vmlinux-x86.h"
#include "bpf/bpf_helpers.h"
#include "bpf/bpf_tracing.h"
// IP Version
#define AF_INET 2
#define AF_INET6 10
// Protocols
#define TCP 6
#define UDP 17
#define UDPLite 136
#define OUTBOUND 0
#define INBOUND 1
char __license[] SEC("license") = "GPL";
// Ring buffer for all connection events
struct {
__uint(type, BPF_MAP_TYPE_RINGBUF);
__uint(max_entries, 1 << 24);
} pm_connection_events SEC(".maps");
// Event struct that will be sent to Go on each new connection. (The name should be the same as the go generate command)
struct Event {
u32 saddr[4];
u32 daddr[4];
u16 sport;
u16 dport;
u32 pid;
u8 ipVersion;
u8 protocol;
u8 direction;
};
struct Event *unused __attribute__((unused));
// Fentry of tcp_connect will be executed when equivalent kernel function is called.
// In the kernel all IP address and ports should be set before tcp_connect is called. [this-function] -> tcp_connect
SEC("fentry/tcp_connect")
int BPF_PROG(tcp_connect, struct sock *sk) {
// Alloc space for the event
struct Event *tcp_info;
tcp_info = bpf_ringbuf_reserve(&pm_connection_events, sizeof(struct Event), 0);
if (!tcp_info) {
return 0;
}
// Read PID (Careful: This is the Thread Group ID in kernel speak!)
tcp_info->pid = __builtin_bswap32((u32)(bpf_get_current_pid_tgid() >> 32));
// Set protocol
tcp_info->protocol = TCP;
// Set direction
tcp_info->direction = OUTBOUND;
// Set src and dist ports
tcp_info->sport = __builtin_bswap16(sk->__sk_common.skc_num);
tcp_info->dport = sk->__sk_common.skc_dport;
// Set src and dist IPs
if (sk->__sk_common.skc_family == AF_INET) {
tcp_info->saddr[0] = __builtin_bswap32(sk->__sk_common.skc_rcv_saddr);
tcp_info->daddr[0] = __builtin_bswap32(sk->__sk_common.skc_daddr);
// Set IP version
tcp_info->ipVersion = 4;
} else if (sk->__sk_common.skc_family == AF_INET6) {
for(int i = 0; i < 4; i++) {
tcp_info->saddr[i] = __builtin_bswap32(sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[i]);
}
for(int i = 0; i < 4; i++) {
tcp_info->daddr[i] = __builtin_bswap32(sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[i]);
}
// Set IP version
tcp_info->ipVersion = 6;
}
// Send event
bpf_ringbuf_submit(tcp_info, 0);
return 0;
};
// Fexit(function exit) of udp_v4_connect will be executed after the ip4_datagram_connect kernel function is called.
// ip4_datagram_connect -> udp_v4_connect
SEC("fexit/ip4_datagram_connect")
int BPF_PROG(udp_v4_connect, struct sock *sk) {
// Ignore everything else then IPv4
if (sk->__sk_common.skc_family != AF_INET) {
return 0;
}
// ip4_datagram_connect return error
if (sk->__sk_common.skc_dport == 0) {
return 0;
}
// Allocate space for the event.
struct Event *udp_info;
udp_info = bpf_ringbuf_reserve(&pm_connection_events, sizeof(struct Event), 0);
if (!udp_info) {
return 0;
}
// Read PID (Careful: This is the Thread Group ID in kernel speak!)
udp_info->pid = __builtin_bswap32((u32)(bpf_get_current_pid_tgid() >> 32));
// Set src and dst ports
udp_info->sport = __builtin_bswap16(sk->__sk_common.skc_num);
udp_info->dport = sk->__sk_common.skc_dport;
// Set src and dst IPs
udp_info->saddr[0] = __builtin_bswap32(sk->__sk_common.skc_rcv_saddr);
udp_info->daddr[0] = __builtin_bswap32(sk->__sk_common.skc_daddr);
// Set IP version
udp_info->ipVersion = 4;
// Set protocol
if(sk->sk_protocol == IPPROTO_UDPLITE) {
udp_info->protocol = UDPLite;
} else {
udp_info->protocol = UDP;
}
// Send event
bpf_ringbuf_submit(udp_info, 0);
return 0;
}
// Fentry(function enter) of udp_v6_connect will be executed after the ip6_datagram_connect kernel function is called.
// ip6_datagram_connect -> udp_v6_connect
SEC("fexit/ip6_datagram_connect")
int BPF_PROG(udp_v6_connect, struct sock *sk) {
// Ignore everything else then IPv6
if (sk->__sk_common.skc_family != AF_INET6) {
return 0;
}
// ip6_datagram_connect return error
if (sk->__sk_common.skc_dport == 0) {
return 0;
}
// Make sure its udp6 socket
struct udp6_sock *us = bpf_skc_to_udp6_sock(sk);
if (!us) {
return 0;
}
// Allocate space for the event.
struct Event *udp_info;
udp_info = bpf_ringbuf_reserve(&pm_connection_events, sizeof(struct Event), 0);
if (!udp_info) {
return 0;
}
// Read PID (Careful: This is the Thread Group ID in kernel speak!)
udp_info->pid = __builtin_bswap32((u32)(bpf_get_current_pid_tgid() >> 32));
// Set src and dst ports
udp_info->sport = __builtin_bswap16(sk->__sk_common.skc_num);
udp_info->dport = sk->__sk_common.skc_dport;
// Set src and dst IPs
for(int i = 0; i < 4; i++) {
udp_info->saddr[i] = __builtin_bswap32(sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[i]);
}
for(int i = 0; i < 4; i++) {
udp_info->daddr[i] = __builtin_bswap32(sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[i]);
}
// IP version
udp_info->ipVersion = 6;
// Set protocol
if(sk->sk_protocol == IPPROTO_UDPLITE) {
udp_info->protocol = UDPLite;
} else {
udp_info->protocol = UDP;
}
// Send event
bpf_ringbuf_submit(udp_info, 0);
return 0;
}

View File

@@ -0,0 +1,17 @@
#!/usr/bin/env bash
# Version of libbpf to fetch headers from
LIBBPF_VERSION=1.2.0
# The headers we want
prefix=libbpf-"$LIBBPF_VERSION"
headers=(
"$prefix"/src/bpf_core_read.h
"$prefix"/src/bpf_helper_defs.h
"$prefix"/src/bpf_helpers.h
"$prefix"/src/bpf_tracing.h
)
# Fetch libbpf release and extract the desired headers
curl -sL "https://github.com/libbpf/libbpf/archive/refs/tags/v${LIBBPF_VERSION}.tar.gz" | \
tar -xz --xform='s#.*/#bpf/#' "${headers[@]}"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
//go:build !windows && !linux
package interception
import (
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
)
// start starts the interception.
func startInterception(_ chan packet.Packet) error {
log.Critical("interception: this platform has no support for packet interception - a lot of functionality will be broken")
return nil
}
// stop starts the interception.
func stopInterception() error {
return nil
}
// ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again.
func ResetVerdictOfAllConnections() error {
return nil
}
// UpdateVerdictOfConnection updates the verdict of the given connection in the OS integration.
func UpdateVerdictOfConnection(conn *network.Connection) error {
return nil
}

View File

@@ -0,0 +1,48 @@
package interception
import (
"context"
"time"
bandwidth "github.com/safing/portmaster/service/firewall/interception/ebpf/bandwidth"
conn_listener "github.com/safing/portmaster/service/firewall/interception/ebpf/connection_listener"
"github.com/safing/portmaster/service/firewall/interception/nfq"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
)
// start starts the interception.
func startInterception(packets chan packet.Packet) error {
// Start packet interception via nfqueue.
err := StartNfqueueInterception(packets)
if err != nil {
return err
}
// Start ebpf new connection listener.
module.StartServiceWorker("ebpf connection listener", 0, func(ctx context.Context) error {
return conn_listener.ConnectionListenerWorker(ctx, packets)
})
// Start ebpf bandwidth stats monitor.
module.StartServiceWorker("ebpf bandwidth stats monitor", 0, func(ctx context.Context) error {
return bandwidth.BandwidthStatsWorker(ctx, 1*time.Second, BandwidthUpdates)
})
return nil
}
// stop starts the interception.
func stopInterception() error {
return StopNfqueueInterception()
}
// ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again.
func ResetVerdictOfAllConnections() error {
return nfq.DeleteAllMarkedConnection()
}
// UpdateVerdictOfConnection deletes the verdict of the given connection so it can be initialized again with the next packet.
func UpdateVerdictOfConnection(conn *network.Connection) error {
return nfq.DeleteMarkedConnection(conn)
}

View File

@@ -0,0 +1,172 @@
package interception
import (
"context"
"fmt"
"time"
"github.com/safing/portbase/log"
kext1 "github.com/safing/portmaster/firewall/interception/windowskext"
kext2 "github.com/safing/portmaster/firewall/interception/windowskext2"
"github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/updates"
)
var useOldKext = false
// start starts the interception.
func startInterception(packets chan packet.Packet) error {
kextFile, err := updates.GetPlatformFile("kext/portmaster-kext.sys")
if err != nil {
return fmt.Errorf("interception: could not get kext sys: %s", err)
}
err = kext2.Init(kextFile.Path())
if err != nil {
return fmt.Errorf("interception: could not init windows kext: %s", err)
}
err = kext2.Start()
if err != nil {
return fmt.Errorf("interception: could not start windows kext: %s", err)
}
version, err := kext2.GetVersion()
if err != nil {
return fmt.Errorf("interception: failed to read version: %s", err)
}
log.Debugf("Kext version: %s", version.String())
if version.Major < 2 {
useOldKext = true
// Transfer ownership.
kext1.SetKextHandler(kext2.GetKextHandle())
kext1.SetKextService(kext2.GetKextServiceHandle(), kextFile.Path())
// Start packet handler.
module.StartServiceWorker("kext packet handler", 0, func(ctx context.Context) error {
kext1.Handler(ctx, packets)
return nil
})
// Start bandwidth stats monitor.
module.StartServiceWorker("kext bandwidth stats monitor", 0, func(ctx context.Context) error {
return kext1.BandwidthStatsWorker(ctx, 1*time.Second, BandwidthUpdates)
})
} else {
// Start packet handler.
module.StartServiceWorker("kext packet handler", 0, func(ctx context.Context) error {
kext2.Handler(ctx, packets, BandwidthUpdates)
return nil
})
// Start bandwidth stats monitor.
module.StartServiceWorker("kext bandwidth request worker", 0, func(ctx context.Context) error {
timer := time.NewTicker(1 * time.Second)
for {
select {
case <-timer.C:
{
err := kext2.SendBandwidthStatsRequest()
if err != nil {
return err
}
}
case <-ctx.Done():
{
return nil
}
}
}
})
// Start kext logging. The worker will periodically send request to the kext to send logs.
module.StartServiceWorker("kext log request worker", 0, func(ctx context.Context) error {
timer := time.NewTicker(1 * time.Second)
for {
select {
case <-timer.C:
{
err := kext2.SendLogRequest()
if err != nil {
return err
}
}
case <-ctx.Done():
{
return nil
}
}
}
})
module.StartServiceWorker("kext clean ended connection worker", 0, func(ctx context.Context) error {
timer := time.NewTicker(30 * time.Second)
for {
select {
case <-timer.C:
{
err := kext2.SendCleanEndedConnection()
if err != nil {
return err
}
}
case <-ctx.Done():
{
return nil
}
}
}
})
}
return nil
}
// stop starts the interception.
func stopInterception() error {
if useOldKext {
return kext1.Stop()
}
return kext2.Stop()
}
// ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again.
func ResetVerdictOfAllConnections() error {
if useOldKext {
return kext1.ClearCache()
}
return kext2.ClearCache()
}
// UpdateVerdictOfConnection updates the verdict of the given connection in the kernel extension.
func UpdateVerdictOfConnection(conn *network.Connection) error {
if useOldKext {
return kext1.UpdateVerdict(conn)
}
return kext2.UpdateVerdict(conn)
}
// GetKextVersion returns the version of the kernel extension.
func GetKextVersion() (string, error) {
if useOldKext {
version, err := kext1.GetVersion()
if err != nil {
return "", err
}
return version.String(), nil
} else {
version, err := kext2.GetVersion()
if err != nil {
return "", err
}
return version.String(), nil
}
}

View File

@@ -0,0 +1,80 @@
package interception
import (
"flag"
"fmt"
"os"
"sync"
"time"
"github.com/safing/portbase/log"
)
var (
packetMetricsDestination string
metrics = &packetMetrics{
done: make(chan struct{}),
}
)
func init() {
flag.StringVar(&packetMetricsDestination, "write-packet-metrics", "", "write packet metrics to the specified file")
}
type (
performanceRecord struct {
start int64
duration time.Duration
verdict string
}
packetMetrics struct {
done chan struct{}
l sync.Mutex
records []*performanceRecord
}
)
func (pm *packetMetrics) record(tp *tracedPacket, verdict string) {
go func(start int64, duration time.Duration) {
pm.l.Lock()
defer pm.l.Unlock()
pm.records = append(pm.records, &performanceRecord{
start: start,
duration: duration,
verdict: verdict,
})
}(tp.start.UnixNano(), time.Since(tp.start))
}
func (pm *packetMetrics) writeMetrics() {
if packetMetricsDestination == "" {
return
}
f, err := os.Create(packetMetricsDestination)
if err != nil {
log.Errorf("Failed to create packet metrics file: %s", err)
return
}
defer func() {
_ = f.Close()
}()
for {
select {
case <-pm.done:
return
case <-time.After(time.Second * 5):
}
pm.l.Lock()
records := pm.records
pm.records = nil
pm.l.Unlock()
for _, r := range records {
fmt.Fprintf(f, "%d;%s;%s;%.2f\n", r.start, r.verdict, r.duration, float64(r.duration)/float64(time.Microsecond))
}
}
}

View File

@@ -0,0 +1,63 @@
package interception
import (
"flag"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/service/network/packet"
)
var (
module *modules.Module
// Packets is a stream of interception network packest.
Packets = make(chan packet.Packet, 1000)
// BandwidthUpdates is a stream of bandwidth usage update for connections.
BandwidthUpdates = make(chan *packet.BandwidthUpdate, 1000)
disableInterception bool
)
func init() {
flag.BoolVar(&disableInterception, "disable-interception", false, "disable packet interception; this breaks a lot of functionality")
module = modules.Register("interception", prep, start, stop, "base", "updates", "network", "notifications", "profiles")
}
func prep() error {
return nil
}
// Start starts the interception.
func start() error {
if disableInterception {
log.Warning("interception: packet interception is disabled via flag - this breaks a lot of functionality")
return nil
}
inputPackets := Packets
if packetMetricsDestination != "" {
go metrics.writeMetrics()
inputPackets = make(chan packet.Packet)
go func() {
for p := range inputPackets {
Packets <- tracePacket(p)
}
}()
}
return startInterception(inputPackets)
}
// Stop starts the interception.
func stop() error {
if disableInterception {
return nil
}
close(metrics.done)
return stopInterception()
}

View File

@@ -0,0 +1,126 @@
//go:build linux
package nfq
import (
"encoding/binary"
"errors"
"fmt"
ct "github.com/florianl/go-conntrack"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network"
)
var nfct *ct.Nfct // Conntrack handler. NFCT: Network Filter Connection Tracking.
// InitNFCT initializes the network filter conntrack library.
func InitNFCT() error {
var err error
nfct, err = ct.Open(&ct.Config{})
if err != nil {
return err
}
return nil
}
// TeardownNFCT deinitializes the network filter conntrack library.
func TeardownNFCT() {
if nfct != nil {
_ = nfct.Close()
}
}
// DeleteAllMarkedConnection deletes all marked entries from the conntrack table.
func DeleteAllMarkedConnection() error {
if nfct == nil {
return errors.New("nfq: nfct not initialized")
}
// Delete all ipv4 marked connections
deleted := deleteMarkedConnections(nfct, ct.IPv4)
if netenv.IPv6Enabled() {
// Delete all ipv6 marked connections
deleted += deleteMarkedConnections(nfct, ct.IPv6)
}
log.Infof("nfq: deleted %d conntrack entries to reset permanent connection verdicts", deleted)
return nil
}
func deleteMarkedConnections(nfct *ct.Nfct, f ct.Family) (deleted int) {
// initialize variables
permanentFlags := []uint32{MarkAcceptAlways, MarkBlockAlways, MarkDropAlways, MarkRerouteNS, MarkRerouteSPN}
filter := ct.FilterAttr{}
filter.MarkMask = []byte{0xFF, 0xFF, 0xFF, 0xFF}
filter.Mark = []byte{0x00, 0x00, 0x00, 0x00} // 4 zeros starting value
numberOfErrors := 0
var deleteError error = nil
// Get all connections from the specified family (ipv4 or ipv6)
for _, mark := range permanentFlags {
binary.BigEndian.PutUint32(filter.Mark, mark) // Little endian is in reverse not sure why. BigEndian makes it in correct order.
currentConnections, err := nfct.Query(ct.Conntrack, f, filter)
if err != nil {
log.Warningf("nfq: error on conntrack query: %s", err)
continue
}
for _, connection := range currentConnections {
deleteError = nfct.Delete(ct.Conntrack, ct.IPv4, connection)
if err != nil {
numberOfErrors++
} else {
deleted++
}
}
}
if numberOfErrors > 0 {
log.Warningf("nfq: failed to delete %d conntrack entries last error is: %s", numberOfErrors, deleteError)
}
return deleted
}
// DeleteMarkedConnection removes a specific connection from the conntrack table.
func DeleteMarkedConnection(conn *network.Connection) error {
if nfct == nil {
return errors.New("nfq: nfct not initialized")
}
con := ct.Con{
Origin: &ct.IPTuple{
Src: &conn.LocalIP,
Dst: &conn.Entity.IP,
Proto: &ct.ProtoTuple{
Number: &conn.Entity.Protocol,
SrcPort: &conn.LocalPort,
DstPort: &conn.Entity.Port,
},
},
}
connections, err := nfct.Get(ct.Conntrack, ct.IPv4, con)
if err != nil {
return fmt.Errorf("nfq: failed to find entry for connection %s: %w", conn.String(), err)
}
if len(connections) > 1 {
log.Warningf("nfq: multiple entries found for single connection: %s -> %d", conn.String(), len(connections))
}
for _, connection := range connections {
deleteErr := nfct.Delete(ct.Conntrack, ct.IPv4, connection)
if err == nil {
err = deleteErr
}
}
if err != nil {
log.Warningf("nfq: error while deleting conntrack entries for connection %s: %s", conn.String(), err)
}
return nil
}

View File

@@ -0,0 +1,241 @@
//go:build linux
// Package nfq contains a nfqueue library experiment.
package nfq
import (
"context"
"runtime"
"strings"
"sync/atomic"
"time"
"github.com/florianl/go-nfqueue"
"github.com/tevino/abool"
"golang.org/x/sys/unix"
"github.com/safing/portbase/log"
pmpacket "github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/process"
)
// Queue wraps a nfqueue.
type Queue struct {
id uint16
afFamily uint8
nf atomic.Value
packets chan pmpacket.Packet
cancelSocketCallback context.CancelFunc
restart chan struct{}
pendingVerdicts uint64
verdictCompleted chan struct{}
}
func (q *Queue) getNfq() *nfqueue.Nfqueue {
return q.nf.Load().(*nfqueue.Nfqueue) //nolint:forcetypeassert // TODO: Check.
}
// New opens a new nfQueue.
func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit
afFamily := unix.AF_INET
if v6 {
afFamily = unix.AF_INET6
}
ctx, cancel := context.WithCancel(context.Background())
q := &Queue{
id: qid,
afFamily: uint8(afFamily),
nf: atomic.Value{},
restart: make(chan struct{}, 1),
packets: make(chan pmpacket.Packet, 1000),
cancelSocketCallback: cancel,
verdictCompleted: make(chan struct{}, 1),
}
// Do not retry if the first one fails immediately as it
// might point to a deeper integration error that's not fixable
// with retrying ...
if err := q.open(ctx); err != nil {
return nil, err
}
go func() {
Wait:
for {
select {
case <-ctx.Done():
return
case <-q.restart:
runtime.Gosched()
}
for {
err := q.open(ctx)
if err == nil {
continue Wait
}
// Wait 100 ms and then try again ...
log.Errorf("Failed to open nfqueue: %s", err)
select {
case <-ctx.Done():
return
case <-time.After(100 * time.Millisecond):
}
}
}
}()
return q, nil
}
// open opens a new netlink socket and creates a new nfqueue.
// Upon success, the new nfqueue is atomically stored in Queue.nf.
// Users must use Queue.getNfq to access it. open does not care about
// any other value or queue that might be stored in Queue.nf at
// the time open is called.
func (q *Queue) open(ctx context.Context) error {
cfg := &nfqueue.Config{
NfQueue: q.id,
MaxPacketLen: 1600, // mtu is normally around 1500, make sure to capture it.
MaxQueueLen: 0xffff,
AfFamily: q.afFamily,
Copymode: nfqueue.NfQnlCopyPacket,
ReadTimeout: 1000 * time.Millisecond,
WriteTimeout: 1000 * time.Millisecond,
}
nf, err := nfqueue.Open(cfg)
if err != nil {
return err
}
if err := nf.RegisterWithErrorFunc(ctx, q.packetHandler(ctx), q.handleError); err != nil {
_ = nf.Close()
return err
}
q.nf.Store(nf)
return nil
}
func (q *Queue) handleError(e error) int {
// embedded interface is required to work-around some
// dep-vendoring weirdness
if opError, ok := e.(interface { //nolint:errorlint // TODO: Check if we can remove workaround.
Timeout() bool
Temporary() bool
}); ok {
if opError.Timeout() || opError.Temporary() {
c := atomic.LoadUint64(&q.pendingVerdicts)
if c > 0 {
log.Tracef("nfqueue: waiting for %d pending verdicts", c)
for atomic.LoadUint64(&q.pendingVerdicts) > 0 { // must NOT use c here
<-q.verdictCompleted
}
}
return 0
}
}
// Check if the queue was already closed. Unfortunately, the exposed error
// variable is in an internal stdlib package. Therefore, check for the error
// string instead. :(
// Official error variable is defined here:
// https://github.com/golang/go/blob/0e85fd7561de869add933801c531bf25dee9561c/src/internal/poll/fd.go#L24
if !strings.HasSuffix(e.Error(), "use of closed file") {
log.Errorf("nfqueue: encountered error while receiving packets: %s\n", e.Error())
}
// Close the existing socket
if nf := q.getNfq(); nf != nil {
// Call Close() on the Con directly, as nf.Close() calls waitgroup.Wait(), which then may deadlock.
_ = nf.Con.Close()
}
// Trigger a restart of the queue
q.restart <- struct{}{}
return 1
}
func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int {
return func(attrs nfqueue.Attribute) int {
if attrs.PacketID == nil {
// we need a packet id to set a verdict,
// if we don't get an ID there's hardly anything
// we can do.
return 0
}
pkt := &packet{
pktID: *attrs.PacketID,
queue: q,
verdictSet: make(chan struct{}),
verdictPending: abool.New(),
}
pkt.Info().PID = process.UndefinedProcessID
pkt.Info().SeenAt = time.Now()
if attrs.Payload == nil {
// There is not payload.
log.Warningf("nfqueue: packet #%d has no payload", pkt.pktID)
return 0
}
if err := pmpacket.Parse(*attrs.Payload, &pkt.Base); err != nil {
log.Warningf("nfqueue: failed to parse payload: %s", err)
_ = pkt.Drop()
return 0
}
select {
case q.packets <- pkt:
// DEBUG:
// log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt))
case <-ctx.Done():
return 0
case <-time.After(time.Second):
log.Warningf("nfqueue: failed to queue packet (%s since it was handed over by the kernel)", time.Since(pkt.Info().SeenAt))
}
go func() {
select {
case <-pkt.verdictSet:
case <-time.After(20 * time.Second):
log.Warningf("nfqueue: no verdict set for packet %s (%s -> %s) after %s, dropping", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt))
if err := pkt.Drop(); err != nil {
log.Warningf("nfqueue: failed to apply default-drop to unveridcted packet %s (%s -> %s)", pkt.ID(), pkt.Info().Src, pkt.Info().Dst)
}
}
}()
return 0 // continue calling this fn
}
}
// Destroy destroys the queue. Any error encountered is logged.
func (q *Queue) Destroy() {
if q == nil {
return
}
q.cancelSocketCallback()
if nf := q.getNfq(); nf != nil {
if err := nf.Close(); err != nil {
log.Errorf("nfqueue: failed to close queue %d: %s", q.id, err)
}
}
}
// PacketChannel returns the packet channel.
func (q *Queue) PacketChannel() <-chan pmpacket.Packet {
return q.packets
}

View File

@@ -0,0 +1,178 @@
//go:build linux
package nfq
import (
"errors"
"fmt"
"sync/atomic"
"github.com/florianl/go-nfqueue"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
pmpacket "github.com/safing/portmaster/service/network/packet"
)
// Firewalling marks used by the Portmaster.
// See TODO on packet.mark() on their relevance
// and a possibility to remove most IPtables rules.
const (
MarkAccept = 1700
MarkBlock = 1701
MarkDrop = 1702
MarkAcceptAlways = 1710
MarkBlockAlways = 1711
MarkDropAlways = 1712
MarkRerouteNS = 1799
MarkRerouteSPN = 1717
)
func markToString(mark int) string {
switch mark {
case MarkAccept:
return "Accept"
case MarkBlock:
return "Block"
case MarkDrop:
return "Drop"
case MarkAcceptAlways:
return "AcceptAlways"
case MarkBlockAlways:
return "BlockAlways"
case MarkDropAlways:
return "DropAlways"
case MarkRerouteNS:
return "RerouteNS"
case MarkRerouteSPN:
return "RerouteSPN"
}
return "unknown"
}
// packet implements the packet.Packet interface.
type packet struct {
pmpacket.Base
pktID uint32
queue *Queue
verdictSet chan struct{}
verdictPending *abool.AtomicBool
}
func (pkt *packet) ID() string {
return fmt.Sprintf("pkt:%d qid:%d", pkt.pktID, pkt.queue.id)
}
// LoadPacketData does nothing on Linux, as data is always fully parsed.
func (pkt *packet) LoadPacketData() error {
return nil
}
// TODO(ppacher): revisit the following behavior:
//
// The legacy implementation of nfqueue (and the interception) module
// always accept a packet but may mark it so that a subsequent rule in
// the C17 chain drops, rejects or modifies it.
//
// For drop/return we could use the actual nfQueue verdicts Drop and Stop.
// Re-routing to local NS or SPN can be done by modifying the packet here
// and using SetVerdictModPacket and reject can be implemented using a simple
// raw-socket.
func (pkt *packet) mark(mark int) (err error) {
if pkt.verdictPending.SetToIf(false, true) {
defer close(pkt.verdictSet)
return pkt.setMark(mark)
}
return errors.New("verdict already set")
}
func (pkt *packet) setMark(mark int) error {
atomic.AddUint64(&pkt.queue.pendingVerdicts, 1)
defer func() {
atomic.AddUint64(&pkt.queue.pendingVerdicts, ^uint64(0))
select {
case pkt.queue.verdictCompleted <- struct{}{}:
default:
}
}()
for {
if err := pkt.queue.getNfq().SetVerdictWithMark(pkt.pktID, nfqueue.NfAccept, mark); err != nil {
// embedded interface is required to work-around some
// dep-vendoring weirdness
if opErr, ok := err.(interface { //nolint:errorlint // TODO: Check if we can remove workaround.
Timeout() bool
Temporary() bool
}); ok {
if opErr.Timeout() || opErr.Temporary() {
continue
}
}
log.Tracer(pkt.Ctx()).Errorf("nfqueue: failed to set verdict %s for %s (%s -> %s): %s", markToString(mark), pkt.ID(), pkt.Info().Src, pkt.Info().Dst, err)
return err
}
break
}
// DEBUG:
// log.Tracer(pkt.Ctx()).Tracef(
// "nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s",
// pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id,
// markToString(mark), time.Since(pkt.Info().SeenAt),
// )
return nil
}
func (pkt *packet) Accept() error {
return pkt.mark(MarkAccept)
}
func (pkt *packet) Block() error {
if pkt.Info().Protocol == pmpacket.ICMP {
// ICMP packets attributed to a blocked connection are always allowed, as
// rejection ICMP packets will have the same mark as the blocked
// connection. This is why we need to drop blocked ICMP packets instead.
return pkt.mark(MarkDrop)
}
return pkt.mark(MarkBlock)
}
func (pkt *packet) Drop() error {
return pkt.mark(MarkDrop)
}
func (pkt *packet) PermanentAccept() error {
// If the packet is localhost only, do not permanently accept the outgoing
// packet, as the packet mark will be copied to the connection mark, which
// will stick and it will bypass the incoming queue.
if !pkt.Info().Inbound && pkt.Info().Dst.IsLoopback() {
return pkt.Accept()
}
return pkt.mark(MarkAcceptAlways)
}
func (pkt *packet) PermanentBlock() error {
if pkt.Info().Protocol == pmpacket.ICMP || pkt.Info().Protocol == pmpacket.ICMPv6 {
// ICMP packets attributed to a blocked connection are always allowed, as
// rejection ICMP packets will have the same mark as the blocked
// connection. This is why we need to drop blocked ICMP packets instead.
return pkt.mark(MarkDropAlways)
}
return pkt.mark(MarkBlockAlways)
}
func (pkt *packet) PermanentDrop() error {
return pkt.mark(MarkDropAlways)
}
func (pkt *packet) RerouteToNameserver() error {
return pkt.mark(MarkRerouteNS)
}
func (pkt *packet) RerouteToTunnel() error {
return pkt.mark(MarkRerouteSPN)
}

View File

@@ -0,0 +1,354 @@
package interception
import (
"context"
"flag"
"fmt"
"sort"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/firewall/interception/nfq"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network/packet"
)
var (
v4chains []string
v4rules []string
v4once []string
v6chains []string
v6rules []string
v6once []string
out4Queue nfQueue
in4Queue nfQueue
out6Queue nfQueue
in6Queue nfQueue
shutdownSignal = make(chan struct{})
experimentalNfqueueBackend bool
)
func init() {
flag.BoolVar(&experimentalNfqueueBackend, "experimental-nfqueue", false, "(deprecated flag; always used)")
}
// nfQueue encapsulates nfQueue providers.
type nfQueue interface {
PacketChannel() <-chan packet.Packet
Destroy()
}
func init() {
v4chains = []string{
"mangle PORTMASTER-INGEST-OUTPUT",
"mangle PORTMASTER-INGEST-INPUT",
"filter PORTMASTER-FILTER",
"nat PORTMASTER-REDIRECT",
}
v4rules = []string{
"mangle PORTMASTER-INGEST-OUTPUT -j CONNMARK --restore-mark",
"mangle PORTMASTER-INGEST-OUTPUT -m mark --mark 0 -j NFQUEUE --queue-num 17040 --queue-bypass",
"mangle PORTMASTER-INGEST-INPUT -j CONNMARK --restore-mark",
"mangle PORTMASTER-INGEST-INPUT -m mark --mark 0 -j NFQUEUE --queue-num 17140 --queue-bypass",
"filter PORTMASTER-FILTER -m mark --mark 0 -j DROP",
"filter PORTMASTER-FILTER -m mark --mark 1700 -j RETURN",
// Accepting ICMP packets with mark 1701 is required for rejecting to work,
// as the rejection ICMP packet will have the same mark. Blocked ICMP
// packets will always result in a drop within the Portmaster.
"filter PORTMASTER-FILTER -m mark --mark 1701 -p icmp -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1701 -j REJECT --reject-with icmp-admin-prohibited",
"filter PORTMASTER-FILTER -m mark --mark 1702 -j DROP",
"filter PORTMASTER-FILTER -j CONNMARK --save-mark",
"filter PORTMASTER-FILTER -m mark --mark 1710 -j RETURN",
// Accepting ICMP packets with mark 1711 is required for rejecting to work,
// as the rejection ICMP packet will have the same mark. Blocked ICMP
// packets will always result in a drop within the Portmaster.
"filter PORTMASTER-FILTER -m mark --mark 1711 -p icmp -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1711 -j REJECT --reject-with icmp-admin-prohibited",
"filter PORTMASTER-FILTER -m mark --mark 1712 -j DROP",
"filter PORTMASTER-FILTER -m mark --mark 1717 -j RETURN",
"nat PORTMASTER-REDIRECT -m mark --mark 1799 -p udp -j DNAT --to 127.0.0.17:53",
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p tcp -j DNAT --to 127.0.0.17:717",
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p udp -j DNAT --to 127.0.0.17:717",
// "nat PORTMASTER-REDIRECT -m mark --mark 1717 ! -p tcp ! -p udp -j DNAT --to 127.0.0.17",
}
v4once = []string{
"mangle OUTPUT -j PORTMASTER-INGEST-OUTPUT",
"mangle INPUT -j PORTMASTER-INGEST-INPUT",
"filter OUTPUT -j PORTMASTER-FILTER",
"filter INPUT -j PORTMASTER-FILTER",
"nat OUTPUT -j PORTMASTER-REDIRECT",
}
v6chains = []string{
"mangle PORTMASTER-INGEST-OUTPUT",
"mangle PORTMASTER-INGEST-INPUT",
"filter PORTMASTER-FILTER",
"nat PORTMASTER-REDIRECT",
}
v6rules = []string{
"mangle PORTMASTER-INGEST-OUTPUT -j CONNMARK --restore-mark",
"mangle PORTMASTER-INGEST-OUTPUT -m mark --mark 0 -j NFQUEUE --queue-num 17060 --queue-bypass",
"mangle PORTMASTER-INGEST-INPUT -j CONNMARK --restore-mark",
"mangle PORTMASTER-INGEST-INPUT -m mark --mark 0 -j NFQUEUE --queue-num 17160 --queue-bypass",
"filter PORTMASTER-FILTER -m mark --mark 0 -j DROP",
"filter PORTMASTER-FILTER -m mark --mark 1700 -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1701 -p icmpv6 -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1701 -j REJECT --reject-with icmp6-adm-prohibited",
"filter PORTMASTER-FILTER -m mark --mark 1702 -j DROP",
"filter PORTMASTER-FILTER -j CONNMARK --save-mark",
"filter PORTMASTER-FILTER -m mark --mark 1710 -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1711 -p icmpv6 -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1711 -j REJECT --reject-with icmp6-adm-prohibited",
"filter PORTMASTER-FILTER -m mark --mark 1712 -j DROP",
"filter PORTMASTER-FILTER -m mark --mark 1717 -j RETURN",
"nat PORTMASTER-REDIRECT -m mark --mark 1799 -p udp -j DNAT --to [::1]:53",
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p tcp -j DNAT --to [::1]:717",
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p udp -j DNAT --to [::1]:717",
// "nat PORTMASTER-REDIRECT -m mark --mark 1717 ! -p tcp ! -p udp -j DNAT --to [::1]",
}
v6once = []string{
"mangle OUTPUT -j PORTMASTER-INGEST-OUTPUT",
"mangle INPUT -j PORTMASTER-INGEST-INPUT",
"filter OUTPUT -j PORTMASTER-FILTER",
"filter INPUT -j PORTMASTER-FILTER",
"nat OUTPUT -j PORTMASTER-REDIRECT",
}
// Reverse because we'd like to insert in a loop
_ = sort.Reverse(sort.StringSlice(v4once)) // silence vet (sort is used just like in the docs)
_ = sort.Reverse(sort.StringSlice(v6once)) // silence vet (sort is used just like in the docs)
}
func activateNfqueueFirewall() error {
if err := activateIPTables(iptables.ProtocolIPv4, v4rules, v4once, v4chains); err != nil {
return err
}
if netenv.IPv6Enabled() {
if err := activateIPTables(iptables.ProtocolIPv6, v6rules, v6once, v6chains); err != nil {
return err
}
}
if err := nfq.InitNFCT(); err != nil {
return err
}
_ = nfq.DeleteAllMarkedConnection()
return nil
}
// DeactivateNfqueueFirewall drops portmaster related IP tables rules.
// Any errors encountered accumulated into a *multierror.Error.
func DeactivateNfqueueFirewall() error {
// IPv4
var result *multierror.Error
if err := deactivateIPTables(iptables.ProtocolIPv4, v4once, v4chains); err != nil {
result = multierror.Append(result, err)
}
// IPv6
if netenv.IPv6Enabled() {
if err := deactivateIPTables(iptables.ProtocolIPv6, v6once, v6chains); err != nil {
result = multierror.Append(result, err)
}
}
_ = nfq.DeleteAllMarkedConnection()
nfq.TeardownNFCT()
return result.ErrorOrNil()
}
func activateIPTables(protocol iptables.Protocol, rules, once, chains []string) error {
tbls, err := iptables.NewWithProtocol(protocol)
if err != nil {
return err
}
for _, chain := range chains {
splittedRule := strings.Split(chain, " ")
if err = tbls.ClearChain(splittedRule[0], splittedRule[1]); err != nil {
return err
}
}
for _, rule := range rules {
splittedRule := strings.Split(rule, " ")
if err = tbls.Append(splittedRule[0], splittedRule[1], splittedRule[2:]...); err != nil {
return err
}
}
for _, rule := range once {
splittedRule := strings.Split(rule, " ")
ok, err := tbls.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
if err != nil {
return err
}
if !ok {
if err = tbls.Insert(splittedRule[0], splittedRule[1], 1, splittedRule[2:]...); err != nil {
return err
}
}
}
return nil
}
func deactivateIPTables(protocol iptables.Protocol, rules, chains []string) error {
tbls, err := iptables.NewWithProtocol(protocol)
if err != nil {
return err
}
var multierr *multierror.Error
for _, rule := range rules {
splittedRule := strings.Split(rule, " ")
ok, err := tbls.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
if err != nil {
multierr = multierror.Append(multierr, err)
}
if ok {
if err = tbls.Delete(splittedRule[0], splittedRule[1], splittedRule[2:]...); err != nil {
multierr = multierror.Append(multierr, err)
}
}
}
for _, chain := range chains {
splittedRule := strings.Split(chain, " ")
if err = tbls.ClearChain(splittedRule[0], splittedRule[1]); err != nil {
multierr = multierror.Append(multierr, err)
}
if err = tbls.DeleteChain(splittedRule[0], splittedRule[1]); err != nil {
multierr = multierror.Append(multierr, err)
}
}
return multierr.ErrorOrNil()
}
// StartNfqueueInterception starts the nfqueue interception.
func StartNfqueueInterception(packets chan<- packet.Packet) (err error) {
// @deprecated, remove in v1
if experimentalNfqueueBackend {
log.Warningf("[DEPRECATED] --experimental-nfqueue has been deprecated as the backend is now used by default")
log.Warningf("[DEPRECATED] please remove the flag from your configuration!")
}
err = activateNfqueueFirewall()
if err != nil {
_ = StopNfqueueInterception()
return fmt.Errorf("could not initialize nfqueue: %w", err)
}
out4Queue, err = nfq.New(17040, false)
if err != nil {
_ = StopNfqueueInterception()
return fmt.Errorf("nfqueue(IPv4, out): %w", err)
}
in4Queue, err = nfq.New(17140, false)
if err != nil {
_ = StopNfqueueInterception()
return fmt.Errorf("nfqueue(IPv4, in): %w", err)
}
if netenv.IPv6Enabled() {
out6Queue, err = nfq.New(17060, true)
if err != nil {
_ = StopNfqueueInterception()
return fmt.Errorf("nfqueue(IPv6, out): %w", err)
}
in6Queue, err = nfq.New(17160, true)
if err != nil {
_ = StopNfqueueInterception()
return fmt.Errorf("nfqueue(IPv6, in): %w", err)
}
} else {
log.Warningf("interception: no IPv6 stack detected, disabling IPv6 network integration")
out6Queue = &disabledNfQueue{}
in6Queue = &disabledNfQueue{}
}
module.StartServiceWorker("nfqueue packet handler", 0, func(_ context.Context) error {
return handleInterception(packets)
})
return nil
}
// StopNfqueueInterception stops the nfqueue interception.
func StopNfqueueInterception() error {
defer close(shutdownSignal)
if out4Queue != nil {
out4Queue.Destroy()
}
if in4Queue != nil {
in4Queue.Destroy()
}
if out6Queue != nil {
out6Queue.Destroy()
}
if in6Queue != nil {
in6Queue.Destroy()
}
err := DeactivateNfqueueFirewall()
if err != nil {
return fmt.Errorf("interception: error while deactivating nfqueue: %w", err)
}
return nil
}
func handleInterception(packets chan<- packet.Packet) error {
for {
var pkt packet.Packet
select {
case <-shutdownSignal:
return nil
case pkt = <-out4Queue.PacketChannel():
pkt.SetOutbound()
case pkt = <-in4Queue.PacketChannel():
pkt.SetInbound()
case pkt = <-out6Queue.PacketChannel():
pkt.SetOutbound()
case pkt = <-in6Queue.PacketChannel():
pkt.SetInbound()
}
select {
case packets <- pkt:
case <-shutdownSignal:
return nil
}
}
}
type disabledNfQueue struct{}
func (dnfq *disabledNfQueue) PacketChannel() <-chan packet.Packet {
return nil
}
func (dnfq *disabledNfQueue) Destroy() {}

View File

@@ -0,0 +1,67 @@
package interception
import (
"time"
"github.com/safing/portmaster/service/network/packet"
)
type tracedPacket struct {
start time.Time
packet.Packet
}
func tracePacket(p packet.Packet) packet.Packet {
return &tracedPacket{
start: time.Now(),
Packet: p,
}
}
func (p *tracedPacket) markServed(v string) {
if packetMetricsDestination == "" {
return
}
metrics.record(p, v)
}
func (p *tracedPacket) Accept() error {
defer p.markServed("accept")
return p.Packet.Accept()
}
func (p *tracedPacket) Block() error {
defer p.markServed("block")
return p.Packet.Block()
}
func (p *tracedPacket) Drop() error {
defer p.markServed("drop")
return p.Packet.Drop()
}
func (p *tracedPacket) PermanentAccept() error {
defer p.markServed("perm-accept")
return p.Packet.PermanentAccept()
}
func (p *tracedPacket) PermanentBlock() error {
defer p.markServed("perm-block")
return p.Packet.PermanentBlock()
}
func (p *tracedPacket) PermanentDrop() error {
defer p.markServed("perm-drop")
return p.Packet.PermanentDrop()
}
func (p *tracedPacket) RerouteToNameserver() error {
defer p.markServed("reroute-ns")
return p.Packet.RerouteToNameserver()
}
func (p *tracedPacket) RerouteToTunnel() error {
defer p.markServed("reroute-tunnel")
return p.Packet.RerouteToTunnel()
}

View File

@@ -0,0 +1,132 @@
//go:build windows
// +build windows
package windowskext
// This file contains example code how to read bandwidth stats from the kext. Its not ment to be used in production.
import (
"context"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/packet"
)
type Rxtxdata struct {
rx uint64
tx uint64
}
type Key struct {
localIP [4]uint32
remoteIP [4]uint32
localPort uint16
remotePort uint16
ipv6 bool
protocol uint8
}
var m = make(map[Key]Rxtxdata)
func BandwidthStatsWorker(ctx context.Context, collectInterval time.Duration, bandwidthUpdates chan *packet.BandwidthUpdate) error {
// Setup ticker.
ticker := time.NewTicker(collectInterval)
defer ticker.Stop()
// Collect bandwidth at every tick.
for {
select {
case <-ticker.C:
err := reportBandwidth(ctx, bandwidthUpdates)
if err != nil {
return err
}
case <-ctx.Done():
return nil
}
}
}
func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.BandwidthUpdate) error {
stats, err := GetConnectionsStats()
if err != nil {
return err
}
// Report all statistics.
for i, stat := range stats {
connID := packet.CreateConnectionID(
packet.IPProtocol(stat.protocol),
convertArrayToIP(stat.localIP, stat.ipV6 == 1), stat.localPort,
convertArrayToIP(stat.remoteIP, stat.ipV6 == 1), stat.remotePort,
false,
)
update := &packet.BandwidthUpdate{
ConnID: connID,
BytesReceived: stat.receivedBytes,
BytesSent: stat.transmittedBytes,
Method: packet.Additive,
}
select {
case bandwidthUpdates <- update:
case <-ctx.Done():
return nil
default:
log.Warningf("kext: bandwidth update queue is full, skipping rest of batch (%d entries)", len(stats)-i)
return nil
}
}
return nil
}
func StartBandwidthConsoleLogger() {
go func() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for range ticker.C {
conns, err := GetConnectionsStats()
if err != nil {
continue
}
for _, conn := range conns {
if conn.receivedBytes == 0 && conn.transmittedBytes == 0 {
continue
}
key := Key{
localIP: conn.localIP,
remoteIP: conn.remoteIP,
localPort: conn.localPort,
remotePort: conn.remotePort,
ipv6: conn.ipV6 == 1,
protocol: conn.protocol,
}
// First we get a "copy" of the entry
if entry, ok := m[key]; ok {
// Then we modify the copy
entry.rx += conn.receivedBytes
entry.tx += conn.transmittedBytes
// Then we reassign map entry
m[key] = entry
} else {
m[key] = Rxtxdata{
rx: conn.receivedBytes,
tx: conn.transmittedBytes,
}
}
}
log.Debug("----------------------------------")
for key, value := range m {
log.Debugf(
"Conn: %d %s:%d %s:%d rx:%d tx:%d", key.protocol,
convertArrayToIP(key.localIP, key.ipv6), key.localPort,
convertArrayToIP(key.remoteIP, key.ipv6), key.remotePort,
value.rx, value.tx,
)
}
}
}()
}

View File

@@ -0,0 +1,4 @@
// +build windows
// Package windowskext provides network interception capabilities on windows via the Portmaster Kernel Extension.
package windowskext

View File

@@ -0,0 +1,209 @@
//go:build windows
// +build windows
package windowskext
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"time"
"unsafe"
"github.com/safing/portmaster/service/process"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
)
const (
// VerdictRequestFlagFastTrackPermitted is set on packets that have been
// already permitted by the kernel extension and the verdict request is only
// informational.
VerdictRequestFlagFastTrackPermitted = 1
// VerdictRequestFlagSocketAuth indicates that the verdict request is for a
// connection that was intercepted on an ALE layer instead of in the network
// stack itself. Thus, no packet data is available.
VerdictRequestFlagSocketAuth = 2
// VerdictRequestFlagExpectSocketAuth indicates that the next verdict
// requests is expected to be an informational socket auth request from
// the ALE layer.
VerdictRequestFlagExpectSocketAuth = 4
)
// Do not change the order of the members! The structure is used to communicate with the kernel extension.
// VerdictRequest is the request structure from the Kext.
type VerdictRequest struct {
id uint32 // ID from RegisterPacket
pid uint64 // Process ID - info only packets
direction uint8
ipV6 uint8 // True: IPv6, False: IPv4
protocol uint8 // Protocol
flags uint8 // Flags
localIP [4]uint32 // Source Address
remoteIP [4]uint32 // Destination Address
localPort uint16 // Source Port
remotePort uint16 // Destination port
_ uint32 // compartmentID
_ uint32 // interfaceIndex
_ uint32 // subInterfaceIndex
packetSize uint32
}
// Do not change the order of the members! The structure is used to communicate with the kernel extension.
type VerdictInfo struct {
id uint32 // ID from RegisterPacket
verdict network.Verdict // verdict for the connection
}
// Do not change the order of the members! The structure to communicate with the kernel extension.
type VerdictUpdateInfo struct {
localIP [4]uint32 // Source Address, only srcIP[0] if IPv4
remoteIP [4]uint32 // Destination Address
localPort uint16 // Source Port
remotePort uint16 // Destination port
ipV6 uint8 // True: IPv6, False: IPv4
protocol uint8 // Protocol (UDP, TCP, ...)
verdict uint8 // New verdict
}
type ConnectionStat struct {
localIP [4]uint32 //Source Address, only srcIP[0] if IPv4
remoteIP [4]uint32 //Destination Address
localPort uint16 //Source Port
remotePort uint16 //Destination port
receivedBytes uint64 //Number of bytes recived on this connection
transmittedBytes uint64 //Number of bytes transsmited from this connection
ipV6 uint8 //True: IPv6, False: IPv4
protocol uint8 //Protocol (UDP, TCP, ...)
}
type VersionInfo struct {
major uint8
minor uint8
revision uint8
build uint8
}
func (v *VersionInfo) String() string {
return fmt.Sprintf("%d.%d.%d.%d", v.major, v.minor, v.revision, v.build)
}
// Handler transforms received packets to the Packet interface.
func Handler(ctx context.Context, packets chan packet.Packet) {
for {
packetInfo, err := RecvVerdictRequest()
if err != nil {
// Check if we are done with processing.
if errors.Is(err, ErrKextNotReady) {
return
}
log.Warningf("failed to get packet from windows kext: %s", err)
continue
}
if packetInfo == nil {
continue
}
// log.Tracef("packet: %+v", packetInfo)
// New Packet
new := &Packet{
verdictRequest: packetInfo,
verdictSet: abool.NewBool(false),
}
info := new.Info()
info.Inbound = packetInfo.direction > 0
info.InTunnel = false
info.Protocol = packet.IPProtocol(packetInfo.protocol)
info.PID = int(packetInfo.pid)
info.SeenAt = time.Now()
// Check PID
if info.PID == 0 {
// Windows does not have zero PIDs.
// Set to UndefinedProcessID.
info.PID = process.UndefinedProcessID
}
// Set IP version
if packetInfo.ipV6 == 1 {
info.Version = packet.IPv6
} else {
info.Version = packet.IPv4
}
// Set IPs
if info.Inbound {
// Inbound
info.Src = convertArrayToIP(packetInfo.remoteIP, info.Version == packet.IPv6)
info.Dst = convertArrayToIP(packetInfo.localIP, info.Version == packet.IPv6)
} else {
// Outbound
info.Src = convertArrayToIP(packetInfo.localIP, info.Version == packet.IPv6)
info.Dst = convertArrayToIP(packetInfo.remoteIP, info.Version == packet.IPv6)
}
// Set Ports
if info.Inbound {
// Inbound
info.SrcPort = packetInfo.remotePort
info.DstPort = packetInfo.localPort
} else {
// Outbound
info.SrcPort = packetInfo.localPort
info.DstPort = packetInfo.remotePort
}
packets <- new
}
}
// convertArrayToIP converts an array of uint32 values to a net.IP address.
func convertArrayToIP(input [4]uint32, ipv6 bool) net.IP {
if !ipv6 {
addressBuf := make([]byte, 4)
binary.BigEndian.PutUint32(addressBuf, input[0])
return net.IP(addressBuf)
}
addressBuf := make([]byte, 16)
for i := 0; i < 4; i++ {
binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
}
return net.IP(addressBuf)
}
func ipAddressToArray(ip net.IP, isIPv6 bool) [4]uint32 {
array := [4]uint32{0}
if isIPv6 {
for i := 0; i < 4; i++ {
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[i], 4), getUInt32Value(&ip[i]))
}
} else {
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[0], 4), getUInt32Value(&ip[0]))
}
return array
}
func asByteArray[T any](obj *T) []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj))
}
func asByteArrayWithLength[T any](obj *T, size uint32) []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), size)
}
func getUInt32Value[T any](obj *T) uint32 {
return *(*uint32)(unsafe.Pointer(obj))
}

View File

@@ -0,0 +1,344 @@
//go:build windows
// +build windows
package windowskext
import (
"errors"
"fmt"
"sync"
"syscall"
"unsafe"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
"golang.org/x/sys/windows"
)
// Package errors
var (
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
ErrNoPacketID = errors.New("the packet has no ID, possibly because it was fast-tracked by the kernel extension")
kextLock sync.RWMutex
driverPath string
kextHandle windows.Handle
service *KextService
)
const (
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value
driverName = "PortmasterKext"
)
// Init initializes the DLL and the Kext (Kernel Driver).
func Init(path string) error {
kextHandle = winInvalidHandleValue
driverPath = path
return nil
}
// Start intercepting.
func Start() error {
kextLock.Lock()
defer kextLock.Unlock()
// initialize and start driver service
var err error
service, err = createKextService(driverName, driverPath)
if err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
running, err := service.isRunning()
if err == nil && !running {
err = service.start(true)
if err != nil {
return fmt.Errorf("failed to start service: %w", err)
}
} else if err != nil {
return fmt.Errorf("service not initialized: %w", err)
}
// Open the driver
filename := `\\.\` + driverName
kextHandle, err = openDriver(filename)
// driver was not installed
if err != nil {
return fmt.Errorf("failed to open driver: %q %w", filename, err)
}
return nil
}
func SetKextHandler(handle windows.Handle) {
kextHandle = handle
}
func SetKextService(handle windows.Handle, path string) {
service = &KextService{handle: handle}
driverPath = path
}
// Stop intercepting.
func Stop() error {
// Prepare kernel for shutdown
err := shutdownRequest()
if err != nil {
log.Warningf("winkext: shutdown request failed: %s", err)
}
kextLock.Lock()
defer kextLock.Unlock()
err = closeDriver(kextHandle)
if err != nil {
log.Warningf("winkext: failed to close the handle: %s", err)
}
err = service.stop(true)
if err != nil {
log.Warningf("winkext: failed to stop service: %s", err)
}
// Driver file may change on the next start so it's better to delete the service
err = service.delete()
if err != nil {
log.Warningf("winkext: failed to delete service: %s", err)
}
kextHandle = winInvalidHandleValue
return nil
}
func shutdownRequest() error {
kextLock.RLock()
defer kextLock.RUnlock()
if kextHandle == winInvalidHandleValue {
return ErrKextNotReady
}
// Sent a shutdown request so the kernel extension can prepare.
_, err := deviceIOControl(kextHandle, IOCTL_SHUTDOWN_REQUEST, nil, nil)
return err
}
// RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil.
func RecvVerdictRequest() (*VerdictRequest, error) {
kextLock.RLock()
defer kextLock.RUnlock()
if kextHandle == winInvalidHandleValue {
return nil, ErrKextNotReady
}
// DEBUG:
// timestamp := time.Now()
// defer log.Tracef("winkext: getting verdict request took %s", time.Since(timestamp))
// Initialize struct for the output data
var new VerdictRequest
// Make driver request
data := asByteArray(&new)
bytesRead, err := deviceIOControl(kextHandle, IOCTL_RECV_VERDICT_REQ, nil, data)
if err != nil {
return nil, err
}
if bytesRead == 0 {
return nil, nil // no error, no new verdict request
}
return &new, nil
}
// SetVerdict sets the verdict for a packet and/or connection.
func SetVerdict(pkt *Packet, verdict network.Verdict) error {
if pkt.verdictRequest.pid != 0 {
return nil // Ignore info only packets
}
if pkt.verdictRequest.id == 0 {
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: no packet ID", verdict)
return ErrNoPacketID
}
kextLock.RLock()
defer kextLock.RUnlock()
if kextHandle == winInvalidHandleValue {
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: kext not ready", verdict)
return ErrKextNotReady
}
verdictInfo := VerdictInfo{pkt.verdictRequest.id, verdict}
// Make driver request
data := asByteArray(&verdictInfo)
_, err := deviceIOControl(kextHandle, IOCTL_SET_VERDICT, data, nil)
if err != nil {
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id)
return err
}
return nil
}
// GetPayload returns the payload of a packet.
func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
if packetID == 0 {
return nil, ErrNoPacketID
}
// Check if driver is initialized
kextLock.RLock()
defer kextLock.RUnlock()
if kextHandle == winInvalidHandleValue {
return nil, ErrKextNotReady
}
buf := make([]byte, packetSize)
// Combine id and length
payload := struct {
id uint32
length uint32
}{packetID, packetSize}
// Make driver request
data := asByteArray(&payload)
bytesRead, err := deviceIOControl(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize))
if err != nil {
return nil, err
}
// check the result and return
if bytesRead == 0 {
return nil, errors.New("windows kext did not return any data")
}
if bytesRead < uint32(len(buf)) {
return buf[:bytesRead], nil
}
return buf, nil
}
func ClearCache() error {
kextLock.RLock()
defer kextLock.RUnlock()
// Check if driver is initialized
if kextHandle == winInvalidHandleValue {
log.Error("kext: failed to clear the cache: kext not ready")
return ErrKextNotReady
}
// Make driver request
_, err := deviceIOControl(kextHandle, IOCTL_CLEAR_CACHE, nil, nil)
return err
}
func UpdateVerdict(conn *network.Connection) error {
kextLock.RLock()
defer kextLock.RUnlock()
// Check if driver is initialized
if kextHandle == winInvalidHandleValue {
log.Error("kext: failed to clear the cache: kext not ready")
return ErrKextNotReady
}
var isIpv6 uint8 = 0
if conn.IPVersion == packet.IPv6 {
isIpv6 = 1
}
// initialize variables
info := VerdictUpdateInfo{
ipV6: isIpv6,
protocol: uint8(conn.IPProtocol),
localIP: ipAddressToArray(conn.LocalIP, isIpv6 == 1),
localPort: conn.LocalPort,
remoteIP: ipAddressToArray(conn.Entity.IP, isIpv6 == 1),
remotePort: conn.Entity.Port,
verdict: uint8(conn.Verdict),
}
// Make driver request
data := asByteArray(&info)
_, err := deviceIOControl(kextHandle, IOCTL_UPDATE_VERDICT, data, nil)
return err
}
func GetVersion() (*VersionInfo, error) {
kextLock.RLock()
defer kextLock.RUnlock()
// Check if driver is initialized
if kextHandle == winInvalidHandleValue {
log.Error("kext: failed to clear the cache: kext not ready")
return nil, ErrKextNotReady
}
data := make([]uint8, 4)
_, err := deviceIOControl(kextHandle, IOCTL_VERSION, nil, data)
if err != nil {
return nil, err
}
version := &VersionInfo{
major: data[0],
minor: data[1],
revision: data[2],
build: data[3],
}
return version, nil
}
var sizeOfConnectionStat = uint32(unsafe.Sizeof(ConnectionStat{}))
func GetConnectionsStats() ([]ConnectionStat, error) {
kextLock.RLock()
defer kextLock.RUnlock()
// Check if driver is initialized
if kextHandle == winInvalidHandleValue {
log.Error("kext: failed to clear the cache: kext not ready")
return nil, ErrKextNotReady
}
var data [100]ConnectionStat
size := len(data)
bytesReturned, err := deviceIOControl(kextHandle, IOCTL_GET_CONNECTIONS_STAT, asByteArray(&size), asByteArray(&data))
if err != nil {
return nil, err
}
return data[:bytesReturned/sizeOfConnectionStat], nil
}
func openDriver(filename string) (windows.Handle, error) {
u16filename, err := syscall.UTF16FromString(filename)
if err != nil {
return winInvalidHandleValue, fmt.Errorf("failed to convert driver filename to UTF16 string %w", err)
}
handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED, 0)
if err != nil {
return winInvalidHandleValue, err
}
return handle, nil
}
func closeDriver(handle windows.Handle) error {
if kextHandle == winInvalidHandleValue {
return ErrKextNotReady
}
return windows.CloseHandle(handle)
}

View File

@@ -0,0 +1,137 @@
//go:build windows
// +build windows
package windowskext
import (
"sync"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
)
// Packet represents an IP packet.
type Packet struct {
packet.Base
verdictRequest *VerdictRequest
verdictSet *abool.AtomicBool
payloadLoaded bool
lock sync.Mutex
}
// FastTrackedByIntegration returns whether the packet has been fast-track
// accepted by the OS integration.
func (pkt *Packet) FastTrackedByIntegration() bool {
return pkt.verdictRequest.flags&VerdictRequestFlagFastTrackPermitted > 0
}
// InfoOnly returns whether the packet is informational only and does not
// represent an actual packet.
func (pkt *Packet) InfoOnly() bool {
return pkt.verdictRequest.flags&VerdictRequestFlagSocketAuth > 0
}
// ExpectInfo returns whether the next packet is expected to be informational only.
func (pkt *Packet) ExpectInfo() bool {
return pkt.verdictRequest.flags&VerdictRequestFlagExpectSocketAuth > 0
}
// GetPayload returns the full raw packet.
func (pkt *Packet) LoadPacketData() error {
pkt.lock.Lock()
defer pkt.lock.Unlock()
if pkt.verdictRequest.id == 0 {
return ErrNoPacketID
}
if !pkt.payloadLoaded {
pkt.payloadLoaded = true
payload, err := GetPayload(pkt.verdictRequest.id, pkt.verdictRequest.packetSize)
if err != nil {
log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to load payload: %s", err)
return packet.ErrFailedToLoadPayload
}
err = packet.Parse(payload, &pkt.Base)
if err != nil {
log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to parse payload: %s", err)
return packet.ErrFailedToLoadPayload
}
}
if len(pkt.Raw()) == 0 {
return packet.ErrFailedToLoadPayload
}
return nil
}
// Accept accepts the packet.
func (pkt *Packet) Accept() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, -network.VerdictAccept)
}
return nil
}
// Block blocks the packet.
func (pkt *Packet) Block() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, -network.VerdictBlock)
}
return nil
}
// Drop drops the packet.
func (pkt *Packet) Drop() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, -network.VerdictDrop)
}
return nil
}
// PermanentAccept permanently accepts connection (and the current packet).
func (pkt *Packet) PermanentAccept() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictAccept)
}
return nil
}
// PermanentBlock permanently blocks connection (and the current packet).
func (pkt *Packet) PermanentBlock() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictBlock)
}
return nil
}
// PermanentDrop permanently drops connection (and the current packet).
func (pkt *Packet) PermanentDrop() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictDrop)
}
return nil
}
// RerouteToNameserver permanently reroutes the connection to the local nameserver (and the current packet).
func (pkt *Packet) RerouteToNameserver() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictRerouteToNameserver)
}
return nil
}
// RerouteToTunnel permanently reroutes the connection to the local tunnel entrypoint (and the current packet).
func (pkt *Packet) RerouteToTunnel() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictRerouteToTunnel)
}
return nil
}

View File

@@ -0,0 +1,198 @@
//go:build windows
// +build windows
package windowskext
import (
"fmt"
"syscall"
"time"
"github.com/safing/portbase/log"
"golang.org/x/sys/windows"
)
type KextService struct {
handle windows.Handle
}
func createKextService(driverName string, driverPath string) (*KextService, error) {
// Open the service manager:
manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
if err != nil {
return nil, fmt.Errorf("failed to open service manager: %d", err)
}
defer windows.CloseServiceHandle(manager)
// Convert the driver name to a UTF16 string
driverNameU16, err := syscall.UTF16FromString(driverName)
if err != nil {
return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err)
}
// Check if there is an old service.
service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS)
if err == nil {
log.Warning("kext: old driver service was found")
oldService := &KextService{handle: service}
err := deleteService(manager, oldService, driverNameU16)
if err != nil {
return nil, fmt.Errorf("failed to delete old driver service: %s", err)
}
service = winInvalidHandleValue
log.Info("kext: old driver service was deleted successfully")
}
driverPathU16, err := syscall.UTF16FromString(driverPath)
// Create the service
service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil)
if err != nil {
return nil, err
}
return &KextService{handle: service}, nil
}
func deleteService(manager windows.Handle, service *KextService, driverName []uint16) error {
// Stop and wait before deleting
_ = service.stop(true)
// Try to delete even if stop failed
err := service.delete()
if err != nil {
return fmt.Errorf("failed to delete old service: %s", err)
}
// Wait until we can no longer open the old service.
// Not very efficient but NotifyServiceStatusChange cannot be used with driver service.
start := time.Now()
timeLimit := time.Duration(30 * time.Second)
for {
handle, err := windows.OpenService(manager, &driverName[0], windows.SERVICE_ALL_ACCESS)
if err != nil {
break
}
_ = windows.CloseServiceHandle(handle)
if time.Since(start) > timeLimit {
return fmt.Errorf("time limit reached")
}
time.Sleep(100 * time.Millisecond)
}
return nil
}
func (s *KextService) isValid() bool {
return s != nil && s.handle != winInvalidHandleValue && s.handle != 0
}
func (s *KextService) isRunning() (bool, error) {
if !s.isValid() {
return false, fmt.Errorf("kext service not initialized")
}
var status windows.SERVICE_STATUS
err := windows.QueryServiceStatus(s.handle, &status)
if err != nil {
return false, err
}
return status.CurrentState == windows.SERVICE_RUNNING, nil
}
func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) {
var status windows.SERVICE_STATUS
status.CurrentState = windows.SERVICE_NO_CHANGE
start := time.Now()
for status.CurrentState != neededStatus {
err := windows.QueryServiceStatus(handle, &status)
if err != nil {
return false, fmt.Errorf("failed while waiting for service to start: %w", err)
}
if time.Since(start) > timeLimit {
return false, fmt.Errorf("time limit reached")
}
// Sleep for 1/10 of the wait hint, recommended time from microsoft
time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond)
}
return true, nil
}
func (s *KextService) start(wait bool) error {
if !s.isValid() {
return fmt.Errorf("kext service not initialized")
}
// Start the service:
err := windows.StartService(s.handle, 0, nil)
if err != nil {
err = windows.GetLastError()
if err != windows.ERROR_SERVICE_ALREADY_RUNNING {
// Failed to start service; clean-up:
var status windows.SERVICE_STATUS
_ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
_ = windows.DeleteService(s.handle)
_ = windows.CloseServiceHandle(s.handle)
s.handle = winInvalidHandleValue
return err
}
}
// Wait for service to start
if wait {
success, err := waitForServiceStatus(s.handle, windows.SERVICE_RUNNING, time.Duration(10*time.Second))
if err != nil || !success {
return fmt.Errorf("service did not start: %w", err)
}
}
return nil
}
func (s *KextService) stop(wait bool) error {
if !s.isValid() {
return fmt.Errorf("kext service not initialized")
}
// Stop the service
var status windows.SERVICE_STATUS
err := windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
if err != nil {
return fmt.Errorf("service failed to stop: %w", err)
}
// Wait for service to stop
if wait {
success, err := waitForServiceStatus(s.handle, windows.SERVICE_STOPPED, time.Duration(10*time.Second))
if err != nil || !success {
return fmt.Errorf("service did not stop: %w", err)
}
}
return nil
}
func (s *KextService) delete() error {
if !s.isValid() {
return fmt.Errorf("kext service not initialized")
}
err := windows.DeleteService(s.handle)
if err != nil {
return fmt.Errorf("failed to delete service: %s", err)
}
// Service wont be deleted until all handles are closed.
err = windows.CloseServiceHandle(s.handle)
if err != nil {
return fmt.Errorf("failed to close service handle: %s", err)
}
s.handle = winInvalidHandleValue
return nil
}

View File

@@ -0,0 +1,72 @@
//go:build windows
// +build windows
package windowskext
import "golang.org/x/sys/windows"
const (
METHOD_BUFFERED = 0
METHOD_IN_DIRECT = 1
METHOD_OUT_DIRECT = 2
METHOD_NEITHER = 3
SIOCTL_TYPE = 40000
)
var (
IOCTL_VERSION = ctlCode(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_SHUTDOWN_REQUEST = ctlCode(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_RECV_VERDICT_REQ = ctlCode(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_SET_VERDICT = ctlCode(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_GET_PAYLOAD = ctlCode(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_CLEAR_CACHE = ctlCode(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_UPDATE_VERDICT = ctlCode(SIOCTL_TYPE, 0x806, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_GET_CONNECTIONS_STAT = ctlCode(SIOCTL_TYPE, 0x807, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
)
func ctlCode(device_type, function, method, access uint32) uint32 {
return (device_type << 16) | (access << 14) | (function << 2) | method
}
func deviceIOControlAsync(handle windows.Handle, code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) {
var inDataPtr *byte = nil
var inDataSize uint32 = 0
if inData != nil {
inDataPtr = &inData[0]
inDataSize = uint32(len(inData))
}
var outDataPtr *byte = nil
var outDataSize uint32 = 0
if outData != nil {
outDataPtr = &outData[0]
outDataSize = uint32(len(outData))
}
overlapped := &windows.Overlapped{}
err := windows.DeviceIoControl(handle,
code,
inDataPtr, inDataSize,
outDataPtr, outDataSize,
nil, overlapped)
if err != nil {
return nil, err
}
return overlapped, nil
}
func deviceIOControl(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) {
overlapped, err := deviceIOControlAsync(handle, code, inData, outData)
if err != nil {
return 0, err
}
var bytesReturned uint32
err = windows.GetOverlappedResult(handle, overlapped, &bytesReturned, true)
return bytesReturned, err
}

649
service/firewall/master.go Normal file
View File

@@ -0,0 +1,649 @@
package firewall
import (
"context"
"fmt"
"net"
"path/filepath"
"strconv"
"strings"
"github.com/agext/levenshtein"
"golang.org/x/net/publicsuffix"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/detection/dga"
"github.com/safing/portmaster/service/intel/customlists"
"github.com/safing/portmaster/service/intel/filterlists"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/profile"
"github.com/safing/portmaster/service/profile/endpoints"
)
const noReasonOptionKey = ""
type deciderFn func(context.Context, *network.Connection, *profile.LayeredProfile, packet.Packet) bool
var defaultDeciders = []deciderFn{
checkPortmasterConnection,
checkIfBroadcastReply,
checkConnectionType,
checkConnectionScope,
checkEndpointLists,
checkInvalidIP,
checkResolverScope,
checkConnectivityDomain,
checkBypassPrevention,
checkFilterLists,
checkCustomFilterList,
checkDomainHeuristics,
checkAutoPermitRelated,
}
// decideOnConnection makes a decision about a connection.
// When called, the connection and profile is already locked.
func decideOnConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) {
// Check if we have a process and profile.
layeredProfile := conn.Process().Profile()
if layeredProfile == nil {
conn.Deny("unknown process or profile", noReasonOptionKey)
return
}
// Check if the layered profile needs updating.
if layeredProfile.NeedsUpdate() {
// Update revision counter in connection.
conn.ProfileRevisionCounter = layeredProfile.Update(
conn.Process().MatchingData(),
conn.Process().CreateProfileCallback,
)
conn.SaveWhenFinished()
// Reset verdict for connection.
log.Tracer(ctx).Infof("filter: profile updated, re-evaluating verdict of %s", conn)
// Reset entity if it exists.
if conn.Entity != nil {
conn.Entity.ResetLists()
}
} else {
// Check if the revision counter of the connection needs updating.
revCnt := layeredProfile.RevisionCnt()
if conn.ProfileRevisionCounter != revCnt {
conn.ProfileRevisionCounter = revCnt
conn.SaveWhenFinished()
}
}
// prepare the entity and resolve all filterlist matches
conn.Entity.ResolveSubDomainLists(ctx, layeredProfile.FilterSubDomains())
conn.Entity.EnableCNAMECheck(ctx, layeredProfile.FilterCNAMEs())
conn.Entity.LoadLists(ctx)
// Run all deciders and return if they came to a conclusion.
done, defaultAction := runDeciders(ctx, defaultDeciders, conn, layeredProfile, pkt)
if done {
return
}
// DNS Request are always default allowed, as the endpoint lists could not
// be checked fully.
if conn.Type == network.DNSRequest {
conn.Accept("allowing dns request", noReasonOptionKey)
return
}
// Deciders did not conclude, use default action.
switch defaultAction {
case profile.DefaultActionPermit:
conn.Accept("allowed by default action", profile.CfgOptionDefaultActionKey)
case profile.DefaultActionAsk:
// Only prompt if there has not been a decision already.
// This prevents prompts from being created when re-evaluating connections.
if conn.Verdict == network.VerdictUndecided {
prompt(ctx, conn)
}
default:
conn.Deny("blocked by default action", profile.CfgOptionDefaultActionKey)
}
}
func runDeciders(ctx context.Context, selectedDeciders []deciderFn, conn *network.Connection, layeredProfile *profile.LayeredProfile, pkt packet.Packet) (done bool, defaultAction uint8) {
// Read-lock all the profiles.
layeredProfile.LockForUsage()
defer layeredProfile.UnlockForUsage()
// Go though all deciders, return if one sets an action.
for _, decider := range selectedDeciders {
if decider(ctx, conn, layeredProfile, pkt) {
return true, profile.DefaultActionNotSet
}
}
// Return the default action.
return false, layeredProfile.DefaultAction()
}
// checkPortmasterConnection allows all connection that originate from
// portmaster itself.
func checkPortmasterConnection(ctx context.Context, conn *network.Connection, _ *profile.LayeredProfile, _ packet.Packet) bool {
// Grant own outgoing or local connections.
// Blocking our own connections can lead to a very literal deadlock.
// This can currently happen, as fast-tracked connections are also
// reset in the OS integration and might show up in the connection
// handling if a packet in the other direction hits the firewall first.
// Ignore other processes.
if conn.Process().Pid != ownPID {
return false
}
// Ignore inbound connection if non-local.
if conn.Inbound {
myIP, err := netenv.IsMyIP(conn.Entity.IP)
if err != nil {
log.Tracer(ctx).Debugf("filter: failed to check if %s is own IP for granting own connection: %s", conn.Entity.IP, err)
return false
}
if !myIP {
return false
}
}
log.Tracer(ctx).Infof("filter: granting own connection %s", conn)
conn.Accept("connection by Portmaster", noReasonOptionKey)
conn.Internal = true
return true
}
func checkIfBroadcastReply(ctx context.Context, conn *network.Connection, _ *profile.LayeredProfile, _ packet.Packet) bool {
// Only check inbound connections.
if !conn.Inbound {
return false
}
// Only check if the process has been identified.
if !conn.Process().IsIdentified() {
return false
}
// Check if the remote IP is part of a local network.
localNet, err := netenv.GetLocalNetwork(conn.Entity.IP)
if err != nil {
log.Tracer(ctx).Warningf("filter: failed to get local network: %s", err)
return false
}
if localNet == nil {
return false
}
// Search for a matching requesting connection.
requestingConn := network.GetMulticastRequestConn(conn, localNet)
if requestingConn == nil {
return false
}
conn.Accept(
fmt.Sprintf(
"response to multi/broadcast query to %s/%s",
packet.IPProtocol(requestingConn.Entity.Protocol),
net.JoinHostPort(
requestingConn.Entity.IP.String(),
strconv.Itoa(int(requestingConn.Entity.Port)),
),
),
"",
)
return true
}
func checkEndpointLists(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
// DNS request from the system resolver require a special decision process,
// because the original requesting process is not known. Here, we only check
// global-only and the most important per-app aspects. The resulting
// connection is then blocked when the original requesting process is known.
if conn.Type == network.DNSRequest && conn.Process().IsSystemResolver() {
return checkEndpointListsForSystemResolverDNSRequests(ctx, conn, p)
}
var result endpoints.EPResult
var reason endpoints.Reason
// check endpoints list
var optionKey string
if conn.Inbound {
result, reason = p.MatchServiceEndpoint(ctx, conn.Entity)
optionKey = profile.CfgOptionServiceEndpointsKey
} else {
result, reason = p.MatchEndpoint(ctx, conn.Entity)
optionKey = profile.CfgOptionEndpointsKey
}
switch result {
case endpoints.Denied, endpoints.MatchError:
conn.DenyWithContext(reason.String(), optionKey, reason.Context())
return true
case endpoints.Permitted:
conn.AcceptWithContext(reason.String(), optionKey, reason.Context())
return true
case endpoints.NoMatch:
return false
}
return false
}
// checkEndpointListsForSystemResolverDNSRequests is a special version of
// checkEndpointLists that is only meant for DNS queries by the system
// resolver. It only checks the endpoint filter list of the local profile and
// does not include the global profile.
func checkEndpointListsForSystemResolverDNSRequests(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile) bool {
var profileEndpoints endpoints.Endpoints
var optionKey string
if conn.Inbound {
profileEndpoints = p.LocalProfileWithoutLocking().GetServiceEndpoints()
optionKey = profile.CfgOptionServiceEndpointsKey
} else {
profileEndpoints = p.LocalProfileWithoutLocking().GetEndpoints()
optionKey = profile.CfgOptionEndpointsKey
}
if profileEndpoints.IsSet() {
result, reason := profileEndpoints.Match(ctx, conn.Entity)
if endpoints.IsDecision(result) {
switch result {
case endpoints.Denied, endpoints.MatchError:
conn.DenyWithContext(reason.String(), optionKey, reason.Context())
return true
case endpoints.Permitted:
conn.AcceptWithContext(reason.String(), optionKey, reason.Context())
return true
case endpoints.NoMatch:
return false
}
}
}
return false
}
var p2pFilterLists = []string{"17-P2P"}
func checkConnectionType(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
switch {
// Block incoming connection, if not from localhost.
case p.BlockInbound() && conn.Inbound &&
!conn.Entity.IPScope.IsLocalhost():
conn.Drop("inbound connections blocked", profile.CfgOptionBlockInboundKey)
return true
// Check for P2P and related connections.
case p.BlockP2P() && !conn.Inbound:
switch {
// Block anything that is in the P2P filter list.
case conn.Entity.MatchLists(p2pFilterLists):
conn.Block("P2P assistive infrastructure blocked based on filter list", profile.CfgOptionBlockP2PKey)
return true
// Remaining P2P deciders only apply to IP connections.
case conn.Type != network.IPConnection:
return false
// Block well known ports of P2P assistive infrastructure.
case conn.Entity.DstPort() == 3478 || // STUN/TURN
conn.Entity.DstPort() == 5349: // STUN/TURN over TLS/DTLS
conn.Block("P2P assistive infrastructure blocked based on port", profile.CfgOptionBlockP2PKey)
return true
// Block direct connections with not previous DNS request.
case conn.Entity.IPScope.IsGlobal() &&
conn.Entity.Domain == "":
conn.Block("direct connections (P2P) blocked", profile.CfgOptionBlockP2PKey)
return true
default:
return false
}
default:
return false
}
}
func checkConnectivityDomain(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
switch {
case conn.Entity.Domain == "":
// Only applies if a domain is available.
return false
case netenv.GetOnlineStatus() > netenv.StatusPortal:
// Special grant only applies if network status is Portal (or even more limited).
return false
case conn.Inbound:
// Special grant only applies to outgoing connections.
return false
case p.BlockScopeInternet():
// Special grant only applies if application is allowed to connect to the Internet.
return false
case netenv.IsConnectivityDomain(conn.Entity.Domain):
// Special grant!
conn.Accept("special grant for connectivity domain during network bootstrap", noReasonOptionKey)
return true
default:
// Not a special grant domain
return false
}
}
func checkConnectionScope(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
// If we are handling a DNS request, check if we can immediately block it.
if conn.Type == network.DNSRequest {
// DNS is expected to resolve to LAN or Internet addresses.
// Localhost queries are immediately responded to by the nameserver.
if p.BlockScopeInternet() && p.BlockScopeLAN() {
conn.Block("Internet and LAN access blocked", profile.CfgOptionBlockScopeInternetKey)
return true
}
return false
}
// Check if the network scope is permitted.
switch conn.Entity.IPScope {
case netutils.Global, netutils.GlobalMulticast:
if p.BlockScopeInternet() {
conn.Deny("Internet access blocked", profile.CfgOptionBlockScopeInternetKey) // Block Outbound / Drop Inbound
return true
}
case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast:
if p.BlockScopeLAN() {
conn.Block("LAN access blocked", profile.CfgOptionBlockScopeLANKey) // Block Outbound / Drop Inbound
return true
}
case netutils.HostLocal:
if p.BlockScopeLocal() {
conn.Block("Localhost access blocked", profile.CfgOptionBlockScopeLocalKey) // Block Outbound / Drop Inbound
return true
}
case netutils.Undefined, netutils.Invalid:
// Block Invalid / Undefined IPs _after_ the rules.
return false
default:
conn.Deny("invalid IP", noReasonOptionKey) // Block Outbound / Drop Inbound
return true
}
return false
}
func checkInvalidIP(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
// Only applies to IP connections.
if conn.Type != network.IPConnection {
return false
}
// Block Invalid / Undefined IPs.
switch conn.Entity.IPScope { //nolint:exhaustive // Only looking for specific values.
case netutils.Undefined, netutils.Invalid:
conn.Deny("invalid IP", noReasonOptionKey) // Block Outbound / Drop Inbound
return true
}
return false
}
func checkBypassPrevention(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
if p.PreventBypassing() {
// check for bypass protection
result, reason, reasonCtx := PreventBypassing(ctx, conn)
switch result {
case endpoints.Denied, endpoints.MatchError:
// Also block on MatchError to be on the safe side.
// PreventBypassing does not use any data that needs to be loaded, so it should not fail anyway.
conn.BlockWithContext("bypass prevention: "+reason, profile.CfgOptionPreventBypassingKey, reasonCtx)
return true
case endpoints.Permitted:
conn.AcceptWithContext("bypass prevention: "+reason, profile.CfgOptionPreventBypassingKey, reasonCtx)
return true
case endpoints.NoMatch:
return false
}
}
return false
}
func checkFilterLists(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
// apply privacy filter lists
result, reason := p.MatchFilterLists(ctx, conn.Entity)
switch result {
case endpoints.Denied:
// If the connection matches a filter list, check if the "unbreak" list matches too and abort blocking.
resolvedUnbreakFilterListIDs := filterlists.GetUnbreakFilterListIDs()
for _, blockedListID := range conn.Entity.BlockedByLists {
for _, unbreakListID := range resolvedUnbreakFilterListIDs {
if blockedListID == unbreakListID {
log.Tracer(ctx).Debugf("filter: unbreak filter %s matched, ignoring other filter list matches", unbreakListID)
return false
}
}
}
// Otherwise, continue with blocking.
conn.DenyWithContext(reason.String(), profile.CfgOptionFilterListsKey, reason.Context())
return true
case endpoints.NoMatch:
// nothing to do
case endpoints.Permitted, endpoints.MatchError:
fallthrough
default:
log.Tracer(ctx).Debugf("filter: filter lists returned unsupported verdict: %s", result)
}
return false
}
func checkResolverScope(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
// If the IP address was resolved, check the scope of the resolver.
switch {
case conn.Type != network.IPConnection:
// Only applies to IP connections.
case !p.RemoveOutOfScopeDNS():
// Out of scope checking is not active.
case conn.Resolver == nil:
// IP address of connection was not resolved.
case conn.Resolver.IPScope.IsGlobal() &&
(conn.Entity.IPScope.IsLAN() || conn.Entity.IPScope.IsLocalhost()):
// Block global resolvers from returning LAN/Localhost IPs.
conn.Block("DNS server horizon violation: global DNS server returned local IP address", profile.CfgOptionRemoveOutOfScopeDNSKey)
return true
case conn.Resolver.IPScope.IsLAN() &&
conn.Entity.IPScope.IsLocalhost():
// Block LAN resolvers from returning Localhost IPs.
conn.Block("DNS server horizon violation: LAN DNS server returned localhost IP address", profile.CfgOptionRemoveOutOfScopeDNSKey)
return true
}
return false
}
func checkDomainHeuristics(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
// Don't check domain heuristics if no domain is available.
if conn.Entity.Domain == "" {
return false
}
// Check if domain heuristics are enabled.
if !p.DomainHeuristics() {
return false
}
trimmedDomain := strings.TrimRight(conn.Entity.Domain, ".")
etld1, err := publicsuffix.EffectiveTLDPlusOne(trimmedDomain)
if err != nil {
// Don't run the check if the domain is a TLD.
return false
}
domainToCheck := strings.Split(etld1, ".")[0]
score := dga.LmsScore(domainToCheck)
if score < 5 {
log.Tracer(ctx).Debugf(
"filter: possible data tunnel by %s in eTLD+1 %s: %s has an lms score of %.2f",
conn.Process(),
etld1,
domainToCheck,
score,
)
conn.Block("possible DGA domain commonly used by malware", profile.CfgOptionDomainHeuristicsKey)
return true
}
log.Tracer(ctx).Tracef("filter: LMS score of eTLD+1 %s is %.2f", etld1, score)
// 100 is a somewhat arbitrary threshold to ensure we don't mess
// around with CDN domain names to early. They use short second-level
// domains that would trigger LMS checks but are to small to actually
// exfiltrate data.
if len(conn.Entity.Domain) > len(etld1)+100 {
domainToCheck = trimmedDomain[0:len(etld1)]
score := dga.LmsScoreOfDomain(domainToCheck)
if score < 10 {
log.Tracer(ctx).Debugf(
"filter: possible data tunnel by %s in subdomain of %s: %s has an lms score of %.2f",
conn.Process(),
conn.Entity.Domain,
domainToCheck,
score,
)
conn.Block("possible data tunnel for covert communication and protection bypassing", profile.CfgOptionDomainHeuristicsKey)
return true
}
log.Tracer(ctx).Tracef("filter: LMS score of entire domain is %.2f", score)
}
return false
}
func checkAutoPermitRelated(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
// Auto permit is disabled for default action permit.
if p.DefaultAction() == profile.DefaultActionPermit {
return false
}
// Check if auto permit is disabled.
if p.DisableAutoPermit() {
return false
}
// Check for relation to auto permit.
related, reason := checkRelation(conn)
if related {
conn.Accept(reason, profile.CfgOptionDisableAutoPermitKey)
return true
}
return false
}
// checkRelation tries to find a relation between a process and a communication. This is for better out of the box experience and is _not_ meant to thwart intentional malware.
func checkRelation(conn *network.Connection) (related bool, reason string) {
// Don't check relation if no domain is available.
if conn.Entity.Domain == "" {
return false, ""
}
// Don't check for unknown processes.
if conn.Process().Pid < 0 {
return false, ""
}
pathElements := strings.Split(conn.Process().Path, string(filepath.Separator))
// only look at the last two path segments
if len(pathElements) > 2 {
pathElements = pathElements[len(pathElements)-2:]
}
domainElements := strings.Split(conn.Entity.Domain, ".")
var domainElement string
var processElement string
matchLoop:
for _, domainElement = range domainElements {
for _, pathElement := range pathElements {
if levenshtein.Match(domainElement, pathElement, nil) > 0.5 {
related = true
processElement = pathElement
break matchLoop
}
}
if levenshtein.Match(domainElement, conn.Process().Name, nil) > 0.5 {
related = true
processElement = conn.Process().Name
break matchLoop
}
if levenshtein.Match(domainElement, conn.Process().ExecName, nil) > 0.5 {
related = true
processElement = conn.Process().ExecName
break matchLoop
}
}
if related {
reason = fmt.Sprintf("auto allowed: domain is related to process: %s is related to %s", domainElement, processElement)
}
return related, reason
}
func checkCustomFilterList(_ context.Context, conn *network.Connection, p *profile.LayeredProfile, _ packet.Packet) bool {
// Check if any custom list is loaded at all.
if !customlists.IsLoaded() {
return false
}
// block if the domain name appears in the custom filter list (check for subdomains if enabled)
if conn.Entity.Domain != "" {
if ok, match := customlists.LookupDomain(conn.Entity.Domain, p.FilterSubDomains()); ok {
conn.Deny(fmt.Sprintf("domain %s matches %s in custom filter list", conn.Entity.Domain, match), customlists.CfgOptionCustomListFileKey)
return true
}
}
// block if any of the CNAME appears in the custom filter list (check for subdomains if enabled)
if p.FilterCNAMEs() {
for _, cname := range conn.Entity.CNAME {
if ok, match := customlists.LookupDomain(cname, p.FilterSubDomains()); ok {
conn.Deny(fmt.Sprintf("domain alias (CNAME) %s matches %s in custom filter list", cname, match), customlists.CfgOptionCustomListFileKey)
return true
}
}
}
// block if ip addresses appears in the custom filter list
if conn.Entity.IP != nil {
if customlists.LookupIP(conn.Entity.IP) {
conn.Deny("IP address is in the custom filter list", customlists.CfgOptionCustomListFileKey)
return true
}
}
// block autonomous system by its number if it appears in the custom filter list
if conn.Entity.ASN != 0 {
if customlists.LookupASN(conn.Entity.ASN) {
conn.Deny("AS is in the custom filter list", customlists.CfgOptionCustomListFileKey)
return true
}
}
// block if the country appears in the custom filter list
if conn.Entity.Country != "" {
if customlists.LookupCountry(conn.Entity.Country) {
conn.Deny("country is in the custom filter list", customlists.CfgOptionCustomListFileKey)
return true
}
}
return false
}

165
service/firewall/module.go Normal file
View File

@@ -0,0 +1,165 @@
package firewall
import (
"context"
"flag"
"fmt"
"path/filepath"
"strings"
"github.com/safing/portbase/config"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
_ "github.com/safing/portmaster/service/core"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/profile"
"github.com/safing/portmaster/spn/access"
"github.com/safing/portmaster/spn/captain"
)
type stringSliceFlag []string
func (ss *stringSliceFlag) String() string {
return strings.Join(*ss, ":")
}
func (ss *stringSliceFlag) Set(value string) error {
*ss = append(*ss, filepath.Clean(value))
return nil
}
var (
module *modules.Module
allowedClients stringSliceFlag
)
func init() {
module = modules.Register("filter", prep, start, stop, "core", "interception", "intel", "netquery")
subsystems.Register(
"filter",
"Privacy Filter",
"DNS and Network Filter",
module,
"config:filter/",
nil,
)
flag.Var(&allowedClients, "allowed-clients", "A list of binaries that are allowed to connect to the Portmaster API")
}
func prep() error {
network.SetDefaultFirewallHandler(verdictHandler)
// Reset connections every time configuration changes
// this will be triggered on spn enable/disable
err := module.RegisterEventHook(
"config",
config.ChangeEvent,
"reset connection verdicts after global config change",
func(ctx context.Context, _ interface{}) error {
resetAllConnectionVerdicts()
return nil
},
)
if err != nil {
log.Errorf("filter: failed to register event hook: %s", err)
}
// Reset connections every time profile changes
err = module.RegisterEventHook(
"profiles",
profile.ConfigChangeEvent,
"reset connection verdicts after profile config change",
func(ctx context.Context, eventData interface{}) error {
// Expected event data: scoped profile ID.
profileID, ok := eventData.(string)
if !ok {
return fmt.Errorf("event data is not a string: %v", eventData)
}
profileSource, profileID, ok := strings.Cut(profileID, "/")
if !ok {
return fmt.Errorf("event data does not seem to be a scoped profile ID: %v", eventData)
}
resetProfileConnectionVerdict(profileSource, profileID)
return nil
},
)
if err != nil {
log.Errorf("filter: failed to register event hook: %s", err)
}
// Reset connections when spn is connected
// connect and disconnecting is triggered on config change event but connecting takеs more time
err = module.RegisterEventHook(
"captain",
captain.SPNConnectedEvent,
"reset connection verdicts on SPN connect",
func(ctx context.Context, _ interface{}) error {
resetAllConnectionVerdicts()
return nil
},
)
if err != nil {
log.Errorf("filter: failed to register event hook: %s", err)
}
// Reset connections when account is updated.
// This will not change verdicts, but will update the feature flags on connections.
err = module.RegisterEventHook(
"access",
access.AccountUpdateEvent,
"update connection feature flags after account update",
func(ctx context.Context, _ interface{}) error {
resetAllConnectionVerdicts()
return nil
},
)
if err != nil {
log.Errorf("filter: failed to register event hook: %s", err)
}
err = module.RegisterEventHook(
"network",
network.ConnectionReattributedEvent,
"reset verdict of re-attributed connection",
func(ctx context.Context, eventData interface{}) error {
// Expected event data: connection ID.
connID, ok := eventData.(string)
if !ok {
return fmt.Errorf("event data is not a string: %v", eventData)
}
resetSingleConnectionVerdict(connID)
return nil
},
)
if err != nil {
log.Errorf("filter: failed to register event hook: %s", err)
}
if err := registerConfig(); err != nil {
return err
}
return prepAPIAuth()
}
func start() error {
getConfig()
startAPIAuth()
module.StartServiceWorker("packet handler", 0, packetHandler)
module.StartServiceWorker("bandwidth update handler", 0, bandwidthUpdateHandler)
// Start stat logger if logging is set to trace.
if log.GetLogLevel() == log.TraceLevel {
module.StartServiceWorker("stat logger", 0, statLogger)
}
return nil
}
func stop() error {
return nil
}

View File

@@ -0,0 +1,742 @@
package firewall
import (
"context"
"errors"
"fmt"
"net"
"os"
"sync/atomic"
"time"
"github.com/google/gopacket/layers"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/compat"
_ "github.com/safing/portmaster/service/core/base"
"github.com/safing/portmaster/service/firewall/inspection"
"github.com/safing/portmaster/service/firewall/interception"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/netquery"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/network/reference"
"github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/spn/access"
)
var (
nameserverIPMatcher func(ip net.IP) bool
nameserverIPMatcherSet = abool.New()
nameserverIPMatcherReady = abool.New()
packetsAccepted = new(uint64)
packetsBlocked = new(uint64)
packetsDropped = new(uint64)
packetsFailed = new(uint64)
blockedIPv4 = net.IPv4(0, 0, 0, 17)
blockedIPv6 = net.ParseIP("::17")
ownPID = os.Getpid()
)
func resetSingleConnectionVerdict(connID string) {
// Create tracing context.
ctx, tracer := log.AddTracer(context.Background())
defer tracer.Submit()
conn, ok := network.GetConnection(connID)
if !ok {
conn, ok = network.GetDNSConnection(connID)
if !ok {
tracer.Debugf("filter: could not find re-attributed connection %s for re-evaluation", connID)
return
}
}
resetConnectionVerdict(ctx, conn)
}
func resetProfileConnectionVerdict(profileSource, profileID string) {
// Create tracing context.
ctx, tracer := log.AddTracer(context.Background())
defer tracer.Submit()
// Resetting will force all the connection to be evaluated by the firewall again
// this will set new verdicts if configuration was update or spn has been disabled or enabled.
tracer.Infof("filter: re-evaluating connections of %s/%s", profileSource, profileID)
// Re-evaluate all connections.
var changedVerdicts int
for _, conn := range network.GetAllConnections() {
// Check if connection is complete and attributed to the deleted profile.
if conn.DataIsComplete() &&
conn.ProcessContext.Profile == profileID &&
conn.ProcessContext.Source == profileSource {
if resetConnectionVerdict(ctx, conn) {
changedVerdicts++
}
}
}
tracer.Infof("filter: changed verdict on %d connections", changedVerdicts)
}
func resetAllConnectionVerdicts() {
// Create tracing context.
ctx, tracer := log.AddTracer(context.Background())
defer tracer.Submit()
// Resetting will force all the connection to be evaluated by the firewall again
// this will set new verdicts if configuration was update or spn has been disabled or enabled.
tracer.Info("filter: re-evaluating all connections")
// Re-evaluate all connections.
var changedVerdicts int
for _, conn := range network.GetAllConnections() {
// Skip incomplete connections.
if !conn.DataIsComplete() {
continue
}
if resetConnectionVerdict(ctx, conn) {
changedVerdicts++
}
}
tracer.Infof("filter: changed verdict on %d connections", changedVerdicts)
}
func resetConnectionVerdict(ctx context.Context, conn *network.Connection) (verdictChanged bool) {
tracer := log.Tracer(ctx)
// Remove any active prompt as we settings are being re-evaluated.
conn.RemovePrompt()
conn.Lock()
defer conn.Unlock()
// Update feature flags.
if err := conn.UpdateFeatures(); err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
tracer.Warningf("filter: failed to update connection feature flags: %s", err)
}
// Skip internal connections:
// - Pre-authenticated connections from Portmaster
// - Redirected DNS requests
// - SPN Uplink to Home Hub
if conn.Internal {
// tracer.Tracef("filter: skipping internal connection %s", conn)
return false
}
tracer.Debugf("filter: re-evaluating verdict of %s", conn)
previousVerdict := conn.Verdict
// Apply privacy filter and check tunneling.
FilterConnection(ctx, conn, nil, true, true)
// Stop existing SPN tunnel if not needed anymore.
if conn.Verdict != network.VerdictRerouteToTunnel && conn.TunnelContext != nil {
err := conn.TunnelContext.StopTunnel()
if err != nil {
tracer.Debugf("filter: failed to stopped unneeded tunnel: %s", err)
}
}
// Save if verdict changed.
if conn.Verdict != previousVerdict {
err := interception.UpdateVerdictOfConnection(conn)
if err != nil {
log.Debugf("filter: failed to update connection verdict: %s", err)
}
conn.Save()
tracer.Infof("filter: verdict of connection %s changed from %s to %s", conn, previousVerdict.Verb(), conn.VerdictVerb())
// Update verdict in OS integration, if an IP connection.
if conn.Type == network.IPConnection {
err := interception.UpdateVerdictOfConnection(conn)
if err != nil {
log.Debugf("filter: failed to update connection verdict: %s", err)
}
}
return true
}
tracer.Tracef("filter: verdict to connection %s unchanged at %s", conn, conn.VerdictVerb())
return false
}
// SetNameserverIPMatcher sets a function that is used to match the internal
// nameserver IP(s). Can only bet set once.
func SetNameserverIPMatcher(fn func(ip net.IP) bool) error {
if !nameserverIPMatcherSet.SetToIf(false, true) {
return errors.New("nameserver IP matcher already set")
}
nameserverIPMatcher = fn
nameserverIPMatcherReady.Set()
return nil
}
func handlePacket(pkt packet.Packet) {
// First, check for an existing connection.
conn, ok := network.GetConnection(pkt.GetConnectionID())
if ok {
// Add packet to connection handler queue or apply verdict directly.
conn.HandlePacket(pkt)
return
}
// Else create new incomplete connection from the packet and start the new handler.
conn = network.NewIncompleteConnection(pkt)
conn.Lock()
defer conn.Unlock()
conn.SetFirewallHandler(fastTrackHandler)
// Let the new connection handler worker handle the packet.
conn.HandlePacket(pkt)
}
// fastTrackedPermit quickly permits certain network critical or internal connections.
func fastTrackedPermit(conn *network.Connection, pkt packet.Packet) (verdict network.Verdict, permanent bool) {
meta := pkt.Info()
// Check if packed was already fast-tracked by the OS integration.
if pkt.FastTrackedByIntegration() {
log.Tracer(pkt.Ctx()).Debugf("filter: fast-tracked by OS integration: %s", pkt)
return network.VerdictAccept, true
}
// Check if connection was already blocked.
if meta.Dst.Equal(blockedIPv4) || meta.Dst.Equal(blockedIPv6) {
return network.VerdictBlock, true
}
// Some programs do a network self-check where they connects to the same
// IP/Port to test network capabilities.
// Eg. dig: https://gitlab.isc.org/isc-projects/bind9/-/issues/1140
if meta.SrcPort == meta.DstPort &&
meta.Src.Equal(meta.Dst) {
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track network self-check: %s", pkt)
return network.VerdictAccept, true
}
switch meta.Protocol { //nolint:exhaustive // Checking for specific values only.
case packet.ICMP, packet.ICMPv6:
// Load packet data.
err := pkt.LoadPacketData()
if err != nil {
log.Tracer(pkt.Ctx()).Debugf("filter: failed to load ICMP packet data: %s", err)
return network.VerdictAccept, true
}
// Submit to ICMP listener.
submitted := netenv.SubmitPacketToICMPListener(pkt)
if submitted {
// If the packet was submitted to the listener, we must not do a
// permanent accept, because then we won't see any future packets of that
// connection and thus cannot continue to submit them.
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track tracing ICMP/v6: %s", pkt)
return network.VerdictAccept, false
}
// Handle echo request and replies regularly.
// Other ICMP packets are considered system business.
icmpLayers := pkt.Layers().LayerClass(layers.LayerClassIPControl)
switch icmpLayer := icmpLayers.(type) {
case *layers.ICMPv4:
switch icmpLayer.TypeCode.Type() {
case layers.ICMPv4TypeEchoRequest,
layers.ICMPv4TypeEchoReply:
return network.VerdictUndecided, false
}
case *layers.ICMPv6:
switch icmpLayer.TypeCode.Type() {
case layers.ICMPv6TypeEchoRequest,
layers.ICMPv6TypeEchoReply:
return network.VerdictUndecided, false
}
}
// Permit all ICMP/v6 packets that are not echo requests or replies.
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting ICMP/v6: %s", pkt)
return network.VerdictAccept, true
case packet.UDP, packet.TCP:
switch meta.DstPort {
case 67, 68, 546, 547:
// Always allow DHCP, DHCPv6.
// DHCP and DHCPv6 must be UDP.
if meta.Protocol != packet.UDP {
return network.VerdictUndecided, false
}
// DHCP is only valid in local network scopes.
switch netutils.ClassifyIP(meta.Dst) { //nolint:exhaustive // Checking for specific values only.
case netutils.HostLocal, netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
default:
return network.VerdictUndecided, false
}
// Log and permit.
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting DHCP: %s", pkt)
return network.VerdictAccept, true
case apiPort:
// Always allow direct access to the Portmaster API.
// Portmaster API is TCP only.
if meta.Protocol != packet.TCP {
return network.VerdictUndecided, false
}
// Check if the api port is even set.
if !apiPortSet {
return network.VerdictUndecided, false
}
// Must be destined for the API IP.
if !meta.Dst.Equal(apiIP) {
return network.VerdictUndecided, false
}
// Only fast-track local requests.
isMe, err := netenv.IsMyIP(meta.Src)
switch {
case err != nil:
log.Tracer(pkt.Ctx()).Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
return network.VerdictUndecided, false
case !isMe:
return network.VerdictUndecided, false
}
// Log and permit.
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting api connection: %s", pkt)
return network.VerdictAccept, true
case 53:
// Always allow direct access to the Portmaster Nameserver.
// DNS is both UDP and TCP.
// Check if a nameserver IP matcher is set.
if !nameserverIPMatcherReady.IsSet() {
return network.VerdictUndecided, false
}
// Check if packet is destined for a nameserver IP.
if !nameserverIPMatcher(meta.Dst) {
return network.VerdictUndecided, false
}
// Only fast-track local requests.
isMe, err := netenv.IsMyIP(meta.Src)
switch {
case err != nil:
log.Tracer(pkt.Ctx()).Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
return network.VerdictUndecided, false
case !isMe:
return network.VerdictUndecided, false
}
// Log and permit.
log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting local dns: %s", pkt)
// Add to DNS request connections to attribute DNS request if outgoing.
if pkt.IsOutbound() {
// Assign PID from packet directly, as processing stops after fast-track.
conn.PID = pkt.Info().PID
network.SaveDNSRequestConnection(conn, pkt)
}
// Accept local DNS, but only make permanent if we have the PID too.
return network.VerdictAccept, conn.PID != process.UndefinedProcessID
}
case compat.SystemIntegrationCheckProtocol:
if pkt.Info().Dst.Equal(compat.SystemIntegrationCheckDstIP) {
compat.SubmitSystemIntegrationCheckPacket(pkt)
return network.VerdictDrop, false
}
}
return network.VerdictUndecided, false
}
func fastTrackHandler(conn *network.Connection, pkt packet.Packet) {
fastTrackedVerdict, permanent := fastTrackedPermit(conn, pkt)
if fastTrackedVerdict != network.VerdictUndecided {
// Set verdict on connection.
conn.Verdict = fastTrackedVerdict
// Apply verdict to (real) packet.
if !pkt.InfoOnly() {
issueVerdict(conn, pkt, fastTrackedVerdict, permanent)
}
// Stop handler if permanent.
if permanent {
conn.SetVerdict(fastTrackedVerdict, "fast-tracked", "", nil)
// Do not finalize verdict, as we are missing necessary data.
conn.StopFirewallHandler()
}
// Do not continue to next handler.
return
}
// If packet is not fast-tracked, continue with gathering more information.
conn.UpdateFirewallHandler(gatherDataHandler)
gatherDataHandler(conn, pkt)
}
func gatherDataHandler(conn *network.Connection, pkt packet.Packet) {
// Get process info
_ = conn.GatherConnectionInfo(pkt)
// Errors are informational and are logged to the context.
// Run this handler again if data is not yet complete.
if !conn.DataIsComplete() {
return
}
// Continue to filter handler, when connection data is complete.
conn.UpdateFirewallHandler(filterHandler)
filterHandler(conn, pkt)
}
func filterHandler(conn *network.Connection, pkt packet.Packet) {
// Skip if data is not complete or packet is info-only.
if !conn.DataIsComplete() || pkt.InfoOnly() {
return
}
filterConnection := true
// Check for special (internal) connection cases.
switch {
case !conn.Inbound && localPortIsPreAuthenticated(conn.Entity.Protocol, conn.LocalPort):
// Approve connection.
conn.Accept("connection by Portmaster", noReasonOptionKey)
conn.Internal = true
filterConnection = false
log.Tracer(pkt.Ctx()).Infof("filter: granting own pre-authenticated connection %s", conn)
// Redirect outbound DNS packets if enabled,
case dnsQueryInterception() &&
pkt.IsOutbound() &&
pkt.Info().DstPort == 53 &&
// that don't match the address of our nameserver,
nameserverIPMatcherReady.IsSet() &&
!nameserverIPMatcher(pkt.Info().Dst) &&
// and are not broadcast queries by us.
// Context:
// - Unicast queries by the resolver are pre-authenticated.
// - Unicast queries by the compat self-check should be redirected.
!(conn.Process().Pid == ownPID &&
conn.Entity.IPScope == netutils.LocalMulticast):
// Reroute rogue dns queries back to Portmaster.
conn.SetVerdict(network.VerdictRerouteToNameserver, "redirecting rogue dns query", "", nil)
conn.Internal = true
log.Tracer(pkt.Ctx()).Infof("filter: redirecting dns query %s to Portmaster", conn)
// Add to DNS request connections to attribute DNS request.
network.SaveDNSRequestConnection(conn, pkt)
// End directly, as no other processing is necessary.
conn.StopFirewallHandler()
issueVerdict(conn, pkt, 0, true)
return
}
// Apply privacy filter and check tunneling.
FilterConnection(pkt.Ctx(), conn, pkt, filterConnection, true)
// Decide how to continue handling connection.
switch {
case conn.Inspecting:
log.Tracer(pkt.Ctx()).Trace("filter: start inspecting")
conn.SetFirewallHandler(inspectAndVerdictHandler)
inspectAndVerdictHandler(conn, pkt)
default:
conn.StopFirewallHandler()
issueVerdict(conn, pkt, 0, true)
}
}
// FilterConnection runs all the filtering (and tunneling) procedures.
func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet, checkFilter, checkTunnel bool) {
// Skip if data is not complete.
if !conn.DataIsComplete() {
return
}
if checkFilter {
if filterEnabled() {
log.Tracer(ctx).Trace("filter: starting decision process")
decideOnConnection(ctx, conn, pkt)
} else {
conn.Accept("privacy filter disabled", noReasonOptionKey)
}
}
// TODO: Enable inspection framework again.
conn.Inspecting = false
// TODO: Quick fix for the SPN.
// Use inspection framework for proper encryption detection.
switch conn.Entity.DstPort() {
case
22, // SSH
443, // HTTPS
465, // SMTP-SSL
853, // DoT
993, // IMAP-SSL
995: // POP3-SSL
conn.Encrypted = true
}
// Check if connection should be tunneled.
if checkTunnel {
checkTunneling(ctx, conn)
}
// Request tunneling if no tunnel is set and connection should be tunneled.
if conn.Verdict == network.VerdictRerouteToTunnel &&
conn.TunnelContext == nil {
err := requestTunneling(ctx, conn)
if err == nil {
conn.ConnectionEstablished = true
} else {
// Set connection to failed, but keep tunneling data.
// The tunneling data makes connection easy to recognize as a failed SPN
// connection and the data will help with debugging and displaying in the UI.
conn.Failed(fmt.Sprintf("failed to request tunneling: %s", err), "")
}
}
}
func verdictHandler(conn *network.Connection, pkt packet.Packet) {
// Ignore info-only packets in this handler.
if pkt.InfoOnly() {
return
}
issueVerdict(conn, pkt, 0, true)
}
func inspectAndVerdictHandler(conn *network.Connection, pkt packet.Packet) {
// Ignore info-only packets in this handler.
if pkt.InfoOnly() {
return
}
// Run inspectors.
pktVerdict, continueInspection := inspection.RunInspectors(conn, pkt)
if continueInspection {
issueVerdict(conn, pkt, pktVerdict, false)
return
}
// we are done with inspecting
conn.StopFirewallHandler()
issueVerdict(conn, pkt, 0, true)
}
func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.Verdict, allowPermanent bool) {
// Check if packed was already fast-tracked by the OS integration.
if pkt.FastTrackedByIntegration() {
return
}
// Enable permanent verdict.
if allowPermanent && !conn.VerdictPermanent {
switch {
case !permanentVerdicts():
// Permanent verdicts are disabled by configuration.
case conn.Entity != nil && reference.IsICMP(conn.Entity.Protocol):
case pkt != nil && reference.IsICMP(uint8(pkt.Info().Protocol)):
// ICMP is handled differently based on payload, so we cannot use persistent verdicts.
default:
conn.VerdictPermanent = true
conn.SaveWhenFinished()
}
}
// do not allow to circumvent decision: e.g. to ACCEPT packets from a DROP-ed connection
if verdict < conn.Verdict {
verdict = conn.Verdict
}
var err error
switch verdict {
case network.VerdictAccept:
atomic.AddUint64(packetsAccepted, 1)
if conn.VerdictPermanent {
err = pkt.PermanentAccept()
} else {
err = pkt.Accept()
}
case network.VerdictBlock:
atomic.AddUint64(packetsBlocked, 1)
if conn.VerdictPermanent {
err = pkt.PermanentBlock()
} else {
err = pkt.Block()
}
case network.VerdictDrop:
atomic.AddUint64(packetsDropped, 1)
if conn.VerdictPermanent {
err = pkt.PermanentDrop()
} else {
err = pkt.Drop()
}
case network.VerdictRerouteToNameserver:
err = pkt.RerouteToNameserver()
case network.VerdictRerouteToTunnel:
err = pkt.RerouteToTunnel()
case network.VerdictFailed:
atomic.AddUint64(packetsFailed, 1)
err = pkt.Drop()
case network.VerdictUndecided, network.VerdictUndeterminable:
log.Tracer(pkt.Ctx()).Warningf("filter: tried to apply verdict %s to pkt %s: dropping instead", verdict, pkt)
fallthrough
default:
atomic.AddUint64(packetsDropped, 1)
err = pkt.Drop()
}
if err != nil {
log.Tracer(pkt.Ctx()).Warningf("filter: failed to apply verdict to pkt %s: %s", pkt, err)
}
}
// func tunnelHandler(pkt packet.Packet) {
// tunnelInfo := GetTunnelInfo(pkt.Info().Dst)
// if tunnelInfo == nil {
// pkt.Block()
// return
// }
//
// entry.CreateTunnel(pkt, tunnelInfo.Domain, tunnelInfo.RRCache.ExportAllARecords())
// log.Tracef("filter: rerouting %s to tunnel entry point", pkt)
// pkt.RerouteToTunnel()
// return
// }
func packetHandler(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
case pkt := <-interception.Packets:
if pkt != nil {
handlePacket(pkt)
} else {
return errors.New("received nil packet from interception")
}
}
}
}
func bandwidthUpdateHandler(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
case bwUpdate := <-interception.BandwidthUpdates:
if bwUpdate != nil {
// DEBUG:
// log.Debugf("filter: bandwidth update: %s", bwUpdate)
updateBandwidth(ctx, bwUpdate)
} else {
return errors.New("received nil bandwidth update from interception")
}
}
}
}
func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
// Check if update makes sense.
if bwUpdate.BytesReceived == 0 && bwUpdate.BytesSent == 0 {
return
}
// Get connection.
conn, ok := network.GetConnection(bwUpdate.ConnID)
if !ok {
return
}
// Do not wait for connections that are locked.
// TODO: Use atomic operations for updating bandwidth stats.
if !conn.TryLock() {
// DEBUG:
// log.Warningf("filter: failed to lock connection for bandwidth update: %s", conn)
return
}
defer conn.Unlock()
bytesIn := bwUpdate.BytesReceived
bytesOut := bwUpdate.BytesSent
// Update stats according to method.
switch bwUpdate.Method {
case packet.Absolute:
bytesIn = bwUpdate.BytesReceived - conn.BytesReceived
bytesOut = bwUpdate.BytesSent - conn.BytesSent
conn.BytesReceived = bwUpdate.BytesReceived
conn.BytesSent = bwUpdate.BytesSent
case packet.Additive:
conn.BytesReceived += bwUpdate.BytesReceived
conn.BytesSent += bwUpdate.BytesSent
default:
log.Warningf("filter: unsupported bandwidth update method: %d", bwUpdate.Method)
return
}
// Update bandwidth in the netquery module.
if netquery.DefaultModule != nil && conn.BandwidthEnabled {
if err := netquery.DefaultModule.Store.UpdateBandwidth(
ctx,
conn.HistoryEnabled,
fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile),
conn.Process().GetKey(),
conn.ID,
bytesIn,
bytesOut,
); err != nil {
log.Errorf("filter: failed to persist bandwidth data: %s", err)
}
}
}
func statLogger(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
case <-time.After(10 * time.Second):
log.Tracef(
"filter: packets accepted %d, blocked %d, dropped %d, failed %d",
atomic.LoadUint64(packetsAccepted),
atomic.LoadUint64(packetsBlocked),
atomic.LoadUint64(packetsDropped),
atomic.LoadUint64(packetsFailed),
)
atomic.StoreUint64(packetsAccepted, 0)
atomic.StoreUint64(packetsBlocked, 0)
atomic.StoreUint64(packetsDropped, 0)
atomic.StoreUint64(packetsFailed, 0)
}
}
}

111
service/firewall/preauth.go Normal file
View File

@@ -0,0 +1,111 @@
package firewall
import (
"fmt"
"net"
"strconv"
"sync"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/resolver"
)
var (
preAuthenticatedPorts = make(map[string]struct{})
preAuthenticatedPortsLock sync.Mutex
)
func init() {
resolver.SetLocalAddrFactory(PermittedAddr)
netenv.SetLocalAddrFactory(PermittedAddr)
}
// PermittedAddr returns an already permitted local address for the given network for reliable connectivity.
// Returns nil in case of error.
func PermittedAddr(network string) net.Addr {
switch network {
case "udp":
return PermittedUDPAddr()
case "tcp":
return PermittedTCPAddr()
}
return nil
}
// PermittedUDPAddr returns an already permitted local udp address for reliable connectivity.
// Returns nil in case of error.
func PermittedUDPAddr() *net.UDPAddr {
preAuthdPort := GetPermittedPort(packet.UDP)
if preAuthdPort == 0 {
return nil
}
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", preAuthdPort))
if err != nil {
return nil
}
return addr
}
// PermittedTCPAddr returns an already permitted local tcp address for reliable connectivity.
// Returns nil in case of error.
func PermittedTCPAddr() *net.TCPAddr {
preAuthdPort := GetPermittedPort(packet.TCP)
if preAuthdPort == 0 {
return nil
}
addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf(":%d", preAuthdPort))
if err != nil {
return nil
}
return addr
}
// GetPermittedPort returns a local port number that is already permitted for communication.
// This bypasses the process attribution step to guarantee connectivity.
// Communication on the returned port is attributed to the Portmaster.
// Every pre-authenticated port is only valid once.
// If no unused local port number can be found, it will return 0, which is
// expected to trigger automatic port selection by the underlying OS.
func GetPermittedPort(protocol packet.IPProtocol) uint16 {
port, ok := network.GetUnusedLocalPort(uint8(protocol))
if !ok {
return 0
}
preAuthenticatedPortsLock.Lock()
defer preAuthenticatedPortsLock.Unlock()
// Save generated port.
key := generateLocalPreAuthKey(uint8(protocol), port)
preAuthenticatedPorts[key] = struct{}{}
return port
}
// localPortIsPreAuthenticated checks if the given protocol and port are
// pre-authenticated and should be attributed to the Portmaster itself.
func localPortIsPreAuthenticated(protocol uint8, port uint16) bool {
preAuthenticatedPortsLock.Lock()
defer preAuthenticatedPortsLock.Unlock()
// Check if the given protocol and port are pre-authenticated.
key := generateLocalPreAuthKey(protocol, port)
_, ok := preAuthenticatedPorts[key]
if ok {
// Immediately remove pre authenticated port.
delete(preAuthenticatedPorts, key)
}
return ok
}
// generateLocalPreAuthKey creates a map key for the pre-authenticated ports.
func generateLocalPreAuthKey(protocol uint8, port uint16) string {
return strconv.Itoa(int(protocol)) + ":" + strconv.Itoa(int(port))
}

327
service/firewall/prompt.go Normal file
View File

@@ -0,0 +1,327 @@
package firewall
import (
"context"
"fmt"
"sync"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portbase/notifications"
"github.com/safing/portmaster/service/intel"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/profile"
"github.com/safing/portmaster/service/profile/endpoints"
)
const (
// notification action IDs.
allowDomainAll = "allow-domain-all"
allowDomainDistinct = "allow-domain-distinct"
blockDomainAll = "block-domain-all"
blockDomainDistinct = "block-domain-distinct"
allowIP = "allow-ip"
blockIP = "block-ip"
allowServingIP = "allow-serving-ip"
blockServingIP = "block-serving-ip"
cancelPrompt = "cancel"
)
var (
promptNotificationCreation sync.Mutex
decisionTimeout int64 = 10 // in seconds
)
type promptData struct {
Entity *intel.Entity
Profile promptProfile
}
type promptProfile struct {
Source string
ID string
LinkedPath string
}
func prompt(ctx context.Context, conn *network.Connection) {
// Create notification.
n := createPrompt(ctx, conn)
if n == nil {
// createPrompt returns nil when no further action should be taken.
return
}
// Add prompt to connection.
conn.SetPrompt(n)
// Get decision timeout and make sure it does not exceed the ask timeout.
timeout := decisionTimeout
if timeout > askTimeout() {
timeout = askTimeout()
}
// wait for response/timeout
select {
case promptResponse := <-n.Response():
switch promptResponse {
case allowDomainAll, allowDomainDistinct, allowIP, allowServingIP:
// Accept
conn.Accept("allowed via prompt", profile.CfgOptionEndpointsKey)
case "":
// Dismissed
conn.Deny("prompting canceled, waiting for new decision", profile.CfgOptionDefaultActionKey)
default:
// Deny
conn.Deny("blocked via prompt", profile.CfgOptionEndpointsKey)
}
case <-time.After(time.Duration(timeout) * time.Second):
log.Tracer(ctx).Debugf("filter: continuing prompting async")
conn.Deny("prompting in progress, please respond to prompt", profile.CfgOptionDefaultActionKey)
case <-ctx.Done():
log.Tracer(ctx).Debugf("filter: aborting prompting because of shutdown")
conn.Drop("shutting down", noReasonOptionKey)
}
}
// promptIDPrefix is an identifier for privacy filter prompts. This is also used
// in the UI, so don't change!
const promptIDPrefix = "filter:prompt"
func createPrompt(ctx context.Context, conn *network.Connection) (n *notifications.Notification) {
expires := time.Now().Add(time.Duration(askTimeout()) * time.Second).Unix()
// Get local profile.
layeredProfile := conn.Process().Profile()
if layeredProfile == nil {
log.Tracer(ctx).Warningf("filter: tried creating prompt for connection without profile")
return nil
}
localProfile := layeredProfile.LocalProfile()
if localProfile == nil {
log.Tracer(ctx).Warningf("filter: tried creating prompt for connection without local profile")
return nil
}
// first check if there is an existing notification for this.
// build notification ID
var nID string
switch {
case conn.Inbound, conn.Entity.Domain == "": // connection to/from IP
nID = fmt.Sprintf(
"%s-%s-%v-%s",
promptIDPrefix,
localProfile.ID,
conn.Inbound,
conn.Entity.IP,
)
default: // connection to domain
nID = fmt.Sprintf(
"%s-%s-%s",
promptIDPrefix,
localProfile.ID,
conn.Entity.Domain,
)
}
// Only handle one notification at a time.
promptNotificationCreation.Lock()
defer promptNotificationCreation.Unlock()
n = notifications.Get(nID)
// If there already is a notification, just update the expiry.
if n != nil {
// Get notification state and action.
n.Lock()
state := n.State
action := n.SelectedActionID
n.Unlock()
// If the notification is still active, extend and return.
// This can happen because user input (prompts changing the endpoint
// lists) can happen any time - also between checking the endpoint lists
// and now.
if state == notifications.Active {
n.Update(expires)
log.Tracer(ctx).Debugf("filter: updated existing prompt notification")
return n
}
// The notification is not active anymore, let's check if there is an
// action we can perform.
// If there already is an action defined, we won't be fast enough to
// receive the action with n.Response(), so we take direct action here.
if action != "" {
switch action {
case allowDomainAll, allowDomainDistinct, allowIP, allowServingIP:
conn.Accept("allowed via prompt", profile.CfgOptionEndpointsKey)
default: // deny
conn.Deny("blocked via prompt", profile.CfgOptionEndpointsKey)
}
return nil // Do not take further action.
}
// Continue to create a new notification because the previous one is not
// active and not actionable.
}
// Reference relevant data for save function
entity := conn.Entity
// Also needed: localProfile
// Create new notification.
n = &notifications.Notification{
EventID: nID,
Type: notifications.Prompt,
Title: "Connection Prompt",
Category: "Privacy Filter",
ShowOnSystem: askWithSystemNotifications(),
EventData: &promptData{
Entity: entity,
Profile: promptProfile{
Source: string(localProfile.Source),
ID: localProfile.ID,
// LinkedPath is used to enhance the display of the prompt in the UI.
// TODO: Using the process path is a workaround. Find a cleaner solution.
LinkedPath: conn.Process().Path,
},
},
Expires: expires,
}
// Set action function.
n.SetActionFunction(func(_ context.Context, n *notifications.Notification) error {
return saveResponse(
localProfile,
entity,
n.SelectedActionID,
)
})
// Get name of profile for notification. The profile is read-locked by the firewall handler.
profileName := localProfile.Name
// add message and actions
switch {
case conn.Inbound:
n.Message = fmt.Sprintf("%s wants to accept connections from %s (%d/%d)", profileName, conn.Entity.IP.String(), conn.Entity.Protocol, conn.Entity.Port)
n.AvailableActions = []*notifications.Action{
{
ID: allowServingIP,
Text: "Allow",
},
{
ID: blockServingIP,
Text: "Block",
},
}
case conn.Entity.Domain == "": // direct connection
n.Message = fmt.Sprintf("%s wants to connect to %s (%d/%d)", profileName, conn.Entity.IP.String(), conn.Entity.Protocol, conn.Entity.Port)
n.AvailableActions = []*notifications.Action{
{
ID: allowIP,
Text: "Allow",
},
{
ID: blockIP,
Text: "Block",
},
}
default: // connection to domain
n.Message = fmt.Sprintf("%s wants to connect to %s", profileName, conn.Entity.Domain)
n.AvailableActions = []*notifications.Action{
{
ID: allowDomainAll,
Text: "Allow",
},
{
ID: blockDomainAll,
Text: "Block",
},
}
}
n.Save()
log.Tracer(ctx).Debugf("filter: sent prompt notification")
return n
}
// promptSavingLock makes sure that only one prompt is saved at a time.
// Should prompts be persisted in bulk, the next save process might load an
// outdated profile and save it, losing config data.
var promptSavingLock sync.Mutex
func saveResponse(p *profile.Profile, entity *intel.Entity, promptResponse string) error {
if promptResponse == cancelPrompt {
return nil
}
promptSavingLock.Lock()
defer promptSavingLock.Unlock()
// Update the profile if necessary.
if p.IsOutdated() {
var err error
p, err = profile.GetLocalProfile(p.ID, nil, nil)
if err != nil {
return err
}
}
var ep endpoints.Endpoint
switch promptResponse {
case allowDomainAll:
ep = &endpoints.EndpointDomain{
EndpointBase: endpoints.EndpointBase{Permitted: true},
OriginalValue: "." + entity.Domain,
}
case allowDomainDistinct:
ep = &endpoints.EndpointDomain{
EndpointBase: endpoints.EndpointBase{Permitted: true},
OriginalValue: entity.Domain,
}
case blockDomainAll:
ep = &endpoints.EndpointDomain{
EndpointBase: endpoints.EndpointBase{Permitted: false},
OriginalValue: "." + entity.Domain,
}
case blockDomainDistinct:
ep = &endpoints.EndpointDomain{
EndpointBase: endpoints.EndpointBase{Permitted: false},
OriginalValue: entity.Domain,
}
case allowIP, allowServingIP:
ep = &endpoints.EndpointIP{
EndpointBase: endpoints.EndpointBase{Permitted: true},
IP: entity.IP,
}
case blockIP, blockServingIP:
ep = &endpoints.EndpointIP{
EndpointBase: endpoints.EndpointBase{Permitted: false},
IP: entity.IP,
}
case cancelPrompt:
return nil
default:
return fmt.Errorf("unknown prompt response: %s", promptResponse)
}
switch promptResponse {
case allowServingIP, blockServingIP:
p.AddServiceEndpoint(ep.String())
log.Infof("filter: added incoming rule to profile %s (LP Rev. %d): %q",
p, p.LayeredProfile().RevisionCnt(), ep.String())
default:
p.AddEndpoint(ep.String())
log.Infof("filter: added outgoing rule to profile %s (LP Rev. %d): %q",
p, p.LayeredProfile().RevisionCnt(), ep.String())
}
return nil
}

215
service/firewall/tunnel.go Normal file
View File

@@ -0,0 +1,215 @@
package firewall
import (
"context"
"errors"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/intel"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/service/profile"
"github.com/safing/portmaster/service/profile/endpoints"
"github.com/safing/portmaster/service/resolver"
"github.com/safing/portmaster/spn/captain"
"github.com/safing/portmaster/spn/crew"
"github.com/safing/portmaster/spn/navigator"
"github.com/safing/portmaster/spn/sluice"
)
func checkTunneling(ctx context.Context, conn *network.Connection) {
// Check if the connection should be tunneled at all.
switch {
case !tunnelEnabled():
// Tunneling is disabled.
return
case !conn.Entity.IPScope.IsGlobal():
// Can't tunnel Local/LAN connections.
return
case conn.Inbound:
// Can't tunnel incoming connections.
return
case conn.Verdict != network.VerdictAccept:
// Connection will be blocked.
return
case conn.IPProtocol != packet.TCP && conn.IPProtocol != packet.UDP:
// Unsupported protocol.
return
case conn.Process().Pid == ownPID:
// Bypass tunneling for certain own connections.
switch {
case !captain.ClientReady():
return
case captain.IsExcepted(conn.Entity.IP):
return
}
}
// Check more extensively for Local/LAN connections.
localNet, err := netenv.GetLocalNetwork(conn.Entity.IP)
if err != nil {
log.Warningf("firewall: failed to check if %s is in my net: %s", conn.Entity.IP, err)
} else if localNet != nil {
// With IPv6, just checking the IP scope is not enough, as the host very
// likely has a public IPv6 address.
// Don't tunnel LAN connections.
// TODO: We currently don't check the full LAN scope, but only the
// broadcast domain of the host - ie. the networks that the host is
// directly attached to.
return
}
// Get profile.
layeredProfile := conn.Process().Profile()
if layeredProfile == nil {
conn.Failed("no profile set", "")
return
}
// Update profile.
if layeredProfile.NeedsUpdate() {
// Update revision counter in connection.
conn.ProfileRevisionCounter = layeredProfile.Update(
conn.Process().MatchingData(),
conn.Process().CreateProfileCallback,
)
conn.SaveWhenFinished()
} else {
// Check if the revision counter of the connection needs updating.
revCnt := layeredProfile.RevisionCnt()
if conn.ProfileRevisionCounter != revCnt {
conn.ProfileRevisionCounter = revCnt
conn.SaveWhenFinished()
}
}
// Check if tunneling is enabled for this app at all.
if !layeredProfile.UseSPN() {
return
}
// Check if tunneling is enabled for entity.
conn.Entity.FetchData(ctx)
result, _ := layeredProfile.MatchSPNUsagePolicy(ctx, conn.Entity)
switch result {
case endpoints.MatchError:
conn.Failed("failed to check SPN rules", profile.CfgOptionSPNUsagePolicyKey)
return
case endpoints.Denied:
return
case endpoints.Permitted, endpoints.NoMatch:
// Continue
}
// Tunnel all the things!
conn.SaveWhenFinished()
// Check if ready.
if !captain.ClientReady() {
// Block connection as SPN is not ready yet.
log.Tracer(ctx).Trace("SPN not ready for tunneling")
conn.Failed("SPN not ready for tunneling", "")
return
}
conn.SetVerdictDirectly(network.VerdictRerouteToTunnel)
conn.Tunneled = true
}
func requestTunneling(ctx context.Context, conn *network.Connection) error {
// Get profile.
layeredProfile := conn.Process().Profile()
if layeredProfile == nil {
return errors.New("no profile set")
}
// Get tunnel options.
conn.TunnelOpts = DeriveTunnelOptions(layeredProfile, conn.Process(), conn.Entity, conn.Encrypted)
// Queue request in sluice.
err := sluice.AwaitRequest(conn, crew.HandleSluiceRequest)
if err != nil {
return err
}
log.Tracer(ctx).Trace("filter: tunneling requested")
return nil
}
func init() {
navigator.DeriveTunnelOptions = func(lp *profile.LayeredProfile, destination *intel.Entity, connEncrypted bool) *navigator.Options {
return DeriveTunnelOptions(lp, nil, destination, connEncrypted)
}
}
// DeriveTunnelOptions derives and returns the tunnel options from the connection and profile.
func DeriveTunnelOptions(lp *profile.LayeredProfile, proc *process.Process, destination *intel.Entity, connEncrypted bool) *navigator.Options {
// Set options.
tunnelOpts := &navigator.Options{
Transit: &navigator.TransitHubOptions{
HubPolicies: lp.StackedTransitHubPolicies(),
},
Destination: &navigator.DestinationHubOptions{
HubPolicies: lp.StackedExitHubPolicies(),
CheckHubPolicyWith: destination,
},
RoutingProfile: lp.SPNRoutingAlgorithm(),
}
if !connEncrypted {
tunnelOpts.Destination.Regard = tunnelOpts.Destination.Regard.Add(navigator.StateTrusted)
// TODO: Add this when all Hubs are on v0.6.21+
// tunnelOpts.Destination.Regard = tunnelOpts.Destination.Regard.Add(navigator.StateAllowUnencrypted)
}
// Add required verified owners if community nodes should not be used.
if !useCommunityNodes() {
tunnelOpts.Transit.RequireVerifiedOwners = captain.NonCommunityVerifiedOwners
tunnelOpts.Destination.RequireVerifiedOwners = captain.NonCommunityVerifiedOwners
}
// Get routing profile for checking for upgrades.
routingProfile := navigator.GetRoutingProfile(tunnelOpts.RoutingProfile)
// If we have any exit hub policies, we must be able to hop in order to follow the policy.
// Switch to single-hop routing to allow for routing with hub selection.
if routingProfile.MaxHops <= 1 && navigator.HubPoliciesAreSet(tunnelOpts.Destination.HubPolicies) {
tunnelOpts.RoutingProfile = navigator.RoutingProfileSingleHopID
}
// If the current home node is not trusted, then upgrade at least to two hops.
if routingProfile.MinHops < 2 {
homeNode, _ := navigator.Main.GetHome()
if homeNode != nil && !homeNode.State.Has(navigator.StateTrusted) {
tunnelOpts.RoutingProfile = navigator.RoutingProfileDoubleHopID
}
}
// Special handling for the internal DNS resolver.
if proc != nil && proc.Pid == ownPID && resolver.IsResolverAddress(destination.IP, destination.Port) {
dnsExitHubPolicy, err := captain.GetDNSExitHubPolicy()
if err != nil {
log.Errorf("firewall: failed to get dns exit hub policy: %s", err)
}
if err == nil && dnsExitHubPolicy.IsSet() {
// Apply the dns exit hub policy, if set.
tunnelOpts.Destination.HubPolicies = []endpoints.Endpoints{dnsExitHubPolicy}
// Use the routing algorithm from the profile, as the home profile won't work with the policy.
tunnelOpts.RoutingProfile = lp.SPNRoutingAlgorithm()
// Raise the routing algorithm at least to single-hop.
if tunnelOpts.RoutingProfile == navigator.RoutingProfileHomeID {
tunnelOpts.RoutingProfile = navigator.RoutingProfileSingleHopID
}
} else {
// Disable any policies for the internal DNS resolver.
tunnelOpts.Destination.HubPolicies = nil
// Always use the home routing profile for the internal DNS resolver.
tunnelOpts.RoutingProfile = navigator.RoutingProfileHomeID
}
}
return tunnelOpts
}