diff --git a/windows_kext/kext_interface/command.go b/windows_kext/kext_interface/command.go index 45b5cc8c..57ba00ed 100644 --- a/windows_kext/kext_interface/command.go +++ b/windows_kext/kext_interface/command.go @@ -5,6 +5,7 @@ import ( "io" ) +// Command IDs. const ( CommandShutdown = 0 CommandVerdict = 1 @@ -17,8 +18,10 @@ const ( CommandCleanEndedConnections = 8 ) +// KextVerdict is the verdict ID used to with the kext. type KextVerdict uint8 +// Kext Verdicts. // Make sure this is in sync with the Rust version. const ( // VerdictUndecided is the default status of new connections. @@ -37,20 +40,20 @@ const ( type Verdict struct { command uint8 - Id uint64 + ID uint64 Verdict uint8 } type RedirectV4 struct { command uint8 - Id uint64 + ID uint64 RemoteAddress [4]byte RemotePort uint16 } type RedirectV6 struct { command uint8 - Id uint64 + ID uint64 RemoteAddress [16]byte RemotePort uint16 } @@ -75,46 +78,55 @@ type UpdateV6 struct { Verdict uint8 } +// SendShutdownCommand sends a Shutdown command to the kext. func SendShutdownCommand(writer io.Writer) error { _, err := writer.Write([]byte{CommandShutdown}) return err } +// SendVerdictCommand sends a Verdict command to the kext. func SendVerdictCommand(writer io.Writer, verdict Verdict) error { verdict.command = CommandVerdict return binary.Write(writer, binary.LittleEndian, verdict) } +// SendUpdateV4Command sends a UpdateV4 command to the kext. func SendUpdateV4Command(writer io.Writer, update UpdateV4) error { update.command = CommandUpdateV4 return binary.Write(writer, binary.LittleEndian, update) } +// SendUpdateV6Command sends a UpdateV6 command to the kext. func SendUpdateV6Command(writer io.Writer, update UpdateV6) error { update.command = CommandUpdateV6 return binary.Write(writer, binary.LittleEndian, update) } +// SendClearCacheCommand sends a ClearCache command to the kext. func SendClearCacheCommand(writer io.Writer) error { _, err := writer.Write([]byte{CommandClearCache}) return err } +// SendGetLogsCommand sends a GetLogs command to the kext. func SendGetLogsCommand(writer io.Writer) error { _, err := writer.Write([]byte{CommandGetLogs}) return err } +// SendGetBandwidthStatsCommand sends a GetBandwidthStats command to the kext. func SendGetBandwidthStatsCommand(writer io.Writer) error { _, err := writer.Write([]byte{CommandBandwidthStats}) return err } +// SendPrintMemoryStatsCommand sends a PrintMemoryStats command to the kext. func SendPrintMemoryStatsCommand(writer io.Writer) error { _, err := writer.Write([]byte{CommandPrintMemoryStats}) return err } +// SendCleanEndedConnectionsCommand sends a CleanEndedConnections command to the kext. func SendCleanEndedConnectionsCommand(writer io.Writer) error { _, err := writer.Write([]byte{CommandCleanEndedConnections}) return err diff --git a/windows_kext/kext_interface/info.go b/windows_kext/kext_interface/info.go index efc7b56d..98efe877 100644 --- a/windows_kext/kext_interface/info.go +++ b/windows_kext/kext_interface/info.go @@ -23,8 +23,8 @@ type connectionV4Internal struct { ProcessId uint64 Direction byte Protocol byte - LocalIp [4]byte - RemoteIp [4]byte + LocalIP [4]byte + RemoteIP [4]byte LocalPort uint16 RemotePort uint16 PayloadLayer uint8 @@ -40,19 +40,19 @@ func (c *ConnectionV4) Compare(other *ConnectionV4) bool { c.ProcessId == other.ProcessId && c.Direction == other.Direction && c.Protocol == other.Protocol && - c.LocalIp == other.LocalIp && - c.RemoteIp == other.RemoteIp && + c.LocalIP == other.LocalIP && + c.RemoteIP == other.RemoteIP && c.LocalPort == other.LocalPort && c.RemotePort == other.RemotePort } type connectionV6Internal struct { Id uint64 - ProcessId uint64 + ProcessID uint64 Direction byte Protocol byte - LocalIp [16]byte - RemoteIp [16]byte + LocalIP [16]byte + RemoteIP [16]byte LocalPort uint16 RemotePort uint16 PayloadLayer uint8 @@ -65,11 +65,11 @@ type ConnectionV6 struct { func (c ConnectionV6) Compare(other *ConnectionV6) bool { return c.Id == other.Id && - c.ProcessId == other.ProcessId && + c.ProcessID == other.ProcessID && c.Direction == other.Direction && c.Protocol == other.Protocol && - c.LocalIp == other.LocalIp && - c.RemoteIp == other.RemoteIp && + c.LocalIP == other.LocalIP && + c.RemoteIP == other.RemoteIP && c.LocalPort == other.LocalPort && c.RemotePort == other.RemotePort } diff --git a/windows_kext/kext_interface/kext_file_test.go b/windows_kext/kext_interface/kext_file_test.go index 0731df7c..83b7d33e 100644 --- a/windows_kext/kext_interface/kext_file_test.go +++ b/windows_kext/kext_interface/kext_file_test.go @@ -9,4 +9,4 @@ func (f *KextFile) Read(buffer []byte) (int, error) { return 0, nil } -func (f *KextFile) flush_buffer() {} +func (f *KextFile) flushBuffer() {} diff --git a/windows_kext/kext_interface/protocol_test.go b/windows_kext/kext_interface/protocol_test.go index 5807bf3e..c908be07 100644 --- a/windows_kext/kext_interface/protocol_test.go +++ b/windows_kext/kext_interface/protocol_test.go @@ -2,6 +2,7 @@ package kext_interface import ( "bytes" + "errors" "io" "math/rand" "os" @@ -9,36 +10,42 @@ import ( ) func TestRustInfoFile(t *testing.T) { + t.Parallel() + file, err := os.Open("../protocol/rust_info_test.bin") if err != nil { panic(err) } - defer file.Close() + defer func() { + _ = file.Close() + }() for { info, err := RecvInfo(file) if err != nil { - if err != io.EOF { + if errors.Is(err, io.EOF) { t.Errorf("unexpected error: %s\n", err) } return } - if info.LogLine != nil { + + switch { + case info.LogLine != nil: if info.LogLine.Severity != 1 { t.Errorf("unexpected Log severity: %d\n", info.LogLine.Severity) } - if info.LogLine.Line != "prefix: test log" { t.Errorf("unexpected Log line: %s\n", info.LogLine.Line) } - } else if info.ConnectionV4 != nil { + + case info.ConnectionV4 != nil: conn := info.ConnectionV4 expected := connectionV4Internal{ Id: 1, ProcessId: 2, Direction: 3, Protocol: 4, - LocalIp: [4]byte{1, 2, 3, 4}, - RemoteIp: [4]byte{2, 3, 4, 5}, + LocalIP: [4]byte{1, 2, 3, 4}, + RemoteIP: [4]byte{2, 3, 4, 5}, LocalPort: 5, RemotePort: 6, PayloadLayer: 7, @@ -49,15 +56,16 @@ func TestRustInfoFile(t *testing.T) { if !bytes.Equal(conn.Payload, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { t.Errorf("unexpected ConnectionV4 payload: %+v\n", conn.Payload) } - } else if info.ConnectionV6 != nil { + + case info.ConnectionV6 != nil: conn := info.ConnectionV6 expected := connectionV6Internal{ Id: 1, - ProcessId: 2, + ProcessID: 2, Direction: 3, Protocol: 4, - LocalIp: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - RemoteIp: [16]byte{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, + LocalIP: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + RemoteIP: [16]byte{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, LocalPort: 5, RemotePort: 6, PayloadLayer: 7, @@ -68,7 +76,8 @@ func TestRustInfoFile(t *testing.T) { if !bytes.Equal(conn.Payload, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { t.Errorf("unexpected ConnectionV6 payload: %+v\n", conn.Payload) } - } else if info.ConnectionEndV4 != nil { + + case info.ConnectionEndV4 != nil: endEvent := info.ConnectionEndV4 expected := ConnectionEndV4{ ProcessId: 1, @@ -82,7 +91,8 @@ func TestRustInfoFile(t *testing.T) { if *endEvent != expected { t.Errorf("unexpected ConnectionEndV4: %+v\n", endEvent) } - } else if info.ConnectionEndV6 != nil { + + case info.ConnectionEndV6 != nil: endEvent := info.ConnectionEndV6 expected := ConnectionEndV6{ ProcessId: 1, @@ -96,7 +106,8 @@ func TestRustInfoFile(t *testing.T) { if *endEvent != expected { t.Errorf("unexpected ConnectionEndV6: %+v\n", endEvent) } - } else if info.BandwidthStats != nil { + + case info.BandwidthStats != nil: stats := info.BandwidthStats if stats.Protocol != 1 { t.Errorf("unexpected Bandwidth stats protocol: %d\n", stats.Protocol) @@ -163,11 +174,15 @@ func TestRustInfoFile(t *testing.T) { } func TestGenerateCommandFile(t *testing.T) { + t.Parallel() + file, err := os.Create("go_command_test.bin") if err != nil { t.Errorf("failed to create file: %s", err) } - defer file.Close() + defer func() { + _ = file.Close() + }() enums := []byte{ CommandShutdown, CommandVerdict, @@ -181,64 +196,80 @@ func TestGenerateCommandFile(t *testing.T) { selected := make([]byte, 5000) for i := range selected { - selected[i] = enums[rand.Intn(len(enums))] + selected[i] = enums[rand.Intn(len(enums))] //nolint:gosec } for _, value := range selected { switch value { case CommandShutdown: - { - SendShutdownCommand(file) + err := SendShutdownCommand(file) + if err != nil { + t.Fatal(err) } - case CommandVerdict: - { - SendVerdictCommand(file, Verdict{ - Id: 1, - Verdict: 2, - }) - } - case CommandUpdateV4: - { - SendUpdateV4Command(file, UpdateV4{ - Protocol: 1, - LocalAddress: [4]byte{1, 2, 3, 4}, - LocalPort: 2, - RemoteAddress: [4]byte{2, 3, 4, 5}, - RemotePort: 3, - Verdict: 4, - }) - } - case CommandUpdateV6: - { - SendUpdateV6Command(file, UpdateV6{ - Protocol: 1, - LocalAddress: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - LocalPort: 2, - RemoteAddress: [16]byte{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, - RemotePort: 3, - Verdict: 4, - }) + case CommandVerdict: + err := SendVerdictCommand(file, Verdict{ + ID: 1, + Verdict: 2, + }) + if err != nil { + t.Fatal(err) } + + case CommandUpdateV4: + err := SendUpdateV4Command(file, UpdateV4{ + Protocol: 1, + LocalAddress: [4]byte{1, 2, 3, 4}, + LocalPort: 2, + RemoteAddress: [4]byte{2, 3, 4, 5}, + RemotePort: 3, + Verdict: 4, + }) + if err != nil { + t.Fatal(err) + } + + case CommandUpdateV6: + err := SendUpdateV6Command(file, UpdateV6{ + Protocol: 1, + LocalAddress: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + LocalPort: 2, + RemoteAddress: [16]byte{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, + RemotePort: 3, + Verdict: 4, + }) + if err != nil { + t.Fatal(err) + } + case CommandClearCache: - { - SendClearCacheCommand(file) + err := SendClearCacheCommand(file) + if err != nil { + t.Fatal(err) } + case CommandGetLogs: - { - SendGetLogsCommand(file) + err := SendGetLogsCommand(file) + if err != nil { + t.Fatal(err) } + case CommandBandwidthStats: - { - SendGetBandwidthStatsCommand(file) + err := SendGetBandwidthStatsCommand(file) + if err != nil { + t.Fatal(err) } + case CommandPrintMemoryStats: - { - SendPrintMemoryStatsCommand(file) + err := SendPrintMemoryStatsCommand(file) + if err != nil { + t.Fatal(err) } + case CommandCleanEndedConnections: - { - SendCleanEndedConnectionsCommand(file) + err := SendCleanEndedConnectionsCommand(file) + if err != nil { + t.Fatal(err) } } } diff --git a/windows_kext/wdk/README.md b/windows_kext/wdk/README.md index 675f2155..36107c4b 100644 --- a/windows_kext/wdk/README.md +++ b/windows_kext/wdk/README.md @@ -11,4 +11,4 @@ see: `wdk/src/irp_helper.rs` Open issues need to be resolved: https://github.com/microsoft/wdkmetadata/issues/59 -https://github.com/microsoft/windows-rs/issues/2805 \ No newline at end of file +https://github.com/microsoft/windows-rs/issues/2805