diff --git a/firewall/master.go b/firewall/master.go index c3d33001..0d1fa4ea 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -150,10 +150,10 @@ func checkSelfCommunication(ctx context.Context, conn *network.Connection, pkt p log.Tracer(ctx).Warningf("filter: failed to find local peer process PID: %s", err) } else { // get primary process - otherProcess, err := process.GetOrFindPrimaryProcess(ctx, otherPid) + otherProcess, err := process.GetOrFindProcess(ctx, otherPid) if err != nil { log.Tracer(ctx).Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err) - } else if otherProcess.Pid == conn.Process().Pid { + } else if otherProcess.Path == conn.Process().Path { conn.Accept("process internal connection", noReasonOptionKey) conn.Internal = true return true diff --git a/process/find.go b/process/find.go index add5d484..dcd7a052 100644 --- a/process/find.go +++ b/process/find.go @@ -24,7 +24,7 @@ func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process return nil, pktInfo.Inbound, err } - process, err = GetOrFindPrimaryProcess(ctx, pid) + process, err = GetOrFindProcess(ctx, pid) if err != nil { log.Tracer(ctx).Debugf("process: failed to find (primary) process with PID: %s", err) return nil, connInbound, err diff --git a/process/process.go b/process/process.go index e99d0b3b..6a2683a1 100644 --- a/process/process.go +++ b/process/process.go @@ -2,14 +2,17 @@ package process import ( "context" + "errors" "fmt" "path/filepath" "runtime" + "strconv" "strings" "sync" "time" processInfo "github.com/shirou/gopsutil/process" + "golang.org/x/sync/singleflight" "github.com/safing/portbase/database/record" "github.com/safing/portbase/log" @@ -20,10 +23,7 @@ const ( onLinux = runtime.GOOS == "linux" ) -var ( - dupReqMap = make(map[int]*sync.WaitGroup) - dupReqLock sync.Mutex -) +var getProcessSingleInflight singleflight.Group // A Process represents a process running on the operating system type Process struct { @@ -51,7 +51,6 @@ type Process struct { FirstSeen int64 LastSeen int64 - Virtual bool // This process is either merged into another process or is not needed. Error string // Cache errors ExecHashes map[string]string @@ -91,117 +90,21 @@ func (p *Process) String() string { return fmt.Sprintf("%s:%s:%d", p.UserName, p.Path, p.Pid) } -// GetOrFindPrimaryProcess returns the highest process in the tree that matches the given PID. -func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { - log.Tracer(ctx).Tracef("process: getting primary process for PID %d", pid) - - switch pid { - case UnidentifiedProcessID: - return GetUnidentifiedProcess(ctx), nil - case SystemProcessID: - return GetSystemProcess(ctx), nil - } - - process, err := loadProcess(ctx, pid) - if err != nil { - return nil, err - } - - for { - if process.ParentPid <= 0 { - return process, nil - } - parentProcess, err := loadProcess(ctx, process.ParentPid) - if err != nil { - log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s", process.Pid, process.ParentPid, err) - saveFailedProcess(process.ParentPid, err.Error()) - return process, nil - } - - // if parent process path does not match, we have reached the top of the tree of matching processes - if process.Path != parentProcess.Path { - // found primary process - - // mark for use, save to storage - process.Lock() - if process.Virtual { - process.Virtual = false - go process.Save() - } - process.Unlock() - - return process, nil - } - - // continue up to process tree - process = parentProcess - } -} - // GetOrFindProcess returns the process for the given PID. func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { log.Tracer(ctx).Tracef("process: getting process for PID %d", pid) - switch pid { - case UnidentifiedProcessID: - return GetUnidentifiedProcess(ctx), nil - case SystemProcessID: - return GetSystemProcess(ctx), nil - } - - p, err := loadProcess(ctx, pid) + p, err, _ := getProcessSingleInflight.Do(strconv.Itoa(pid), func() (interface{}, error) { + return loadProcess(ctx, pid) + }) if err != nil { return nil, err } - - // mark for use, save to storage - p.Lock() - if p.Virtual { - p.Virtual = false - go p.Save() - } - p.Unlock() - return p, nil -} - -func deduplicateRequest(ctx context.Context, pid int) (finishRequest func()) { - dupReqLock.Lock() - - // get duplicate request waitgroup - wg, requestActive := dupReqMap[pid] - - // someone else is already on it! - if requestActive { - dupReqLock.Unlock() - - // log that we are waiting - log.Tracer(ctx).Tracef("intel: waiting for duplicate request for PID %d to complete", pid) - // wait - wg.Wait() - // done! - return nil + if p == nil { + return nil, errors.New("process getter returned nil") } - // we are currently the only one doing a request for this - - // create new waitgroup - wg = new(sync.WaitGroup) - // add worker (us!) - wg.Add(1) - // add to registry - dupReqMap[pid] = wg - - dupReqLock.Unlock() - - // return function to mark request as finished - return func() { - dupReqLock.Lock() - defer dupReqLock.Unlock() - // mark request as done - wg.Done() - // delete from registry - delete(dupReqMap, pid) - } + return p.(*Process), nil } func loadProcess(ctx context.Context, pid int) (*Process, error) { @@ -218,24 +121,9 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { return process, nil } - // dedupe! - markRequestFinished := deduplicateRequest(ctx, pid) - if markRequestFinished == nil { - // we waited for another request, recheck the storage! - process, ok = GetProcessFromStorage(pid) - if ok { - return process, nil - } - // if cache is still empty, go ahead - } else { - // we are the first! - defer markRequestFinished() - } - // Create new a process object. new := &Process{ Pid: pid, - Virtual: true, // caller must decide to actually use the process - we need to save now. FirstSeen: time.Now().Unix(), } @@ -312,14 +200,3 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { new.Save() return new, nil } - -func saveFailedProcess(pid int, err string) { - failed := &Process{ - Pid: pid, - FirstSeen: time.Now().Unix(), - Virtual: true, // not needed - Error: err, - } - - failed.Save() -}