diff --git a/firewall/interception/windowskext/kext.go b/firewall/interception/windowskext/kext.go index 52488494..3613cf33 100644 --- a/firewall/interception/windowskext/kext.go +++ b/firewall/interception/windowskext/kext.go @@ -4,9 +4,12 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "syscall" + "time" "unsafe" + "github.com/Safing/portbase/log" "github.com/Safing/portmaster/network" "github.com/tevino/abool" "golang.org/x/sys/windows" @@ -16,11 +19,19 @@ import ( var ( ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands") - kext *WinKext - kextLock sync.RWMutex - ready = abool.NewBool(false) + winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA) + + kext *WinKext + kextLock sync.RWMutex + ready = abool.NewBool(false) + urgentRequests *int32 ) +func init() { + var urgentRequestsValue int32 + urgentRequests = &urgentRequestsValue +} + // WinKext holds the DLL handle. type WinKext struct { sync.RWMutex @@ -80,7 +91,7 @@ func Init(dllPath, driverPath string) error { // initialize dll/kext rc, _, lastErr := new.init.Call() if rc != windows.NO_ERROR { - return formatErr(lastErr) + return formatErr(lastErr, rc) } // set kext @@ -105,7 +116,7 @@ func Start() error { uintptr(unsafe.Pointer(&charArray[0])), ) if rc != windows.NO_ERROR { - return formatErr(lastErr) + return formatErr(lastErr, rc) } ready.Set() @@ -123,7 +134,7 @@ func Stop() error { rc, _, lastErr := kext.stop.Call() if rc != windows.NO_ERROR { - return formatErr(lastErr) + return formatErr(lastErr, rc) } return nil } @@ -138,14 +149,28 @@ func RecvVerdictRequest() (*VerdictRequest, error) { new := &VerdictRequest{} + // wait for urgent requests to complete + for i := 1; i <= 100; i++ { + if atomic.LoadInt32(urgentRequests) <= 0 { + break + } + if i == 100 { + log.Warningf("winkext: RecvVerdictRequest waited 100 times") + } + time.Sleep(100 * time.Microsecond) + } + + // timestamp := time.Now() rc, _, lastErr := kext.recvVerdictRequest.Call( uintptr(unsafe.Pointer(new)), ) - if rc != 0 { - if rc == 13 /* ERROR_INVALID_DATA */ { + // log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp)) + + if rc != windows.NO_ERROR { + if rc == winErrInvalidData { return nil, nil } - return nil, formatErr(lastErr) + return nil, formatErr(lastErr, rc) } return new, nil } @@ -158,12 +183,16 @@ func SetVerdict(packetID uint32, verdict network.Verdict) error { return ErrKextNotReady } + atomic.AddInt32(urgentRequests, 1) + // timestamp := time.Now() rc, _, lastErr := kext.setVerdict.Call( uintptr(packetID), uintptr(verdict), ) + // log.Tracef("winkext: settings verdict for packetID %d took %s", packetID, time.Now().Sub(timestamp)) + atomic.AddInt32(urgentRequests, -1) if rc != windows.NO_ERROR { - return formatErr(lastErr) + return formatErr(lastErr, rc) } return nil } @@ -178,13 +207,18 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) { buf := make([]byte, packetSize) + atomic.AddInt32(urgentRequests, 1) + // timestamp := time.Now() rc, _, lastErr := kext.getPayload.Call( uintptr(packetID), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&packetSize)), ) + // log.Tracef("winkext: getting payload for packetID %d took %s", packetID, time.Now().Sub(timestamp)) + atomic.AddInt32(urgentRequests, -1) + if rc != windows.NO_ERROR { - return nil, formatErr(lastErr) + return nil, formatErr(lastErr, rc) } if packetSize == 0 { @@ -197,10 +231,10 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) { return buf, nil } -func formatErr(err error) error { +func formatErr(err error, rc uintptr) error { sysErr, ok := err.(syscall.Errno) if ok { - return fmt.Errorf("%s [0x%X]", err, uintptr(sysErr)) + return fmt.Errorf("%s [LE 0x%X] [RC 0x%X]", err, uintptr(sysErr), rc) } return err } diff --git a/firewall/interception/windowskext/packet.go b/firewall/interception/windowskext/packet.go index 1356bbaf..48f4f7de 100644 --- a/firewall/interception/windowskext/packet.go +++ b/firewall/interception/windowskext/packet.go @@ -31,6 +31,7 @@ func (pkt *Packet) GetPayload() ([]byte, error) { payload, err := GetPayload(pkt.verdictRequest.id, pkt.verdictRequest.packetSize) if err != nil { + log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to load payload %s", err) log.Errorf("windowskext: failed to load payload %s", err) return nil, packet.ErrFailedToLoadPayload } diff --git a/firewall/interception/windowskext/test/main.go b/firewall/interception/windowskext/test/main.go index f9040510..0fcc0208 100644 --- a/firewall/interception/windowskext/test/main.go +++ b/firewall/interception/windowskext/test/main.go @@ -68,7 +68,7 @@ func main() { // stop err = windowskext.Stop() if err != nil { - fmt.Printf("error stopping: %s\n", err) + panic(err) } log.Info("shutdown complete")