diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 53ca7582..cbb5ac3c 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -2,7 +2,6 @@ package windowskext import ( "encoding/binary" - "fmt" "net" "github.com/tevino/abool" @@ -23,7 +22,7 @@ type VerdictRequest struct { remoteIP [4]uint32 /* Destination Address */ localPort uint16 /* Source Port */ remotePort uint16 /* Destination port */ - compartmentId uint32 + compartmentID uint32 interfaceIndex uint32 subInterfaceIndex uint32 packetSize uint32 diff --git a/process/database.go b/process/database.go index 4f43785f..c5d0783f 100644 --- a/process/database.go +++ b/process/database.go @@ -94,13 +94,27 @@ func CleanProcessStorage(thresholdDuration time.Duration) { defer processesLock.Unlock() threshold := time.Now().Add(-thresholdDuration).Unix() + + // clean primary processes for _, p := range processes { p.Lock() - if p.FirstCommEstablished < threshold && p.CommCount == 0 { + if !p.Virtual && p.LastCommEstablished < threshold && p.CommCount == 0 { go p.Delete() } p.Unlock() } + + // clean virtual processes + for _, p := range processes { + p.Lock() + if p.Virtual { + _, parentIsAlive := processes[p.ParentPid] + if !parentIsAlive { + go p.Delete() + } + } + p.Unlock() + } } // SetDBController sets the database controller and allows the package to push database updates on a save. It must be set by the package that registers the "network" database. diff --git a/process/find.go b/process/find.go index dab6ef8e..66ee99af 100644 --- a/process/find.go +++ b/process/find.go @@ -65,7 +65,7 @@ func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, er return nil, direction, ErrConnectionNotFound } - process, err = GetOrFindProcess(pid) + process, err = GetOrFindPrimaryProcess(pid) if err != nil { return nil, direction, err } @@ -114,7 +114,7 @@ func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, re return nil, ErrConnectionNotFound } - process, err = GetOrFindProcess(pid) + process, err = GetOrFindPrimaryProcess(pid) if err != nil { return nil, err } diff --git a/process/process.go b/process/process.go index 5496d911..143a225e 100644 --- a/process/process.go +++ b/process/process.go @@ -4,8 +4,8 @@ package process import ( "fmt" + "path/filepath" "runtime" - "strings" "sync" "time" @@ -45,6 +45,7 @@ type Process struct { FirstCommEstablished int64 LastCommEstablished int64 CommCount uint + Virtual bool // This process is merged into another process } // ProfileSet returns the assigned profile set. @@ -88,8 +89,54 @@ func (p *Process) RemoveCommunication() { } } +// GetOrFindPrimaryProcess returns the highest process in the tree that matches the given PID. +func GetOrFindPrimaryProcess(pid int) (*Process, error) { + process, err := loadProcess(pid) + if err != nil { + return nil, err + } + + for { + parentProcess, err := loadProcess(process.ParentPid) + if err != nil { + log.Tracef("process: could not get parent (%d): %s", process.Pid, err) + return process, nil + } + + // parent process does not match, we reached the top of the tree of matching processes + if process.Path != parentProcess.Path { + // save to storage + process.Save() + // return primary process + return process, nil + } + + // mark as virtual + process.Lock() + process.Virtual = true + process.Unlock() + + // save to storage + process.Save() + + // continue up to process tree + process = parentProcess + } +} + // GetOrFindProcess returns the process for the given PID. func GetOrFindProcess(pid int) (*Process, error) { + p, err := loadProcess(pid) + if err != nil { + return nil, err + } + + // save to storage + p.Save() + return p, nil +} + +func loadProcess(pid int) (*Process, error) { process, ok := GetProcessFromStorage(pid) if ok { return process, nil @@ -116,7 +163,7 @@ func GetOrFindProcess(pid int) (*Process, error) { var uids []int32 uids, err = pInfo.Uids() if err != nil { - log.Warningf("process: failed to get UID: %s", err) + log.Warningf("process: failed to get UID for p%d: %s", pid, err) } else { new.UserID = int(uids[0]) } @@ -125,7 +172,7 @@ func GetOrFindProcess(pid int) (*Process, error) { // Username new.UserName, err = pInfo.Username() if err != nil { - log.Warningf("process: failed to get Username: %s", err) + log.Warningf("process: failed to get Username for p%d: %s", pid, err) } // TODO: User Home @@ -134,7 +181,7 @@ func GetOrFindProcess(pid int) (*Process, error) { // PPID ppid, err := pInfo.Ppid() if err != nil { - log.Warningf("process: failed to get PPID: %s", err) + log.Warningf("process: failed to get PPID for p%d: %s", pid, err) } else { new.ParentPid = int(ppid) } @@ -142,8 +189,10 @@ func GetOrFindProcess(pid int) (*Process, error) { // Path new.Path, err = pInfo.Exe() if err != nil { - log.Warningf("process: failed to get Path: %s", err) + log.Warningf("process: failed to get Path for p%d: %s", pid, err) } + // Executable Name + _, new.ExecName = filepath.Split(new.Path) // Current working directory // net yet implemented for windows @@ -155,13 +204,16 @@ func GetOrFindProcess(pid int) (*Process, error) { // Command line arguments new.CmdLine, err = pInfo.Cmdline() if err != nil { - log.Warningf("process: failed to get Cmdline: %s", err) + log.Warningf("process: failed to get Cmdline for p%d: %s", pid, err) } // Name new.Name, err = pInfo.Name() if err != nil { - log.Warningf("process: failed to get Name: %s", err) + log.Warningf("process: failed to get Name for p%d: %s", pid, err) + } + if new.Name == "" { + new.Name = new.ExecName } // TODO: App Icon @@ -239,16 +291,7 @@ func GetOrFindProcess(pid int) (*Process, error) { // } // } // } - - // Executable Information - - // FIXME: use os specific path seperator - splittedPath := strings.Split(new.Path, "/") - new.ExecName = splittedPath[len(splittedPath)-1] } - // save to storage - new.Save() - return new, nil }