Refactoring

This commit is contained in:
Vladimir
2022-10-19 13:25:11 -07:00
parent 3b341496af
commit 1f677cb93f
4 changed files with 212 additions and 158 deletions

View File

@@ -9,7 +9,6 @@ import (
"os/exec"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
@@ -34,6 +33,8 @@ var (
kextHandle windows.Handle
)
const driverName = "PortmasterKext"
func init() {
var urgentRequestsValue int32
urgentRequests = &urgentRequestsValue
@@ -50,112 +51,37 @@ func Start() error {
kextLock.Lock()
defer kextLock.Unlock()
filename := `\\.\PortmasterKext`
filename := `\\.\` + driverName
u16fname, err := syscall.UTF16FromString(filename)
if err != nil {
return fmt.Errorf("Bad filename: %s", err)
}
u16DriverPath, err := syscall.UTF16FromString(driverPath)
if err != nil {
return fmt.Errorf("Bad driver path: %s", err)
}
kextHandle, err = windows.CreateFile(&u16fname[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, 0, 0)
// check if driver is already installed
var err error
kextHandle, err = openDriver(filename)
if err == nil {
return nil // All good
return nil // device was already initialized
}
service, err := portmasterDriverInstall(&u16DriverPath[0])
// initialize and start driver service
service, err := driverInstall(driverPath)
if err != nil {
return fmt.Errorf("Faield to start service: %s", err)
return fmt.Errorf("Failed to start service: %s", err)
}
kextHandle, err = windows.CreateFile(&u16fname[0],
windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil,
windows.OPEN_EXISTING, 0, 0)
// open the driver
kextHandle, err = openDriver(filename)
// close the service handles
windows.DeleteService(service)
windows.CloseServiceHandle(service)
// driver was not installed
if err != nil {
return fmt.Errorf("Faield to kext service: %s %q", err, filename)
return fmt.Errorf("Failed to start the kext service: %s %q", err, filename)
}
ready.Set()
testRead()
return nil
}
func testRead() {
buf := [5]byte{1, 2, 3, 4, 5}
_, err := deviceIoControl(IOCTL_TEST, &buf[0], uintptr(len(buf)))
if err != nil {
log.Criticalf("Erro reading test data: %s", err)
}
log.Criticalf("Read restul: %v", buf)
}
func createService(manager windows.Handle, portmasterKextPath *uint16) (windows.Handle, error) {
u16fname, err := syscall.UTF16FromString("PortmasterKext")
if err != nil {
return 0, fmt.Errorf("Bad service: %s", err)
}
service, err := windows.OpenService(manager, &u16fname[0], windows.SERVICE_ALL_ACCESS)
if err == nil {
return service, nil
}
service, err = windows.CreateService(manager, &u16fname[0], &u16fname[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, portmasterKextPath, nil, nil, nil, nil, nil)
if err != nil {
return 0, err
}
return service, nil
}
func portmasterDriverInstall(portmasterKextPath *uint16) (windows.Handle, error) {
// Open the service manager:
manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
if err != nil {
return 0, fmt.Errorf("Failed to open service manager: %d", err)
}
defer windows.CloseServiceHandle(manager)
var service windows.Handle
retryLoop:
for i := 0; i < 3; i++ {
service, err = createService(manager, portmasterKextPath)
if err == nil {
break retryLoop
}
}
if err != nil {
return 0, fmt.Errorf("Failed to create service: %s", err)
}
err = windows.StartService(service, 0, nil)
// Start the service:
if err != nil {
err = windows.GetLastError()
if err == windows.ERROR_SERVICE_ALREADY_RUNNING {
// windows.SetLastError(0)
// windows.SetLast
} else {
// Failed to start service; clean-up:
var status windows.SERVICE_STATUS
_ = windows.ControlService(service, windows.SERVICE_CONTROL_STOP, &status)
_ = windows.DeleteService(service)
_ = windows.CloseServiceHandle(service)
service = 0
//windows.SetLastError(err)
}
}
return service, nil
}
// Stop intercepting.
func Stop() error {
kextLock.Lock()
@@ -165,11 +91,11 @@ func Stop() error {
}
ready.UnSet()
err := windows.CloseHandle(kextHandle)
err := closeDriver(kextHandle)
if err != nil {
log.Errorf("kext: faield to close handle: %s", err)
log.Errorf("winkext: failed to close the handle: %s", err)
}
_, _ = exec.Command("sc", "stop", "PortmasterKext").Output()
_, _ = exec.Command("sc", "stop", driverName).Output()
return nil
}
@@ -194,14 +120,16 @@ func RecvVerdictRequest() (*VerdictRequest, error) {
timestamp := time.Now()
var new VerdictRequest
data := (*byte)(unsafe.Pointer(&new))
_, err := deviceIoControl(IOCTL_RECV_VERDICT_REQ, data, unsafe.Sizeof(new))
data := asByteArray(&new)
bytesRead, err := deviceIoControlRead(kextHandle, IOCTL_RECV_VERDICT_REQ, data)
if err != nil {
return nil, err
}
log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp))
if bytesRead == 0 {
return nil, nil // no error, no new verdict request
}
log.Criticalf("%v", new)
log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp))
return &new, nil
}
@@ -225,8 +153,8 @@ func SetVerdict(pkt *Packet, verdict network.Verdict) error {
}{pkt.verdictRequest.id, verdict}
atomic.AddInt32(urgentRequests, 1)
_, err := deviceIoControlBufferd(IOCTL_SET_VERDICT,
(*byte)(unsafe.Pointer(&verdictInfo)), unsafe.Sizeof(verdictInfo), nil, 0)
data := asByteArray(&verdictInfo)
_, err := deviceIoControlWrite(kextHandle, IOCTL_SET_VERDICT, data)
atomic.AddInt32(urgentRequests, -1)
if err != nil {
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id)
@@ -255,25 +183,21 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
}{packetID, packetSize}
atomic.AddInt32(urgentRequests, 1)
data := asByteArray(&payload)
bytesRead, err := deviceIoControlReadWrite(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize))
writenSize, err := deviceIoControlBufferd(IOCTL_GET_PAYLOAD,
(*byte)(unsafe.Pointer(&payload)), unsafe.Sizeof(payload),
&buf[0], uintptr(packetSize))
// timestamp := time.Now()
// log.Tracef("winkext: getting payload for packetID %d took %s", packetID, time.Now().Sub(timestamp))
atomic.AddInt32(urgentRequests, -1)
if err != nil {
return nil, err
}
if writenSize == 0 {
if bytesRead == 0 {
return nil, errors.New("windows kext did not return any data")
}
if writenSize < uint32(len(buf)) {
return buf[:writenSize], nil
if bytesRead < uint32(len(buf)) {
return buf[:bytesRead], nil
}
return buf, nil
@@ -287,6 +211,10 @@ func ClearCache() error {
return ErrKextNotReady
}
_, err := deviceIoControl(IOCTL_CLEAR_CACHE, nil, 0)
_, err := deviceIoControlRead(kextHandle, IOCTL_CLEAR_CACHE, nil)
return err
}
func asByteArray[T any](obj *T) []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj))
}

View File

@@ -1,50 +0,0 @@
//go:build windows
// +build windows
package windowskext
import "golang.org/x/sys/windows"
const (
METHOD_BUFFERED = 0
METHOD_IN_DIRECT = 1
METHOD_OUT_DIRECT = 2
METHOD_NEITHER = 3
SIOCTL_TYPE = 40000
)
var (
IOCTL_HELLO = ctl_code(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_RECV_VERDICT_REQ_POLL = ctl_code(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_RECV_VERDICT_REQ = ctl_code(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_SET_VERDICT = ctl_code(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_GET_PAYLOAD = ctl_code(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_CLEAR_CACHE = ctl_code(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_TEST = ctl_code(SIOCTL_TYPE, 0x806, METHOD_NEITHER, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
)
func ctl_code(device_type, function, method, access uint32) uint32 {
return (device_type << 16) | (access << 14) | (function << 2) | method
}
func deviceIoControl(code uint32, data *byte, size uintptr) (uint32, error) {
var bytesReturned uint32
err := windows.DeviceIoControl(kextHandle,
code,
nil, 0,
data, uint32(size),
&bytesReturned, nil)
return bytesReturned, err
}
func deviceIoControlBufferd(code uint32, inData *byte, inSize uintptr, outData *byte, outSize uintptr) (uint32, error) {
var bytesReturned uint32
err := windows.DeviceIoControl(kextHandle,
code,
inData, uint32(inSize),
outData, uint32(outSize),
&bytesReturned, nil)
return bytesReturned, err
}

View File

@@ -0,0 +1,85 @@
//go:build windows
// +build windows
package windowskext
import (
"fmt"
"syscall"
"golang.org/x/sys/windows"
)
func createService(manager windows.Handle, portmasterKextPath *uint16) (windows.Handle, error) {
u16filename, err := syscall.UTF16FromString(driverName)
if err != nil {
return 0, fmt.Errorf("Bad service: %s", err)
}
service, err := windows.OpenService(manager, &u16filename[0], windows.SERVICE_ALL_ACCESS)
if err == nil {
return service, nil
}
service, err = windows.CreateService(manager, &u16filename[0], &u16filename[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, portmasterKextPath, nil, nil, nil, nil, nil)
if err != nil {
return 0, err
}
return service, nil
}
func driverInstall(portmasterKextPath string) (windows.Handle, error) {
u16kextPath, _ := syscall.UTF16FromString(portmasterKextPath)
// Open the service manager:
manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
if err != nil {
return 0, fmt.Errorf("Failed to open service manager: %d", err)
}
defer windows.CloseServiceHandle(manager)
var service windows.Handle
retryLoop:
for i := 0; i < 3; i++ {
service, err = createService(manager, &u16kextPath[0])
if err == nil {
break retryLoop
}
}
if err != nil {
return 0, fmt.Errorf("Failed to create service: %s", err)
}
err = windows.StartService(service, 0, nil)
// Start the service:
if err != nil {
err = windows.GetLastError()
if err == windows.ERROR_SERVICE_ALREADY_RUNNING {
// windows.SetLastError(0)
} else {
// Failed to start service; clean-up:
var status windows.SERVICE_STATUS
_ = windows.ControlService(service, windows.SERVICE_CONTROL_STOP, &status)
_ = windows.DeleteService(service)
_ = windows.CloseServiceHandle(service)
service = 0
//windows.SetLastError(err)
}
}
return service, nil
}
func openDriver(filename string) (windows.Handle, error) {
u16filename, _ := syscall.UTF16FromString(filename)
handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, 0, 0)
if err != nil {
return 0, err
}
return handle, nil
}
func closeDriver(handle windows.Handle) error {
return windows.CloseHandle(handle)
}

View File

@@ -0,0 +1,91 @@
//go:build windows
// +build windows
package windowskext
import "golang.org/x/sys/windows"
const (
METHOD_BUFFERED = 0
METHOD_IN_DIRECT = 1
METHOD_OUT_DIRECT = 2
METHOD_NEITHER = 3
SIOCTL_TYPE = 40000
)
var (
IOCTL_HELLO = ctlCode(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_RECV_VERDICT_REQ_POLL = ctlCode(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_RECV_VERDICT_REQ = ctlCode(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_SET_VERDICT = ctlCode(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_GET_PAYLOAD = ctlCode(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_CLEAR_CACHE = ctlCode(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
)
func ctlCode(device_type, function, method, access uint32) uint32 {
return (device_type << 16) | (access << 14) | (function << 2) | method
}
func deviceIoControlRead(handle windows.Handle, code uint32, data []byte) (uint32, error) {
var bytesReturned uint32
var dataPtr *byte = nil
var dataSize uint32 = 0
if data != nil {
dataPtr = &data[0]
dataSize = uint32(len(data))
}
err := windows.DeviceIoControl(handle,
code,
nil, 0,
dataPtr, dataSize,
&bytesReturned, nil)
return bytesReturned, err
}
func deviceIoControlWrite(handle windows.Handle, code uint32, data []byte) (uint32, error) {
var bytesReturned uint32
var dataPtr *byte = nil
var dataSize uint32 = 0
if data != nil {
dataPtr = &data[0]
dataSize = uint32(len(data))
}
err := windows.DeviceIoControl(handle,
code,
dataPtr, dataSize,
nil, 0,
&bytesReturned, nil)
return bytesReturned, err
}
func deviceIoControlReadWrite(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) {
var bytesReturned uint32
var inDataPtr *byte = nil
var inDataSize uint32 = 0
if inData != nil {
inDataPtr = &inData[0]
inDataSize = uint32(len(inData))
}
var outDataPtr *byte = nil
var outDataSize uint32 = 0
if outData != nil {
outDataPtr = &outData[0]
outDataSize = uint32(len(outData))
}
err := windows.DeviceIoControl(handle,
code,
inDataPtr, inDataSize,
outDataPtr, outDataSize,
&bytesReturned, nil)
return bytesReturned, err
}