Removed legacy code and refactoring
This commit is contained in:
@@ -52,14 +52,10 @@ type VerdictInfo struct {
|
||||
|
||||
// Handler transforms received packets to the Packet interface.
|
||||
func Handler(packets chan packet.Packet) {
|
||||
if !ready.IsSet() {
|
||||
return
|
||||
}
|
||||
|
||||
defer close(packets)
|
||||
|
||||
for {
|
||||
if !ready.IsSet() {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -6,15 +6,13 @@ package windowskext
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/network"
|
||||
"github.com/tevino/abool"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
@@ -23,25 +21,22 @@ var (
|
||||
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
|
||||
ErrNoPacketID = errors.New("the packet has no ID, possibly because it was fast-tracked by the kernel extension")
|
||||
|
||||
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
|
||||
|
||||
kextLock sync.RWMutex
|
||||
ready = abool.NewBool(false)
|
||||
urgentRequests *int32
|
||||
driverPath string
|
||||
kextLock sync.RWMutex
|
||||
driverPath string
|
||||
|
||||
kextHandle windows.Handle
|
||||
service *KextService
|
||||
)
|
||||
|
||||
const driverName = "PortmasterKext"
|
||||
|
||||
func init() {
|
||||
var urgentRequestsValue int32
|
||||
urgentRequests = &urgentRequestsValue
|
||||
}
|
||||
const (
|
||||
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
|
||||
winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value
|
||||
driverName = "PortmasterKext"
|
||||
)
|
||||
|
||||
// Init initializes the DLL and the Kext (Kernel Driver).
|
||||
func Init(path string) error {
|
||||
kextHandle = winInvalidHandleValue
|
||||
driverPath = path
|
||||
return nil
|
||||
}
|
||||
@@ -61,24 +56,25 @@ func Start() error {
|
||||
}
|
||||
|
||||
// initialize and start driver service
|
||||
service, err := driverInstall(driverPath)
|
||||
service, err = createKextService(driverName, driverPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to start service: %s", err)
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
err = service.start()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %w", err)
|
||||
}
|
||||
|
||||
// open the driver
|
||||
kextHandle, err = openDriver(filename)
|
||||
|
||||
// close the service handles
|
||||
_ = windows.DeleteService(service)
|
||||
_ = windows.CloseServiceHandle(service)
|
||||
|
||||
// driver was not installed
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to start the kext service: %s %q", err, filename)
|
||||
return fmt.Errorf("failed to open driver: %q %w", filename, err)
|
||||
}
|
||||
|
||||
ready.Set()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -86,20 +82,27 @@ func Start() error {
|
||||
func Stop() error {
|
||||
kextLock.Lock()
|
||||
defer kextLock.Unlock()
|
||||
if !ready.IsSet() {
|
||||
return ErrKextNotReady
|
||||
}
|
||||
ready.UnSet()
|
||||
|
||||
err := closeDriver(kextHandle)
|
||||
if err != nil {
|
||||
log.Errorf("winkext: failed to close the handle: %s", err)
|
||||
log.Warningf("winkext: failed to close the handle: %s", err)
|
||||
}
|
||||
|
||||
_, err = exec.Command("sc", "stop", driverName).Output() // This is a question of taste, but it is a robust and solid solution
|
||||
err = service.stop()
|
||||
if err != nil {
|
||||
log.Errorf("winkext: failed to stop the service: %q", err)
|
||||
log.Warningf("winkext: failed to stop service: %s", err)
|
||||
}
|
||||
// Driver file may change on the next start so it's better to delete the service
|
||||
err = service.delete()
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to delete service: %s", err)
|
||||
}
|
||||
err = service.closeHandle()
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to close the handle: %s", err)
|
||||
}
|
||||
|
||||
kextHandle = winInvalidHandleValue
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -107,19 +110,9 @@ func Stop() error {
|
||||
func RecvVerdictRequest() (*VerdictRequest, error) {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if !ready.IsSet() {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
// wait for urgent requests to complete
|
||||
for i := 1; i <= 100; i++ {
|
||||
if atomic.LoadInt32(urgentRequests) <= 0 {
|
||||
break
|
||||
}
|
||||
if i == 100 {
|
||||
log.Warningf("winkext: RecvVerdictRequest waited 100 times")
|
||||
}
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
}
|
||||
|
||||
timestamp := time.Now()
|
||||
// Initialize struct for the output data
|
||||
@@ -127,7 +120,7 @@ func RecvVerdictRequest() (*VerdictRequest, error) {
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&new)
|
||||
bytesRead, err := deviceIoControlRead(kextHandle, IOCTL_RECV_VERDICT_REQ, data)
|
||||
bytesRead, err := deviceIOControl(kextHandle, IOCTL_RECV_VERDICT_REQ, nil, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,7 +141,7 @@ func SetVerdict(pkt *Packet, verdict network.Verdict) error {
|
||||
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if !ready.IsSet() {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: kext not ready", verdict)
|
||||
return ErrKextNotReady
|
||||
}
|
||||
@@ -156,10 +149,8 @@ func SetVerdict(pkt *Packet, verdict network.Verdict) error {
|
||||
verdictInfo := VerdictInfo{pkt.verdictRequest.id, verdict}
|
||||
|
||||
// Make driver request
|
||||
atomic.AddInt32(urgentRequests, 1)
|
||||
data := asByteArray(&verdictInfo)
|
||||
_, err := deviceIoControlWrite(kextHandle, IOCTL_SET_VERDICT, data)
|
||||
atomic.AddInt32(urgentRequests, -1)
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_SET_VERDICT, data, nil)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id)
|
||||
return err
|
||||
@@ -176,7 +167,7 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
|
||||
// Check if driver is initialized
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if !ready.IsSet() {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
@@ -189,11 +180,8 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
|
||||
}{packetID, packetSize}
|
||||
|
||||
// Make driver request
|
||||
atomic.AddInt32(urgentRequests, 1)
|
||||
data := asByteArray(&payload)
|
||||
bytesRead, err := deviceIoControlReadWrite(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize))
|
||||
|
||||
atomic.AddInt32(urgentRequests, -1)
|
||||
bytesRead, err := deviceIOControl(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -216,16 +204,38 @@ func ClearCache() error {
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if !ready.IsSet() {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
// Make driver request
|
||||
_, err := deviceIoControlRead(kextHandle, IOCTL_CLEAR_CACHE, nil)
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_CLEAR_CACHE, nil, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func asByteArray[T any](obj *T) []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj))
|
||||
}
|
||||
|
||||
func openDriver(filename string) (windows.Handle, error) {
|
||||
u16filename, err := syscall.UTF16FromString(filename)
|
||||
if err != nil {
|
||||
return winInvalidHandleValue, fmt.Errorf("failed to convert driver filename to UTF16 string %w", err)
|
||||
}
|
||||
|
||||
handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED, 0)
|
||||
if err != nil {
|
||||
return winInvalidHandleValue, err
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func closeDriver(handle windows.Handle) error {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
return windows.CloseHandle(handle)
|
||||
}
|
||||
|
||||
@@ -10,78 +10,105 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func createService(manager windows.Handle, portmasterKextPath *uint16) (windows.Handle, error) {
|
||||
u16filename, err := syscall.UTF16FromString(driverName)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Bad service: %s", err)
|
||||
}
|
||||
// Check if it's already created
|
||||
service, err := windows.OpenService(manager, &u16filename[0], windows.SERVICE_ALL_ACCESS)
|
||||
if err == nil {
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// Create the service
|
||||
service, err = windows.CreateService(manager, &u16filename[0], &u16filename[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, portmasterKextPath, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return service, nil
|
||||
type KextService struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func driverInstall(portmasterKextPath string) (windows.Handle, error) {
|
||||
u16kextPath, _ := syscall.UTF16FromString(portmasterKextPath)
|
||||
func createKextService(driverName string, driverPath string) (*KextService, error) {
|
||||
// Open the service manager:
|
||||
manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Failed to open service manager: %d", err)
|
||||
return nil, fmt.Errorf("failed to open service manager: %d", err)
|
||||
}
|
||||
defer windows.CloseServiceHandle(manager)
|
||||
|
||||
// Try to create the service. Retry if it fails.
|
||||
var service windows.Handle
|
||||
retryLoop:
|
||||
for i := 0; i < 3; i++ {
|
||||
service, err = createService(manager, &u16kextPath[0])
|
||||
if err == nil {
|
||||
break retryLoop
|
||||
}
|
||||
driverNameU16, err := syscall.UTF16FromString(driverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err)
|
||||
}
|
||||
// Check if it's already created
|
||||
service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS)
|
||||
if err == nil {
|
||||
return &KextService{handle: service}, nil // service was already created
|
||||
}
|
||||
|
||||
driverPathU16, err := syscall.UTF16FromString(driverPath)
|
||||
|
||||
// Create the service
|
||||
service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Failed to create service: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KextService{handle: service}, nil
|
||||
}
|
||||
|
||||
func (s *KextService) isValid() bool {
|
||||
return s != nil && s.handle != winInvalidHandleValue && s.handle != 0
|
||||
}
|
||||
|
||||
func (s *KextService) start() error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
// Start the service:
|
||||
err = windows.StartService(service, 0, nil)
|
||||
err := windows.StartService(s.handle, 0, nil)
|
||||
|
||||
if err != nil {
|
||||
err = windows.GetLastError()
|
||||
if err != windows.ERROR_SERVICE_ALREADY_RUNNING {
|
||||
// Failed to start service; clean-up:
|
||||
var status windows.SERVICE_STATUS
|
||||
_ = windows.ControlService(service, windows.SERVICE_CONTROL_STOP, &status)
|
||||
_ = windows.DeleteService(service)
|
||||
_ = windows.CloseServiceHandle(service)
|
||||
service = 0
|
||||
_ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
|
||||
_ = windows.DeleteService(s.handle)
|
||||
_ = windows.CloseServiceHandle(s.handle)
|
||||
s.handle = winInvalidHandleValue
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return service, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func openDriver(filename string) (windows.Handle, error) {
|
||||
u16filename, _ := syscall.UTF16FromString(filename)
|
||||
|
||||
handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
func (s *KextService) stop() error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
var status windows.SERVICE_STATUS
|
||||
err := windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service failed to stop: %w", err)
|
||||
}
|
||||
|
||||
if status.CurrentState != windows.SERVICE_STOP_PENDING && status.CurrentState != windows.SERVICE_STOPPED {
|
||||
return fmt.Errorf("service unexpected status after stop: %d", status.CurrentState)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeDriver(handle windows.Handle) error {
|
||||
return windows.CloseHandle(handle)
|
||||
func (s *KextService) delete() error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
err := windows.DeleteService(s.handle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete service: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) closeHandle() error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
err := windows.CloseServiceHandle(s.handle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close service handle: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,47 +27,7 @@ func ctlCode(device_type, function, method, access uint32) uint32 {
|
||||
return (device_type << 16) | (access << 14) | (function << 2) | method
|
||||
}
|
||||
|
||||
func deviceIoControlRead(handle windows.Handle, code uint32, data []byte) (uint32, error) {
|
||||
var bytesReturned uint32
|
||||
|
||||
var dataPtr *byte = nil
|
||||
var dataSize uint32 = 0
|
||||
if data != nil {
|
||||
dataPtr = &data[0]
|
||||
dataSize = uint32(len(data))
|
||||
}
|
||||
|
||||
err := windows.DeviceIoControl(handle,
|
||||
code,
|
||||
nil, 0,
|
||||
dataPtr, dataSize,
|
||||
&bytesReturned, nil)
|
||||
|
||||
return bytesReturned, err
|
||||
}
|
||||
|
||||
func deviceIoControlWrite(handle windows.Handle, code uint32, data []byte) (uint32, error) {
|
||||
var bytesReturned uint32
|
||||
|
||||
var dataPtr *byte = nil
|
||||
var dataSize uint32 = 0
|
||||
if data != nil {
|
||||
dataPtr = &data[0]
|
||||
dataSize = uint32(len(data))
|
||||
}
|
||||
|
||||
err := windows.DeviceIoControl(handle,
|
||||
code,
|
||||
dataPtr, dataSize,
|
||||
nil, 0,
|
||||
&bytesReturned, nil)
|
||||
|
||||
return bytesReturned, err
|
||||
}
|
||||
|
||||
func deviceIoControlReadWrite(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) {
|
||||
var bytesReturned uint32
|
||||
|
||||
func deviceIOControlAsync(handle windows.Handle, code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) {
|
||||
var inDataPtr *byte = nil
|
||||
var inDataSize uint32 = 0
|
||||
if inData != nil {
|
||||
@@ -81,11 +41,30 @@ func deviceIoControlReadWrite(handle windows.Handle, code uint32, inData []byte,
|
||||
outDataPtr = &outData[0]
|
||||
outDataSize = uint32(len(outData))
|
||||
}
|
||||
|
||||
overlapped := &windows.Overlapped{}
|
||||
err := windows.DeviceIoControl(handle,
|
||||
code,
|
||||
inDataPtr, inDataSize,
|
||||
outDataPtr, outDataSize,
|
||||
&bytesReturned, nil)
|
||||
nil, overlapped)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return overlapped, nil
|
||||
|
||||
}
|
||||
|
||||
func deviceIOControl(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) {
|
||||
overlapped, err := deviceIOControlAsync(handle, code, inData, outData)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var bytesReturned uint32
|
||||
err = windows.GetOverlappedResult(handle, overlapped, &bytesReturned, true)
|
||||
|
||||
return bytesReturned, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user