diff --git a/firewall/interception.go b/firewall/interception.go index bc4395a7..4999c5b3 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -149,6 +149,10 @@ func resetAllConnectionVerdicts() { // Save if verdict changed. if conn.Verdict.Firewall != previousVerdict { + err := interception.UpdateVerdictOfConnection(conn) + if err != nil { + log.Debugf("filter: failed to update connection verdict: %s", err) + } conn.Save() tracer.Infof("filter: verdict of connection %s changed from %s to %s", conn, previousVerdict.Verb(), conn.VerdictVerb()) changedVerdicts++ @@ -159,11 +163,6 @@ func resetAllConnectionVerdicts() { } tracer.Infof("filter: changed verdict on %d connections", changedVerdicts) tracer.Submit() - - err := interception.ResetVerdictOfAllConnections() - if err != nil { - log.Errorf("interception: failed to remove persistent verdicts: %s", err) - } } func interceptionStart() error { diff --git a/firewall/interception/interception_linux.go b/firewall/interception/interception_linux.go index 6fe38edf..be94e5b9 100644 --- a/firewall/interception/interception_linux.go +++ b/firewall/interception/interception_linux.go @@ -2,6 +2,7 @@ package interception import ( "github.com/safing/portmaster/firewall/interception/nfq" + "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/packet" ) @@ -19,3 +20,8 @@ func stop() error { func ResetVerdictOfAllConnections() error { return nfq.DeleteAllMarkedConnection() } + +// UpdateVerdictOfConnection deletes the verdict of specific connection so in can be initialized again with the next packet. +func UpdateVerdictOfConnection(conn *network.Connection) error { + return nfq.DeleteMarkedConnection(conn) +} diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index 382869be..7a29affe 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -4,22 +4,19 @@ import ( "fmt" "github.com/safing/portmaster/firewall/interception/windowskext" + "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/updates" ) // start starts the interception. func start(ch chan packet.Packet) error { - dllFile, err := updates.GetPlatformFile("kext/portmaster-kext.dll") - if err != nil { - return fmt.Errorf("interception: could not get kext dll: %s", err) - } kextFile, err := updates.GetPlatformFile("kext/portmaster-kext.sys") if err != nil { return fmt.Errorf("interception: could not get kext sys: %s", err) } - err = windowskext.Init(dllFile.Path(), kextFile.Path()) + err = windowskext.Init(kextFile.Path()) if err != nil { return fmt.Errorf("interception: could not init windows kext: %s", err) } @@ -43,3 +40,18 @@ func stop() error { func ResetVerdictOfAllConnections() error { return windowskext.ClearCache() } + +// UpdateVerdictOfConnection updates the verdict of specific connection in the kernel extension. +func UpdateVerdictOfConnection(conn *network.Connection) error { + return windowskext.UpdateVerdict(conn) +} + +// GetKextVersion returns the version of the kernel extension. +func GetKextVersion() (string, error) { + version, err := windowskext.GetVersion() + if err != nil { + return "", err + } + + return version.String(), nil +} diff --git a/firewall/interception/nfq/conntrack.go b/firewall/interception/nfq/conntrack.go index ac25728d..4de81036 100644 --- a/firewall/interception/nfq/conntrack.go +++ b/firewall/interception/nfq/conntrack.go @@ -4,20 +4,37 @@ package nfq import ( "encoding/binary" + "fmt" ct "github.com/florianl/go-conntrack" "github.com/safing/portbase/log" "github.com/safing/portmaster/netenv" + "github.com/safing/portmaster/network" ) -// DeleteAllMarkedConnection deletes all marked entries from the conntrack table. -func DeleteAllMarkedConnection() error { - nfct, err := ct.Open(&ct.Config{}) +var nfct *ct.Nfct // Conntrack handler. NFCT: Network Filter Connection Tracking. + +// InitNFCT initializes the network filter conntrack library. +func InitNFCT() error { + var err error + nfct, err = ct.Open(&ct.Config{}) if err != nil { return err } - defer func() { _ = nfct.Close() }() + return nil +} + +// TeardownNFCT deinitializes the network filter conntrack library. +func TeardownNFCT() { + _ = nfct.Close() +} + +// DeleteAllMarkedConnection deletes all marked entries from the conntrack table. +func DeleteAllMarkedConnection() error { + if nfct == nil { + return fmt.Errorf("nfq: nfct not initialized") + } // Delete all ipv4 marked connections deleted := deleteMarkedConnections(nfct, ct.IPv4) @@ -64,3 +81,43 @@ func deleteMarkedConnections(nfct *ct.Nfct, f ct.Family) (deleted int) { } return deleted } + +// DeleteMarkedConnection removes a specific connection from the conntrack table. +func DeleteMarkedConnection(conn *network.Connection) error { + if nfct == nil { + return fmt.Errorf("nfq: nfct not initialized") + } + + con := ct.Con{ + Origin: &ct.IPTuple{ + Src: &conn.LocalIP, + Dst: &conn.Entity.IP, + Proto: &ct.ProtoTuple{ + Number: &conn.Entity.Protocol, + SrcPort: &conn.LocalPort, + DstPort: &conn.Entity.Port, + }, + }, + } + connections, err := nfct.Get(ct.Conntrack, ct.IPv4, con) + if err != nil { + return fmt.Errorf("nfq: failed to find entry for connection %s: %w", conn.String(), err) + } + + if len(connections) > 1 { + log.Warningf("nfq: multiple entries found for single connection: %s -> %d", conn.String(), len(connections)) + } + + for _, connection := range connections { + deleteErr := nfct.Delete(ct.Conntrack, ct.IPv4, connection) + if err == nil { + err = deleteErr + } + } + + if err != nil { + log.Warningf("nfq: error while deleting conntrack entries for connection %s: %s", conn.String(), err) + } + + return nil +} diff --git a/firewall/interception/nfqueue_linux.go b/firewall/interception/nfqueue_linux.go index 488cc7a4..2ba1ab8f 100644 --- a/firewall/interception/nfqueue_linux.go +++ b/firewall/interception/nfqueue_linux.go @@ -147,6 +147,11 @@ func activateNfqueueFirewall() error { } } + if err := nfq.InitNFCT(); err != nil { + return err + } + _ = nfq.DeleteAllMarkedConnection() + return nil } @@ -166,6 +171,9 @@ func DeactivateNfqueueFirewall() error { } } + _ = nfq.DeleteAllMarkedConnection() + nfq.TeardownNFCT() + return result.ErrorOrNil() } diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 623c49a3..2932a31a 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package windowskext @@ -5,11 +6,14 @@ package windowskext import ( "encoding/binary" "errors" + "fmt" "net" + "unsafe" "github.com/tevino/abool" "github.com/safing/portbase/log" + "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/packet" ) @@ -25,6 +29,7 @@ const ( VerdictRequestFlagSocketAuth = 2 ) +// Do not change the order of the members! The structure is used to communicate with the kernel extension. // VerdictRequest is the request structure from the Kext. type VerdictRequest struct { id uint32 // ID from RegisterPacket @@ -43,16 +48,40 @@ type VerdictRequest struct { packetSize uint32 } +// Do not change the order of the members! The structure is used to communicate with the kernel extension. +type VerdictInfo struct { + id uint32 // ID from RegisterPacket + verdict network.Verdict // verdict for the connection +} + +// Do not change the order of the members! The structure to communicate with the kernel extension. +type VerdictUpdateInfo struct { + localIP [4]uint32 // Source Address, only srcIP[0] if IPv4 + remoteIP [4]uint32 // Destination Address + localPort uint16 // Source Port + remotePort uint16 // Destination port + ipV6 uint8 // True: IPv6, False: IPv4 + protocol uint8 // Protocol (UDP, TCP, ...) + verdict uint8 // New verdict +} + +type VersionInfo struct { + major uint8 + minor uint8 + revision uint8 + build uint8 +} + +func (v *VersionInfo) String() string { + return fmt.Sprintf("%d.%d.%d.%d", v.major, v.minor, v.revision, v.build) +} + // Handler transforms received packets to the Packet interface. func Handler(packets chan packet.Packet) { - if !ready.IsSet() { - return - } - defer close(packets) for { - if !ready.IsSet() { + if kextHandle == winInvalidHandleValue { return } @@ -145,3 +174,28 @@ func convertIPv6(input [4]uint32) net.IP { } return net.IP(addressBuf) } + +func ipAddressToArray(ip net.IP, isIPv6 bool) [4]uint32 { + array := [4]uint32{0} + if isIPv6 { + for i := 0; i < 4; i++ { + binary.BigEndian.PutUint32(asByteArrayWithLength(&array[i], 4), getUInt32Value(&ip[i])) + } + } else { + binary.BigEndian.PutUint32(asByteArrayWithLength(&array[0], 4), getUInt32Value(&ip[0])) + } + + return array +} + +func asByteArray[T any](obj *T) []byte { + return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj)) +} + +func asByteArrayWithLength[T any](obj *T, size uint32) []byte { + return unsafe.Slice((*byte)(unsafe.Pointer(obj)), size) +} + +func getUInt32Value[T any](obj *T) uint32 { + return *(*uint32)(unsafe.Pointer(obj)) +} diff --git a/firewall/interception/windowskext/kext.go b/firewall/interception/windowskext/kext.go index d438538e..e4ae511a 100644 --- a/firewall/interception/windowskext/kext.go +++ b/firewall/interception/windowskext/kext.go @@ -7,14 +7,13 @@ import ( "errors" "fmt" "sync" - "sync/atomic" "syscall" "time" "unsafe" "github.com/safing/portbase/log" "github.com/safing/portmaster/network" - "github.com/tevino/abool" + "github.com/safing/portmaster/network/packet" "golang.org/x/sys/windows" ) @@ -23,92 +22,23 @@ var ( ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands") ErrNoPacketID = errors.New("the packet has no ID, possibly because it was fast-tracked by the kernel extension") - winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA) - - kext *WinKext - kextLock sync.RWMutex - ready = abool.NewBool(false) - urgentRequests *int32 -) - -func init() { - var urgentRequestsValue int32 - urgentRequests = &urgentRequestsValue -} - -// WinKext holds the DLL handle. -type WinKext struct { - sync.RWMutex - - dll *windows.DLL + kextLock sync.RWMutex driverPath string - init *windows.Proc - start *windows.Proc - stop *windows.Proc - recvVerdictRequest *windows.Proc - setVerdict *windows.Proc - getPayload *windows.Proc - clearCache *windows.Proc -} + kextHandle windows.Handle + service *KextService +) + +const ( + winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA) + winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value + driverName = "PortmasterKext" +) // Init initializes the DLL and the Kext (Kernel Driver). -func Init(dllPath, driverPath string) error { - - new := &WinKext{ - driverPath: driverPath, - } - - var err error - - // load dll - new.dll, err = windows.LoadDLL(dllPath) - if err != nil { - return err - } - - // load functions - new.init, err = new.dll.FindProc("PortmasterInit") - if err != nil { - return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err) - } - new.start, err = new.dll.FindProc("PortmasterStart") - if err != nil { - return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err) - } - new.stop, err = new.dll.FindProc("PortmasterStop") - if err != nil { - return fmt.Errorf("could not find proc PortmasterStop in dll: %s", err) - } - new.recvVerdictRequest, err = new.dll.FindProc("PortmasterRecvVerdictRequest") - if err != nil { - return fmt.Errorf("could not find proc PortmasterRecvVerdictRequest in dll: %s", err) - } - new.setVerdict, err = new.dll.FindProc("PortmasterSetVerdict") - if err != nil { - return fmt.Errorf("could not find proc PortmasterSetVerdict in dll: %s", err) - } - new.getPayload, err = new.dll.FindProc("PortmasterGetPayload") - if err != nil { - return fmt.Errorf("could not find proc PortmasterGetPayload in dll: %s", err) - } - new.clearCache, err = new.dll.FindProc("PortmasterClearCache") - if err != nil { - // the loaded dll is an old version - log.Errorf("could not find proc PortmasterClearCache (v1.0.12+) in dll: %s", err) - } - - // initialize dll/kext - rc, _, lastErr := new.init.Call() - if rc != windows.NO_ERROR { - return formatErr(lastErr, rc) - } - - // set kext - kextLock.Lock() - defer kextLock.Unlock() - kext = new - +func Init(path string) error { + kextHandle = winInvalidHandleValue + driverPath = path return nil } @@ -117,19 +47,33 @@ func Start() error { kextLock.Lock() defer kextLock.Unlock() - // convert to C string - charArray := make([]byte, len(kext.driverPath)+1) - copy(charArray, []byte(kext.driverPath)) - charArray[len(charArray)-1] = 0 // force NULL byte at the end - - rc, _, lastErr := kext.start.Call( - uintptr(unsafe.Pointer(&charArray[0])), - ) - if rc != windows.NO_ERROR { - return formatErr(lastErr, rc) + // initialize and start driver service + var err error + service, err = createKextService(driverName, driverPath) + if err != nil { + return fmt.Errorf("failed to create service: %w", err) + } + + running, err := service.isRunning() + if err == nil && !running { + err = service.start(true) + + if err != nil { + return fmt.Errorf("failed to start service: %w", err) + } + } else if err != nil { + return fmt.Errorf("service not initialized: %w", err) + } + + // Open the driver + filename := `\\.\` + driverName + kextHandle, err = openDriver(filename) + + // driver was not installed + if err != nil { + return fmt.Errorf("failed to open driver: %q %w", filename, err) } - ready.Set() return nil } @@ -137,15 +81,27 @@ func Start() error { func Stop() error { kextLock.Lock() defer kextLock.Unlock() - if !ready.IsSet() { - return ErrKextNotReady - } - ready.UnSet() - rc, _, lastErr := kext.stop.Call() - if rc != windows.NO_ERROR { - return formatErr(lastErr, rc) + err := closeDriver(kextHandle) + if err != nil { + log.Warningf("winkext: failed to close the handle: %s", err) } + + err = service.stop(true) + if err != nil { + log.Warningf("winkext: failed to stop service: %s", err) + } + // Driver file may change on the next start so it's better to delete the service + err = service.delete() + if err != nil { + log.Warningf("winkext: failed to delete service: %s", err) + } + err = service.closeHandle() + if err != nil { + log.Warningf("winkext: failed to close the handle: %s", err) + } + + kextHandle = winInvalidHandleValue return nil } @@ -153,36 +109,26 @@ func Stop() error { func RecvVerdictRequest() (*VerdictRequest, error) { kextLock.RLock() defer kextLock.RUnlock() - if !ready.IsSet() { + if kextHandle == winInvalidHandleValue { return nil, ErrKextNotReady } - new := &VerdictRequest{} + timestamp := time.Now() + defer log.Tracef("winkext: getting verdict request took %s", time.Since(timestamp)) + // Initialize struct for the output data + var new VerdictRequest - // wait for urgent requests to complete - for i := 1; i <= 100; i++ { - if atomic.LoadInt32(urgentRequests) <= 0 { - break - } - if i == 100 { - log.Warningf("winkext: RecvVerdictRequest waited 100 times") - } - time.Sleep(100 * time.Microsecond) + // Make driver request + data := asByteArray(&new) + bytesRead, err := deviceIOControl(kextHandle, IOCTL_RECV_VERDICT_REQ, nil, data) + if err != nil { + return nil, err + } + if bytesRead == 0 { + return nil, nil // no error, no new verdict request } - // timestamp := time.Now() - rc, _, lastErr := kext.recvVerdictRequest.Call( - uintptr(unsafe.Pointer(new)), - ) - // log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp)) - - if rc != windows.NO_ERROR { - if rc == winErrInvalidData { - return nil, nil - } - return nil, formatErr(lastErr, rc) - } - return new, nil + return &new, nil } // SetVerdict sets the verdict for a packet and/or connection. @@ -194,22 +140,19 @@ func SetVerdict(pkt *Packet, verdict network.Verdict) error { kextLock.RLock() defer kextLock.RUnlock() - if !ready.IsSet() { + if kextHandle == winInvalidHandleValue { log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: kext not ready", verdict) return ErrKextNotReady } - atomic.AddInt32(urgentRequests, 1) - // timestamp := time.Now() - rc, _, lastErr := kext.setVerdict.Call( - uintptr(pkt.verdictRequest.id), - uintptr(verdict), - ) - // log.Tracef("winkext: settings verdict for packetID %d took %s", packetID, time.Now().Sub(timestamp)) - atomic.AddInt32(urgentRequests, -1) - if rc != windows.NO_ERROR { + verdictInfo := VerdictInfo{pkt.verdictRequest.id, verdict} + + // Make driver request + data := asByteArray(&verdictInfo) + _, err := deviceIOControl(kextHandle, IOCTL_SET_VERDICT, data, nil) + if err != nil { log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id) - return formatErr(lastErr, rc) + return err } return nil } @@ -220,34 +163,36 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) { return nil, ErrNoPacketID } + // Check if driver is initialized kextLock.RLock() defer kextLock.RUnlock() - if !ready.IsSet() { + if kextHandle == winInvalidHandleValue { return nil, ErrKextNotReady } buf := make([]byte, packetSize) - atomic.AddInt32(urgentRequests, 1) - // timestamp := time.Now() - rc, _, lastErr := kext.getPayload.Call( - uintptr(packetID), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&packetSize)), - ) - // log.Tracef("winkext: getting payload for packetID %d took %s", packetID, time.Now().Sub(timestamp)) - atomic.AddInt32(urgentRequests, -1) + // Combine id and length + payload := struct { + id uint32 + length uint32 + }{packetID, packetSize} - if rc != windows.NO_ERROR { - return nil, formatErr(lastErr, rc) + // Make driver request + data := asByteArray(&payload) + bytesRead, err := deviceIOControl(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize)) + + if err != nil { + return nil, err } - if packetSize == 0 { + // check the result and return + if bytesRead == 0 { return nil, errors.New("windows kext did not return any data") } - if packetSize < uint32(len(buf)) { - return buf[:packetSize], nil + if bytesRead < uint32(len(buf)) { + return buf[:bytesRead], nil } return buf, nil @@ -256,28 +201,94 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) { func ClearCache() error { kextLock.RLock() defer kextLock.RUnlock() - if !ready.IsSet() { + + // Check if driver is initialized + if kextHandle == winInvalidHandleValue { log.Error("kext: failed to clear the cache: kext not ready") return ErrKextNotReady } - if kext.clearCache == nil { - log.Error("kext: cannot clear cache: clearCache function missing") - } - - rc, _, lastErr := kext.clearCache.Call() - - if rc != windows.NO_ERROR { - return formatErr(lastErr, rc) - } - - return nil -} - -func formatErr(err error, rc uintptr) error { - sysErr, ok := err.(syscall.Errno) - if ok { - return fmt.Errorf("%s [LE 0x%X] [RC 0x%X]", err, uintptr(sysErr), rc) - } + // Make driver request + _, err := deviceIOControl(kextHandle, IOCTL_CLEAR_CACHE, nil, nil) return err } + +func UpdateVerdict(conn *network.Connection) error { + kextLock.RLock() + defer kextLock.RUnlock() + + // Check if driver is initialized + if kextHandle == winInvalidHandleValue { + log.Error("kext: failed to clear the cache: kext not ready") + return ErrKextNotReady + } + + var isIpv6 uint8 = 0 + if conn.IPVersion == packet.IPv6 { + isIpv6 = 1 + } + + // initialize variables + info := VerdictUpdateInfo{ + ipV6: isIpv6, + protocol: uint8(conn.IPProtocol), + localIP: ipAddressToArray(conn.LocalIP, isIpv6 == 1), + localPort: conn.LocalPort, + remoteIP: ipAddressToArray(conn.Entity.IP, isIpv6 == 1), + remotePort: conn.Entity.Port, + verdict: uint8(conn.Verdict.Active), + } + + // Make driver request + data := asByteArray(&info) + err := deviceIoControlDirect(kextHandle, IOCTL_UPDATE_VERDICT, data) + return err +} + +func GetVersion() (*VersionInfo, error) { + kextLock.RLock() + defer kextLock.RUnlock() + + // Check if driver is initialized + if kextHandle == winInvalidHandleValue { + log.Error("kext: failed to clear the cache: kext not ready") + return nil, ErrKextNotReady + } + + data := make([]uint8, 4) + err := deviceIoControlDirect(kextHandle, IOCTL_VERSION, data) + + if err != nil { + return nil, err + } + + version := &VersionInfo{ + major: data[0], + minor: data[1], + revision: data[2], + build: data[3], + } + return version, nil +} + +func openDriver(filename string) (windows.Handle, error) { + u16filename, err := syscall.UTF16FromString(filename) + if err != nil { + return winInvalidHandleValue, fmt.Errorf("failed to convert driver filename to UTF16 string %w", err) + } + + handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED, 0) + if err != nil { + return winInvalidHandleValue, err + } + + return handle, nil +} + +func closeDriver(handle windows.Handle) error { + if kextHandle == winInvalidHandleValue { + return ErrKextNotReady + } + + return windows.CloseHandle(handle) +} diff --git a/firewall/interception/windowskext/service.go b/firewall/interception/windowskext/service.go new file mode 100644 index 00000000..ed4429e8 --- /dev/null +++ b/firewall/interception/windowskext/service.go @@ -0,0 +1,161 @@ +//go:build windows +// +build windows + +package windowskext + +import ( + "fmt" + "syscall" + "time" + + "golang.org/x/sys/windows" +) + +type KextService struct { + handle windows.Handle +} + +func createKextService(driverName string, driverPath string) (*KextService, error) { + // Open the service manager: + manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS) + if err != nil { + return nil, fmt.Errorf("failed to open service manager: %d", err) + } + defer windows.CloseServiceHandle(manager) + + driverNameU16, err := syscall.UTF16FromString(driverName) + if err != nil { + return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err) + } + // Check if it's already created + service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS) + if err == nil { + return &KextService{handle: service}, nil // service was already created + } + + driverPathU16, err := syscall.UTF16FromString(driverPath) + + // Create the service + service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil) + if err != nil { + return nil, err + } + + return &KextService{handle: service}, nil +} + +func (s *KextService) isValid() bool { + return s != nil && s.handle != winInvalidHandleValue && s.handle != 0 +} + +func (s *KextService) isRunning() (bool, error) { + if !s.isValid() { + return false, fmt.Errorf("kext service not initialized") + } + var status windows.SERVICE_STATUS + err := windows.QueryServiceStatus(s.handle, &status) + if err != nil { + return false, err + } + return status.CurrentState == windows.SERVICE_RUNNING, nil +} + +func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) { + var status windows.SERVICE_STATUS + status.CurrentState = windows.SERVICE_NO_CHANGE + start := time.Now() + for status.CurrentState == neededStatus { + err := windows.QueryServiceStatus(handle, &status) + if err != nil { + return false, fmt.Errorf("failed while waiting for service to start: %w", err) + } + + if time.Since(start) > timeLimit { + return false, fmt.Errorf("time limit reached") + } + + // Sleep for 1/10 of the wait hint, recommended time from microsoft + time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond) + } + + return true, nil +} + +func (s *KextService) start(wait bool) error { + if !s.isValid() { + return fmt.Errorf("kext service not initialized") + } + + // Start the service: + err := windows.StartService(s.handle, 0, nil) + + if err != nil { + err = windows.GetLastError() + if err != windows.ERROR_SERVICE_ALREADY_RUNNING { + // Failed to start service; clean-up: + var status windows.SERVICE_STATUS + _ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status) + _ = windows.DeleteService(s.handle) + _ = windows.CloseServiceHandle(s.handle) + s.handle = winInvalidHandleValue + return err + } + } + + // Wait for service to start + if wait { + success, err := waitForServiceStatus(s.handle, windows.SERVICE_RUNNING, time.Duration(10*time.Second)) + if err != nil || !success { + return fmt.Errorf("service did not start: %w", err) + } + } + + return nil +} + +func (s *KextService) stop(wait bool) error { + if !s.isValid() { + return fmt.Errorf("kext service not initialized") + } + + // Stop the service + var status windows.SERVICE_STATUS + err := windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status) + if err != nil { + return fmt.Errorf("service failed to stop: %w", err) + } + + // Wait for service to stop + if wait { + success, err := waitForServiceStatus(s.handle, windows.SERVICE_STOPPED, time.Duration(10*time.Second)) + if err != nil || !success { + return fmt.Errorf("service did not stop: %w", err) + } + } + + return nil +} + +func (s *KextService) delete() error { + if !s.isValid() { + return fmt.Errorf("kext service not initialized") + } + + err := windows.DeleteService(s.handle) + if err != nil { + return fmt.Errorf("failed to delete service: %s", err) + } + return nil +} + +func (s *KextService) closeHandle() error { + if !s.isValid() { + return fmt.Errorf("kext service not initialized") + } + + err := windows.CloseServiceHandle(s.handle) + if err != nil { + return fmt.Errorf("failed to close service handle: %s", err) + } + return nil +} diff --git a/firewall/interception/windowskext/syscall.go b/firewall/interception/windowskext/syscall.go new file mode 100644 index 00000000..2b85adf5 --- /dev/null +++ b/firewall/interception/windowskext/syscall.go @@ -0,0 +1,89 @@ +//go:build windows +// +build windows + +package windowskext + +import "golang.org/x/sys/windows" + +const ( + METHOD_BUFFERED = 0 + METHOD_IN_DIRECT = 1 + METHOD_OUT_DIRECT = 2 + METHOD_NEITHER = 3 + + SIOCTL_TYPE = 40000 +) + +var ( + IOCTL_VERSION = ctlCode(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_RECV_VERDICT_REQ_POLL = ctlCode(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) // Not used + IOCTL_RECV_VERDICT_REQ = ctlCode(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_SET_VERDICT = ctlCode(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_GET_PAYLOAD = ctlCode(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_CLEAR_CACHE = ctlCode(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_UPDATE_VERDICT = ctlCode(SIOCTL_TYPE, 0x806, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) +) + +func ctlCode(device_type, function, method, access uint32) uint32 { + return (device_type << 16) | (access << 14) | (function << 2) | method +} + +func deviceIOControlAsync(handle windows.Handle, code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) { + var inDataPtr *byte = nil + var inDataSize uint32 = 0 + if inData != nil { + inDataPtr = &inData[0] + inDataSize = uint32(len(inData)) + } + + var outDataPtr *byte = nil + var outDataSize uint32 = 0 + if outData != nil { + outDataPtr = &outData[0] + outDataSize = uint32(len(outData)) + } + + overlapped := &windows.Overlapped{} + err := windows.DeviceIoControl(handle, + code, + inDataPtr, inDataSize, + outDataPtr, outDataSize, + nil, overlapped) + + if err != nil { + return nil, err + } + + return overlapped, nil + +} + +func deviceIOControl(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) { + overlapped, err := deviceIOControlAsync(handle, code, inData, outData) + if err != nil { + return 0, err + } + + var bytesReturned uint32 + err = windows.GetOverlappedResult(handle, overlapped, &bytesReturned, true) + + return bytesReturned, err +} + +// Use for METHOD_NEITHER IOCTL, the data buffer is passed directly to the kernel +func deviceIoControlDirect(handle windows.Handle, code uint32, data []byte) error { + var dataPtr *byte = nil + var dataSize uint32 = 0 + if data != nil { + dataPtr = &data[0] + dataSize = uint32(len(data)) + } + + err := windows.DeviceIoControl(handle, + code, + dataPtr, dataSize, + nil, 0, + nil, nil) + + return err +} diff --git a/updates/helper/updates.go b/updates/helper/updates.go index 17fe6116..80193ee8 100644 --- a/updates/helper/updates.go +++ b/updates/helper/updates.go @@ -52,7 +52,6 @@ func MandatoryUpdates() (identifiers []string) { identifiers = append( identifiers, PlatformIdentifier("core/portmaster-core.exe"), - PlatformIdentifier("kext/portmaster-kext.dll"), PlatformIdentifier("kext/portmaster-kext.sys"), PlatformIdentifier("kext/portmaster-kext.pdb"), PlatformIdentifier("start/portmaster-start.exe"),