diff --git a/firewall/interception/windowskext/bandwidth_stats.go b/firewall/interception/windowskext/bandwidth_stats.go index b013623c..6e9dd05f 100644 --- a/firewall/interception/windowskext/bandwidth_stats.go +++ b/firewall/interception/windowskext/bandwidth_stats.go @@ -6,9 +6,11 @@ package windowskext // This file contains example code how to read bandwidth stats from the kext. Its not ment to be used in production. import ( + "context" "time" "github.com/safing/portbase/log" + "github.com/safing/portmaster/network/packet" ) type Rxtxdata struct { @@ -27,7 +29,56 @@ type Key struct { var m = make(map[Key]Rxtxdata) -func StartBandwidthWorker() { +func BandwidthStatsWorker(ctx context.Context, collectInterval time.Duration, bandwidthUpdates chan *packet.BandwidthUpdate) error { + // Setup ticker. + ticker := time.NewTicker(collectInterval) + defer ticker.Stop() + + // Collect bandwidth at every tick. + for { + select { + case <-ticker.C: + err := reportBandwidth(ctx, bandwidthUpdates) + if err != nil { + return err + } + case <-ctx.Done(): + return nil + } + } +} + +func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.BandwidthUpdate) error { + stats, err := GetConnectionsStats() + if err != nil { + return err + } + + // Report all statistics. + for _, stat := range stats { + connID := packet.CreateConnectionID( + packet.IPProtocol(stat.protocol), + convertArrayToIP(stat.localIP, stat.ipV6 == 1), stat.localPort, + convertArrayToIP(stat.remoteIP, stat.ipV6 == 1), stat.remotePort, + false, + ) + update := &packet.BandwidthUpdate{ + ConnID: connID, + RecvBytes: stat.receivedBytes, + SentBytes: stat.transmittedBytes, + Method: packet.Additive, + } + select { + case bandwidthUpdates <- update: + case <-ctx.Done(): + return nil + } + } + + return nil +} + +func StartBandwithConsoleLogger() { go func() { ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() @@ -66,14 +117,13 @@ func StartBandwidthWorker() { } log.Debug("----------------------------------") for key, value := range m { - if key.ipv6 { - log.Debugf("Conn: %d %s:%d %s:%d rx:%d tx:%d", key.protocol, convertIPv6(key.localIP), key.localPort, convertIPv6(key.remoteIP), key.remotePort, value.rx, value.tx) - } else { - log.Debugf("Conn: %d %s:%d %s:%d rx:%d tx:%d", key.protocol, convertIPv4(key.localIP), key.localPort, convertIPv4(key.remoteIP), key.remotePort, value.rx, value.tx) - } - + log.Debugf( + "Conn: %d %s:%d %s:%d rx:%d tx:%d", key.protocol, + convertArrayToIP(key.localIP, key.ipv6), key.localPort, + convertArrayToIP(key.remoteIP, key.ipv6), key.remotePort, + value.rx, value.tx, + ) } - } }() } diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index d943fccc..f5d66761 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -4,6 +4,7 @@ package windowskext import ( + "context" "encoding/binary" "errors" "fmt" @@ -96,9 +97,7 @@ func (v *VersionInfo) String() string { } // Handler transforms received packets to the Packet interface. -func Handler(packets chan packet.Packet) { - defer close(packets) - +func Handler(ctx context.Context, packets chan packet.Packet) { for { packetInfo, err := RecvVerdictRequest() if err != nil { @@ -144,28 +143,14 @@ func Handler(packets chan packet.Packet) { } // Set IPs - if info.Version == packet.IPv4 { - // IPv4 - if info.Inbound { - // Inbound - info.Src = convertIPv4(packetInfo.remoteIP) - info.Dst = convertIPv4(packetInfo.localIP) - } else { - // Outbound - info.Src = convertIPv4(packetInfo.localIP) - info.Dst = convertIPv4(packetInfo.remoteIP) - } + if info.Inbound { + // Inbound + info.Src = convertArrayToIP(packetInfo.remoteIP, info.Version == packet.IPv6) + info.Dst = convertArrayToIP(packetInfo.localIP, info.Version == packet.IPv6) } else { - // IPv6 - if info.Inbound { - // Inbound - info.Src = convertIPv6(packetInfo.remoteIP) - info.Dst = convertIPv6(packetInfo.localIP) - } else { - // Outbound - info.Src = convertIPv6(packetInfo.localIP) - info.Dst = convertIPv6(packetInfo.remoteIP) - } + // Outbound + info.Src = convertArrayToIP(packetInfo.localIP, info.Version == packet.IPv6) + info.Dst = convertArrayToIP(packetInfo.remoteIP, info.Version == packet.IPv6) } // Set Ports @@ -183,14 +168,14 @@ func Handler(packets chan packet.Packet) { } } -// convertIPv4 as needed for data from the kernel -func convertIPv4(input [4]uint32) net.IP { - addressBuf := make([]byte, 4) - binary.BigEndian.PutUint32(addressBuf, input[0]) - return net.IP(addressBuf) -} +// convertArrayToIP converts an array of uint32 values to a net.IP address. +func convertArrayToIP(input [4]uint32, ipv6 bool) net.IP { + if !ipv6 { + addressBuf := make([]byte, 4) + binary.BigEndian.PutUint32(addressBuf, input[0]) + return net.IP(addressBuf) + } -func convertIPv6(input [4]uint32) net.IP { addressBuf := make([]byte, 16) for i := 0; i < 4; i++ { binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i]) diff --git a/firewall/interception/windowskext/kext.go b/firewall/interception/windowskext/kext.go index 7766a786..351d1b74 100644 --- a/firewall/interception/windowskext/kext.go +++ b/firewall/interception/windowskext/kext.go @@ -8,7 +8,6 @@ import ( "fmt" "sync" "syscall" - "time" "unsafe" "github.com/safing/portbase/log" @@ -290,6 +289,8 @@ func GetVersion() (*VersionInfo, error) { return version, nil } +var sizeOfConnectionStat = uint32(unsafe.Sizeof(ConnectionStat{})) + func GetConnectionsStats() ([]ConnectionStat, error) { kextLock.RLock() defer kextLock.RUnlock() @@ -302,12 +303,13 @@ func GetConnectionsStats() ([]ConnectionStat, error) { var data [100]ConnectionStat size := len(data) - _, err := deviceIOControl(kextHandle, IOCTL_GET_CONNECTIONS_STAT, asByteArray(&size), asByteArray(&data)) + bytesReturned, err := deviceIOControl(kextHandle, IOCTL_GET_CONNECTIONS_STAT, asByteArray(&size), asByteArray(&data)) if err != nil { return nil, err } - return data[:], nil + + return data[:bytesReturned/sizeOfConnectionStat], nil } func openDriver(filename string) (windows.Handle, error) {