Improve performance, logging

This commit is contained in:
Daniel
2019-05-10 11:57:51 +02:00
parent e72ed023db
commit 16db10b84b
9 changed files with 282 additions and 129 deletions

View File

@@ -60,7 +60,7 @@ func (p *Process) Save() {
processesLock.Unlock() processesLock.Unlock()
} }
if dbControllerFlag.IsSet() { if dbControllerFlag.IsSet() && p.Error == "" {
go dbController.PushUpdate(p) go dbController.PushUpdate(p)
} }
} }
@@ -109,10 +109,15 @@ func CleanProcessStorage(thresholdDuration time.Duration) {
} }
} }
// clean virtual processes // clean virtual/failed processes
for _, p := range processesCopy { for _, p := range processesCopy {
p.Lock() p.Lock()
if p.Virtual { switch {
case p.Error != "":
if p.Meta().Created < threshold {
delete = true
}
case p.Virtual:
_, parentIsAlive := processes[p.ParentPid] _, parentIsAlive := processes[p.ParentPid]
if !parentIsAlive { if !parentIsAlive {
delete = true delete = true

View File

@@ -1,6 +1,7 @@
package process package process
import ( import (
"context"
"errors" "errors"
"net" "net"
@@ -55,24 +56,29 @@ func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) {
// GetProcessByPacket returns the process that owns the given packet. // GetProcessByPacket returns the process that owns the given packet.
func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) { func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) {
log.Tracer(pkt.Ctx()).Tracef("process: getting process and profile by packet")
var pid int var pid int
pid, direction, err = GetPidByPacket(pkt) pid, direction, err = GetPidByPacket(pkt)
if err != nil { if err != nil {
log.Tracer(pkt.Ctx()).Errorf("process: failed to find PID of connection: %s", err)
return nil, direction, err return nil, direction, err
} }
if pid < 0 { if pid < 0 {
log.Tracer(pkt.Ctx()).Errorf("process: %s", ErrConnectionNotFound.Error())
return nil, direction, ErrConnectionNotFound return nil, direction, ErrConnectionNotFound
} }
process, err = GetOrFindPrimaryProcess(pid) process, err = GetOrFindPrimaryProcess(pkt.Ctx(), pid)
if err != nil { if err != nil {
log.Tracer(pkt.Ctx()).Errorf("process: failed to find (primary) process with PID: %s", err)
return nil, direction, err return nil, direction, err
} }
err = process.FindProfiles() err = process.FindProfiles(pkt.Ctx())
if err != nil { if err != nil {
log.Errorf("failed to find profiles for process %s: %s", process.String(), err) log.Tracer(pkt.Ctx()).Errorf("process: failed to find profiles for process %s: %s", process, err)
log.Errorf("failed to find profiles for process %s: %s", process, err)
} }
return process, direction, nil return process, direction, nil
@@ -103,29 +109,33 @@ func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remote
} }
// GetProcessByEndpoints returns the process that owns the described link. // GetProcessByEndpoints returns the process that owns the described link.
func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) { func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) {
log.Tracer(ctx).Tracef("process: getting process and profile by endpoints")
var pid int var pid int
pid, _, err = GetPidByEndpoints(localIP, localPort, remoteIP, remotePort, protocol) pid, _, err = GetPidByEndpoints(localIP, localPort, remoteIP, remotePort, protocol)
if err != nil { if err != nil {
log.Tracer(ctx).Errorf("process: failed to find PID of connection: %s", err)
return nil, err return nil, err
} }
if pid < 0 { if pid < 0 {
log.Tracer(ctx).Errorf("process: %s", ErrConnectionNotFound.Error())
return nil, ErrConnectionNotFound return nil, ErrConnectionNotFound
} }
process, err = GetOrFindPrimaryProcess(pid) process, err = GetOrFindPrimaryProcess(ctx, pid)
if err != nil { if err != nil {
log.Tracer(ctx).Errorf("process: failed to find (primary) process with PID: %s", err)
return nil, err return nil, err
} }
err = process.FindProfiles() err = process.FindProfiles(ctx)
if err != nil { if err != nil {
log.Errorf("failed to find profiles for process %s: %s", process.String(), err) log.Tracer(ctx).Errorf("process: failed to find profiles for process %s: %s", process, err)
log.Errorf("process: failed to find profiles for process %s: %s", process, err)
} }
return process, nil return process, nil
} }
// GetActiveConnectionIDs returns a list of all active connection IDs. // GetActiveConnectionIDs returns a list of all active connection IDs.

View File

@@ -10,15 +10,15 @@ import (
) )
var ( var (
tcp4Connections []*connectionEntry tcp4Connections []*ConnectionEntry
tcp4Listeners []*connectionEntry tcp4Listeners []*ConnectionEntry
tcp6Connections []*connectionEntry tcp6Connections []*ConnectionEntry
tcp6Listeners []*connectionEntry tcp6Listeners []*ConnectionEntry
udp4Connections []*connectionEntry udp4Connections []*ConnectionEntry
udp4Listeners []*connectionEntry udp4Listeners []*ConnectionEntry
udp6Connections []*connectionEntry udp6Connections []*ConnectionEntry
udp6Listeners []*connectionEntry udp6Listeners []*ConnectionEntry
ipHelper *IPHelper ipHelper *IPHelper
lock sync.RWMutex lock sync.RWMutex
@@ -34,6 +34,7 @@ func checkIPHelper() (err error) {
return nil return nil
} }
// GetTCP4PacketInfo returns the pid of the given IPv4/TCP connection.
func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
// search // search
@@ -69,6 +70,7 @@ func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote
return -1, pktDirection, nil return -1, pktDirection, nil
} }
// GetTCP6PacketInfo returns the pid of the given IPv6/TCP connection.
func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
// search // search
@@ -104,6 +106,7 @@ func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote
return -1, pktDirection, nil return -1, pktDirection, nil
} }
// GetUDP4PacketInfo returns the pid of the given IPv4/UDP connection.
func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
// search // search
@@ -139,6 +142,7 @@ func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote
return -1, pktDirection, nil return -1, pktDirection, nil
} }
// GetUDP6PacketInfo returns the pid of the given IPv6/UDP connection.
func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
// search // search
@@ -174,7 +178,7 @@ func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote
return -1, pktDirection, nil return -1, pktDirection, nil
} }
func search(connections, listeners []*connectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16, pktDirection bool) (pid int, direction bool) { func search(connections, listeners []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16, pktDirection bool) (pid int, direction bool) {
lock.RLock() lock.RLock()
defer lock.RUnlock() defer lock.RUnlock()
@@ -203,7 +207,7 @@ func search(connections, listeners []*connectionEntry, localIP, remoteIP net.IP,
return -1, pktDirection return -1, pktDirection
} }
func searchConnections(list []*connectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16) (pid int) { func searchConnections(list []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16) (pid int) {
for _, entry := range list { for _, entry := range list {
if localPort == entry.localPort && if localPort == entry.localPort &&
@@ -217,7 +221,7 @@ func searchConnections(list []*connectionEntry, localIP, remoteIP net.IP, localP
return -1 return -1
} }
func searchListeners(list []*connectionEntry, localIP net.IP, localPort uint16) (pid int) { func searchListeners(list []*ConnectionEntry, localIP net.IP, localPort uint16) (pid int) {
for _, entry := range list { for _, entry := range list {
if localPort == entry.localPort && if localPort == entry.localPort &&
@@ -230,6 +234,7 @@ func searchListeners(list []*connectionEntry, localIP net.IP, localPort uint16)
return -1 return -1
} }
// GetActiveConnectionIDs returns all currently active connection IDs.
func GetActiveConnectionIDs() (connections []string) { func GetActiveConnectionIDs() (connections []string) {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()

View File

@@ -14,11 +14,12 @@ var (
errInvalid = errors.New("IPHelper not initialzed or broken") errInvalid = errors.New("IPHelper not initialzed or broken")
) )
// IPHelper represents a subset of the Windows iphlpapi.dll.
type IPHelper struct { type IPHelper struct {
dll *windows.LazyDLL dll *windows.LazyDLL
getExtendedTcpTable *windows.LazyProc getExtendedTCPTable *windows.LazyProc
getExtendedUdpTable *windows.LazyProc getExtendedUDPTable *windows.LazyProc
// getOwnerModuleFromTcpEntry *windows.LazyProc // getOwnerModuleFromTcpEntry *windows.LazyProc
// getOwnerModuleFromTcp6Entry *windows.LazyProc // getOwnerModuleFromTcp6Entry *windows.LazyProc
// getOwnerModuleFromUdpEntry *windows.LazyProc // getOwnerModuleFromUdpEntry *windows.LazyProc
@@ -27,6 +28,7 @@ type IPHelper struct {
valid *abool.AtomicBool valid *abool.AtomicBool
} }
// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded).
func New() (*IPHelper, error) { func New() (*IPHelper, error) {
new := &IPHelper{} new := &IPHelper{}
@@ -41,13 +43,13 @@ func New() (*IPHelper, error) {
} }
// load functions // load functions
new.getExtendedTcpTable = new.dll.NewProc("GetExtendedTcpTable") new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable")
err = new.getExtendedTcpTable.Find() err = new.getExtendedTCPTable.Find()
if err != nil { if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err) return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err)
} }
new.getExtendedUdpTable = new.dll.NewProc("GetExtendedUdpTable") new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable")
err = new.getExtendedUdpTable.Find() err = new.getExtendedUDPTable.Find()
if err != nil { if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err) return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err)
} }

View File

@@ -6,18 +6,24 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
// Windows API constants
const ( const (
iphelper_TCP_TABLE_OWNER_PID_ALL uintptr = 5 iphelperTCPTableOwnerPIDAll uintptr = 5
iphelper_UDP_TABLE_OWNER_PID uintptr = 1 iphelperUDPTableOwnerPID uintptr = 1
iphelper_TCP_STATE_LISTEN uint32 = 2 iphelperTCPStateListen uint32 = 2
winErrInsufficientBuffer = uintptr(windows.ERROR_INSUFFICIENT_BUFFER)
winErrInvalidParameter = uintptr(windows.ERROR_INVALID_PARAMETER)
) )
type connectionEntry struct { // ConnectionEntry describes a connection state table entry.
type ConnectionEntry struct {
localIP net.IP localIP net.IP
remoteIP net.IP remoteIP net.IP
localPort uint16 localPort uint16
@@ -25,17 +31,17 @@ type connectionEntry struct {
pid int pid int
} }
func (entry *connectionEntry) String() string { func (entry *ConnectionEntry) String() string {
return fmt.Sprintf("PID=%d %s:%d <> %s:%d", entry.pid, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort) return fmt.Sprintf("PID=%d %s:%d <> %s:%d", entry.pid, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)
} }
type iphelperTcpTable struct { type iphelperTCPTable struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx
numEntries uint32 numEntries uint32
table [4096]iphelperTcpRow table [4096]iphelperTCPRow
} }
type iphelperTcpRow struct { type iphelperTCPRow struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366913(v=vs.85).aspx // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366913(v=vs.85).aspx
state uint32 state uint32
localAddr uint32 localAddr uint32
@@ -45,51 +51,52 @@ type iphelperTcpRow struct {
owningPid uint32 owningPid uint32
} }
type iphelperTcp6Table struct { type iphelperTCP6Table struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366905(v=vs.85).aspx // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366905(v=vs.85).aspx
numEntries uint32 numEntries uint32
table [4096]iphelperTcp6Row table [4096]iphelperTCP6Row
} }
type iphelperTcp6Row struct { type iphelperTCP6Row struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366896(v=vs.85).aspx // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366896(v=vs.85).aspx
localAddr [16]byte localAddr [16]byte
localScopeId uint32 localScopeID uint32
localPort uint32 localPort uint32
remoteAddr [16]byte remoteAddr [16]byte
remoteScopeId uint32 remoteScopeID uint32
remotePort uint32 remotePort uint32
state uint32 state uint32
owningPid uint32 owningPid uint32
} }
type iphelperUdpTable struct { type iphelperUDPTable struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366932(v=vs.85).aspx // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366932(v=vs.85).aspx
numEntries uint32 numEntries uint32
table [4096]iphelperUdpRow table [4096]iphelperUDPRow
} }
type iphelperUdpRow struct { type iphelperUDPRow struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366928(v=vs.85).aspx // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366928(v=vs.85).aspx
localAddr uint32 localAddr uint32
localPort uint32 localPort uint32
owningPid uint32 owningPid uint32
} }
type iphelperUdp6Table struct { type iphelperUDP6Table struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366925(v=vs.85).aspx // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366925(v=vs.85).aspx
numEntries uint32 numEntries uint32
table [4096]iphelperUdp6Row table [4096]iphelperUDP6Row
} }
type iphelperUdp6Row struct { type iphelperUDP6Row struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366923(v=vs.85).aspx // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366923(v=vs.85).aspx
localAddr [16]byte localAddr [16]byte
localScopeId uint32 localScopeID uint32
localPort uint32 localPort uint32
owningPid uint32 owningPid uint32
} }
// IP and Protocol constants
const ( const (
IPv4 uint8 = 4 IPv4 uint8 = 4
IPv6 uint8 = 6 IPv6 uint8 = 6
@@ -98,7 +105,51 @@ const (
UDP uint8 = 17 UDP uint8 = 17
) )
func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connections []*connectionEntry, listeners []*connectionEntry, err error) { const (
startBufSize = 4096
bufSizeUses = 100
)
var (
bufSize = startBufSize
bufSizeUsageLeft = bufSizeUses
bufSizeLock sync.Mutex
)
func getBufSize() int {
bufSizeLock.Lock()
defer bufSizeLock.Unlock()
// using bufSize
bufSizeUsageLeft--
// check if we want to reset
if bufSizeUsageLeft <= 0 {
// reset
bufSize = startBufSize
bufSizeUsageLeft = bufSizeUses
}
return bufSize
}
func increaseBufSize() int {
bufSizeLock.Lock()
defer bufSizeLock.Unlock()
// increase
bufSize = bufSize * 2
// not too much
if bufSize > 65536 {
bufSize = 65536
}
// reset
bufSizeUsageLeft = bufSizeUses
// return new bufSize
return bufSize
}
// GetTables returns the current connection state table of Windows of the given protocol and IP version.
func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connections []*ConnectionEntry, listeners []*ConnectionEntry, err error) {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365928(v=vs.85).aspx // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365928(v=vs.85).aspx
if !ipHelper.valid.IsSet() { if !ipHelper.valid.IsSet() {
@@ -115,50 +166,61 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
return nil, nil, errors.New("invalid protocol") return nil, nil, errors.New("invalid protocol")
} }
bufSize := 4096 // try max 3 times
buf := make([]byte, bufSize) maxTries := 3
var r1 uintptr bufSize := getBufSize()
var buf []byte
switch protocol { for i := 1; i <= maxTries; i++ {
case TCP: buf = make([]byte, bufSize)
r1, _, err = ipHelper.getExtendedTcpTable.Call( var r1 uintptr
uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable
uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize
0, // _In_ BOOL bOrder
uintptr(afClass), // _In_ ULONG ulAf
iphelper_TCP_TABLE_OWNER_PID_ALL, // _In_ TCP_TABLE_CLASS TableClass
0, // _In_ ULONG Reserved
)
case UDP:
r1, _, err = ipHelper.getExtendedUdpTable.Call(
uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable,
uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize,
0, // _In_ BOOL bOrder,
uintptr(afClass), // _In_ ULONG ulAf,
iphelper_UDP_TABLE_OWNER_PID, // _In_ UDP_TABLE_CLASS TableClass,
0, // _In_ ULONG Reserved
)
}
switch r1 { switch protocol {
// case windows.ERROR_INSUFFICIENT_BUFFER: case TCP:
// return nil, fmt.Errorf("insufficient buffer error: %s", err) r1, _, err = ipHelper.getExtendedTCPTable.Call(
// case windows.ERROR_INVALID_PARAMETER: uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable
// return nil, fmt.Errorf("invalid parameter: %s", err) uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize
case windows.NO_ERROR: 0, // _In_ BOOL bOrder
default: uintptr(afClass), // _In_ ULONG ulAf
return nil, nil, fmt.Errorf("unexpected error: %s", err) iphelperTCPTableOwnerPIDAll, // _In_ TCP_TABLE_CLASS TableClass
0, // _In_ ULONG Reserved
)
case UDP:
r1, _, err = ipHelper.getExtendedUDPTable.Call(
uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable,
uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize,
0, // _In_ BOOL bOrder,
uintptr(afClass), // _In_ ULONG ulAf,
iphelperUDPTableOwnerPID, // _In_ UDP_TABLE_CLASS TableClass,
0, // _In_ ULONG Reserved
)
}
switch r1 {
case winErrInsufficientBuffer:
if i >= maxTries {
return nil, nil, fmt.Errorf("insufficient buffer error (tried %d times): [NT 0x%X] %s", i, r1, err)
}
bufSize = increaseBufSize()
case winErrInvalidParameter:
return nil, nil, fmt.Errorf("invalid parameter: [NT 0x%X] %s", r1, err)
case windows.NO_ERROR:
// success
break
default:
return nil, nil, fmt.Errorf("unexpected error: [NT 0x%X] %s", r1, err)
}
} }
// parse output // parse output
switch { switch {
case protocol == TCP && ipVersion == IPv4: case protocol == TCP && ipVersion == IPv4:
tcpTable := (*iphelperTcpTable)(unsafe.Pointer(&buf[0])) tcpTable := (*iphelperTCPTable)(unsafe.Pointer(&buf[0]))
table := tcpTable.table[:tcpTable.numEntries] table := tcpTable.table[:tcpTable.numEntries]
for _, row := range table { for _, row := range table {
new := &connectionEntry{} new := &ConnectionEntry{}
// PID // PID
new.pid = int(row.owningPid) new.pid = int(row.owningPid)
@@ -170,7 +232,7 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
new.localPort = uint16(row.localPort>>8 | row.localPort<<8) new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
// remote // remote
if row.state == iphelper_TCP_STATE_LISTEN { if row.state == iphelperTCPStateListen {
listeners = append(listeners, new) listeners = append(listeners, new)
} else { } else {
new.remoteIP = convertIPv4(row.remoteAddr) new.remoteIP = convertIPv4(row.remoteAddr)
@@ -182,11 +244,11 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
case protocol == TCP && ipVersion == IPv6: case protocol == TCP && ipVersion == IPv6:
tcpTable := (*iphelperTcp6Table)(unsafe.Pointer(&buf[0])) tcpTable := (*iphelperTCP6Table)(unsafe.Pointer(&buf[0]))
table := tcpTable.table[:tcpTable.numEntries] table := tcpTable.table[:tcpTable.numEntries]
for _, row := range table { for _, row := range table {
new := &connectionEntry{} new := &ConnectionEntry{}
// PID // PID
new.pid = int(row.owningPid) new.pid = int(row.owningPid)
@@ -196,7 +258,7 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
new.localPort = uint16(row.localPort>>8 | row.localPort<<8) new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
// remote // remote
if row.state == iphelper_TCP_STATE_LISTEN { if row.state == iphelperTCPStateListen {
if new.localIP.Equal(net.IPv6zero) { if new.localIP.Equal(net.IPv6zero) {
new.localIP = nil new.localIP = nil
} }
@@ -211,11 +273,11 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
case protocol == UDP && ipVersion == IPv4: case protocol == UDP && ipVersion == IPv4:
udpTable := (*iphelperUdpTable)(unsafe.Pointer(&buf[0])) udpTable := (*iphelperUDPTable)(unsafe.Pointer(&buf[0]))
table := udpTable.table[:udpTable.numEntries] table := udpTable.table[:udpTable.numEntries]
for _, row := range table { for _, row := range table {
new := &connectionEntry{} new := &ConnectionEntry{}
// PID // PID
new.pid = int(row.owningPid) new.pid = int(row.owningPid)
@@ -232,11 +294,11 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
case protocol == UDP && ipVersion == IPv6: case protocol == UDP && ipVersion == IPv6:
udpTable := (*iphelperUdp6Table)(unsafe.Pointer(&buf[0])) udpTable := (*iphelperUDP6Table)(unsafe.Pointer(&buf[0]))
table := udpTable.table[:udpTable.numEntries] table := udpTable.table[:udpTable.numEntries]
for _, row := range table { for _, row := range table {
new := &connectionEntry{} new := &ConnectionEntry{}
// PID // PID
new.pid = int(row.owningPid) new.pid = int(row.owningPid)

View File

@@ -1,6 +1,7 @@
package process package process
import ( import (
"context"
"fmt" "fmt"
"github.com/Safing/portbase/database" "github.com/Safing/portbase/database"
@@ -14,7 +15,8 @@ var (
) )
// FindProfiles finds and assigns a profile set to the process. // FindProfiles finds and assigns a profile set to the process.
func (p *Process) FindProfiles() error { func (p *Process) FindProfiles(ctx context.Context) error {
log.Tracer(ctx).Trace("process: loading profile set")
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
@@ -66,7 +68,7 @@ func (p *Process) FindProfiles() error {
// FIXME: implement! // FIXME: implement!
p.UserProfileKey = userProfile.Key() p.UserProfileKey = userProfile.Key()
p.profileSet = profile.NewSet(fmt.Sprintf("%d-%s", p.Pid, p.Path), userProfile, nil) p.profileSet = profile.NewSet(ctx, fmt.Sprintf("%d-%s", p.Pid, p.Path), userProfile, nil)
go p.Save() go p.Save()
return nil return nil

View File

@@ -3,6 +3,7 @@
package process package process
import ( import (
"context"
"fmt" "fmt"
"path/filepath" "path/filepath"
"runtime" "runtime"
@@ -16,6 +17,11 @@ import (
"github.com/Safing/portmaster/profile" "github.com/Safing/portmaster/profile"
) )
var (
dupReqMap = make(map[int]*sync.Mutex)
dupReqLock sync.Mutex
)
// A Process represents a process running on the operating system // A Process represents a process running on the operating system
type Process struct { type Process struct {
record.Base record.Base
@@ -45,7 +51,9 @@ type Process struct {
FirstCommEstablished int64 FirstCommEstablished int64
LastCommEstablished int64 LastCommEstablished int64
CommCount uint CommCount uint
Virtual bool // This process is merged into another process
Virtual bool // This process is either merged into another process or is not needed.
Error string // If this is set, the process is invalid. This is used to cache failing or inexistent processes.
} }
// ProfileSet returns the assigned profile set. // ProfileSet returns the assigned profile set.
@@ -90,7 +98,9 @@ func (p *Process) RemoveCommunication() {
} }
// GetOrFindPrimaryProcess returns the highest process in the tree that matches the given PID. // GetOrFindPrimaryProcess returns the highest process in the tree that matches the given PID.
func GetOrFindPrimaryProcess(pid int) (*Process, error) { func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) {
log.Tracer(ctx).Tracef("process: getting primary process for PID %d", pid)
if pid == -1 { if pid == -1 {
return UnknownProcess, nil return UnknownProcess, nil
} }
@@ -98,44 +108,52 @@ func GetOrFindPrimaryProcess(pid int) (*Process, error) {
return OSProcess, nil return OSProcess, nil
} }
process, err := loadProcess(pid) process, err := loadProcess(ctx, pid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if process.Error != "" {
return nil, fmt.Errorf("%s [cached error]", process.Error)
}
for { for {
if process.ParentPid == 0 { if process.ParentPid == 0 {
return OSProcess, nil return OSProcess, nil
} }
parentProcess, err := loadProcess(process.ParentPid) parentProcess, err := loadProcess(ctx, process.ParentPid)
if err != nil { if err != nil {
log.Tracef("process: could not get parent (%d): %s", process.Pid, err) log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s", process.Pid, process.ParentPid, err)
return process, nil
}
if parentProcess.Error != "" {
log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s [cached error]", process.Pid, process.ParentPid, parentProcess.Error)
return process, nil return process, nil
} }
// parent process does not match, we reached the top of the tree of matching processes // if parent process path does not match, we have reached the top of the tree of matching processes
if process.Path != parentProcess.Path { if process.Path != parentProcess.Path {
// save to storage // found primary process
process.Save()
// return primary process // mark for use, save to storage
process.Lock()
if process.Virtual {
process.Virtual = false
go process.Save()
}
process.Unlock()
return process, nil return process, nil
} }
// mark as virtual
process.Lock()
process.Virtual = true
process.Unlock()
// save to storage
process.Save()
// continue up to process tree // continue up to process tree
process = parentProcess process = parentProcess
} }
} }
// GetOrFindProcess returns the process for the given PID. // GetOrFindProcess returns the process for the given PID.
func GetOrFindProcess(pid int) (*Process, error) { func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) {
log.Tracer(ctx).Tracef("process: getting process for PID %d", pid)
if pid == -1 { if pid == -1 {
return UnknownProcess, nil return UnknownProcess, nil
} }
@@ -143,17 +161,25 @@ func GetOrFindProcess(pid int) (*Process, error) {
return OSProcess, nil return OSProcess, nil
} }
p, err := loadProcess(pid) p, err := loadProcess(ctx, pid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if p.Error != "" {
return nil, fmt.Errorf("%s [cached error]", p.Error)
}
// save to storage // mark for use, save to storage
p.Save() p.Lock()
if p.Virtual {
p.Virtual = false
go p.Save()
}
p.Unlock()
return p, nil return p, nil
} }
func loadProcess(pid int) (*Process, error) { func loadProcess(ctx context.Context, pid int) (*Process, error) {
if pid == -1 { if pid == -1 {
return UnknownProcess, nil return UnknownProcess, nil
} }
@@ -166,8 +192,39 @@ func loadProcess(pid int) (*Process, error) {
return process, nil return process, nil
} }
// dedup requests
dupReqLock.Lock()
mutex, requestActive := dupReqMap[pid]
if !requestActive {
mutex = new(sync.Mutex)
mutex.Lock()
dupReqMap[pid] = mutex
dupReqLock.Unlock()
} else {
dupReqLock.Unlock()
log.Tracer(ctx).Tracef("process: waiting for duplicate request for PID %d to complete", pid)
mutex.Lock()
// wait until duplicate request is finished, then fetch current Process and return
mutex.Unlock()
process, ok = GetProcessFromStorage(pid)
if ok {
return process, nil
}
return nil, fmt.Errorf("previous request for process with PID %d failed", pid)
}
// lock request for this pid
defer func() {
dupReqLock.Lock()
delete(dupReqMap, pid)
dupReqLock.Unlock()
mutex.Unlock()
}()
// create new process
new := &Process{ new := &Process{
Pid: pid, Pid: pid,
Virtual: true, // caller must decide to actually use the process - we need to save now.
} }
switch { switch {
@@ -187,16 +244,15 @@ func loadProcess(pid int) (*Process, error) {
var uids []int32 var uids []int32
uids, err = pInfo.Uids() uids, err = pInfo.Uids()
if err != nil { if err != nil {
log.Warningf("process: failed to get UID for p%d: %s", pid, err) return failedToLoad(new, fmt.Errorf("failed to get UID for p%d: %s", pid, err))
} else {
new.UserID = int(uids[0])
} }
new.UserID = int(uids[0])
} }
// Username // Username
new.UserName, err = pInfo.Username() new.UserName, err = pInfo.Username()
if err != nil { if err != nil {
log.Warningf("process: failed to get Username for p%d: %s", pid, err) return failedToLoad(new, fmt.Errorf("process: failed to get Username for p%d: %s", pid, err))
} }
// TODO: User Home // TODO: User Home
@@ -205,15 +261,14 @@ func loadProcess(pid int) (*Process, error) {
// PPID // PPID
ppid, err := pInfo.Ppid() ppid, err := pInfo.Ppid()
if err != nil { if err != nil {
log.Warningf("process: failed to get PPID for p%d: %s", pid, err) return failedToLoad(new, fmt.Errorf("failed to get PPID for p%d: %s", pid, err))
} else {
new.ParentPid = int(ppid)
} }
new.ParentPid = int(ppid)
// Path // Path
new.Path, err = pInfo.Exe() new.Path, err = pInfo.Exe()
if err != nil { if err != nil {
log.Warningf("process: failed to get Path for p%d: %s", pid, err) return failedToLoad(new, fmt.Errorf("failed to get Path for p%d: %s", pid, err))
} }
// Executable Name // Executable Name
_, new.ExecName = filepath.Split(new.Path) _, new.ExecName = filepath.Split(new.Path)
@@ -228,18 +283,21 @@ func loadProcess(pid int) (*Process, error) {
// Command line arguments // Command line arguments
new.CmdLine, err = pInfo.Cmdline() new.CmdLine, err = pInfo.Cmdline()
if err != nil { if err != nil {
log.Warningf("process: failed to get Cmdline for p%d: %s", pid, err) return failedToLoad(new, fmt.Errorf("failed to get Cmdline for p%d: %s", pid, err))
} }
// Name // Name
new.Name, err = pInfo.Name() new.Name, err = pInfo.Name()
if err != nil { if err != nil {
log.Warningf("process: failed to get Name for p%d: %s", pid, err) return failedToLoad(new, fmt.Errorf("failed to get Name for p%d: %s", pid, err))
} }
if new.Name == "" { if new.Name == "" {
new.Name = new.ExecName new.Name = new.ExecName
} }
// OS specifics
new.specialOSInit()
// TODO: App Icon // TODO: App Icon
// new.Icon, err = // new.Icon, err =
@@ -317,5 +375,12 @@ func loadProcess(pid int) (*Process, error) {
// } // }
} }
new.Save()
return new, nil return new, nil
} }
func failedToLoad(p *Process, err error) (*Process, error) {
p.Error = err.Error()
p.Save()
return nil, err
}

View File

@@ -1,6 +1,7 @@
package profile package profile
import ( import (
"context"
"sync" "sync"
"github.com/Safing/portbase/log" "github.com/Safing/portbase/log"
@@ -11,13 +12,13 @@ var (
activeProfileSetsLock sync.RWMutex activeProfileSetsLock sync.RWMutex
) )
func activateProfileSet(set *Set) { func activateProfileSet(ctx context.Context, set *Set) {
activeProfileSetsLock.Lock() activeProfileSetsLock.Lock()
defer activeProfileSetsLock.Unlock() defer activeProfileSetsLock.Unlock()
set.Lock() set.Lock()
defer set.Unlock() defer set.Unlock()
activeProfileSets[set.id] = set activeProfileSets[set.id] = set
log.Tracef("profile: activated profile set %s", set.id) log.Tracer(ctx).Tracef("profile: activated profile set %s", set.id)
} }
// DeactivateProfileSet marks a profile set as not active. // DeactivateProfileSet marks a profile set as not active.

View File

@@ -1,6 +1,7 @@
package profile package profile
import ( import (
"context"
"net" "net"
"sync" "sync"
@@ -27,7 +28,7 @@ type Set struct {
} }
// NewSet returns a new profile set with given the profiles. // NewSet returns a new profile set with given the profiles.
func NewSet(id string, user, stamp *Profile) *Set { func NewSet(ctx context.Context, id string, user, stamp *Profile) *Set {
new := &Set{ new := &Set{
id: id, id: id,
profiles: [4]*Profile{ profiles: [4]*Profile{
@@ -37,7 +38,7 @@ func NewSet(id string, user, stamp *Profile) *Set {
nil, // Default nil, // Default
}, },
} }
activateProfileSet(new) activateProfileSet(ctx, new)
new.Update(status.SecurityLevelFortress) new.Update(status.SecurityLevelFortress)
return new return new
} }