diff --git a/process/database.go b/process/database.go index a01ebe65..1a6ea558 100644 --- a/process/database.go +++ b/process/database.go @@ -60,7 +60,7 @@ func (p *Process) Save() { processesLock.Unlock() } - if dbControllerFlag.IsSet() { + if dbControllerFlag.IsSet() && p.Error == "" { go dbController.PushUpdate(p) } } @@ -109,10 +109,15 @@ func CleanProcessStorage(thresholdDuration time.Duration) { } } - // clean virtual processes + // clean virtual/failed processes for _, p := range processesCopy { p.Lock() - if p.Virtual { + switch { + case p.Error != "": + if p.Meta().Created < threshold { + delete = true + } + case p.Virtual: _, parentIsAlive := processes[p.ParentPid] if !parentIsAlive { delete = true diff --git a/process/find.go b/process/find.go index 66ee99af..903899d1 100644 --- a/process/find.go +++ b/process/find.go @@ -1,6 +1,7 @@ package process import ( + "context" "errors" "net" @@ -55,24 +56,29 @@ func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) { // GetProcessByPacket returns the process that owns the given packet. func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) { + log.Tracer(pkt.Ctx()).Tracef("process: getting process and profile by packet") var pid int pid, direction, err = GetPidByPacket(pkt) if err != nil { + log.Tracer(pkt.Ctx()).Errorf("process: failed to find PID of connection: %s", err) return nil, direction, err } if pid < 0 { + log.Tracer(pkt.Ctx()).Errorf("process: %s", ErrConnectionNotFound.Error()) return nil, direction, ErrConnectionNotFound } - process, err = GetOrFindPrimaryProcess(pid) + process, err = GetOrFindPrimaryProcess(pkt.Ctx(), pid) if err != nil { + log.Tracer(pkt.Ctx()).Errorf("process: failed to find (primary) process with PID: %s", err) return nil, direction, err } - err = process.FindProfiles() + err = process.FindProfiles(pkt.Ctx()) if err != nil { - log.Errorf("failed to find profiles for process %s: %s", process.String(), err) + log.Tracer(pkt.Ctx()).Errorf("process: failed to find profiles for process %s: %s", process, err) + log.Errorf("failed to find profiles for process %s: %s", process, err) } return process, direction, nil @@ -103,29 +109,33 @@ func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remote } // GetProcessByEndpoints returns the process that owns the described link. -func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) { +func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) { + log.Tracer(ctx).Tracef("process: getting process and profile by endpoints") var pid int pid, _, err = GetPidByEndpoints(localIP, localPort, remoteIP, remotePort, protocol) if err != nil { + log.Tracer(ctx).Errorf("process: failed to find PID of connection: %s", err) return nil, err } if pid < 0 { + log.Tracer(ctx).Errorf("process: %s", ErrConnectionNotFound.Error()) return nil, ErrConnectionNotFound } - process, err = GetOrFindPrimaryProcess(pid) + process, err = GetOrFindPrimaryProcess(ctx, pid) if err != nil { + log.Tracer(ctx).Errorf("process: failed to find (primary) process with PID: %s", err) return nil, err } - err = process.FindProfiles() + err = process.FindProfiles(ctx) if err != nil { - log.Errorf("failed to find profiles for process %s: %s", process.String(), err) + log.Tracer(ctx).Errorf("process: failed to find profiles for process %s: %s", process, err) + log.Errorf("process: failed to find profiles for process %s: %s", process, err) } return process, nil - } // GetActiveConnectionIDs returns a list of all active connection IDs. diff --git a/process/iphelper/get.go b/process/iphelper/get.go index 85b788a4..78c6bb12 100644 --- a/process/iphelper/get.go +++ b/process/iphelper/get.go @@ -10,15 +10,15 @@ import ( ) var ( - tcp4Connections []*connectionEntry - tcp4Listeners []*connectionEntry - tcp6Connections []*connectionEntry - tcp6Listeners []*connectionEntry + tcp4Connections []*ConnectionEntry + tcp4Listeners []*ConnectionEntry + tcp6Connections []*ConnectionEntry + tcp6Listeners []*ConnectionEntry - udp4Connections []*connectionEntry - udp4Listeners []*connectionEntry - udp6Connections []*connectionEntry - udp6Listeners []*connectionEntry + udp4Connections []*ConnectionEntry + udp4Listeners []*ConnectionEntry + udp6Connections []*ConnectionEntry + udp6Listeners []*ConnectionEntry ipHelper *IPHelper lock sync.RWMutex @@ -34,6 +34,7 @@ func checkIPHelper() (err error) { return nil } +// GetTCP4PacketInfo returns the pid of the given IPv4/TCP connection. func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { // search @@ -69,6 +70,7 @@ func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote return -1, pktDirection, nil } +// GetTCP6PacketInfo returns the pid of the given IPv6/TCP connection. func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { // search @@ -104,6 +106,7 @@ func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote return -1, pktDirection, nil } +// GetUDP4PacketInfo returns the pid of the given IPv4/UDP connection. func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { // search @@ -139,6 +142,7 @@ func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote return -1, pktDirection, nil } +// GetUDP6PacketInfo returns the pid of the given IPv6/UDP connection. func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { // search @@ -174,7 +178,7 @@ func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote return -1, pktDirection, nil } -func search(connections, listeners []*connectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16, pktDirection bool) (pid int, direction bool) { +func search(connections, listeners []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16, pktDirection bool) (pid int, direction bool) { lock.RLock() defer lock.RUnlock() @@ -203,7 +207,7 @@ func search(connections, listeners []*connectionEntry, localIP, remoteIP net.IP, return -1, pktDirection } -func searchConnections(list []*connectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16) (pid int) { +func searchConnections(list []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16) (pid int) { for _, entry := range list { if localPort == entry.localPort && @@ -217,7 +221,7 @@ func searchConnections(list []*connectionEntry, localIP, remoteIP net.IP, localP return -1 } -func searchListeners(list []*connectionEntry, localIP net.IP, localPort uint16) (pid int) { +func searchListeners(list []*ConnectionEntry, localIP net.IP, localPort uint16) (pid int) { for _, entry := range list { if localPort == entry.localPort && @@ -230,6 +234,7 @@ func searchListeners(list []*connectionEntry, localIP net.IP, localPort uint16) return -1 } +// GetActiveConnectionIDs returns all currently active connection IDs. func GetActiveConnectionIDs() (connections []string) { lock.Lock() defer lock.Unlock() diff --git a/process/iphelper/iphelper.go b/process/iphelper/iphelper.go index f6ea8fa6..f4799fe9 100644 --- a/process/iphelper/iphelper.go +++ b/process/iphelper/iphelper.go @@ -14,11 +14,12 @@ var ( errInvalid = errors.New("IPHelper not initialzed or broken") ) +// IPHelper represents a subset of the Windows iphlpapi.dll. type IPHelper struct { dll *windows.LazyDLL - getExtendedTcpTable *windows.LazyProc - getExtendedUdpTable *windows.LazyProc + getExtendedTCPTable *windows.LazyProc + getExtendedUDPTable *windows.LazyProc // getOwnerModuleFromTcpEntry *windows.LazyProc // getOwnerModuleFromTcp6Entry *windows.LazyProc // getOwnerModuleFromUdpEntry *windows.LazyProc @@ -27,6 +28,7 @@ type IPHelper struct { valid *abool.AtomicBool } +// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded). func New() (*IPHelper, error) { new := &IPHelper{} @@ -41,13 +43,13 @@ func New() (*IPHelper, error) { } // load functions - new.getExtendedTcpTable = new.dll.NewProc("GetExtendedTcpTable") - err = new.getExtendedTcpTable.Find() + new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable") + err = new.getExtendedTCPTable.Find() if err != nil { return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err) } - new.getExtendedUdpTable = new.dll.NewProc("GetExtendedUdpTable") - err = new.getExtendedUdpTable.Find() + new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable") + err = new.getExtendedUDPTable.Find() if err != nil { return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err) } diff --git a/process/iphelper/tables.go b/process/iphelper/tables.go index c27f849f..d80ab0c5 100644 --- a/process/iphelper/tables.go +++ b/process/iphelper/tables.go @@ -6,18 +6,24 @@ import ( "errors" "fmt" "net" + "sync" "unsafe" "golang.org/x/sys/windows" ) +// Windows API constants const ( - iphelper_TCP_TABLE_OWNER_PID_ALL uintptr = 5 - iphelper_UDP_TABLE_OWNER_PID uintptr = 1 - iphelper_TCP_STATE_LISTEN uint32 = 2 + iphelperTCPTableOwnerPIDAll uintptr = 5 + iphelperUDPTableOwnerPID uintptr = 1 + iphelperTCPStateListen uint32 = 2 + + winErrInsufficientBuffer = uintptr(windows.ERROR_INSUFFICIENT_BUFFER) + winErrInvalidParameter = uintptr(windows.ERROR_INVALID_PARAMETER) ) -type connectionEntry struct { +// ConnectionEntry describes a connection state table entry. +type ConnectionEntry struct { localIP net.IP remoteIP net.IP localPort uint16 @@ -25,17 +31,17 @@ type connectionEntry struct { pid int } -func (entry *connectionEntry) String() string { +func (entry *ConnectionEntry) String() string { return fmt.Sprintf("PID=%d %s:%d <> %s:%d", entry.pid, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort) } -type iphelperTcpTable struct { +type iphelperTCPTable struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx numEntries uint32 - table [4096]iphelperTcpRow + table [4096]iphelperTCPRow } -type iphelperTcpRow struct { +type iphelperTCPRow struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366913(v=vs.85).aspx state uint32 localAddr uint32 @@ -45,51 +51,52 @@ type iphelperTcpRow struct { owningPid uint32 } -type iphelperTcp6Table struct { +type iphelperTCP6Table struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366905(v=vs.85).aspx numEntries uint32 - table [4096]iphelperTcp6Row + table [4096]iphelperTCP6Row } -type iphelperTcp6Row struct { +type iphelperTCP6Row struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366896(v=vs.85).aspx localAddr [16]byte - localScopeId uint32 + localScopeID uint32 localPort uint32 remoteAddr [16]byte - remoteScopeId uint32 + remoteScopeID uint32 remotePort uint32 state uint32 owningPid uint32 } -type iphelperUdpTable struct { +type iphelperUDPTable struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366932(v=vs.85).aspx numEntries uint32 - table [4096]iphelperUdpRow + table [4096]iphelperUDPRow } -type iphelperUdpRow struct { +type iphelperUDPRow struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366928(v=vs.85).aspx localAddr uint32 localPort uint32 owningPid uint32 } -type iphelperUdp6Table struct { +type iphelperUDP6Table struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366925(v=vs.85).aspx numEntries uint32 - table [4096]iphelperUdp6Row + table [4096]iphelperUDP6Row } -type iphelperUdp6Row struct { +type iphelperUDP6Row struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366923(v=vs.85).aspx localAddr [16]byte - localScopeId uint32 + localScopeID uint32 localPort uint32 owningPid uint32 } +// IP and Protocol constants const ( IPv4 uint8 = 4 IPv6 uint8 = 6 @@ -98,7 +105,51 @@ const ( UDP uint8 = 17 ) -func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connections []*connectionEntry, listeners []*connectionEntry, err error) { +const ( + startBufSize = 4096 + bufSizeUses = 100 +) + +var ( + bufSize = startBufSize + bufSizeUsageLeft = bufSizeUses + bufSizeLock sync.Mutex +) + +func getBufSize() int { + bufSizeLock.Lock() + defer bufSizeLock.Unlock() + + // using bufSize + bufSizeUsageLeft-- + // check if we want to reset + if bufSizeUsageLeft <= 0 { + // reset + bufSize = startBufSize + bufSizeUsageLeft = bufSizeUses + } + + return bufSize +} + +func increaseBufSize() int { + bufSizeLock.Lock() + defer bufSizeLock.Unlock() + + // increase + bufSize = bufSize * 2 + // not too much + if bufSize > 65536 { + bufSize = 65536 + } + // reset + bufSizeUsageLeft = bufSizeUses + // return new bufSize + return bufSize +} + +// GetTables returns the current connection state table of Windows of the given protocol and IP version. +func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connections []*ConnectionEntry, listeners []*ConnectionEntry, err error) { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365928(v=vs.85).aspx if !ipHelper.valid.IsSet() { @@ -115,50 +166,61 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection return nil, nil, errors.New("invalid protocol") } - bufSize := 4096 - buf := make([]byte, bufSize) - var r1 uintptr + // try max 3 times + maxTries := 3 + bufSize := getBufSize() + var buf []byte - switch protocol { - case TCP: - r1, _, err = ipHelper.getExtendedTcpTable.Call( - uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable - uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize - 0, // _In_ BOOL bOrder - uintptr(afClass), // _In_ ULONG ulAf - iphelper_TCP_TABLE_OWNER_PID_ALL, // _In_ TCP_TABLE_CLASS TableClass - 0, // _In_ ULONG Reserved - ) - case UDP: - r1, _, err = ipHelper.getExtendedUdpTable.Call( - uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable, - uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize, - 0, // _In_ BOOL bOrder, - uintptr(afClass), // _In_ ULONG ulAf, - iphelper_UDP_TABLE_OWNER_PID, // _In_ UDP_TABLE_CLASS TableClass, - 0, // _In_ ULONG Reserved - ) - } + for i := 1; i <= maxTries; i++ { + buf = make([]byte, bufSize) + var r1 uintptr - switch r1 { - // case windows.ERROR_INSUFFICIENT_BUFFER: - // return nil, fmt.Errorf("insufficient buffer error: %s", err) - // case windows.ERROR_INVALID_PARAMETER: - // return nil, fmt.Errorf("invalid parameter: %s", err) - case windows.NO_ERROR: - default: - return nil, nil, fmt.Errorf("unexpected error: %s", err) + switch protocol { + case TCP: + r1, _, err = ipHelper.getExtendedTCPTable.Call( + uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable + uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize + 0, // _In_ BOOL bOrder + uintptr(afClass), // _In_ ULONG ulAf + iphelperTCPTableOwnerPIDAll, // _In_ TCP_TABLE_CLASS TableClass + 0, // _In_ ULONG Reserved + ) + case UDP: + r1, _, err = ipHelper.getExtendedUDPTable.Call( + uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable, + uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize, + 0, // _In_ BOOL bOrder, + uintptr(afClass), // _In_ ULONG ulAf, + iphelperUDPTableOwnerPID, // _In_ UDP_TABLE_CLASS TableClass, + 0, // _In_ ULONG Reserved + ) + } + + switch r1 { + case winErrInsufficientBuffer: + if i >= maxTries { + return nil, nil, fmt.Errorf("insufficient buffer error (tried %d times): [NT 0x%X] %s", i, r1, err) + } + bufSize = increaseBufSize() + case winErrInvalidParameter: + return nil, nil, fmt.Errorf("invalid parameter: [NT 0x%X] %s", r1, err) + case windows.NO_ERROR: + // success + break + default: + return nil, nil, fmt.Errorf("unexpected error: [NT 0x%X] %s", r1, err) + } } // parse output switch { case protocol == TCP && ipVersion == IPv4: - tcpTable := (*iphelperTcpTable)(unsafe.Pointer(&buf[0])) + tcpTable := (*iphelperTCPTable)(unsafe.Pointer(&buf[0])) table := tcpTable.table[:tcpTable.numEntries] for _, row := range table { - new := &connectionEntry{} + new := &ConnectionEntry{} // PID new.pid = int(row.owningPid) @@ -170,7 +232,7 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection new.localPort = uint16(row.localPort>>8 | row.localPort<<8) // remote - if row.state == iphelper_TCP_STATE_LISTEN { + if row.state == iphelperTCPStateListen { listeners = append(listeners, new) } else { new.remoteIP = convertIPv4(row.remoteAddr) @@ -182,11 +244,11 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection case protocol == TCP && ipVersion == IPv6: - tcpTable := (*iphelperTcp6Table)(unsafe.Pointer(&buf[0])) + tcpTable := (*iphelperTCP6Table)(unsafe.Pointer(&buf[0])) table := tcpTable.table[:tcpTable.numEntries] for _, row := range table { - new := &connectionEntry{} + new := &ConnectionEntry{} // PID new.pid = int(row.owningPid) @@ -196,7 +258,7 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection new.localPort = uint16(row.localPort>>8 | row.localPort<<8) // remote - if row.state == iphelper_TCP_STATE_LISTEN { + if row.state == iphelperTCPStateListen { if new.localIP.Equal(net.IPv6zero) { new.localIP = nil } @@ -211,11 +273,11 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection case protocol == UDP && ipVersion == IPv4: - udpTable := (*iphelperUdpTable)(unsafe.Pointer(&buf[0])) + udpTable := (*iphelperUDPTable)(unsafe.Pointer(&buf[0])) table := udpTable.table[:udpTable.numEntries] for _, row := range table { - new := &connectionEntry{} + new := &ConnectionEntry{} // PID new.pid = int(row.owningPid) @@ -232,11 +294,11 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection case protocol == UDP && ipVersion == IPv6: - udpTable := (*iphelperUdp6Table)(unsafe.Pointer(&buf[0])) + udpTable := (*iphelperUDP6Table)(unsafe.Pointer(&buf[0])) table := udpTable.table[:udpTable.numEntries] for _, row := range table { - new := &connectionEntry{} + new := &ConnectionEntry{} // PID new.pid = int(row.owningPid) diff --git a/process/matching.go b/process/matching.go index 44b62541..1edaa966 100644 --- a/process/matching.go +++ b/process/matching.go @@ -1,6 +1,7 @@ package process import ( + "context" "fmt" "github.com/Safing/portbase/database" @@ -14,7 +15,8 @@ var ( ) // FindProfiles finds and assigns a profile set to the process. -func (p *Process) FindProfiles() error { +func (p *Process) FindProfiles(ctx context.Context) error { + log.Tracer(ctx).Trace("process: loading profile set") p.Lock() defer p.Unlock() @@ -66,7 +68,7 @@ func (p *Process) FindProfiles() error { // FIXME: implement! p.UserProfileKey = userProfile.Key() - p.profileSet = profile.NewSet(fmt.Sprintf("%d-%s", p.Pid, p.Path), userProfile, nil) + p.profileSet = profile.NewSet(ctx, fmt.Sprintf("%d-%s", p.Pid, p.Path), userProfile, nil) go p.Save() return nil diff --git a/process/process.go b/process/process.go index e7ddc9dd..2963fdc1 100644 --- a/process/process.go +++ b/process/process.go @@ -3,6 +3,7 @@ package process import ( + "context" "fmt" "path/filepath" "runtime" @@ -16,6 +17,11 @@ import ( "github.com/Safing/portmaster/profile" ) +var ( + dupReqMap = make(map[int]*sync.Mutex) + dupReqLock sync.Mutex +) + // A Process represents a process running on the operating system type Process struct { record.Base @@ -45,7 +51,9 @@ type Process struct { FirstCommEstablished int64 LastCommEstablished int64 CommCount uint - Virtual bool // This process is merged into another process + + Virtual bool // This process is either merged into another process or is not needed. + Error string // If this is set, the process is invalid. This is used to cache failing or inexistent processes. } // ProfileSet returns the assigned profile set. @@ -90,7 +98,9 @@ func (p *Process) RemoveCommunication() { } // GetOrFindPrimaryProcess returns the highest process in the tree that matches the given PID. -func GetOrFindPrimaryProcess(pid int) (*Process, error) { +func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { + log.Tracer(ctx).Tracef("process: getting primary process for PID %d", pid) + if pid == -1 { return UnknownProcess, nil } @@ -98,44 +108,52 @@ func GetOrFindPrimaryProcess(pid int) (*Process, error) { return OSProcess, nil } - process, err := loadProcess(pid) + process, err := loadProcess(ctx, pid) if err != nil { return nil, err } + if process.Error != "" { + return nil, fmt.Errorf("%s [cached error]", process.Error) + } for { if process.ParentPid == 0 { return OSProcess, nil } - parentProcess, err := loadProcess(process.ParentPid) + parentProcess, err := loadProcess(ctx, process.ParentPid) if err != nil { - log.Tracef("process: could not get parent (%d): %s", process.Pid, err) + log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s", process.Pid, process.ParentPid, err) + return process, nil + } + if parentProcess.Error != "" { + log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s [cached error]", process.Pid, process.ParentPid, parentProcess.Error) return process, nil } - // parent process does not match, we reached the top of the tree of matching processes + // if parent process path does not match, we have reached the top of the tree of matching processes if process.Path != parentProcess.Path { - // save to storage - process.Save() - // return primary process + // found primary process + + // mark for use, save to storage + process.Lock() + if process.Virtual { + process.Virtual = false + go process.Save() + } + process.Unlock() + return process, nil } - // mark as virtual - process.Lock() - process.Virtual = true - process.Unlock() - - // save to storage - process.Save() - // continue up to process tree process = parentProcess } } // GetOrFindProcess returns the process for the given PID. -func GetOrFindProcess(pid int) (*Process, error) { +func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { + log.Tracer(ctx).Tracef("process: getting process for PID %d", pid) + if pid == -1 { return UnknownProcess, nil } @@ -143,17 +161,25 @@ func GetOrFindProcess(pid int) (*Process, error) { return OSProcess, nil } - p, err := loadProcess(pid) + p, err := loadProcess(ctx, pid) if err != nil { return nil, err } + if p.Error != "" { + return nil, fmt.Errorf("%s [cached error]", p.Error) + } - // save to storage - p.Save() + // mark for use, save to storage + p.Lock() + if p.Virtual { + p.Virtual = false + go p.Save() + } + p.Unlock() return p, nil } -func loadProcess(pid int) (*Process, error) { +func loadProcess(ctx context.Context, pid int) (*Process, error) { if pid == -1 { return UnknownProcess, nil } @@ -166,8 +192,39 @@ func loadProcess(pid int) (*Process, error) { return process, nil } + // dedup requests + dupReqLock.Lock() + mutex, requestActive := dupReqMap[pid] + if !requestActive { + mutex = new(sync.Mutex) + mutex.Lock() + dupReqMap[pid] = mutex + dupReqLock.Unlock() + } else { + dupReqLock.Unlock() + log.Tracer(ctx).Tracef("process: waiting for duplicate request for PID %d to complete", pid) + mutex.Lock() + // wait until duplicate request is finished, then fetch current Process and return + mutex.Unlock() + process, ok = GetProcessFromStorage(pid) + if ok { + return process, nil + } + return nil, fmt.Errorf("previous request for process with PID %d failed", pid) + } + + // lock request for this pid + defer func() { + dupReqLock.Lock() + delete(dupReqMap, pid) + dupReqLock.Unlock() + mutex.Unlock() + }() + + // create new process new := &Process{ - Pid: pid, + Pid: pid, + Virtual: true, // caller must decide to actually use the process - we need to save now. } switch { @@ -187,16 +244,15 @@ func loadProcess(pid int) (*Process, error) { var uids []int32 uids, err = pInfo.Uids() if err != nil { - log.Warningf("process: failed to get UID for p%d: %s", pid, err) - } else { - new.UserID = int(uids[0]) + return failedToLoad(new, fmt.Errorf("failed to get UID for p%d: %s", pid, err)) } + new.UserID = int(uids[0]) } // Username new.UserName, err = pInfo.Username() if err != nil { - log.Warningf("process: failed to get Username for p%d: %s", pid, err) + return failedToLoad(new, fmt.Errorf("process: failed to get Username for p%d: %s", pid, err)) } // TODO: User Home @@ -205,15 +261,14 @@ func loadProcess(pid int) (*Process, error) { // PPID ppid, err := pInfo.Ppid() if err != nil { - log.Warningf("process: failed to get PPID for p%d: %s", pid, err) - } else { - new.ParentPid = int(ppid) + return failedToLoad(new, fmt.Errorf("failed to get PPID for p%d: %s", pid, err)) } + new.ParentPid = int(ppid) // Path new.Path, err = pInfo.Exe() if err != nil { - log.Warningf("process: failed to get Path for p%d: %s", pid, err) + return failedToLoad(new, fmt.Errorf("failed to get Path for p%d: %s", pid, err)) } // Executable Name _, new.ExecName = filepath.Split(new.Path) @@ -228,18 +283,21 @@ func loadProcess(pid int) (*Process, error) { // Command line arguments new.CmdLine, err = pInfo.Cmdline() if err != nil { - log.Warningf("process: failed to get Cmdline for p%d: %s", pid, err) + return failedToLoad(new, fmt.Errorf("failed to get Cmdline for p%d: %s", pid, err)) } // Name new.Name, err = pInfo.Name() if err != nil { - log.Warningf("process: failed to get Name for p%d: %s", pid, err) + return failedToLoad(new, fmt.Errorf("failed to get Name for p%d: %s", pid, err)) } if new.Name == "" { new.Name = new.ExecName } + // OS specifics + new.specialOSInit() + // TODO: App Icon // new.Icon, err = @@ -317,5 +375,12 @@ func loadProcess(pid int) (*Process, error) { // } } + new.Save() return new, nil } + +func failedToLoad(p *Process, err error) (*Process, error) { + p.Error = err.Error() + p.Save() + return nil, err +} diff --git a/profile/active.go b/profile/active.go index e6b05c8a..fdc5ceb8 100644 --- a/profile/active.go +++ b/profile/active.go @@ -1,6 +1,7 @@ package profile import ( + "context" "sync" "github.com/Safing/portbase/log" @@ -11,13 +12,13 @@ var ( activeProfileSetsLock sync.RWMutex ) -func activateProfileSet(set *Set) { +func activateProfileSet(ctx context.Context, set *Set) { activeProfileSetsLock.Lock() defer activeProfileSetsLock.Unlock() set.Lock() defer set.Unlock() activeProfileSets[set.id] = set - log.Tracef("profile: activated profile set %s", set.id) + log.Tracer(ctx).Tracef("profile: activated profile set %s", set.id) } // DeactivateProfileSet marks a profile set as not active. diff --git a/profile/set.go b/profile/set.go index a1a5120c..c21b475b 100644 --- a/profile/set.go +++ b/profile/set.go @@ -1,6 +1,7 @@ package profile import ( + "context" "net" "sync" @@ -27,7 +28,7 @@ type Set struct { } // NewSet returns a new profile set with given the profiles. -func NewSet(id string, user, stamp *Profile) *Set { +func NewSet(ctx context.Context, id string, user, stamp *Profile) *Set { new := &Set{ id: id, profiles: [4]*Profile{ @@ -37,7 +38,7 @@ func NewSet(id string, user, stamp *Profile) *Set { nil, // Default }, } - activateProfileSet(new) + activateProfileSet(ctx, new) new.Update(status.SecurityLevelFortress) return new }