Merge branch 'develop' into feature/new-kext
This commit is contained in:
132
service/firewall/interception/windowskext/bandwidth_stats.go
Normal file
132
service/firewall/interception/windowskext/bandwidth_stats.go
Normal file
@@ -0,0 +1,132 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
// This file contains example code how to read bandwidth stats from the kext. Its not ment to be used in production.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
type Rxtxdata struct {
|
||||
rx uint64
|
||||
tx uint64
|
||||
}
|
||||
|
||||
type Key struct {
|
||||
localIP [4]uint32
|
||||
remoteIP [4]uint32
|
||||
localPort uint16
|
||||
remotePort uint16
|
||||
ipv6 bool
|
||||
protocol uint8
|
||||
}
|
||||
|
||||
var m = make(map[Key]Rxtxdata)
|
||||
|
||||
func BandwidthStatsWorker(ctx context.Context, collectInterval time.Duration, bandwidthUpdates chan *packet.BandwidthUpdate) error {
|
||||
// Setup ticker.
|
||||
ticker := time.NewTicker(collectInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Collect bandwidth at every tick.
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := reportBandwidth(ctx, bandwidthUpdates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.BandwidthUpdate) error {
|
||||
stats, err := GetConnectionsStats()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Report all statistics.
|
||||
for i, stat := range stats {
|
||||
connID := packet.CreateConnectionID(
|
||||
packet.IPProtocol(stat.protocol),
|
||||
convertArrayToIP(stat.localIP, stat.ipV6 == 1), stat.localPort,
|
||||
convertArrayToIP(stat.remoteIP, stat.ipV6 == 1), stat.remotePort,
|
||||
false,
|
||||
)
|
||||
update := &packet.BandwidthUpdate{
|
||||
ConnID: connID,
|
||||
BytesReceived: stat.receivedBytes,
|
||||
BytesSent: stat.transmittedBytes,
|
||||
Method: packet.Additive,
|
||||
}
|
||||
select {
|
||||
case bandwidthUpdates <- update:
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
log.Warningf("kext: bandwidth update queue is full, skipping rest of batch (%d entries)", len(stats)-i)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func StartBandwidthConsoleLogger() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
conns, err := GetConnectionsStats()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, conn := range conns {
|
||||
if conn.receivedBytes == 0 && conn.transmittedBytes == 0 {
|
||||
continue
|
||||
}
|
||||
key := Key{
|
||||
localIP: conn.localIP,
|
||||
remoteIP: conn.remoteIP,
|
||||
localPort: conn.localPort,
|
||||
remotePort: conn.remotePort,
|
||||
ipv6: conn.ipV6 == 1,
|
||||
protocol: conn.protocol,
|
||||
}
|
||||
|
||||
// First we get a "copy" of the entry
|
||||
if entry, ok := m[key]; ok {
|
||||
// Then we modify the copy
|
||||
entry.rx += conn.receivedBytes
|
||||
entry.tx += conn.transmittedBytes
|
||||
|
||||
// Then we reassign map entry
|
||||
m[key] = entry
|
||||
} else {
|
||||
m[key] = Rxtxdata{
|
||||
rx: conn.receivedBytes,
|
||||
tx: conn.transmittedBytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Debug("----------------------------------")
|
||||
for key, value := range m {
|
||||
log.Debugf(
|
||||
"Conn: %d %s:%d %s:%d rx:%d tx:%d", key.protocol,
|
||||
convertArrayToIP(key.localIP, key.ipv6), key.localPort,
|
||||
convertArrayToIP(key.remoteIP, key.ipv6), key.remotePort,
|
||||
value.rx, value.tx,
|
||||
)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
4
service/firewall/interception/windowskext/doc.go
Normal file
4
service/firewall/interception/windowskext/doc.go
Normal file
@@ -0,0 +1,4 @@
|
||||
// +build windows
|
||||
|
||||
// Package windowskext provides network interception capabilities on windows via the Portmaster Kernel Extension.
|
||||
package windowskext
|
||||
209
service/firewall/interception/windowskext/handler.go
Normal file
209
service/firewall/interception/windowskext/handler.go
Normal file
@@ -0,0 +1,209 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/safing/portmaster/service/process"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
const (
|
||||
// VerdictRequestFlagFastTrackPermitted is set on packets that have been
|
||||
// already permitted by the kernel extension and the verdict request is only
|
||||
// informational.
|
||||
VerdictRequestFlagFastTrackPermitted = 1
|
||||
|
||||
// VerdictRequestFlagSocketAuth indicates that the verdict request is for a
|
||||
// connection that was intercepted on an ALE layer instead of in the network
|
||||
// stack itself. Thus, no packet data is available.
|
||||
VerdictRequestFlagSocketAuth = 2
|
||||
|
||||
// VerdictRequestFlagExpectSocketAuth indicates that the next verdict
|
||||
// requests is expected to be an informational socket auth request from
|
||||
// the ALE layer.
|
||||
VerdictRequestFlagExpectSocketAuth = 4
|
||||
)
|
||||
|
||||
// Do not change the order of the members! The structure is used to communicate with the kernel extension.
|
||||
// VerdictRequest is the request structure from the Kext.
|
||||
type VerdictRequest struct {
|
||||
id uint32 // ID from RegisterPacket
|
||||
pid uint64 // Process ID - info only packets
|
||||
direction uint8
|
||||
ipV6 uint8 // True: IPv6, False: IPv4
|
||||
protocol uint8 // Protocol
|
||||
flags uint8 // Flags
|
||||
localIP [4]uint32 // Source Address
|
||||
remoteIP [4]uint32 // Destination Address
|
||||
localPort uint16 // Source Port
|
||||
remotePort uint16 // Destination port
|
||||
_ uint32 // compartmentID
|
||||
_ uint32 // interfaceIndex
|
||||
_ uint32 // subInterfaceIndex
|
||||
packetSize uint32
|
||||
}
|
||||
|
||||
// Do not change the order of the members! The structure is used to communicate with the kernel extension.
|
||||
type VerdictInfo struct {
|
||||
id uint32 // ID from RegisterPacket
|
||||
verdict network.Verdict // verdict for the connection
|
||||
}
|
||||
|
||||
// Do not change the order of the members! The structure to communicate with the kernel extension.
|
||||
type VerdictUpdateInfo struct {
|
||||
localIP [4]uint32 // Source Address, only srcIP[0] if IPv4
|
||||
remoteIP [4]uint32 // Destination Address
|
||||
localPort uint16 // Source Port
|
||||
remotePort uint16 // Destination port
|
||||
ipV6 uint8 // True: IPv6, False: IPv4
|
||||
protocol uint8 // Protocol (UDP, TCP, ...)
|
||||
verdict uint8 // New verdict
|
||||
}
|
||||
|
||||
type ConnectionStat struct {
|
||||
localIP [4]uint32 //Source Address, only srcIP[0] if IPv4
|
||||
remoteIP [4]uint32 //Destination Address
|
||||
localPort uint16 //Source Port
|
||||
remotePort uint16 //Destination port
|
||||
receivedBytes uint64 //Number of bytes recived on this connection
|
||||
transmittedBytes uint64 //Number of bytes transsmited from this connection
|
||||
ipV6 uint8 //True: IPv6, False: IPv4
|
||||
protocol uint8 //Protocol (UDP, TCP, ...)
|
||||
}
|
||||
|
||||
type VersionInfo struct {
|
||||
major uint8
|
||||
minor uint8
|
||||
revision uint8
|
||||
build uint8
|
||||
}
|
||||
|
||||
func (v *VersionInfo) String() string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", v.major, v.minor, v.revision, v.build)
|
||||
}
|
||||
|
||||
// Handler transforms received packets to the Packet interface.
|
||||
func Handler(ctx context.Context, packets chan packet.Packet) {
|
||||
for {
|
||||
packetInfo, err := RecvVerdictRequest()
|
||||
if err != nil {
|
||||
// Check if we are done with processing.
|
||||
if errors.Is(err, ErrKextNotReady) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Warningf("failed to get packet from windows kext: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if packetInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// log.Tracef("packet: %+v", packetInfo)
|
||||
|
||||
// New Packet
|
||||
new := &Packet{
|
||||
verdictRequest: packetInfo,
|
||||
verdictSet: abool.NewBool(false),
|
||||
}
|
||||
info := new.Info()
|
||||
info.Inbound = packetInfo.direction > 0
|
||||
info.InTunnel = false
|
||||
info.Protocol = packet.IPProtocol(packetInfo.protocol)
|
||||
info.PID = int(packetInfo.pid)
|
||||
info.SeenAt = time.Now()
|
||||
|
||||
// Check PID
|
||||
if info.PID == 0 {
|
||||
// Windows does not have zero PIDs.
|
||||
// Set to UndefinedProcessID.
|
||||
info.PID = process.UndefinedProcessID
|
||||
}
|
||||
|
||||
// Set IP version
|
||||
if packetInfo.ipV6 == 1 {
|
||||
info.Version = packet.IPv6
|
||||
} else {
|
||||
info.Version = packet.IPv4
|
||||
}
|
||||
|
||||
// Set IPs
|
||||
if info.Inbound {
|
||||
// Inbound
|
||||
info.Src = convertArrayToIP(packetInfo.remoteIP, info.Version == packet.IPv6)
|
||||
info.Dst = convertArrayToIP(packetInfo.localIP, info.Version == packet.IPv6)
|
||||
} else {
|
||||
// Outbound
|
||||
info.Src = convertArrayToIP(packetInfo.localIP, info.Version == packet.IPv6)
|
||||
info.Dst = convertArrayToIP(packetInfo.remoteIP, info.Version == packet.IPv6)
|
||||
}
|
||||
|
||||
// Set Ports
|
||||
if info.Inbound {
|
||||
// Inbound
|
||||
info.SrcPort = packetInfo.remotePort
|
||||
info.DstPort = packetInfo.localPort
|
||||
} else {
|
||||
// Outbound
|
||||
info.SrcPort = packetInfo.localPort
|
||||
info.DstPort = packetInfo.remotePort
|
||||
}
|
||||
|
||||
packets <- new
|
||||
}
|
||||
}
|
||||
|
||||
// convertArrayToIP converts an array of uint32 values to a net.IP address.
|
||||
func convertArrayToIP(input [4]uint32, ipv6 bool) net.IP {
|
||||
if !ipv6 {
|
||||
addressBuf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(addressBuf, input[0])
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
addressBuf := make([]byte, 16)
|
||||
for i := 0; i < 4; i++ {
|
||||
binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
|
||||
}
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
func ipAddressToArray(ip net.IP, isIPv6 bool) [4]uint32 {
|
||||
array := [4]uint32{0}
|
||||
if isIPv6 {
|
||||
for i := 0; i < 4; i++ {
|
||||
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[i], 4), getUInt32Value(&ip[i]))
|
||||
}
|
||||
} else {
|
||||
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[0], 4), getUInt32Value(&ip[0]))
|
||||
}
|
||||
|
||||
return array
|
||||
}
|
||||
|
||||
func asByteArray[T any](obj *T) []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj))
|
||||
}
|
||||
|
||||
func asByteArrayWithLength[T any](obj *T, size uint32) []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), size)
|
||||
}
|
||||
|
||||
func getUInt32Value[T any](obj *T) uint32 {
|
||||
return *(*uint32)(unsafe.Pointer(obj))
|
||||
}
|
||||
344
service/firewall/interception/windowskext/kext.go
Normal file
344
service/firewall/interception/windowskext/kext.go
Normal file
@@ -0,0 +1,344 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// Package errors
|
||||
var (
|
||||
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
|
||||
ErrNoPacketID = errors.New("the packet has no ID, possibly because it was fast-tracked by the kernel extension")
|
||||
|
||||
kextLock sync.RWMutex
|
||||
driverPath string
|
||||
|
||||
kextHandle windows.Handle
|
||||
service *KextService
|
||||
)
|
||||
|
||||
const (
|
||||
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
|
||||
winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value
|
||||
driverName = "PortmasterKext"
|
||||
)
|
||||
|
||||
// Init initializes the DLL and the Kext (Kernel Driver).
|
||||
func Init(path string) error {
|
||||
kextHandle = winInvalidHandleValue
|
||||
driverPath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start intercepting.
|
||||
func Start() error {
|
||||
kextLock.Lock()
|
||||
defer kextLock.Unlock()
|
||||
|
||||
// initialize and start driver service
|
||||
var err error
|
||||
service, err = createKextService(driverName, driverPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
running, err := service.isRunning()
|
||||
if err == nil && !running {
|
||||
err = service.start(true)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("service not initialized: %w", err)
|
||||
}
|
||||
|
||||
// Open the driver
|
||||
filename := `\\.\` + driverName
|
||||
kextHandle, err = openDriver(filename)
|
||||
|
||||
// driver was not installed
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open driver: %q %w", filename, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetKextHandler(handle windows.Handle) {
|
||||
kextHandle = handle
|
||||
}
|
||||
|
||||
func SetKextService(handle windows.Handle, path string) {
|
||||
service = &KextService{handle: handle}
|
||||
driverPath = path
|
||||
}
|
||||
|
||||
// Stop intercepting.
|
||||
func Stop() error {
|
||||
// Prepare kernel for shutdown
|
||||
err := shutdownRequest()
|
||||
if err != nil {
|
||||
log.Warningf("winkext: shutdown request failed: %s", err)
|
||||
}
|
||||
|
||||
kextLock.Lock()
|
||||
defer kextLock.Unlock()
|
||||
|
||||
err = closeDriver(kextHandle)
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to close the handle: %s", err)
|
||||
}
|
||||
|
||||
err = service.stop(true)
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to stop service: %s", err)
|
||||
}
|
||||
// Driver file may change on the next start so it's better to delete the service
|
||||
err = service.delete()
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to delete service: %s", err)
|
||||
}
|
||||
|
||||
kextHandle = winInvalidHandleValue
|
||||
return nil
|
||||
}
|
||||
|
||||
func shutdownRequest() error {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return ErrKextNotReady
|
||||
}
|
||||
// Sent a shutdown request so the kernel extension can prepare.
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_SHUTDOWN_REQUEST, nil, nil)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil.
|
||||
func RecvVerdictRequest() (*VerdictRequest, error) {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
// DEBUG:
|
||||
// timestamp := time.Now()
|
||||
// defer log.Tracef("winkext: getting verdict request took %s", time.Since(timestamp))
|
||||
|
||||
// Initialize struct for the output data
|
||||
var new VerdictRequest
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&new)
|
||||
bytesRead, err := deviceIOControl(kextHandle, IOCTL_RECV_VERDICT_REQ, nil, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if bytesRead == 0 {
|
||||
return nil, nil // no error, no new verdict request
|
||||
}
|
||||
|
||||
return &new, nil
|
||||
}
|
||||
|
||||
// SetVerdict sets the verdict for a packet and/or connection.
|
||||
func SetVerdict(pkt *Packet, verdict network.Verdict) error {
|
||||
if pkt.verdictRequest.pid != 0 {
|
||||
return nil // Ignore info only packets
|
||||
}
|
||||
if pkt.verdictRequest.id == 0 {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: no packet ID", verdict)
|
||||
return ErrNoPacketID
|
||||
}
|
||||
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: kext not ready", verdict)
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
verdictInfo := VerdictInfo{pkt.verdictRequest.id, verdict}
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&verdictInfo)
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_SET_VERDICT, data, nil)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPayload returns the payload of a packet.
|
||||
func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
|
||||
if packetID == 0 {
|
||||
return nil, ErrNoPacketID
|
||||
}
|
||||
|
||||
// Check if driver is initialized
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
buf := make([]byte, packetSize)
|
||||
|
||||
// Combine id and length
|
||||
payload := struct {
|
||||
id uint32
|
||||
length uint32
|
||||
}{packetID, packetSize}
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&payload)
|
||||
bytesRead, err := deviceIOControl(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check the result and return
|
||||
if bytesRead == 0 {
|
||||
return nil, errors.New("windows kext did not return any data")
|
||||
}
|
||||
|
||||
if bytesRead < uint32(len(buf)) {
|
||||
return buf[:bytesRead], nil
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func ClearCache() error {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
// Make driver request
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_CLEAR_CACHE, nil, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func UpdateVerdict(conn *network.Connection) error {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
var isIpv6 uint8 = 0
|
||||
if conn.IPVersion == packet.IPv6 {
|
||||
isIpv6 = 1
|
||||
}
|
||||
|
||||
// initialize variables
|
||||
info := VerdictUpdateInfo{
|
||||
ipV6: isIpv6,
|
||||
protocol: uint8(conn.IPProtocol),
|
||||
localIP: ipAddressToArray(conn.LocalIP, isIpv6 == 1),
|
||||
localPort: conn.LocalPort,
|
||||
remoteIP: ipAddressToArray(conn.Entity.IP, isIpv6 == 1),
|
||||
remotePort: conn.Entity.Port,
|
||||
verdict: uint8(conn.Verdict),
|
||||
}
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&info)
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_UPDATE_VERDICT, data, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func GetVersion() (*VersionInfo, error) {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
data := make([]uint8, 4)
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_VERSION, nil, data)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
version := &VersionInfo{
|
||||
major: data[0],
|
||||
minor: data[1],
|
||||
revision: data[2],
|
||||
build: data[3],
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
var sizeOfConnectionStat = uint32(unsafe.Sizeof(ConnectionStat{}))
|
||||
|
||||
func GetConnectionsStats() ([]ConnectionStat, error) {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
var data [100]ConnectionStat
|
||||
size := len(data)
|
||||
bytesReturned, err := deviceIOControl(kextHandle, IOCTL_GET_CONNECTIONS_STAT, asByteArray(&size), asByteArray(&data))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return data[:bytesReturned/sizeOfConnectionStat], nil
|
||||
}
|
||||
|
||||
func openDriver(filename string) (windows.Handle, error) {
|
||||
u16filename, err := syscall.UTF16FromString(filename)
|
||||
if err != nil {
|
||||
return winInvalidHandleValue, fmt.Errorf("failed to convert driver filename to UTF16 string %w", err)
|
||||
}
|
||||
|
||||
handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED, 0)
|
||||
if err != nil {
|
||||
return winInvalidHandleValue, err
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func closeDriver(handle windows.Handle) error {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
return windows.CloseHandle(handle)
|
||||
}
|
||||
137
service/firewall/interception/windowskext/packet.go
Normal file
137
service/firewall/interception/windowskext/packet.go
Normal file
@@ -0,0 +1,137 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
// Packet represents an IP packet.
|
||||
type Packet struct {
|
||||
packet.Base
|
||||
|
||||
verdictRequest *VerdictRequest
|
||||
verdictSet *abool.AtomicBool
|
||||
|
||||
payloadLoaded bool
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// FastTrackedByIntegration returns whether the packet has been fast-track
|
||||
// accepted by the OS integration.
|
||||
func (pkt *Packet) FastTrackedByIntegration() bool {
|
||||
return pkt.verdictRequest.flags&VerdictRequestFlagFastTrackPermitted > 0
|
||||
}
|
||||
|
||||
// InfoOnly returns whether the packet is informational only and does not
|
||||
// represent an actual packet.
|
||||
func (pkt *Packet) InfoOnly() bool {
|
||||
return pkt.verdictRequest.flags&VerdictRequestFlagSocketAuth > 0
|
||||
}
|
||||
|
||||
// ExpectInfo returns whether the next packet is expected to be informational only.
|
||||
func (pkt *Packet) ExpectInfo() bool {
|
||||
return pkt.verdictRequest.flags&VerdictRequestFlagExpectSocketAuth > 0
|
||||
}
|
||||
|
||||
// GetPayload returns the full raw packet.
|
||||
func (pkt *Packet) LoadPacketData() error {
|
||||
pkt.lock.Lock()
|
||||
defer pkt.lock.Unlock()
|
||||
|
||||
if pkt.verdictRequest.id == 0 {
|
||||
return ErrNoPacketID
|
||||
}
|
||||
|
||||
if !pkt.payloadLoaded {
|
||||
pkt.payloadLoaded = true
|
||||
|
||||
payload, err := GetPayload(pkt.verdictRequest.id, pkt.verdictRequest.packetSize)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to load payload: %s", err)
|
||||
return packet.ErrFailedToLoadPayload
|
||||
}
|
||||
|
||||
err = packet.Parse(payload, &pkt.Base)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to parse payload: %s", err)
|
||||
return packet.ErrFailedToLoadPayload
|
||||
}
|
||||
}
|
||||
|
||||
if len(pkt.Raw()) == 0 {
|
||||
return packet.ErrFailedToLoadPayload
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Accept accepts the packet.
|
||||
func (pkt *Packet) Accept() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, -network.VerdictAccept)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Block blocks the packet.
|
||||
func (pkt *Packet) Block() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, -network.VerdictBlock)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Drop drops the packet.
|
||||
func (pkt *Packet) Drop() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, -network.VerdictDrop)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PermanentAccept permanently accepts connection (and the current packet).
|
||||
func (pkt *Packet) PermanentAccept() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictAccept)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PermanentBlock permanently blocks connection (and the current packet).
|
||||
func (pkt *Packet) PermanentBlock() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictBlock)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PermanentDrop permanently drops connection (and the current packet).
|
||||
func (pkt *Packet) PermanentDrop() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictDrop)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RerouteToNameserver permanently reroutes the connection to the local nameserver (and the current packet).
|
||||
func (pkt *Packet) RerouteToNameserver() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictRerouteToNameserver)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RerouteToTunnel permanently reroutes the connection to the local tunnel entrypoint (and the current packet).
|
||||
func (pkt *Packet) RerouteToTunnel() error {
|
||||
if pkt.verdictSet.SetToIf(false, true) {
|
||||
return SetVerdict(pkt, network.VerdictRerouteToTunnel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
198
service/firewall/interception/windowskext/service.go
Normal file
198
service/firewall/interception/windowskext/service.go
Normal file
@@ -0,0 +1,198 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type KextService struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func createKextService(driverName string, driverPath string) (*KextService, error) {
|
||||
// Open the service manager:
|
||||
manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open service manager: %d", err)
|
||||
}
|
||||
defer windows.CloseServiceHandle(manager)
|
||||
|
||||
// Convert the driver name to a UTF16 string
|
||||
driverNameU16, err := syscall.UTF16FromString(driverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err)
|
||||
}
|
||||
|
||||
// Check if there is an old service.
|
||||
service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS)
|
||||
if err == nil {
|
||||
log.Warning("kext: old driver service was found")
|
||||
oldService := &KextService{handle: service}
|
||||
err := deleteService(manager, oldService, driverNameU16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete old driver service: %s", err)
|
||||
}
|
||||
|
||||
service = winInvalidHandleValue
|
||||
log.Info("kext: old driver service was deleted successfully")
|
||||
}
|
||||
|
||||
driverPathU16, err := syscall.UTF16FromString(driverPath)
|
||||
|
||||
// Create the service
|
||||
service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KextService{handle: service}, nil
|
||||
}
|
||||
|
||||
func deleteService(manager windows.Handle, service *KextService, driverName []uint16) error {
|
||||
// Stop and wait before deleting
|
||||
_ = service.stop(true)
|
||||
|
||||
// Try to delete even if stop failed
|
||||
err := service.delete()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete old service: %s", err)
|
||||
}
|
||||
|
||||
// Wait until we can no longer open the old service.
|
||||
// Not very efficient but NotifyServiceStatusChange cannot be used with driver service.
|
||||
start := time.Now()
|
||||
timeLimit := time.Duration(30 * time.Second)
|
||||
for {
|
||||
handle, err := windows.OpenService(manager, &driverName[0], windows.SERVICE_ALL_ACCESS)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
_ = windows.CloseServiceHandle(handle)
|
||||
|
||||
if time.Since(start) > timeLimit {
|
||||
return fmt.Errorf("time limit reached")
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) isValid() bool {
|
||||
return s != nil && s.handle != winInvalidHandleValue && s.handle != 0
|
||||
}
|
||||
|
||||
func (s *KextService) isRunning() (bool, error) {
|
||||
if !s.isValid() {
|
||||
return false, fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
var status windows.SERVICE_STATUS
|
||||
err := windows.QueryServiceStatus(s.handle, &status)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return status.CurrentState == windows.SERVICE_RUNNING, nil
|
||||
}
|
||||
|
||||
func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) {
|
||||
var status windows.SERVICE_STATUS
|
||||
status.CurrentState = windows.SERVICE_NO_CHANGE
|
||||
start := time.Now()
|
||||
for status.CurrentState != neededStatus {
|
||||
err := windows.QueryServiceStatus(handle, &status)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed while waiting for service to start: %w", err)
|
||||
}
|
||||
|
||||
if time.Since(start) > timeLimit {
|
||||
return false, fmt.Errorf("time limit reached")
|
||||
}
|
||||
|
||||
// Sleep for 1/10 of the wait hint, recommended time from microsoft
|
||||
time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *KextService) start(wait bool) error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
// Start the service:
|
||||
err := windows.StartService(s.handle, 0, nil)
|
||||
|
||||
if err != nil {
|
||||
err = windows.GetLastError()
|
||||
if err != windows.ERROR_SERVICE_ALREADY_RUNNING {
|
||||
// Failed to start service; clean-up:
|
||||
var status windows.SERVICE_STATUS
|
||||
_ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
|
||||
_ = windows.DeleteService(s.handle)
|
||||
_ = windows.CloseServiceHandle(s.handle)
|
||||
s.handle = winInvalidHandleValue
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for service to start
|
||||
if wait {
|
||||
success, err := waitForServiceStatus(s.handle, windows.SERVICE_RUNNING, time.Duration(10*time.Second))
|
||||
if err != nil || !success {
|
||||
return fmt.Errorf("service did not start: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) stop(wait bool) error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
// Stop the service
|
||||
var status windows.SERVICE_STATUS
|
||||
err := windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service failed to stop: %w", err)
|
||||
}
|
||||
|
||||
// Wait for service to stop
|
||||
if wait {
|
||||
success, err := waitForServiceStatus(s.handle, windows.SERVICE_STOPPED, time.Duration(10*time.Second))
|
||||
if err != nil || !success {
|
||||
return fmt.Errorf("service did not stop: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) delete() error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
err := windows.DeleteService(s.handle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete service: %s", err)
|
||||
}
|
||||
|
||||
// Service wont be deleted until all handles are closed.
|
||||
err = windows.CloseServiceHandle(s.handle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close service handle: %s", err)
|
||||
}
|
||||
|
||||
s.handle = winInvalidHandleValue
|
||||
return nil
|
||||
}
|
||||
72
service/firewall/interception/windowskext/syscall.go
Normal file
72
service/firewall/interception/windowskext/syscall.go
Normal file
@@ -0,0 +1,72 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
METHOD_BUFFERED = 0
|
||||
METHOD_IN_DIRECT = 1
|
||||
METHOD_OUT_DIRECT = 2
|
||||
METHOD_NEITHER = 3
|
||||
|
||||
SIOCTL_TYPE = 40000
|
||||
)
|
||||
|
||||
var (
|
||||
IOCTL_VERSION = ctlCode(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_SHUTDOWN_REQUEST = ctlCode(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_RECV_VERDICT_REQ = ctlCode(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_SET_VERDICT = ctlCode(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_GET_PAYLOAD = ctlCode(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_CLEAR_CACHE = ctlCode(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_UPDATE_VERDICT = ctlCode(SIOCTL_TYPE, 0x806, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_GET_CONNECTIONS_STAT = ctlCode(SIOCTL_TYPE, 0x807, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
)
|
||||
|
||||
func ctlCode(device_type, function, method, access uint32) uint32 {
|
||||
return (device_type << 16) | (access << 14) | (function << 2) | method
|
||||
}
|
||||
|
||||
func deviceIOControlAsync(handle windows.Handle, code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) {
|
||||
var inDataPtr *byte = nil
|
||||
var inDataSize uint32 = 0
|
||||
if inData != nil {
|
||||
inDataPtr = &inData[0]
|
||||
inDataSize = uint32(len(inData))
|
||||
}
|
||||
|
||||
var outDataPtr *byte = nil
|
||||
var outDataSize uint32 = 0
|
||||
if outData != nil {
|
||||
outDataPtr = &outData[0]
|
||||
outDataSize = uint32(len(outData))
|
||||
}
|
||||
|
||||
overlapped := &windows.Overlapped{}
|
||||
err := windows.DeviceIoControl(handle,
|
||||
code,
|
||||
inDataPtr, inDataSize,
|
||||
outDataPtr, outDataSize,
|
||||
nil, overlapped)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return overlapped, nil
|
||||
|
||||
}
|
||||
|
||||
func deviceIOControl(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) {
|
||||
overlapped, err := deviceIOControlAsync(handle, code, inData, outData)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var bytesReturned uint32
|
||||
err = windows.GetOverlappedResult(handle, overlapped, &bytesReturned, true)
|
||||
|
||||
return bytesReturned, err
|
||||
}
|
||||
Reference in New Issue
Block a user