From 101bf16727517bfe5652f3e1fc4ec83f26942d1b Mon Sep 17 00:00:00 2001 From: Vladimir Date: Mon, 7 Nov 2022 16:09:41 +0100 Subject: [PATCH] better kext service handling --- firewall/interception/windowskext/kext.go | 29 +++++----- firewall/interception/windowskext/service.go | 61 +++++++++++++++++--- 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/firewall/interception/windowskext/kext.go b/firewall/interception/windowskext/kext.go index 59e4af35..949eb0aa 100644 --- a/firewall/interception/windowskext/kext.go +++ b/firewall/interception/windowskext/kext.go @@ -46,28 +46,27 @@ func Start() error { kextLock.Lock() defer kextLock.Unlock() - filename := `\\.\` + driverName - - // check if driver is already installed - var err error - kextHandle, err = openDriver(filename) - if err == nil { - return nil // device was already initialized - } - // initialize and start driver service - service, err = createKextService(driverName, driverPath) + service, err := createKextService(driverName, driverPath) if err != nil { + log.Warningf("winkext: failed to create service: %s", err) return fmt.Errorf("failed to create service: %w", err) } - err = service.start() + running, err := service.isRunning() + if err == nil && !running { + err = service.start(true) - if err != nil { - return fmt.Errorf("failed to start service: %w", err) + if err != nil { + log.Warningf("winkext: failed to start service: %s", err) + return fmt.Errorf("failed to start service: %w", err) + } + } else if err != nil { + return fmt.Errorf("service not initialized: %w", err) } - // open the driver + // Open the driver + filename := `\\.\` + driverName kextHandle, err = openDriver(filename) // driver was not installed @@ -88,7 +87,7 @@ func Stop() error { log.Warningf("winkext: failed to close the handle: %s", err) } - err = service.stop() + err = service.stop(true) if err != nil { log.Warningf("winkext: failed to stop service: %s", err) } diff --git a/firewall/interception/windowskext/service.go b/firewall/interception/windowskext/service.go index b8f8ed70..08b37088 100644 --- a/firewall/interception/windowskext/service.go +++ b/firewall/interception/windowskext/service.go @@ -6,6 +6,7 @@ package windowskext import ( "fmt" "syscall" + "time" "golang.org/x/sys/windows" ) @@ -47,7 +48,40 @@ func (s *KextService) isValid() bool { return s != nil && s.handle != winInvalidHandleValue && s.handle != 0 } -func (s *KextService) start() error { +func (s *KextService) isRunning() (bool, error) { + if !s.isValid() { + return false, fmt.Errorf("kext service not initialized") + } + var status windows.SERVICE_STATUS + err := windows.QueryServiceStatus(s.handle, &status) + if err != nil { + return false, err + } + return status.CurrentState == windows.SERVICE_RUNNING, nil +} + +func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) { + var status windows.SERVICE_STATUS + status.CurrentState = windows.SERVICE_NO_CHANGE + start := time.Now() + for status.CurrentState == neededStatus { + err := windows.QueryServiceStatus(handle, &status) + if err != nil { + return false, fmt.Errorf("failed while waiting for service to start: %w", err) + } + + if time.Now().Sub(start) > timeLimit { + return false, fmt.Errorf("time limit reached") + } + + // Sleep for 1/10 of the wait hint, recommended time from microsoft + time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond) + } + + return true, nil +} + +func (s *KextService) start(wait bool) error { if !s.isValid() { return fmt.Errorf("kext service not initialized") } @@ -68,22 +102,35 @@ func (s *KextService) start() error { } } + // Wait for service to start + if wait { + success, err := waitForServiceStatus(s.handle, windows.SERVICE_RUNNING, time.Duration(10*time.Second)) + if err != nil || !success { + return fmt.Errorf("service did not start: %w", err) + } + } + return nil } -func (s *KextService) stop() error { +func (s *KextService) stop(wait bool) error { if !s.isValid() { - return fmt.Errorf("kext service not initialized") + return fmt.Errorf("kext service not initialized %v", s) } + // Stop the service var status windows.SERVICE_STATUS err := windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status) if err != nil { return fmt.Errorf("service failed to stop: %w", err) } - if status.CurrentState != windows.SERVICE_STOP_PENDING && status.CurrentState != windows.SERVICE_STOPPED { - return fmt.Errorf("service unexpected status after stop: %d", status.CurrentState) + // Wait for service to stop + if wait { + success, err := waitForServiceStatus(s.handle, windows.SERVICE_STOPPED, time.Duration(10*time.Second)) + if err != nil || !success { + return fmt.Errorf("service did not stop: %w", err) + } } return nil @@ -91,7 +138,7 @@ func (s *KextService) stop() error { func (s *KextService) delete() error { if !s.isValid() { - return fmt.Errorf("kext service not initialized") + return fmt.Errorf("kext service not initialized %v", s) } err := windows.DeleteService(s.handle) @@ -103,7 +150,7 @@ func (s *KextService) delete() error { func (s *KextService) closeHandle() error { if !s.isValid() { - return fmt.Errorf("kext service not initialized") + return fmt.Errorf("kext service not initialized %v", s) } err := windows.CloseServiceHandle(s.handle)