diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 2932a31a..f7952ec8 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -81,10 +81,6 @@ func Handler(packets chan packet.Packet) { defer close(packets) for { - if kextHandle == winInvalidHandleValue { - return - } - packetInfo, err := RecvVerdictRequest() if err != nil { // Check if we are done with processing. diff --git a/firewall/interception/windowskext/kext.go b/firewall/interception/windowskext/kext.go index e4ae511a..d0da7baf 100644 --- a/firewall/interception/windowskext/kext.go +++ b/firewall/interception/windowskext/kext.go @@ -79,10 +79,16 @@ func Start() error { // Stop intercepting. func Stop() error { + // Prepare kernel for shutdown + err := shutdownRequest() + if err != nil { + log.Warningf("winkext: shutdown request failed: %s", err) + } + kextLock.Lock() defer kextLock.Unlock() - err := closeDriver(kextHandle) + err = closeDriver(kextHandle) if err != nil { log.Warningf("winkext: failed to close the handle: %s", err) } @@ -105,6 +111,18 @@ func Stop() error { return nil } +func shutdownRequest() error { + kextLock.RLock() + defer kextLock.RUnlock() + if kextHandle == winInvalidHandleValue { + return ErrKextNotReady + } + // Sent a shutdown request so the kernel extension can prepare. + _, err := deviceIOControl(kextHandle, IOCTL_SHUTDOWN_REQUEST, nil, nil) + + return err +} + // RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil. func RecvVerdictRequest() (*VerdictRequest, error) { kextLock.RLock() @@ -241,7 +259,7 @@ func UpdateVerdict(conn *network.Connection) error { // Make driver request data := asByteArray(&info) - err := deviceIoControlDirect(kextHandle, IOCTL_UPDATE_VERDICT, data) + _, err := deviceIOControl(kextHandle, IOCTL_UPDATE_VERDICT, data, nil) return err } @@ -256,7 +274,7 @@ func GetVersion() (*VersionInfo, error) { } data := make([]uint8, 4) - err := deviceIoControlDirect(kextHandle, IOCTL_VERSION, data) + _, err := deviceIOControl(kextHandle, IOCTL_VERSION, data, nil) if err != nil { return nil, err diff --git a/firewall/interception/windowskext/syscall.go b/firewall/interception/windowskext/syscall.go index 2b85adf5..21706871 100644 --- a/firewall/interception/windowskext/syscall.go +++ b/firewall/interception/windowskext/syscall.go @@ -15,13 +15,13 @@ const ( ) var ( - IOCTL_VERSION = ctlCode(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) - IOCTL_RECV_VERDICT_REQ_POLL = ctlCode(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) // Not used - IOCTL_RECV_VERDICT_REQ = ctlCode(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) - IOCTL_SET_VERDICT = ctlCode(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) - IOCTL_GET_PAYLOAD = ctlCode(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) - IOCTL_CLEAR_CACHE = ctlCode(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) - IOCTL_UPDATE_VERDICT = ctlCode(SIOCTL_TYPE, 0x806, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_VERSION = ctlCode(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_SHUTDOWN_REQUEST = ctlCode(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_RECV_VERDICT_REQ = ctlCode(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_SET_VERDICT = ctlCode(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_GET_PAYLOAD = ctlCode(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_CLEAR_CACHE = ctlCode(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_UPDATE_VERDICT = ctlCode(SIOCTL_TYPE, 0x806, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) ) func ctlCode(device_type, function, method, access uint32) uint32 { @@ -69,21 +69,3 @@ func deviceIOControl(handle windows.Handle, code uint32, inData []byte, outData return bytesReturned, err } - -// Use for METHOD_NEITHER IOCTL, the data buffer is passed directly to the kernel -func deviceIoControlDirect(handle windows.Handle, code uint32, data []byte) error { - var dataPtr *byte = nil - var dataSize uint32 = 0 - if data != nil { - dataPtr = &data[0] - dataSize = uint32(len(data)) - } - - err := windows.DeviceIoControl(handle, - code, - dataPtr, dataSize, - nil, 0, - nil, nil) - - return err -}