Removed legacy code and refactoring

This commit is contained in:
Vladimir
2022-11-07 11:07:49 +01:00
parent 1b480066be
commit 3768db6b32
4 changed files with 157 additions and 145 deletions

View File

@@ -52,14 +52,10 @@ type VerdictInfo struct {
// Handler transforms received packets to the Packet interface.
func Handler(packets chan packet.Packet) {
if !ready.IsSet() {
return
}
defer close(packets)
for {
if !ready.IsSet() {
if kextHandle == winInvalidHandleValue {
return
}

View File

@@ -6,15 +6,13 @@ package windowskext
import (
"errors"
"fmt"
"os/exec"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/network"
"github.com/tevino/abool"
"golang.org/x/sys/windows"
)
@@ -23,25 +21,22 @@ var (
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
ErrNoPacketID = errors.New("the packet has no ID, possibly because it was fast-tracked by the kernel extension")
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
kextLock sync.RWMutex
ready = abool.NewBool(false)
urgentRequests *int32
driverPath string
kextLock sync.RWMutex
driverPath string
kextHandle windows.Handle
service *KextService
)
const driverName = "PortmasterKext"
func init() {
var urgentRequestsValue int32
urgentRequests = &urgentRequestsValue
}
const (
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value
driverName = "PortmasterKext"
)
// Init initializes the DLL and the Kext (Kernel Driver).
func Init(path string) error {
kextHandle = winInvalidHandleValue
driverPath = path
return nil
}
@@ -61,24 +56,25 @@ func Start() error {
}
// initialize and start driver service
service, err := driverInstall(driverPath)
service, err = createKextService(driverName, driverPath)
if err != nil {
return fmt.Errorf("Failed to start service: %s", err)
return fmt.Errorf("failed to create service: %w", err)
}
err = service.start()
if err != nil {
return fmt.Errorf("failed to start service: %w", err)
}
// open the driver
kextHandle, err = openDriver(filename)
// close the service handles
_ = windows.DeleteService(service)
_ = windows.CloseServiceHandle(service)
// driver was not installed
if err != nil {
return fmt.Errorf("Failed to start the kext service: %s %q", err, filename)
return fmt.Errorf("failed to open driver: %q %w", filename, err)
}
ready.Set()
return nil
}
@@ -86,20 +82,27 @@ func Start() error {
func Stop() error {
kextLock.Lock()
defer kextLock.Unlock()
if !ready.IsSet() {
return ErrKextNotReady
}
ready.UnSet()
err := closeDriver(kextHandle)
if err != nil {
log.Errorf("winkext: failed to close the handle: %s", err)
log.Warningf("winkext: failed to close the handle: %s", err)
}
_, err = exec.Command("sc", "stop", driverName).Output() // This is a question of taste, but it is a robust and solid solution
err = service.stop()
if err != nil {
log.Errorf("winkext: failed to stop the service: %q", err)
log.Warningf("winkext: failed to stop service: %s", err)
}
// Driver file may change on the next start so it's better to delete the service
err = service.delete()
if err != nil {
log.Warningf("winkext: failed to delete service: %s", err)
}
err = service.closeHandle()
if err != nil {
log.Warningf("winkext: failed to close the handle: %s", err)
}
kextHandle = winInvalidHandleValue
return nil
}
@@ -107,19 +110,9 @@ func Stop() error {
func RecvVerdictRequest() (*VerdictRequest, error) {
kextLock.RLock()
defer kextLock.RUnlock()
if !ready.IsSet() {
if kextHandle == winInvalidHandleValue {
return nil, ErrKextNotReady
}
// wait for urgent requests to complete
for i := 1; i <= 100; i++ {
if atomic.LoadInt32(urgentRequests) <= 0 {
break
}
if i == 100 {
log.Warningf("winkext: RecvVerdictRequest waited 100 times")
}
time.Sleep(100 * time.Microsecond)
}
timestamp := time.Now()
// Initialize struct for the output data
@@ -127,7 +120,7 @@ func RecvVerdictRequest() (*VerdictRequest, error) {
// Make driver request
data := asByteArray(&new)
bytesRead, err := deviceIoControlRead(kextHandle, IOCTL_RECV_VERDICT_REQ, data)
bytesRead, err := deviceIOControl(kextHandle, IOCTL_RECV_VERDICT_REQ, nil, data)
if err != nil {
return nil, err
}
@@ -148,7 +141,7 @@ func SetVerdict(pkt *Packet, verdict network.Verdict) error {
kextLock.RLock()
defer kextLock.RUnlock()
if !ready.IsSet() {
if kextHandle == winInvalidHandleValue {
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: kext not ready", verdict)
return ErrKextNotReady
}
@@ -156,10 +149,8 @@ func SetVerdict(pkt *Packet, verdict network.Verdict) error {
verdictInfo := VerdictInfo{pkt.verdictRequest.id, verdict}
// Make driver request
atomic.AddInt32(urgentRequests, 1)
data := asByteArray(&verdictInfo)
_, err := deviceIoControlWrite(kextHandle, IOCTL_SET_VERDICT, data)
atomic.AddInt32(urgentRequests, -1)
_, err := deviceIOControl(kextHandle, IOCTL_SET_VERDICT, data, nil)
if err != nil {
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id)
return err
@@ -176,7 +167,7 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
// Check if driver is initialized
kextLock.RLock()
defer kextLock.RUnlock()
if !ready.IsSet() {
if kextHandle == winInvalidHandleValue {
return nil, ErrKextNotReady
}
@@ -189,11 +180,8 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
}{packetID, packetSize}
// Make driver request
atomic.AddInt32(urgentRequests, 1)
data := asByteArray(&payload)
bytesRead, err := deviceIoControlReadWrite(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize))
atomic.AddInt32(urgentRequests, -1)
bytesRead, err := deviceIOControl(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize))
if err != nil {
return nil, err
@@ -216,16 +204,38 @@ func ClearCache() error {
defer kextLock.RUnlock()
// Check if driver is initialized
if !ready.IsSet() {
if kextHandle == winInvalidHandleValue {
log.Error("kext: failed to clear the cache: kext not ready")
return ErrKextNotReady
}
// Make driver request
_, err := deviceIoControlRead(kextHandle, IOCTL_CLEAR_CACHE, nil)
_, err := deviceIOControl(kextHandle, IOCTL_CLEAR_CACHE, nil, nil)
return err
}
func asByteArray[T any](obj *T) []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj))
}
func openDriver(filename string) (windows.Handle, error) {
u16filename, err := syscall.UTF16FromString(filename)
if err != nil {
return winInvalidHandleValue, fmt.Errorf("failed to convert driver filename to UTF16 string %w", err)
}
handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED, 0)
if err != nil {
return winInvalidHandleValue, err
}
return handle, nil
}
func closeDriver(handle windows.Handle) error {
if kextHandle == winInvalidHandleValue {
return ErrKextNotReady
}
return windows.CloseHandle(handle)
}

View File

@@ -10,78 +10,105 @@ import (
"golang.org/x/sys/windows"
)
func createService(manager windows.Handle, portmasterKextPath *uint16) (windows.Handle, error) {
u16filename, err := syscall.UTF16FromString(driverName)
if err != nil {
return 0, fmt.Errorf("Bad service: %s", err)
}
// Check if it's already created
service, err := windows.OpenService(manager, &u16filename[0], windows.SERVICE_ALL_ACCESS)
if err == nil {
return service, nil
}
// Create the service
service, err = windows.CreateService(manager, &u16filename[0], &u16filename[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, portmasterKextPath, nil, nil, nil, nil, nil)
if err != nil {
return 0, err
}
return service, nil
type KextService struct {
handle windows.Handle
}
func driverInstall(portmasterKextPath string) (windows.Handle, error) {
u16kextPath, _ := syscall.UTF16FromString(portmasterKextPath)
func createKextService(driverName string, driverPath string) (*KextService, error) {
// Open the service manager:
manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
if err != nil {
return 0, fmt.Errorf("Failed to open service manager: %d", err)
return nil, fmt.Errorf("failed to open service manager: %d", err)
}
defer windows.CloseServiceHandle(manager)
// Try to create the service. Retry if it fails.
var service windows.Handle
retryLoop:
for i := 0; i < 3; i++ {
service, err = createService(manager, &u16kextPath[0])
if err == nil {
break retryLoop
}
driverNameU16, err := syscall.UTF16FromString(driverName)
if err != nil {
return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err)
}
// Check if it's already created
service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS)
if err == nil {
return &KextService{handle: service}, nil // service was already created
}
driverPathU16, err := syscall.UTF16FromString(driverPath)
// Create the service
service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil)
if err != nil {
return 0, fmt.Errorf("Failed to create service: %s", err)
return nil, err
}
return &KextService{handle: service}, nil
}
func (s *KextService) isValid() bool {
return s != nil && s.handle != winInvalidHandleValue && s.handle != 0
}
func (s *KextService) start() error {
if !s.isValid() {
return fmt.Errorf("kext service not initialized")
}
// Start the service:
err = windows.StartService(service, 0, nil)
err := windows.StartService(s.handle, 0, nil)
if err != nil {
err = windows.GetLastError()
if err != windows.ERROR_SERVICE_ALREADY_RUNNING {
// Failed to start service; clean-up:
var status windows.SERVICE_STATUS
_ = windows.ControlService(service, windows.SERVICE_CONTROL_STOP, &status)
_ = windows.DeleteService(service)
_ = windows.CloseServiceHandle(service)
service = 0
_ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
_ = windows.DeleteService(s.handle)
_ = windows.CloseServiceHandle(s.handle)
s.handle = winInvalidHandleValue
return err
}
}
return service, nil
return nil
}
func openDriver(filename string) (windows.Handle, error) {
u16filename, _ := syscall.UTF16FromString(filename)
handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED, 0)
if err != nil {
return 0, err
func (s *KextService) stop() error {
if !s.isValid() {
return fmt.Errorf("kext service not initialized")
}
return handle, nil
var status windows.SERVICE_STATUS
err := windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
if err != nil {
return fmt.Errorf("service failed to stop: %w", err)
}
if status.CurrentState != windows.SERVICE_STOP_PENDING && status.CurrentState != windows.SERVICE_STOPPED {
return fmt.Errorf("service unexpected status after stop: %d", status.CurrentState)
}
return nil
}
func closeDriver(handle windows.Handle) error {
return windows.CloseHandle(handle)
func (s *KextService) delete() error {
if !s.isValid() {
return fmt.Errorf("kext service not initialized")
}
err := windows.DeleteService(s.handle)
if err != nil {
return fmt.Errorf("failed to delete service: %s", err)
}
return nil
}
func (s *KextService) closeHandle() error {
if !s.isValid() {
return fmt.Errorf("kext service not initialized")
}
err := windows.CloseServiceHandle(s.handle)
if err != nil {
return fmt.Errorf("failed to close service handle: %s", err)
}
return nil
}

View File

@@ -27,47 +27,7 @@ func ctlCode(device_type, function, method, access uint32) uint32 {
return (device_type << 16) | (access << 14) | (function << 2) | method
}
func deviceIoControlRead(handle windows.Handle, code uint32, data []byte) (uint32, error) {
var bytesReturned uint32
var dataPtr *byte = nil
var dataSize uint32 = 0
if data != nil {
dataPtr = &data[0]
dataSize = uint32(len(data))
}
err := windows.DeviceIoControl(handle,
code,
nil, 0,
dataPtr, dataSize,
&bytesReturned, nil)
return bytesReturned, err
}
func deviceIoControlWrite(handle windows.Handle, code uint32, data []byte) (uint32, error) {
var bytesReturned uint32
var dataPtr *byte = nil
var dataSize uint32 = 0
if data != nil {
dataPtr = &data[0]
dataSize = uint32(len(data))
}
err := windows.DeviceIoControl(handle,
code,
dataPtr, dataSize,
nil, 0,
&bytesReturned, nil)
return bytesReturned, err
}
func deviceIoControlReadWrite(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) {
var bytesReturned uint32
func deviceIOControlAsync(handle windows.Handle, code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) {
var inDataPtr *byte = nil
var inDataSize uint32 = 0
if inData != nil {
@@ -81,11 +41,30 @@ func deviceIoControlReadWrite(handle windows.Handle, code uint32, inData []byte,
outDataPtr = &outData[0]
outDataSize = uint32(len(outData))
}
overlapped := &windows.Overlapped{}
err := windows.DeviceIoControl(handle,
code,
inDataPtr, inDataSize,
outDataPtr, outDataSize,
&bytesReturned, nil)
nil, overlapped)
if err != nil {
return nil, err
}
return overlapped, nil
}
func deviceIOControl(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) {
overlapped, err := deviceIOControlAsync(handle, code, inData, outData)
if err != nil {
return 0, err
}
var bytesReturned uint32
err = windows.GetOverlappedResult(handle, overlapped, &bytesReturned, true)
return bytesReturned, err
}