diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index 6a1c28b3..be5b8c4a 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -105,26 +105,25 @@ func startInterception(packets chan packet.Packet) error { } }) - // Start kext logging. The worker will periodically send request to the kext to print memory stats. - // module.StartServiceWorker("kext memory stats request worker", 0, func(ctx context.Context) error { - // timer := time.NewTicker(20 * time.Second) - // for { - // select { - // case <-timer.C: - // { - // err := kext2.SendPrintMemoryStatsCommand() - // if err != nil { - // return err - // } - // } - // case <-ctx.Done(): - // { - // return nil - // } - // } + module.StartServiceWorker("kext clean ended connection worker", 0, func(ctx context.Context) error { + timer := time.NewTicker(30 * time.Second) + for { + select { + case <-timer.C: + { + err := kext2.SendCleanEndedConnection() + if err != nil { + return err + } + } + case <-ctx.Done(): + { + return nil + } + } - // } - // }) + } + }) } return nil diff --git a/firewall/interception/windowskext/service.go b/firewall/interception/windowskext/service.go index 420a4196..facba765 100644 --- a/firewall/interception/windowskext/service.go +++ b/firewall/interception/windowskext/service.go @@ -24,6 +24,7 @@ func createKextService(driverName string, driverPath string) (*KextService, erro } defer windows.CloseServiceHandle(manager) + // Convert the driver name to a UTF16 string driverNameU16, err := syscall.UTF16FromString(driverName) if err != nil { return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err) @@ -47,103 +48,103 @@ func createKextService(driverName string, driverPath string) (*KextService, erro // Create the service service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil) - if err != nil { - return nil, err - } - - return &KextService{handle: service}, nil -} - -func deleteService(manager windows.Handle, service *KextService, driverName []uint16) error { - // Stop and wait before deleting - _ = service.stop(true) - - // Try to delete even if stop failed - err := service.delete() - if err != nil { - return fmt.Errorf("failed to delete old service: %s", err) - } - - // Wait until we can no longer open the old service. - // Not very efficient but NotifyServiceStatusChange cannot be used with driver service. - start := time.Now() - timeLimit := time.Duration(30 * time.Second) - for { - handle, err := windows.OpenService(manager, &driverName[0], windows.SERVICE_ALL_ACCESS) - if err != nil { - break - } - _ = windows.CloseServiceHandle(handle) - - if time.Since(start) > timeLimit { - return fmt.Errorf("time limit reached") - } - - time.Sleep(100 * time.Millisecond) - } - return nil -} - -func (s *KextService) isValid() bool { - return s != nil && s.handle != winInvalidHandleValue && s.handle != 0 -} - -func (s *KextService) isRunning() (bool, error) { - if !s.isValid() { - return false, fmt.Errorf("kext service not initialized") - } - var status windows.SERVICE_STATUS - err := windows.QueryServiceStatus(s.handle, &status) - if err != nil { - return false, err - } - return status.CurrentState == windows.SERVICE_RUNNING, nil -} - -func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) { - var status windows.SERVICE_STATUS - status.CurrentState = windows.SERVICE_NO_CHANGE - start := time.Now() - for status.CurrentState == neededStatus { - err := windows.QueryServiceStatus(handle, &status) - if err != nil { - return false, fmt.Errorf("failed while waiting for service to start: %w", err) - } - - if time.Since(start) > timeLimit { - return false, fmt.Errorf("time limit reached") - } - - // Sleep for 1/10 of the wait hint, recommended time from microsoft - time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond) - } - - return true, nil -} - -func (s *KextService) start(wait bool) error { - if !s.isValid() { - return fmt.Errorf("kext service not initialized") - } - - // Start the service: - err := windows.StartService(s.handle, 0, nil) - - if err != nil { - err = windows.GetLastError() - if err != windows.ERROR_SERVICE_ALREADY_RUNNING { - // Failed to start service; clean-up: - var status windows.SERVICE_STATUS - _ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status) - _ = windows.DeleteService(s.handle) - _ = windows.CloseServiceHandle(s.handle) - s.handle = winInvalidHandleValue - return err - } - } - - // Wait for service to start - if wait { + if err != nil { + return nil, err + } + + return &KextService{handle: service}, nil +} + +func deleteService(manager windows.Handle, service *KextService, driverName []uint16) error { + // Stop and wait before deleting + _ = service.stop(true) + + // Try to delete even if stop failed + err := service.delete() + if err != nil { + return fmt.Errorf("failed to delete old service: %s", err) + } + + // Wait until we can no longer open the old service. + // Not very efficient but NotifyServiceStatusChange cannot be used with driver service. + start := time.Now() + timeLimit := time.Duration(30 * time.Second) + for { + handle, err := windows.OpenService(manager, &driverName[0], windows.SERVICE_ALL_ACCESS) + if err != nil { + break + } + _ = windows.CloseServiceHandle(handle) + + if time.Since(start) > timeLimit { + return fmt.Errorf("time limit reached") + } + + time.Sleep(100 * time.Millisecond) + } + return nil +} + +func (s *KextService) isValid() bool { + return s != nil && s.handle != winInvalidHandleValue && s.handle != 0 +} + +func (s *KextService) isRunning() (bool, error) { + if !s.isValid() { + return false, fmt.Errorf("kext service not initialized") + } + var status windows.SERVICE_STATUS + err := windows.QueryServiceStatus(s.handle, &status) + if err != nil { + return false, err + } + return status.CurrentState == windows.SERVICE_RUNNING, nil +} + +func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) { + var status windows.SERVICE_STATUS + status.CurrentState = windows.SERVICE_NO_CHANGE + start := time.Now() + for status.CurrentState == neededStatus { + err := windows.QueryServiceStatus(handle, &status) + if err != nil { + return false, fmt.Errorf("failed while waiting for service to start: %w", err) + } + + if time.Since(start) > timeLimit { + return false, fmt.Errorf("time limit reached") + } + + // Sleep for 1/10 of the wait hint, recommended time from microsoft + time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond) + } + + return true, nil +} + +func (s *KextService) start(wait bool) error { + if !s.isValid() { + return fmt.Errorf("kext service not initialized") + } + + // Start the service: + err := windows.StartService(s.handle, 0, nil) + + if err != nil { + err = windows.GetLastError() + if err != windows.ERROR_SERVICE_ALREADY_RUNNING { + // Failed to start service; clean-up: + var status windows.SERVICE_STATUS + _ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status) + _ = windows.DeleteService(s.handle) + _ = windows.CloseServiceHandle(s.handle) + s.handle = winInvalidHandleValue + return err + } + } + + // Wait for service to start + if wait { success, err := waitForServiceStatus(s.handle, windows.SERVICE_RUNNING, time.Duration(10*time.Second)) if err != nil || !success { return fmt.Errorf("service did not start: %w", err) diff --git a/firewall/interception/windowskext2/handler.go b/firewall/interception/windowskext2/handler.go index 37978023..1d10fa79 100644 --- a/firewall/interception/windowskext2/handler.go +++ b/firewall/interception/windowskext2/handler.go @@ -45,8 +45,10 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch // New Packet new := &Packet{ verdictRequest: conn.Id, + payload: conn.Payload, verdictSet: abool.NewBool(false), } + new.Base.Payload() info := new.Info() info.Inbound = conn.Direction > 0 info.InTunnel = false @@ -95,6 +97,7 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch // New Packet new := &Packet{ verdictRequest: conn.Id, + payload: conn.Payload, verdictSet: abool.NewBool(false), } info := new.Info() diff --git a/firewall/interception/windowskext2/kext.go b/firewall/interception/windowskext2/kext.go index 31fa5f27..665cf25d 100644 --- a/firewall/interception/windowskext2/kext.go +++ b/firewall/interception/windowskext2/kext.go @@ -102,13 +102,17 @@ func SendPrintMemoryStatsCommand() error { return kext_interface.SendPrintMemoryStatsCommand(kextFile) } +func SendCleanEndedConnection() error { + return kext_interface.SendCleanEndedConnectionsCommand(kextFile) +} + // RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil. func RecvVerdictRequest() (*kext_interface.Info, error) { return kext_interface.RecvInfo(kextFile) } // SetVerdict sets the verdict for a packet and/or connection. -func SetVerdict(pkt *Packet, verdict network.Verdict) error { +func SetVerdict(pkt *Packet, verdict kext_interface.KextVerdict) error { verdictCommand := kext_interface.Verdict{Id: pkt.verdictRequest, Verdict: uint8(verdict)} return kext_interface.SendVerdictCommand(kextFile, verdictCommand) } diff --git a/firewall/interception/windowskext2/packet.go b/firewall/interception/windowskext2/packet.go index a8f768f6..ef671a65 100644 --- a/firewall/interception/windowskext2/packet.go +++ b/firewall/interception/windowskext2/packet.go @@ -4,12 +4,12 @@ package windowskext import ( - "fmt" "sync" "github.com/tevino/abool" + "github.com/vlabo/portmaster_windows_rust_kext/kext_interface" - "github.com/safing/portmaster/network" + "github.com/safing/portbase/log" "github.com/safing/portmaster/network/packet" ) @@ -18,6 +18,7 @@ type Packet struct { packet.Base verdictRequest uint64 + payload []byte verdictSet *abool.AtomicBool payloadLoaded bool @@ -33,7 +34,7 @@ func (pkt *Packet) FastTrackedByIntegration() bool { // InfoOnly returns whether the packet is informational only and does not // represent an actual packet. func (pkt *Packet) InfoOnly() bool { - return pkt.verdictRequest == 0 + return false } // ExpectInfo returns whether the next packet is expected to be informational only. @@ -43,13 +44,33 @@ func (pkt *Packet) ExpectInfo() bool { // GetPayload returns the full raw packet. func (pkt *Packet) LoadPacketData() error { - return fmt.Errorf("Not implemented") + pkt.lock.Lock() + defer pkt.lock.Unlock() + + if !pkt.payloadLoaded { + pkt.payloadLoaded = true + + if len(pkt.payload) > 0 { + err := packet.Parse(pkt.payload, &pkt.Base) + if err != nil { + log.Tracef("payload: %#v", pkt.payload) + log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to parse payload: %s", err) + return packet.ErrFailedToLoadPayload + } + } + } + + if len(pkt.Raw()) == 0 { + return packet.ErrFailedToLoadPayload + } + + return nil } // Accept accepts the packet. func (pkt *Packet) Accept() error { if pkt.verdictSet.SetToIf(false, true) { - return SetVerdict(pkt, -network.VerdictAccept) + return SetVerdict(pkt, kext_interface.VerdictAccept) } return nil } @@ -57,7 +78,7 @@ func (pkt *Packet) Accept() error { // Block blocks the packet. func (pkt *Packet) Block() error { if pkt.verdictSet.SetToIf(false, true) { - return SetVerdict(pkt, -network.VerdictBlock) + return SetVerdict(pkt, kext_interface.VerdictBlock) } return nil } @@ -65,7 +86,7 @@ func (pkt *Packet) Block() error { // Drop drops the packet. func (pkt *Packet) Drop() error { if pkt.verdictSet.SetToIf(false, true) { - return SetVerdict(pkt, -network.VerdictDrop) + return SetVerdict(pkt, kext_interface.VerdictDrop) } return nil } @@ -73,7 +94,7 @@ func (pkt *Packet) Drop() error { // PermanentAccept permanently accepts connection (and the current packet). func (pkt *Packet) PermanentAccept() error { if pkt.verdictSet.SetToIf(false, true) { - return SetVerdict(pkt, network.VerdictAccept) + return SetVerdict(pkt, kext_interface.VerdictAccept) } return nil } @@ -81,7 +102,7 @@ func (pkt *Packet) PermanentAccept() error { // PermanentBlock permanently blocks connection (and the current packet). func (pkt *Packet) PermanentBlock() error { if pkt.verdictSet.SetToIf(false, true) { - return SetVerdict(pkt, network.VerdictBlock) + return SetVerdict(pkt, kext_interface.VerdictBlock) } return nil } @@ -89,7 +110,7 @@ func (pkt *Packet) PermanentBlock() error { // PermanentDrop permanently drops connection (and the current packet). func (pkt *Packet) PermanentDrop() error { if pkt.verdictSet.SetToIf(false, true) { - return SetVerdict(pkt, network.VerdictDrop) + return SetVerdict(pkt, kext_interface.VerdictDrop) } return nil } @@ -97,7 +118,7 @@ func (pkt *Packet) PermanentDrop() error { // RerouteToNameserver permanently reroutes the connection to the local nameserver (and the current packet). func (pkt *Packet) RerouteToNameserver() error { if pkt.verdictSet.SetToIf(false, true) { - return SetVerdict(pkt, network.VerdictRerouteToNameserver) + return SetVerdict(pkt, kext_interface.VerdictRerouteToNameserver) } return nil } @@ -105,7 +126,7 @@ func (pkt *Packet) RerouteToNameserver() error { // RerouteToTunnel permanently reroutes the connection to the local tunnel entrypoint (and the current packet). func (pkt *Packet) RerouteToTunnel() error { if pkt.verdictSet.SetToIf(false, true) { - return SetVerdict(pkt, network.VerdictRerouteToTunnel) + return SetVerdict(pkt, kext_interface.VerdictRerouteToTunnel) } return nil }