Merge pull request #984 from safing/fix/verdict-cache-update
Fix/verdict cache update
This commit is contained in:
@@ -149,6 +149,10 @@ func resetAllConnectionVerdicts() {
|
||||
|
||||
// Save if verdict changed.
|
||||
if conn.Verdict.Firewall != previousVerdict {
|
||||
err := interception.UpdateVerdictOfConnection(conn)
|
||||
if err != nil {
|
||||
log.Debugf("filter: failed to update connection verdict: %s", err)
|
||||
}
|
||||
conn.Save()
|
||||
tracer.Infof("filter: verdict of connection %s changed from %s to %s", conn, previousVerdict.Verb(), conn.VerdictVerb())
|
||||
changedVerdicts++
|
||||
@@ -159,11 +163,6 @@ func resetAllConnectionVerdicts() {
|
||||
}
|
||||
tracer.Infof("filter: changed verdict on %d connections", changedVerdicts)
|
||||
tracer.Submit()
|
||||
|
||||
err := interception.ResetVerdictOfAllConnections()
|
||||
if err != nil {
|
||||
log.Errorf("interception: failed to remove persistent verdicts: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func interceptionStart() error {
|
||||
|
||||
@@ -2,6 +2,7 @@ package interception
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/firewall/interception/nfq"
|
||||
"github.com/safing/portmaster/network"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
)
|
||||
|
||||
@@ -19,3 +20,8 @@ func stop() error {
|
||||
func ResetVerdictOfAllConnections() error {
|
||||
return nfq.DeleteAllMarkedConnection()
|
||||
}
|
||||
|
||||
// UpdateVerdictOfConnection deletes the verdict of specific connection so in can be initialized again with the next packet.
|
||||
func UpdateVerdictOfConnection(conn *network.Connection) error {
|
||||
return nfq.DeleteMarkedConnection(conn)
|
||||
}
|
||||
|
||||
@@ -4,22 +4,19 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portmaster/firewall/interception/windowskext"
|
||||
"github.com/safing/portmaster/network"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"github.com/safing/portmaster/updates"
|
||||
)
|
||||
|
||||
// start starts the interception.
|
||||
func start(ch chan packet.Packet) error {
|
||||
dllFile, err := updates.GetPlatformFile("kext/portmaster-kext.dll")
|
||||
if err != nil {
|
||||
return fmt.Errorf("interception: could not get kext dll: %s", err)
|
||||
}
|
||||
kextFile, err := updates.GetPlatformFile("kext/portmaster-kext.sys")
|
||||
if err != nil {
|
||||
return fmt.Errorf("interception: could not get kext sys: %s", err)
|
||||
}
|
||||
|
||||
err = windowskext.Init(dllFile.Path(), kextFile.Path())
|
||||
err = windowskext.Init(kextFile.Path())
|
||||
if err != nil {
|
||||
return fmt.Errorf("interception: could not init windows kext: %s", err)
|
||||
}
|
||||
@@ -43,3 +40,18 @@ func stop() error {
|
||||
func ResetVerdictOfAllConnections() error {
|
||||
return windowskext.ClearCache()
|
||||
}
|
||||
|
||||
// UpdateVerdictOfConnection updates the verdict of specific connection in the kernel extension.
|
||||
func UpdateVerdictOfConnection(conn *network.Connection) error {
|
||||
return windowskext.UpdateVerdict(conn)
|
||||
}
|
||||
|
||||
// GetKextVersion returns the version of the kernel extension.
|
||||
func GetKextVersion() (string, error) {
|
||||
version, err := windowskext.GetVersion()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return version.String(), nil
|
||||
}
|
||||
|
||||
@@ -4,20 +4,37 @@ package nfq
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
ct "github.com/florianl/go-conntrack"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/netenv"
|
||||
"github.com/safing/portmaster/network"
|
||||
)
|
||||
|
||||
// DeleteAllMarkedConnection deletes all marked entries from the conntrack table.
|
||||
func DeleteAllMarkedConnection() error {
|
||||
nfct, err := ct.Open(&ct.Config{})
|
||||
var nfct *ct.Nfct // Conntrack handler. NFCT: Network Filter Connection Tracking.
|
||||
|
||||
// InitNFCT initializes the network filter conntrack library.
|
||||
func InitNFCT() error {
|
||||
var err error
|
||||
nfct, err = ct.Open(&ct.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = nfct.Close() }()
|
||||
return nil
|
||||
}
|
||||
|
||||
// TeardownNFCT deinitializes the network filter conntrack library.
|
||||
func TeardownNFCT() {
|
||||
_ = nfct.Close()
|
||||
}
|
||||
|
||||
// DeleteAllMarkedConnection deletes all marked entries from the conntrack table.
|
||||
func DeleteAllMarkedConnection() error {
|
||||
if nfct == nil {
|
||||
return fmt.Errorf("nfq: nfct not initialized")
|
||||
}
|
||||
|
||||
// Delete all ipv4 marked connections
|
||||
deleted := deleteMarkedConnections(nfct, ct.IPv4)
|
||||
@@ -64,3 +81,43 @@ func deleteMarkedConnections(nfct *ct.Nfct, f ct.Family) (deleted int) {
|
||||
}
|
||||
return deleted
|
||||
}
|
||||
|
||||
// DeleteMarkedConnection removes a specific connection from the conntrack table.
|
||||
func DeleteMarkedConnection(conn *network.Connection) error {
|
||||
if nfct == nil {
|
||||
return fmt.Errorf("nfq: nfct not initialized")
|
||||
}
|
||||
|
||||
con := ct.Con{
|
||||
Origin: &ct.IPTuple{
|
||||
Src: &conn.LocalIP,
|
||||
Dst: &conn.Entity.IP,
|
||||
Proto: &ct.ProtoTuple{
|
||||
Number: &conn.Entity.Protocol,
|
||||
SrcPort: &conn.LocalPort,
|
||||
DstPort: &conn.Entity.Port,
|
||||
},
|
||||
},
|
||||
}
|
||||
connections, err := nfct.Get(ct.Conntrack, ct.IPv4, con)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nfq: failed to find entry for connection %s: %w", conn.String(), err)
|
||||
}
|
||||
|
||||
if len(connections) > 1 {
|
||||
log.Warningf("nfq: multiple entries found for single connection: %s -> %d", conn.String(), len(connections))
|
||||
}
|
||||
|
||||
for _, connection := range connections {
|
||||
deleteErr := nfct.Delete(ct.Conntrack, ct.IPv4, connection)
|
||||
if err == nil {
|
||||
err = deleteErr
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warningf("nfq: error while deleting conntrack entries for connection %s: %s", conn.String(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -147,6 +147,11 @@ func activateNfqueueFirewall() error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := nfq.InitNFCT(); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = nfq.DeleteAllMarkedConnection()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -166,6 +171,9 @@ func DeactivateNfqueueFirewall() error {
|
||||
}
|
||||
}
|
||||
|
||||
_ = nfq.DeleteAllMarkedConnection()
|
||||
nfq.TeardownNFCT()
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
@@ -5,11 +6,14 @@ package windowskext
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/network"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
)
|
||||
|
||||
@@ -25,6 +29,7 @@ const (
|
||||
VerdictRequestFlagSocketAuth = 2
|
||||
)
|
||||
|
||||
// Do not change the order of the members! The structure is used to communicate with the kernel extension.
|
||||
// VerdictRequest is the request structure from the Kext.
|
||||
type VerdictRequest struct {
|
||||
id uint32 // ID from RegisterPacket
|
||||
@@ -43,16 +48,40 @@ type VerdictRequest struct {
|
||||
packetSize uint32
|
||||
}
|
||||
|
||||
// Do not change the order of the members! The structure is used to communicate with the kernel extension.
|
||||
type VerdictInfo struct {
|
||||
id uint32 // ID from RegisterPacket
|
||||
verdict network.Verdict // verdict for the connection
|
||||
}
|
||||
|
||||
// Do not change the order of the members! The structure to communicate with the kernel extension.
|
||||
type VerdictUpdateInfo struct {
|
||||
localIP [4]uint32 // Source Address, only srcIP[0] if IPv4
|
||||
remoteIP [4]uint32 // Destination Address
|
||||
localPort uint16 // Source Port
|
||||
remotePort uint16 // Destination port
|
||||
ipV6 uint8 // True: IPv6, False: IPv4
|
||||
protocol uint8 // Protocol (UDP, TCP, ...)
|
||||
verdict uint8 // New verdict
|
||||
}
|
||||
|
||||
type VersionInfo struct {
|
||||
major uint8
|
||||
minor uint8
|
||||
revision uint8
|
||||
build uint8
|
||||
}
|
||||
|
||||
func (v *VersionInfo) String() string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", v.major, v.minor, v.revision, v.build)
|
||||
}
|
||||
|
||||
// Handler transforms received packets to the Packet interface.
|
||||
func Handler(packets chan packet.Packet) {
|
||||
if !ready.IsSet() {
|
||||
return
|
||||
}
|
||||
|
||||
defer close(packets)
|
||||
|
||||
for {
|
||||
if !ready.IsSet() {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -145,3 +174,28 @@ func convertIPv6(input [4]uint32) net.IP {
|
||||
}
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
func ipAddressToArray(ip net.IP, isIPv6 bool) [4]uint32 {
|
||||
array := [4]uint32{0}
|
||||
if isIPv6 {
|
||||
for i := 0; i < 4; i++ {
|
||||
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[i], 4), getUInt32Value(&ip[i]))
|
||||
}
|
||||
} else {
|
||||
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[0], 4), getUInt32Value(&ip[0]))
|
||||
}
|
||||
|
||||
return array
|
||||
}
|
||||
|
||||
func asByteArray[T any](obj *T) []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj))
|
||||
}
|
||||
|
||||
func asByteArrayWithLength[T any](obj *T, size uint32) []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), size)
|
||||
}
|
||||
|
||||
func getUInt32Value[T any](obj *T) uint32 {
|
||||
return *(*uint32)(unsafe.Pointer(obj))
|
||||
}
|
||||
|
||||
@@ -7,14 +7,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/network"
|
||||
"github.com/tevino/abool"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
@@ -23,92 +22,23 @@ var (
|
||||
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
|
||||
ErrNoPacketID = errors.New("the packet has no ID, possibly because it was fast-tracked by the kernel extension")
|
||||
|
||||
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
|
||||
|
||||
kext *WinKext
|
||||
kextLock sync.RWMutex
|
||||
ready = abool.NewBool(false)
|
||||
urgentRequests *int32
|
||||
)
|
||||
|
||||
func init() {
|
||||
var urgentRequestsValue int32
|
||||
urgentRequests = &urgentRequestsValue
|
||||
}
|
||||
|
||||
// WinKext holds the DLL handle.
|
||||
type WinKext struct {
|
||||
sync.RWMutex
|
||||
|
||||
dll *windows.DLL
|
||||
kextLock sync.RWMutex
|
||||
driverPath string
|
||||
|
||||
init *windows.Proc
|
||||
start *windows.Proc
|
||||
stop *windows.Proc
|
||||
recvVerdictRequest *windows.Proc
|
||||
setVerdict *windows.Proc
|
||||
getPayload *windows.Proc
|
||||
clearCache *windows.Proc
|
||||
}
|
||||
kextHandle windows.Handle
|
||||
service *KextService
|
||||
)
|
||||
|
||||
const (
|
||||
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
|
||||
winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value
|
||||
driverName = "PortmasterKext"
|
||||
)
|
||||
|
||||
// Init initializes the DLL and the Kext (Kernel Driver).
|
||||
func Init(dllPath, driverPath string) error {
|
||||
|
||||
new := &WinKext{
|
||||
driverPath: driverPath,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// load dll
|
||||
new.dll, err = windows.LoadDLL(dllPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load functions
|
||||
new.init, err = new.dll.FindProc("PortmasterInit")
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err)
|
||||
}
|
||||
new.start, err = new.dll.FindProc("PortmasterStart")
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err)
|
||||
}
|
||||
new.stop, err = new.dll.FindProc("PortmasterStop")
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not find proc PortmasterStop in dll: %s", err)
|
||||
}
|
||||
new.recvVerdictRequest, err = new.dll.FindProc("PortmasterRecvVerdictRequest")
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not find proc PortmasterRecvVerdictRequest in dll: %s", err)
|
||||
}
|
||||
new.setVerdict, err = new.dll.FindProc("PortmasterSetVerdict")
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not find proc PortmasterSetVerdict in dll: %s", err)
|
||||
}
|
||||
new.getPayload, err = new.dll.FindProc("PortmasterGetPayload")
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not find proc PortmasterGetPayload in dll: %s", err)
|
||||
}
|
||||
new.clearCache, err = new.dll.FindProc("PortmasterClearCache")
|
||||
if err != nil {
|
||||
// the loaded dll is an old version
|
||||
log.Errorf("could not find proc PortmasterClearCache (v1.0.12+) in dll: %s", err)
|
||||
}
|
||||
|
||||
// initialize dll/kext
|
||||
rc, _, lastErr := new.init.Call()
|
||||
if rc != windows.NO_ERROR {
|
||||
return formatErr(lastErr, rc)
|
||||
}
|
||||
|
||||
// set kext
|
||||
kextLock.Lock()
|
||||
defer kextLock.Unlock()
|
||||
kext = new
|
||||
|
||||
func Init(path string) error {
|
||||
kextHandle = winInvalidHandleValue
|
||||
driverPath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -117,19 +47,33 @@ func Start() error {
|
||||
kextLock.Lock()
|
||||
defer kextLock.Unlock()
|
||||
|
||||
// convert to C string
|
||||
charArray := make([]byte, len(kext.driverPath)+1)
|
||||
copy(charArray, []byte(kext.driverPath))
|
||||
charArray[len(charArray)-1] = 0 // force NULL byte at the end
|
||||
|
||||
rc, _, lastErr := kext.start.Call(
|
||||
uintptr(unsafe.Pointer(&charArray[0])),
|
||||
)
|
||||
if rc != windows.NO_ERROR {
|
||||
return formatErr(lastErr, rc)
|
||||
// initialize and start driver service
|
||||
var err error
|
||||
service, err = createKextService(driverName, driverPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
running, err := service.isRunning()
|
||||
if err == nil && !running {
|
||||
err = service.start(true)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("service not initialized: %w", err)
|
||||
}
|
||||
|
||||
// Open the driver
|
||||
filename := `\\.\` + driverName
|
||||
kextHandle, err = openDriver(filename)
|
||||
|
||||
// driver was not installed
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open driver: %q %w", filename, err)
|
||||
}
|
||||
|
||||
ready.Set()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -137,15 +81,27 @@ func Start() error {
|
||||
func Stop() error {
|
||||
kextLock.Lock()
|
||||
defer kextLock.Unlock()
|
||||
if !ready.IsSet() {
|
||||
return ErrKextNotReady
|
||||
}
|
||||
ready.UnSet()
|
||||
|
||||
rc, _, lastErr := kext.stop.Call()
|
||||
if rc != windows.NO_ERROR {
|
||||
return formatErr(lastErr, rc)
|
||||
err := closeDriver(kextHandle)
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to close the handle: %s", err)
|
||||
}
|
||||
|
||||
err = service.stop(true)
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to stop service: %s", err)
|
||||
}
|
||||
// Driver file may change on the next start so it's better to delete the service
|
||||
err = service.delete()
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to delete service: %s", err)
|
||||
}
|
||||
err = service.closeHandle()
|
||||
if err != nil {
|
||||
log.Warningf("winkext: failed to close the handle: %s", err)
|
||||
}
|
||||
|
||||
kextHandle = winInvalidHandleValue
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -153,36 +109,26 @@ func Stop() error {
|
||||
func RecvVerdictRequest() (*VerdictRequest, error) {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if !ready.IsSet() {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
new := &VerdictRequest{}
|
||||
timestamp := time.Now()
|
||||
defer log.Tracef("winkext: getting verdict request took %s", time.Since(timestamp))
|
||||
// Initialize struct for the output data
|
||||
var new VerdictRequest
|
||||
|
||||
// wait for urgent requests to complete
|
||||
for i := 1; i <= 100; i++ {
|
||||
if atomic.LoadInt32(urgentRequests) <= 0 {
|
||||
break
|
||||
}
|
||||
if i == 100 {
|
||||
log.Warningf("winkext: RecvVerdictRequest waited 100 times")
|
||||
}
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
// Make driver request
|
||||
data := asByteArray(&new)
|
||||
bytesRead, err := deviceIOControl(kextHandle, IOCTL_RECV_VERDICT_REQ, nil, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if bytesRead == 0 {
|
||||
return nil, nil // no error, no new verdict request
|
||||
}
|
||||
|
||||
// timestamp := time.Now()
|
||||
rc, _, lastErr := kext.recvVerdictRequest.Call(
|
||||
uintptr(unsafe.Pointer(new)),
|
||||
)
|
||||
// log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp))
|
||||
|
||||
if rc != windows.NO_ERROR {
|
||||
if rc == winErrInvalidData {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, formatErr(lastErr, rc)
|
||||
}
|
||||
return new, nil
|
||||
return &new, nil
|
||||
}
|
||||
|
||||
// SetVerdict sets the verdict for a packet and/or connection.
|
||||
@@ -194,22 +140,19 @@ func SetVerdict(pkt *Packet, verdict network.Verdict) error {
|
||||
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if !ready.IsSet() {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s: kext not ready", verdict)
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
atomic.AddInt32(urgentRequests, 1)
|
||||
// timestamp := time.Now()
|
||||
rc, _, lastErr := kext.setVerdict.Call(
|
||||
uintptr(pkt.verdictRequest.id),
|
||||
uintptr(verdict),
|
||||
)
|
||||
// log.Tracef("winkext: settings verdict for packetID %d took %s", packetID, time.Now().Sub(timestamp))
|
||||
atomic.AddInt32(urgentRequests, -1)
|
||||
if rc != windows.NO_ERROR {
|
||||
verdictInfo := VerdictInfo{pkt.verdictRequest.id, verdict}
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&verdictInfo)
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_SET_VERDICT, data, nil)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id)
|
||||
return formatErr(lastErr, rc)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -220,34 +163,36 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
|
||||
return nil, ErrNoPacketID
|
||||
}
|
||||
|
||||
// Check if driver is initialized
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if !ready.IsSet() {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
buf := make([]byte, packetSize)
|
||||
|
||||
atomic.AddInt32(urgentRequests, 1)
|
||||
// timestamp := time.Now()
|
||||
rc, _, lastErr := kext.getPayload.Call(
|
||||
uintptr(packetID),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&packetSize)),
|
||||
)
|
||||
// log.Tracef("winkext: getting payload for packetID %d took %s", packetID, time.Now().Sub(timestamp))
|
||||
atomic.AddInt32(urgentRequests, -1)
|
||||
// Combine id and length
|
||||
payload := struct {
|
||||
id uint32
|
||||
length uint32
|
||||
}{packetID, packetSize}
|
||||
|
||||
if rc != windows.NO_ERROR {
|
||||
return nil, formatErr(lastErr, rc)
|
||||
// Make driver request
|
||||
data := asByteArray(&payload)
|
||||
bytesRead, err := deviceIOControl(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if packetSize == 0 {
|
||||
// check the result and return
|
||||
if bytesRead == 0 {
|
||||
return nil, errors.New("windows kext did not return any data")
|
||||
}
|
||||
|
||||
if packetSize < uint32(len(buf)) {
|
||||
return buf[:packetSize], nil
|
||||
if bytesRead < uint32(len(buf)) {
|
||||
return buf[:bytesRead], nil
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
@@ -256,28 +201,94 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
|
||||
func ClearCache() error {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
if !ready.IsSet() {
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
if kext.clearCache == nil {
|
||||
log.Error("kext: cannot clear cache: clearCache function missing")
|
||||
}
|
||||
|
||||
rc, _, lastErr := kext.clearCache.Call()
|
||||
|
||||
if rc != windows.NO_ERROR {
|
||||
return formatErr(lastErr, rc)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatErr(err error, rc uintptr) error {
|
||||
sysErr, ok := err.(syscall.Errno)
|
||||
if ok {
|
||||
return fmt.Errorf("%s [LE 0x%X] [RC 0x%X]", err, uintptr(sysErr), rc)
|
||||
}
|
||||
// Make driver request
|
||||
_, err := deviceIOControl(kextHandle, IOCTL_CLEAR_CACHE, nil, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func UpdateVerdict(conn *network.Connection) error {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
var isIpv6 uint8 = 0
|
||||
if conn.IPVersion == packet.IPv6 {
|
||||
isIpv6 = 1
|
||||
}
|
||||
|
||||
// initialize variables
|
||||
info := VerdictUpdateInfo{
|
||||
ipV6: isIpv6,
|
||||
protocol: uint8(conn.IPProtocol),
|
||||
localIP: ipAddressToArray(conn.LocalIP, isIpv6 == 1),
|
||||
localPort: conn.LocalPort,
|
||||
remoteIP: ipAddressToArray(conn.Entity.IP, isIpv6 == 1),
|
||||
remotePort: conn.Entity.Port,
|
||||
verdict: uint8(conn.Verdict.Active),
|
||||
}
|
||||
|
||||
// Make driver request
|
||||
data := asByteArray(&info)
|
||||
err := deviceIoControlDirect(kextHandle, IOCTL_UPDATE_VERDICT, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func GetVersion() (*VersionInfo, error) {
|
||||
kextLock.RLock()
|
||||
defer kextLock.RUnlock()
|
||||
|
||||
// Check if driver is initialized
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
log.Error("kext: failed to clear the cache: kext not ready")
|
||||
return nil, ErrKextNotReady
|
||||
}
|
||||
|
||||
data := make([]uint8, 4)
|
||||
err := deviceIoControlDirect(kextHandle, IOCTL_VERSION, data)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
version := &VersionInfo{
|
||||
major: data[0],
|
||||
minor: data[1],
|
||||
revision: data[2],
|
||||
build: data[3],
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
func openDriver(filename string) (windows.Handle, error) {
|
||||
u16filename, err := syscall.UTF16FromString(filename)
|
||||
if err != nil {
|
||||
return winInvalidHandleValue, fmt.Errorf("failed to convert driver filename to UTF16 string %w", err)
|
||||
}
|
||||
|
||||
handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED, 0)
|
||||
if err != nil {
|
||||
return winInvalidHandleValue, err
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func closeDriver(handle windows.Handle) error {
|
||||
if kextHandle == winInvalidHandleValue {
|
||||
return ErrKextNotReady
|
||||
}
|
||||
|
||||
return windows.CloseHandle(handle)
|
||||
}
|
||||
|
||||
161
firewall/interception/windowskext/service.go
Normal file
161
firewall/interception/windowskext/service.go
Normal file
@@ -0,0 +1,161 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type KextService struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func createKextService(driverName string, driverPath string) (*KextService, error) {
|
||||
// Open the service manager:
|
||||
manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open service manager: %d", err)
|
||||
}
|
||||
defer windows.CloseServiceHandle(manager)
|
||||
|
||||
driverNameU16, err := syscall.UTF16FromString(driverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err)
|
||||
}
|
||||
// Check if it's already created
|
||||
service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS)
|
||||
if err == nil {
|
||||
return &KextService{handle: service}, nil // service was already created
|
||||
}
|
||||
|
||||
driverPathU16, err := syscall.UTF16FromString(driverPath)
|
||||
|
||||
// Create the service
|
||||
service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KextService{handle: service}, nil
|
||||
}
|
||||
|
||||
func (s *KextService) isValid() bool {
|
||||
return s != nil && s.handle != winInvalidHandleValue && s.handle != 0
|
||||
}
|
||||
|
||||
func (s *KextService) isRunning() (bool, error) {
|
||||
if !s.isValid() {
|
||||
return false, fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
var status windows.SERVICE_STATUS
|
||||
err := windows.QueryServiceStatus(s.handle, &status)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return status.CurrentState == windows.SERVICE_RUNNING, nil
|
||||
}
|
||||
|
||||
func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) {
|
||||
var status windows.SERVICE_STATUS
|
||||
status.CurrentState = windows.SERVICE_NO_CHANGE
|
||||
start := time.Now()
|
||||
for status.CurrentState == neededStatus {
|
||||
err := windows.QueryServiceStatus(handle, &status)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed while waiting for service to start: %w", err)
|
||||
}
|
||||
|
||||
if time.Since(start) > timeLimit {
|
||||
return false, fmt.Errorf("time limit reached")
|
||||
}
|
||||
|
||||
// Sleep for 1/10 of the wait hint, recommended time from microsoft
|
||||
time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *KextService) start(wait bool) error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
// Start the service:
|
||||
err := windows.StartService(s.handle, 0, nil)
|
||||
|
||||
if err != nil {
|
||||
err = windows.GetLastError()
|
||||
if err != windows.ERROR_SERVICE_ALREADY_RUNNING {
|
||||
// Failed to start service; clean-up:
|
||||
var status windows.SERVICE_STATUS
|
||||
_ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
|
||||
_ = windows.DeleteService(s.handle)
|
||||
_ = windows.CloseServiceHandle(s.handle)
|
||||
s.handle = winInvalidHandleValue
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for service to start
|
||||
if wait {
|
||||
success, err := waitForServiceStatus(s.handle, windows.SERVICE_RUNNING, time.Duration(10*time.Second))
|
||||
if err != nil || !success {
|
||||
return fmt.Errorf("service did not start: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) stop(wait bool) error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
// Stop the service
|
||||
var status windows.SERVICE_STATUS
|
||||
err := windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service failed to stop: %w", err)
|
||||
}
|
||||
|
||||
// Wait for service to stop
|
||||
if wait {
|
||||
success, err := waitForServiceStatus(s.handle, windows.SERVICE_STOPPED, time.Duration(10*time.Second))
|
||||
if err != nil || !success {
|
||||
return fmt.Errorf("service did not stop: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) delete() error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
err := windows.DeleteService(s.handle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete service: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *KextService) closeHandle() error {
|
||||
if !s.isValid() {
|
||||
return fmt.Errorf("kext service not initialized")
|
||||
}
|
||||
|
||||
err := windows.CloseServiceHandle(s.handle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close service handle: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
89
firewall/interception/windowskext/syscall.go
Normal file
89
firewall/interception/windowskext/syscall.go
Normal file
@@ -0,0 +1,89 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package windowskext
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
METHOD_BUFFERED = 0
|
||||
METHOD_IN_DIRECT = 1
|
||||
METHOD_OUT_DIRECT = 2
|
||||
METHOD_NEITHER = 3
|
||||
|
||||
SIOCTL_TYPE = 40000
|
||||
)
|
||||
|
||||
var (
|
||||
IOCTL_VERSION = ctlCode(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_RECV_VERDICT_REQ_POLL = ctlCode(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) // Not used
|
||||
IOCTL_RECV_VERDICT_REQ = ctlCode(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_SET_VERDICT = ctlCode(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_GET_PAYLOAD = ctlCode(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_CLEAR_CACHE = ctlCode(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
IOCTL_UPDATE_VERDICT = ctlCode(SIOCTL_TYPE, 0x806, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
|
||||
)
|
||||
|
||||
func ctlCode(device_type, function, method, access uint32) uint32 {
|
||||
return (device_type << 16) | (access << 14) | (function << 2) | method
|
||||
}
|
||||
|
||||
func deviceIOControlAsync(handle windows.Handle, code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) {
|
||||
var inDataPtr *byte = nil
|
||||
var inDataSize uint32 = 0
|
||||
if inData != nil {
|
||||
inDataPtr = &inData[0]
|
||||
inDataSize = uint32(len(inData))
|
||||
}
|
||||
|
||||
var outDataPtr *byte = nil
|
||||
var outDataSize uint32 = 0
|
||||
if outData != nil {
|
||||
outDataPtr = &outData[0]
|
||||
outDataSize = uint32(len(outData))
|
||||
}
|
||||
|
||||
overlapped := &windows.Overlapped{}
|
||||
err := windows.DeviceIoControl(handle,
|
||||
code,
|
||||
inDataPtr, inDataSize,
|
||||
outDataPtr, outDataSize,
|
||||
nil, overlapped)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return overlapped, nil
|
||||
|
||||
}
|
||||
|
||||
func deviceIOControl(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) {
|
||||
overlapped, err := deviceIOControlAsync(handle, code, inData, outData)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var bytesReturned uint32
|
||||
err = windows.GetOverlappedResult(handle, overlapped, &bytesReturned, true)
|
||||
|
||||
return bytesReturned, err
|
||||
}
|
||||
|
||||
// Use for METHOD_NEITHER IOCTL, the data buffer is passed directly to the kernel
|
||||
func deviceIoControlDirect(handle windows.Handle, code uint32, data []byte) error {
|
||||
var dataPtr *byte = nil
|
||||
var dataSize uint32 = 0
|
||||
if data != nil {
|
||||
dataPtr = &data[0]
|
||||
dataSize = uint32(len(data))
|
||||
}
|
||||
|
||||
err := windows.DeviceIoControl(handle,
|
||||
code,
|
||||
dataPtr, dataSize,
|
||||
nil, 0,
|
||||
nil, nil)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -52,7 +52,6 @@ func MandatoryUpdates() (identifiers []string) {
|
||||
identifiers = append(
|
||||
identifiers,
|
||||
PlatformIdentifier("core/portmaster-core.exe"),
|
||||
PlatformIdentifier("kext/portmaster-kext.dll"),
|
||||
PlatformIdentifier("kext/portmaster-kext.sys"),
|
||||
PlatformIdentifier("kext/portmaster-kext.pdb"),
|
||||
PlatformIdentifier("start/portmaster-start.exe"),
|
||||
|
||||
Reference in New Issue
Block a user