From 78a0b3c1fbabb0f850453aa985f417972ecb7ebe Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 26 Apr 2019 11:33:24 +0200 Subject: [PATCH] Add windowskext integration, update related packages --- .gitignore | 6 + dnsonly.go | 8 +- firewall/firewall.go | 31 ++- firewall/interception/interception_windows.go | 72 +++++- firewall/interception/kexttest/main.go | 21 -- firewall/interception/nfqueue/LICENSE | 201 ---------------- firewall/interception/nfqueue/README.md | 45 +--- firewall/interception/windowskext/handler.go | 214 ++++++++---------- firewall/interception/windowskext/kext.go | 189 ++++++++++++++-- firewall/interception/windowskext/packet.go | 71 +++++- .../windowskext/test/endian/main.go | 26 +++ .../interception/windowskext/test/main.go | 110 +++++++++ firewall/master.go | 14 +- main.go | 26 ++- nameserver/nameserver.go | 70 +++--- nameserver/only/nameserver.go | 17 +- network/communication.go | 4 +- network/packet/packet.go | 163 ++++++------- network/packet/parse.go | 91 ++++++++ network/status.go | 2 +- network/unknown.go | 4 +- pmctl/run.go | 14 +- pmctl/upgrade.go | 15 +- process/find.go | 26 +-- process/getpid_windows.go | 2 +- process/iphelper/get.go | 163 +++++++------ process/iphelper/tables.go | 40 ++-- process/iphelper/test/main.go | 2 +- process/process_windows.go | 4 +- profile/const_windows.go | 6 + updates/fetch.go | 9 +- updates/get.go | 2 +- updates/latest.go | 1 + 33 files changed, 979 insertions(+), 690 deletions(-) delete mode 100644 firewall/interception/kexttest/main.go delete mode 100644 firewall/interception/nfqueue/LICENSE create mode 100644 firewall/interception/windowskext/test/endian/main.go create mode 100644 firewall/interception/windowskext/test/main.go create mode 100644 network/packet/parse.go create mode 100644 profile/const_windows.go diff --git a/.gitignore b/.gitignore index 31265c6b..45ee86f2 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,9 @@ _testmain.go # Output of the go coverage tool, specifically when used with LiteIDE *.out + +# OS specifics +.DS_Store + +# Custom dev scripts +win_dev_* diff --git a/dnsonly.go b/dnsonly.go index 1be8919b..3d534a89 100644 --- a/dnsonly.go +++ b/dnsonly.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "os/signal" + "runtime" "syscall" "github.com/Safing/portbase/info" @@ -16,6 +17,8 @@ import ( func main() { + runtime.GOMAXPROCS(4) + // Set Info info.Set("Portmaster (DNS only)", "0.2.0", "AGPLv3", false) @@ -25,10 +28,7 @@ func main() { if err == modules.ErrCleanExit { os.Exit(0) } else { - err = modules.Shutdown() - if err != nil { - log.Shutdown() - } + modules.Shutdown() os.Exit(1) } } diff --git a/firewall/firewall.go b/firewall/firewall.go index 1dfafa03..e956fbe1 100644 --- a/firewall/firewall.go +++ b/firewall/firewall.go @@ -39,7 +39,7 @@ var ( ) func init() { - modules.Register("firewall", prep, start, stop, "core", "network", "nameserver", "profile") + modules.Register("firewall", prep, start, stop, "core", "network", "nameserver", "profile", "updates") } func prep() (err error) { @@ -91,26 +91,36 @@ func stop() error { func handlePacket(pkt packet.Packet) { - // log.Tracef("handling packet: %s", pkt) + log.Tracef("handling packet: %s", pkt) // allow local dns - if pkt.MatchesIP(packet.Remote, localNet4) && pkt.GetTCPUDPHeader() != nil && pkt.GetTCPUDPHeader().DstPort == 53 { + if pkt.Info().Src.Equal(pkt.Info().Dst) && pkt.Info().DstPort == 53 { + log.Tracef("accepting local dns: %s", pkt) pkt.PermanentAccept() return } - // allow ICMP and IGMP + // allow ICMP, IGMP and DHCP // TODO: actually handle these - switch pkt.GetIPHeader().Protocol { + switch pkt.Info().Protocol { case packet.ICMP: + log.Tracef("accepting ICMP: %s", pkt) pkt.PermanentAccept() return case packet.ICMPv6: + log.Tracef("accepting ICMPv6: %s", pkt) pkt.PermanentAccept() return case packet.IGMP: + log.Tracef("accepting IGMP: %s", pkt) pkt.PermanentAccept() return + case packet.UDP: + if pkt.Info().DstPort == 67 || pkt.Info().DstPort == 68 { + log.Tracef("accepting DHCP: %s", pkt) + pkt.PermanentAccept() + return + } } // log.Debugf("firewall: pkt %s has ID %s", pkt, pkt.GetLinkID()) @@ -122,11 +132,11 @@ func handlePacket(pkt packet.Packet) { // check if packet is destined for tunnel // switch pkt.IPVersion() { // case packet.IPv4: - // if TunnelNet4 != nil && TunnelNet4.Contains(pkt.GetIPHeader().Dst) { + // if TunnelNet4 != nil && TunnelNet4.Contains(pkt.Info().Dst) { // tunnelHandler(pkt) // } // case packet.IPv6: - // if TunnelNet6 != nil && TunnelNet6.Contains(pkt.GetIPHeader().Dst) { + // if TunnelNet6 != nil && TunnelNet6.Contains(pkt.Info().Dst) { // tunnelHandler(pkt) // } // } @@ -169,8 +179,11 @@ func initialHandler(pkt packet.Packet, link *network.Link) { // add new Link to Communication (and save both) comm.AddLink(link) + log.Tracef("comm [%s] has new link [%s]", comm, link) + // reroute dns requests to nameserver - if comm.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.GetTCPUDPHeader() != nil && !pkt.GetIPHeader().Dst.Equal(localhost) && pkt.GetTCPUDPHeader().DstPort == 53 { + if comm.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) { + log.Tracef("redirecting [%s] to nameserver", link) link.RerouteToNameserver() verdict(pkt, link.GetVerdict()) link.StopFirewallHandler() @@ -283,7 +296,7 @@ func verdict(pkt packet.Packet, action network.Verdict) { } // func tunnelHandler(pkt packet.Packet) { -// tunnelInfo := GetTunnelInfo(pkt.GetIPHeader().Dst) +// tunnelInfo := GetTunnelInfo(pkt.Info().Dst) // if tunnelInfo == nil { // pkt.Block() // return diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index ddebb0b7..d084ee73 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -3,8 +3,12 @@ package interception import ( "fmt" - "github.com/Safing/portmaster/firewall/interception/windivert" + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/notifications" + "github.com/Safing/portbase/utils/osdetail" + "github.com/Safing/portmaster/firewall/interception/windowskext" "github.com/Safing/portmaster/network/packet" + "github.com/Safing/portmaster/updates" ) var Packets chan packet.Packet @@ -17,15 +21,73 @@ func init() { // Start starts the interception. func Start() error { - wd, err := windivert.New("/WinDivert.dll", "") + dllFile, err := updates.GetPlatformFile("kext/portmaster-kext.dll") if err != nil { - return fmt.Errorf("firewall/interception: could not init windivert: %s", err) + return fmt.Errorf("interception: could not get kext dll: %s", err) + } + kextFile, err := updates.GetPlatformFile("kext/portmaster-kext.sys") + if err != nil { + return fmt.Errorf("interception: could not get kext sys: %s", err) } - return wd.Packets(Packets) + err = windowskext.Init(dllFile.Path(), kextFile.Path()) + if err != nil { + return fmt.Errorf("interception: could not init windows kext: %s", err) + } + + err = windowskext.Start() + if err != nil { + return fmt.Errorf("interception: could not start windows kext: %s", err) + } + + go windowskext.Handler(Packets) + go handleWindowsDNSCache() + + return nil } // Stop starts the interception. func Stop() error { - return nil + return windowskext.Stop() +} + +func handleWindowsDNSCache() { + + err := osdetail.StopService("dnscache") + if err != nil { + // cannot stop dnscache, try disabling + if err == osdetail.ErrServiceNotStoppable { + err := osdetail.DisableDNSCache() + if err != nil { + log.Warningf("firewall/interception: failed to disable Windows Service \"DNS Client\" (dnscache) for better interception: %s", err) + notifyDisableDNSCache() + } + notifyRebootRequired() + return + } + + // error while stopping service + log.Warningf("firewall/interception: failed to stop Windows Service \"DNS Client\" (dnscache) for better interception: %s", err) + notifyDisableDNSCache() + } + + // log that service is stopped + log.Info("firewall/interception: Windows Service \"DNS Client\" (dnscache) is stopped for better interception") + +} + +func notifyDisableDNSCache() { + (¬ifications.Notification{ + ID: "windows-disable-dns-cache", + Message: "The Portmaster needs the Windows Service \"DNS Client\" (dnscache) to be disabled for best effectiveness.", + Type: notifications.Warning, + }).Init().Save() +} + +func notifyRebootRequired() { + (¬ifications.Notification{ + ID: "windows-dnscache-reboot-required", + Message: "Please restart your system to complete Portmaster integration.", + Type: notifications.Warning, + }).Init().Save() } diff --git a/firewall/interception/kexttest/main.go b/firewall/interception/kexttest/main.go deleted file mode 100644 index 9ef3aa3f..00000000 --- a/firewall/interception/kexttest/main.go +++ /dev/null @@ -1,21 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/Safing/portmaster/firewall/interception/windowskext" -) - -func main() { - kext, err := windowskext.New("./WinDivert.dll") - if err != nil { - panic(err) - } - - vR, err := kext.RecvVerdictRequest() - if err != nil { - panic(err) - } - - fmt.Printf("verdictRequest: %+v", vR) -} diff --git a/firewall/interception/nfqueue/LICENSE b/firewall/interception/nfqueue/LICENSE deleted file mode 100644 index ad410e11..00000000 --- a/firewall/interception/nfqueue/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ -Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/firewall/interception/nfqueue/README.md b/firewall/interception/nfqueue/README.md index 895a1fce..ebab35e4 100644 --- a/firewall/interception/nfqueue/README.md +++ b/firewall/interception/nfqueue/README.md @@ -1,42 +1,3 @@ -Go-NFQueue -========== -Go Wrapper For Creating IPTables' NFQueue clients in Go - -Usage ------- -Check the `examples/main.go` file - -```bash - cd $GOPATH/github.com/OneOfOne/go-nfqueue/examples - go build -race && sudo ./examples -``` -* Open another terminal : -```bash -sudo iptables -I INPUT 1 -m conntrack --ctstate NEW -j NFQUEUE --queue-num 0 -#or -sudo iptables -I INPUT -i eth0 -m conntrack --ctstate NEW -j NFQUEUE --queue-num 0 -curl --head localhost -ping localhost -sudo iptables -D INPUT -m conntrack --ctstate NEW -j NFQUEUE --queue-num 0 -``` -Then you can `ctrl+c` the program to exit. - -* If you have recent enough iptables/nfqueue you could also use a balanced (multithreaded queue). -* check the example in `examples/mq/multiqueue.go` - -```bash -iptables -I INPUT 1 -m conntrack --ctstate NEW -j NFQUEUE --queue-balance 0:5 --queue-cpu-fanout -``` -Notes ------ - -You must run the executable as root. -This is *WIP*, but all patches are welcome. - -License -------- -go-nfqueue is under the Apache v2 license, check the included license file. -Copyright © [Ahmed W.](http://www.limitlessfx.com/) -See the included `LICENSE` file. - -> Copyright (c) 2014 Ahmed W. \ No newline at end of file +Parts of this package (this directory) are forked from the go-nfqueue repo: https://github.com/OneOfOne/go-nfqueue +These portions are copyrighted by Ahmed W. +The (high probable) fork commit is: https://github.com/OneOfOne/go-nfqueue/commit/3bdd8bdfd98a1ed51119f9cf7494162484dfbe7c diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 192ffa40..53ca7582 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -1,154 +1,130 @@ package windowskext import ( - "errors" + "encoding/binary" "fmt" + "net" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" + "github.com/tevino/abool" "github.com/Safing/portbase/log" "github.com/Safing/portmaster/network/packet" - - "github.com/tevino/abool" ) -func (wd *WinDivert) Packets(packets chan packet.Packet) error { - go wd.packetHandler(packets) - return nil +// VerdictRequest is the request structure from the Kext. +type VerdictRequest struct { + id uint32 /* ID from RegisterPacket */ + processID uint64 /* Process ID. Nice to have*/ + direction uint8 + ipV6 uint8 /* True: IPv6, False: IPv4 */ + protocol uint8 /* Protocol */ + _ uint8 + localIP [4]uint32 /* Source Address */ + remoteIP [4]uint32 /* Destination Address */ + localPort uint16 /* Source Port */ + remotePort uint16 /* Destination port */ + compartmentId uint32 + interfaceIndex uint32 + subInterfaceIndex uint32 + packetSize uint32 } -func (wd *WinDivert) packetHandler(packets chan packet.Packet) { +// Handler transforms received packets to the Packet interface. +func Handler(packets chan packet.Packet) { + if !ready.IsSet() { + return + } + defer close(packets) for { - if !wd.valid.IsSet() { + if !ready.IsSet() { return } - packetData, packetAddress, err := wd.Recv() + packetInfo, err := RecvVerdictRequest() if err != nil { - log.Warningf("failed to get packet from windivert: %s", err) + log.Warningf("failed to get packet from windows kext: %s", err) continue } - ipHeader, tpcUdpHeader, payload, err := parseIpPacket(packetData) - if err != nil { - log.Warningf("failed to parse packet from windivert: %s", err) - log.Warningf("failed packet payload (%d): %s", len(packetData), string(packetData)) + if packetInfo == nil { continue } + // log.Tracef("packet: %+v", packetInfo) + + // New Packet new := &Packet{ - windivert: wd, - packetData: packetData, - packetAddress: packetAddress, - verdictSet: abool.NewBool(false), + verdictRequest: packetInfo, + verdictSet: abool.NewBool(false), } - new.IPHeader = ipHeader - new.TCPUDPHeader = tpcUdpHeader - new.Payload = payload - if packetAddress.Direction == directionInbound { - new.Direction = packet.InBound + + info := new.Info() + info.Direction = packetInfo.direction > 0 + info.InTunnel = false + info.Protocol = packet.IPProtocol(packetInfo.protocol) + + // IP version + if packetInfo.ipV6 == 1 { + info.Version = packet.IPv6 } else { - new.Direction = packet.OutBound + info.Version = packet.IPv4 + } + + // IPs + if info.Version == packet.IPv4 { + // IPv4 + if info.Direction { + // Inbound + info.Src = convertIPv4(packetInfo.remoteIP) + info.Dst = convertIPv4(packetInfo.localIP) + } else { + // Outbound + info.Src = convertIPv4(packetInfo.localIP) + info.Dst = convertIPv4(packetInfo.remoteIP) + } + } else { + // IPv6 + if info.Direction { + // Inbound + info.Src = convertIPv6(packetInfo.remoteIP) + info.Dst = convertIPv6(packetInfo.localIP) + } else { + // Outbound + info.Src = convertIPv6(packetInfo.localIP) + info.Dst = convertIPv6(packetInfo.remoteIP) + } + } + + // Ports + if info.Direction { + // Inbound + info.SrcPort = packetInfo.remotePort + info.DstPort = packetInfo.localPort + } else { + // Outbound + info.SrcPort = packetInfo.localPort + info.DstPort = packetInfo.remotePort } packets <- new } } -func parseIpPacket(packetData []byte) (ipHeader *packet.IPHeader, tpcUdpHeader *packet.TCPUDPHeader, payload []byte, err error) { - - var parsedPacket gopacket.Packet - - if len(packetData) == 0 { - return nil, nil, nil, errors.New("empty packet") - } - - switch packetData[0] >> 4 { - case 4: - parsedPacket = gopacket.NewPacket(packetData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) - if ipv4Layer := parsedPacket.Layer(layers.LayerTypeIPv4); ipv4Layer != nil { - ipv4, _ := ipv4Layer.(*layers.IPv4) - ipHeader = &packet.IPHeader{ - Version: 4, - Protocol: packet.IPProtocol(ipv4.Protocol), - Tos: ipv4.TOS, - TTL: ipv4.TTL, - Src: ipv4.SrcIP, - Dst: ipv4.DstIP, - } - } else { - var err error - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - err = errLayer.Error() - } - return nil, nil, nil, fmt.Errorf("failed to parse IPv4 packet: %s", err) - } - case 6: - parsedPacket = gopacket.NewPacket(packetData, layers.LayerTypeIPv6, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) - if ipv6Layer := parsedPacket.Layer(layers.LayerTypeIPv6); ipv6Layer != nil { - ipv6, _ := ipv6Layer.(*layers.IPv6) - ipHeader = &packet.IPHeader{ - Version: 6, - Protocol: packet.IPProtocol(ipv6.NextHeader), - Tos: ipv6.TrafficClass, - TTL: ipv6.HopLimit, - Src: ipv6.SrcIP, - Dst: ipv6.DstIP, - } - } else { - var err error - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - err = errLayer.Error() - } - return nil, nil, nil, fmt.Errorf("failed to parse IPv6 packet: %s", err) - } - default: - return nil, nil, nil, errors.New("unknown IP version") - } - - switch ipHeader.Protocol { - case packet.TCP: - if tcpLayer := parsedPacket.Layer(layers.LayerTypeTCP); tcpLayer != nil { - tcp, _ := tcpLayer.(*layers.TCP) - tpcUdpHeader = &packet.TCPUDPHeader{ - SrcPort: uint16(tcp.SrcPort), - DstPort: uint16(tcp.DstPort), - Checksum: tcp.Checksum, - } - } else { - var err error - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - err = errLayer.Error() - } - return nil, nil, nil, fmt.Errorf("could not parse TCP layer: %s", err) - } - case packet.UDP: - if udpLayer := parsedPacket.Layer(layers.LayerTypeUDP); udpLayer != nil { - udp, _ := udpLayer.(*layers.UDP) - tpcUdpHeader = &packet.TCPUDPHeader{ - SrcPort: uint16(udp.SrcPort), - DstPort: uint16(udp.DstPort), - Checksum: udp.Checksum, - } - } else { - var err error - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - err = errLayer.Error() - } - return nil, nil, nil, fmt.Errorf("could not parse UDP layer: %s", err) - } - } - - if appLayer := parsedPacket.ApplicationLayer(); appLayer != nil { - payload = appLayer.Payload() - } - - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - return nil, nil, nil, errLayer.Error() - } - - return +func convertIPv4(input [4]uint32) net.IP { + return net.IPv4( + uint8(input[0]>>24&0xFF), + uint8(input[0]>>16&0xFF), + uint8(input[0]>>8&0xFF), + uint8(input[0]&0xFF), + ) +} + +func convertIPv6(input [4]uint32) net.IP { + addressBuf := make([]byte, 16) + for i := 0; i < 4; i++ { + binary.BigEndian.PutUint32(addressBuf[i:i+3], input[i]) + } + return net.IP(addressBuf) } diff --git a/firewall/interception/windowskext/kext.go b/firewall/interception/windowskext/kext.go index 0d8e264c..52488494 100644 --- a/firewall/interception/windowskext/kext.go +++ b/firewall/interception/windowskext/kext.go @@ -1,61 +1,206 @@ package windowskext import ( + "errors" "fmt" + "sync" + "syscall" "unsafe" + "github.com/Safing/portmaster/network" "github.com/tevino/abool" "golang.org/x/sys/windows" ) +// Package errors +var ( + ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands") + + kext *WinKext + kextLock sync.RWMutex + ready = abool.NewBool(false) +) + +// WinKext holds the DLL handle. type WinKext struct { - dll *windows.DLL + sync.RWMutex + dll *windows.DLL + driverPath string + + init *windows.Proc + start *windows.Proc + stop *windows.Proc recvVerdictRequest *windows.Proc - - valid *abool.AtomicBool + setVerdict *windows.Proc + getPayload *windows.Proc } -type VerdictRequest struct { - ID uint32 - ProcessID uint32 - Direction bool - IPv6 bool - Protocol uint8 - SrcIP [4]uint32 - DstIP [4]uint32 - SrcPort uint16 - DstPort uint16 -} +// Init initializes the DLL and the Kext (Kernel Driver). +func Init(dllPath, driverPath string) error { -func New(dllLocation string) (*WinKext, error) { + new := &WinKext{ + driverPath: driverPath, + } - new := &WinKext{} var err error // load dll - new.dll, err = windows.LoadDLL(dllLocation) + new.dll, err = windows.LoadDLL(dllPath) if err != nil { - return nil, err + return err } // load functions + new.init, err = new.dll.FindProc("PortmasterInit") + if err != nil { + return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err) + } + new.start, err = new.dll.FindProc("PortmasterStart") + if err != nil { + return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err) + } + new.stop, err = new.dll.FindProc("PortmasterStop") + if err != nil { + return fmt.Errorf("could not find proc PortmasterStop in dll: %s", err) + } new.recvVerdictRequest, err = new.dll.FindProc("PortmasterRecvVerdictRequest") if err != nil { - return nil, fmt.Errorf("could not find proc PortmasterRecvVerdictRequest: %s", err) + return fmt.Errorf("could not find proc PortmasterRecvVerdictRequest in dll: %s", err) + } + new.setVerdict, err = new.dll.FindProc("PortmasterSetVerdict") + if err != nil { + return fmt.Errorf("could not find proc PortmasterSetVerdict in dll: %s", err) + } + new.getPayload, err = new.dll.FindProc("PortmasterGetPayload") + if err != nil { + return fmt.Errorf("could not find proc PortmasterGetPayload in dll: %s", err) } - return new, nil + // initialize dll/kext + rc, _, lastErr := new.init.Call() + if rc != windows.NO_ERROR { + return formatErr(lastErr) + } + + // set kext + kextLock.Lock() + defer kextLock.Unlock() + kext = new + + return nil } -func (kext *WinKext) RecvVerdictRequest() (*VerdictRequest, error) { +// Start intercepting. +func Start() error { + kextLock.Lock() + defer kextLock.Unlock() + + // convert to C string + charArray := make([]byte, len(kext.driverPath)+1) + copy(charArray, []byte(kext.driverPath)) + charArray[len(charArray)-1] = 0 // force NULL byte at the end + + rc, _, lastErr := kext.start.Call( + uintptr(unsafe.Pointer(&charArray[0])), + ) + if rc != windows.NO_ERROR { + return formatErr(lastErr) + } + + ready.Set() + return nil +} + +// Stop intercepting. +func Stop() error { + kextLock.Lock() + defer kextLock.Unlock() + if !ready.IsSet() { + return ErrKextNotReady + } + ready.UnSet() + + rc, _, lastErr := kext.stop.Call() + if rc != windows.NO_ERROR { + return formatErr(lastErr) + } + return nil +} + +// RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil. +func RecvVerdictRequest() (*VerdictRequest, error) { + kextLock.RLock() + defer kextLock.RUnlock() + if !ready.IsSet() { + return nil, ErrKextNotReady + } + new := &VerdictRequest{} rc, _, lastErr := kext.recvVerdictRequest.Call( uintptr(unsafe.Pointer(new)), ) if rc != 0 { - return nil, lastErr + if rc == 13 /* ERROR_INVALID_DATA */ { + return nil, nil + } + return nil, formatErr(lastErr) } return new, nil } + +// SetVerdict sets the verdict for a packet and/or connection. +func SetVerdict(packetID uint32, verdict network.Verdict) error { + kextLock.RLock() + defer kextLock.RUnlock() + if !ready.IsSet() { + return ErrKextNotReady + } + + rc, _, lastErr := kext.setVerdict.Call( + uintptr(packetID), + uintptr(verdict), + ) + if rc != windows.NO_ERROR { + return formatErr(lastErr) + } + return nil +} + +// GetPayload returns the payload of a packet. +func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) { + kextLock.RLock() + defer kextLock.RUnlock() + if !ready.IsSet() { + return nil, ErrKextNotReady + } + + buf := make([]byte, packetSize) + + rc, _, lastErr := kext.getPayload.Call( + uintptr(packetID), + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&packetSize)), + ) + if rc != windows.NO_ERROR { + return nil, formatErr(lastErr) + } + + if packetSize == 0 { + return nil, errors.New("windows kext did not return any data") + } + + if packetSize < uint32(len(buf)) { + return buf[:packetSize], nil + } + return buf, nil +} + +func formatErr(err error) error { + sysErr, ok := err.(syscall.Errno) + if ok { + return fmt.Errorf("%s [0x%X]", err, uintptr(sysErr)) + } + return err +} diff --git a/firewall/interception/windowskext/packet.go b/firewall/interception/windowskext/packet.go index 0a5ef492..408edde4 100644 --- a/firewall/interception/windowskext/packet.go +++ b/firewall/interception/windowskext/packet.go @@ -1,55 +1,108 @@ package windowskext import ( + "sync" + "github.com/tevino/abool" + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/network" "github.com/Safing/portmaster/network/packet" ) +// Packet represents an IP packet. type Packet struct { packet.PacketBase - kextID uint32 - packetData []byte + verdictRequest *VerdictRequest + verdictSet *abool.AtomicBool - verdictSet *abool.AtomicBool + payloadLoaded bool + lock sync.Mutex } +// GetPayload returns the full raw packet. +func (pkt *Packet) GetPayload() ([]byte, error) { + pkt.lock.Lock() + defer pkt.lock.Unlock() + + if !pkt.payloadLoaded { + pkt.payloadLoaded = true + + payload, err := GetPayload(pkt.verdictRequest.id, pkt.verdictRequest.packetSize) + if err != nil { + log.Errorf("windowskext: failed to load payload %s", err) + return nil, packet.ErrFailedToLoadPayload + } + pkt.Payload = payload + } + + if len(pkt.Payload) == 0 { + return nil, packet.ErrFailedToLoadPayload + } + return pkt.Payload, nil +} + +// Accept accepts the packet. func (pkt *Packet) Accept() error { if pkt.verdictSet.SetToIf(false, true) { - return pkt.windivert.Send(pkt.packetData, pkt.packetAddress) + return SetVerdict(pkt.verdictRequest.id, -network.VerdictAccept) } return nil } +// Block blocks the packet. func (pkt *Packet) Block() error { if pkt.verdictSet.SetToIf(false, true) { - // TODO: implement blocking mechanism - return nil + return SetVerdict(pkt.verdictRequest.id, -network.VerdictBlock) } return nil } +// Drop drops the packet. func (pkt *Packet) Drop() error { + if pkt.verdictSet.SetToIf(false, true) { + return SetVerdict(pkt.verdictRequest.id, -network.VerdictDrop) + } return nil } +// PermanentAccept permanently accepts connection (and the current packet). func (pkt *Packet) PermanentAccept() error { - return pkt.Accept() + if pkt.verdictSet.SetToIf(false, true) { + return SetVerdict(pkt.verdictRequest.id, network.VerdictAccept) + } + return nil } +// PermanentBlock permanently blocks connection (and the current packet). func (pkt *Packet) PermanentBlock() error { - return pkt.Block() + if pkt.verdictSet.SetToIf(false, true) { + return SetVerdict(pkt.verdictRequest.id, network.VerdictBlock) + } + return nil } +// PermanentDrop permanently drops connection (and the current packet). func (pkt *Packet) PermanentDrop() error { - return pkt.Drop() + if pkt.verdictSet.SetToIf(false, true) { + return SetVerdict(pkt.verdictRequest.id, network.VerdictDrop) + } + return nil } +// RerouteToNameserver permanently reroutes the connection to the local nameserver (and the current packet). func (pkt *Packet) RerouteToNameserver() error { + if pkt.verdictSet.SetToIf(false, true) { + return SetVerdict(pkt.verdictRequest.id, network.VerdictRerouteToNameserver) + } return nil } +// RerouteToTunnel permanently reroutes the connection to the local tunnel entrypoint (and the current packet). func (pkt *Packet) RerouteToTunnel() error { + if pkt.verdictSet.SetToIf(false, true) { + return SetVerdict(pkt.verdictRequest.id, network.VerdictRerouteToTunnel) + } return nil } diff --git a/firewall/interception/windowskext/test/endian/main.go b/firewall/interception/windowskext/test/endian/main.go new file mode 100644 index 00000000..22116b78 --- /dev/null +++ b/firewall/interception/windowskext/test/endian/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + "unsafe" +) + +const integerSize int = int(unsafe.Sizeof(0)) + +func isBigEndian() bool { + var i int = 0x1 + bs := (*[integerSize]byte)(unsafe.Pointer(&i)) + if bs[0] == 0 { + return true + } else { + return false + } +} + +func main() { + if isBigEndian() { + fmt.Println("System is Big Endian (Network Byte Order): uint16 0x1234 is 0x1234 in memory") + } else { + fmt.Println("System is Little Endian (Host Byte Order): uint16 0x1234 is 0x3412 in memory") + } +} diff --git a/firewall/interception/windowskext/test/main.go b/firewall/interception/windowskext/test/main.go new file mode 100644 index 00000000..f9040510 --- /dev/null +++ b/firewall/interception/windowskext/test/main.go @@ -0,0 +1,110 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/Safing/portbase/log" + "github.com/Safing/portmaster/firewall/interception/windowskext" + "github.com/Safing/portmaster/network/packet" +) + +var ( + packets chan packet.Packet +) + +func main() { + + // check parameter count + if len(os.Args) < 3 { + fmt.Printf("usage: %s ", os.Args[0]) + os.Exit(1) + } + + // check parameters + for i := 1; i < 3; i++ { + if _, err := os.Stat(os.Args[i]); err != nil { + fmt.Printf("could not access %s: %s", os.Args[i], err) + os.Exit(2) + } + } + + // logging + log.Start() + log.Info("starting Portmaster Windows Kext Test Program") + + // init + err := windowskext.Init(os.Args[1], os.Args[2]) + if err != nil { + panic(err) + } + + // start + err = windowskext.Start() + if err != nil { + panic(err) + } + + packets = make(chan packet.Packet, 1000) + go windowskext.Handler(packets) + go handlePackets() + + // catch interrupt for clean shutdown + signalCh := make(chan os.Signal) + signal.Notify( + signalCh, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT, + ) + <-signalCh + fmt.Println(" ") + log.Warning("program was interrupted, shutting down") + + // stop + err = windowskext.Stop() + if err != nil { + fmt.Printf("error stopping: %s\n", err) + } + + log.Info("shutdown complete") + log.Shutdown() + + os.Exit(0) +} + +func handlePackets() { + for { + pkt := <-packets + + if pkt == nil { + log.Infof("stopped handling packets") + return + } + + log.Infof("received packet: %s", pkt) + + data, err := pkt.GetPayload() + if err != nil { + log.Errorf("failed to get payload: %s", err) + } else { + log.Infof("payload is: %x", data) + } + + // reroute dns requests to nameserver + if pkt.IsOutbound() && !pkt.Info().Src.Equal(pkt.Info().Dst) && pkt.Info().DstPort == 53 { + log.Infof("rerouting %s", pkt) + pkt.RerouteToNameserver() + continue + } + + // accept all + log.Infof("accepting %s", pkt) + pkt.PermanentAccept() + + } +} diff --git a/firewall/master.go b/firewall/master.go index 117368c2..d2625886 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -496,18 +496,14 @@ func DecideOnLink(comm *network.Communication, link *network.Link, pkt packet.Pa // remoteIP var remoteIP net.IP if comm.Direction { - remoteIP = pkt.GetIPHeader().Src + remoteIP = pkt.Info().Src } else { - remoteIP = pkt.GetIPHeader().Dst + remoteIP = pkt.Info().Dst } // protocol and destination port - protocol := uint8(pkt.GetIPHeader().Protocol) - var dstPort uint16 - tcpUDPHeader := pkt.GetTCPUDPHeader() - if tcpUDPHeader != nil { - dstPort = tcpUDPHeader.DstPort - } + protocol := uint8(pkt.Info().Protocol) + dstPort := pkt.Info().DstPort // check endpoints list result, reason := profileSet.CheckEndpointIP(fqdn, remoteIP, protocol, dstPort, comm.Direction) @@ -635,7 +631,7 @@ func DecideOnLink(comm *network.Communication, link *network.Link, pkt packet.Pa case "permit-domain-distinct": // everything already set case "permit-ip", "permit-ip-incoming": - if pkt.GetIPHeader().Version == packet.IPv4 { + if pkt.Info().Version == packet.IPv4 { new.Type = profile.EptIPv4 } else { new.Type = profile.EptIPv6 diff --git a/main.go b/main.go index 90809691..43a636a0 100644 --- a/main.go +++ b/main.go @@ -39,10 +39,7 @@ func main() { if err == modules.ErrCleanExit { os.Exit(0) } else { - err = modules.Shutdown() - if err != nil { - log.Shutdown() - } + modules.Shutdown() os.Exit(1) } } @@ -53,6 +50,7 @@ func main() { signal.Notify( signalCh, os.Interrupt, + os.Kill, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, @@ -60,9 +58,18 @@ func main() { ) select { case <-signalCh: + fmt.Println(" ") log.Warning("main: program was interrupted, shutting down.") + // catch signals during shutdown + go func() { + for { + <-signalCh + fmt.Println(" again, but already shutting down") + } + }() + if printStackOnExit { fmt.Println("=== PRINTING STACK ===") pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) @@ -70,13 +77,18 @@ func main() { } go func() { - time.Sleep(3 * time.Second) + time.Sleep(5 * time.Second) fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====") pprof.Lookup("goroutine").WriteTo(os.Stdout, 2) os.Exit(1) }() - modules.Shutdown() - os.Exit(0) + + err := modules.Shutdown() + if err != nil { + os.Exit(1) + } else { + os.Exit(0) + } case <-modules.ShuttingDown(): } diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index a0dd7a7d..378e19dc 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -43,7 +43,7 @@ func prep() error { } func start() error { - server := &dns.Server{Addr: "127.0.0.1:53", Net: "udp"} + server := &dns.Server{Addr: "0.0.0.0:53", Net: "udp"} dns.HandleFunc(".", handleRequest) go run(server) return nil @@ -68,16 +68,47 @@ func nxDomain(w dns.ResponseWriter, query *dns.Msg) { func handleRequest(w dns.ResponseWriter, query *dns.Msg) { - // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain - // only process first question, that's how everyone does it. question := query.Question[0] fqdn := dns.Fqdn(question.Name) qtype := dns.Type(question.Qtype) - // use this to time how long it takes process this request - // timed := time.Now() - // defer log.Tracef("nameserver: took %s to handle request for %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String()) + // get addresses + remoteAddr, ok := w.RemoteAddr().(*net.UDPAddr) + if !ok { + log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", fqdn, qtype) + return + } + localAddr, ok := w.RemoteAddr().(*net.UDPAddr) + if !ok { + log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", fqdn, qtype) + return + } + + // ignore external request + if !remoteAddr.IP.Equal(localAddr.IP) { + log.Warningf("nameserver: external request for %s%s, ignoring", fqdn, qtype) + return + } + + log.Tracef("nameserver: handling request for %s%s from %s:%d", fqdn, qtype, remoteAddr.IP, remoteAddr.Port) + + // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain + + // check class + if question.Qclass != dns.ClassINET { + // we only serve IN records, return nxdomain + nxDomain(w, query) + return + } + + // handle request for localhost + if fqdn == "localhost." { + m := new(dns.Msg) + m.SetReply(query) + m.Answer = localhostIPs + w.WriteMsg(m) + } // check if valid domain name if !netutils.IsValidFqdn(fqdn) { @@ -96,37 +127,12 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { return } - // check class - if question.Qclass != dns.ClassINET { - // we only serve IN records, send NXDOMAIN - nxDomain(w, query) - return - } - - // handle request for localhost - if fqdn == "localhost." { - m := new(dns.Msg) - m.SetReply(query) - m.Answer = localhostIPs - w.WriteMsg(m) - } - - // get remote address - // start := time.Now() - rAddr, ok := w.RemoteAddr().(*net.UDPAddr) - // log.Tracef("nameserver: took %s to get remote address", time.Since(start)) - if !ok { - log.Warningf("nameserver: could not get address of request, returning nxdomain") - nxDomain(w, query) - return - } - // [1/2] use this to time how long it takes to get process info // timed := time.Now() // get connection // start = time.Now() - comm, err := network.GetCommunicationByDNSRequest(rAddr.IP, uint16(rAddr.Port), fqdn) + comm, err := network.GetCommunicationByDNSRequest(remoteAddr.IP, uint16(remoteAddr.Port), fqdn) // log.Tracef("nameserver: took %s to get comms (and maybe process)", time.Since(start)) if err != nil { log.Warningf("nameserver: someone is requesting %s, but could not identify process: %s, returning nxdomain", fqdn, err) diff --git a/nameserver/only/nameserver.go b/nameserver/only/nameserver.go index 4c1edc27..a82a8016 100644 --- a/nameserver/only/nameserver.go +++ b/nameserver/only/nameserver.go @@ -1,6 +1,7 @@ package only import ( + "net" "time" "github.com/miekg/dns" @@ -18,7 +19,7 @@ func init() { } func start() error { - server := &dns.Server{Addr: "127.0.0.1:53", Net: "udp"} + server := &dns.Server{Addr: "0.0.0.0:53", Net: "udp"} dns.HandleFunc(".", handleRequest) go run(server) return nil @@ -51,6 +52,15 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { fqdn := dns.Fqdn(question.Name) qtype := dns.Type(question.Qtype) + // debug log + rAddr, ok := w.RemoteAddr().(*net.UDPAddr) + if !ok { + log.Warningf("nameserver: could not get address of request, returning nxdomain") + nxDomain(w, query) + return + } + // log.Tracef("nameserver: got request for %s%s from %s:%d", fqdn, qtype, rAddr.IP, uint16(rAddr.Port)) + // use this to time how long it takes process this request // timed := time.Now() // defer log.Tracef("nameserver: took %s to handle request for %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String()) @@ -65,7 +75,7 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { // check for possible DNS tunneling / data transmission // TODO: improve this lms := algs.LmsScoreOfDomain(fqdn) - // log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms) + log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms) if lms < 10 { log.Tracef("nameserver: possible data tunnel: %s has lms score of %f, returning nxdomain", fqdn, lms) nxDomain(w, query) @@ -85,7 +95,7 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { // log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start)) if rrCache == nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains - log.Infof("nameserver: %s is nxdomain", fqdn) + log.Infof("nameserver: %s%s is nxdomain", fqdn, qtype) nxDomain(w, query) return } @@ -97,4 +107,5 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { m.Ns = rrCache.Ns m.Extra = rrCache.Extra w.WriteMsg(m) + log.Tracef("nameserver: replied to request for %s%s from %s:%d", fqdn, qtype, rAddr.IP, uint16(rAddr.Port)) } diff --git a/network/communication.go b/network/communication.go index 85bfca1a..053022ec 100644 --- a/network/communication.go +++ b/network/communication.go @@ -154,7 +154,7 @@ func GetCommunicationByFirstPacket(pkt packet.Packet) (*Communication, error) { // Incoming if direction { - switch netutils.ClassifyIP(pkt.GetIPHeader().Src) { + switch netutils.ClassifyIP(pkt.Info().Src) { case netutils.HostLocal: domain = IncomingHost case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: @@ -186,7 +186,7 @@ func GetCommunicationByFirstPacket(pkt packet.Packet) (*Communication, error) { if err != nil { // if no domain could be found, it must be a direct connection (ie. no DNS) - switch netutils.ClassifyIP(pkt.GetIPHeader().Dst) { + switch netutils.ClassifyIP(pkt.Info().Dst) { case netutils.HostLocal: domain = PeerHost case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: diff --git a/network/packet/packet.go b/network/packet/packet.go index 16dcf166..dac03710 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -3,6 +3,7 @@ package packet import ( + "errors" "fmt" "net" ) @@ -43,6 +44,10 @@ const ( STOP ) +var ( + ErrFailedToLoadPayload = errors.New("could not load packet payload") +) + // Returns the byte size of the ip, IPv4 = 4 bytes, IPv6 = 16 func (v IPVersion) ByteSize() int { switch v { @@ -92,58 +97,59 @@ func (v Verdict) String() string { return fmt.Sprintf("", uint8(v)) } -type IPHeader struct { - Version IPVersion +// PacketInfo holds IP and TCP/UDP header information +type PacketInfo struct { + Direction bool + InTunnel bool - Tos, TTL uint8 - Protocol IPProtocol - Src, Dst net.IP -} - -type TCPUDPHeader struct { + Version IPVersion + Src, Dst net.IP + Protocol IPProtocol SrcPort, DstPort uint16 - Checksum uint16 //not implemented } type PacketBase struct { - linkID string - Direction bool - InTunnel bool - Payload []byte - *IPHeader - *TCPUDPHeader + info PacketInfo + linkID string + Payload []byte } -func (pkt *PacketBase) GetIPHeader() *IPHeader { - return pkt.IPHeader +func (pkt *PacketBase) Info() *PacketInfo { + return &pkt.info } -func (pkt *PacketBase) GetTCPUDPHeader() *TCPUDPHeader { - return pkt.TCPUDPHeader -} - -func (pkt *PacketBase) GetPayload() []byte { - return pkt.Payload +func (pkt *PacketBase) SetPacketInfo(packetInfo PacketInfo) { + pkt.info = packetInfo } func (pkt *PacketBase) SetInbound() { - pkt.Direction = true + pkt.info.Direction = true } func (pkt *PacketBase) SetOutbound() { - pkt.Direction = false + pkt.info.Direction = false } func (pkt *PacketBase) IsInbound() bool { - return pkt.Direction + return pkt.info.Direction } func (pkt *PacketBase) IsOutbound() bool { - return !pkt.Direction + return !pkt.info.Direction } -func (pkt *PacketBase) IPVersion() IPVersion { - return pkt.Version +func (pkt *PacketBase) HasPorts() bool { + switch pkt.info.Protocol { + case TCP: + return true + case UDP: + return true + } + return false +} + +func (pkt *PacketBase) GetPayload() ([]byte, error) { + return pkt.Payload, ErrFailedToLoadPayload } func (pkt *PacketBase) GetLinkID() string { @@ -154,82 +160,63 @@ func (pkt *PacketBase) GetLinkID() string { } func (pkt *PacketBase) createLinkID() { - if pkt.IPHeader.Protocol == TCP || pkt.IPHeader.Protocol == UDP { - if pkt.Direction { - pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Dst, pkt.DstPort, pkt.Src, pkt.SrcPort) + if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { + if pkt.info.Direction { + pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) } else { - pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort) + pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort) } } else { - if pkt.Direction { - pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Dst, pkt.Src) + if pkt.info.Direction { + pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src) } else { - pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Src, pkt.Dst) + pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) } } } // Matches checks if a the packet matches a given endpoint (remote or local) in protocol, network and port. // +// Comparison matrix: // IN OUT // Local Dst Src // Remote Src Dst // -func (pkt *PacketBase) MatchesAddress(endpoint bool, protocol IPProtocol, network *net.IPNet, port uint16) bool { - if pkt.Protocol != protocol { +func (pkt *PacketBase) MatchesAddress(remote bool, protocol IPProtocol, network *net.IPNet, port uint16) bool { + if pkt.info.Protocol != protocol { return false } - if pkt.Direction != endpoint { - if !network.Contains(pkt.Src) { + if pkt.info.Direction != remote { + if !network.Contains(pkt.info.Src) { return false } - if port != 0 && pkt.TCPUDPHeader != nil { - if pkt.SrcPort != port { - return false - } + if pkt.info.SrcPort != port { + return false } } else { - if !network.Contains(pkt.Dst) { + if !network.Contains(pkt.info.Dst) { return false } - if port != 0 && pkt.TCPUDPHeader != nil { - if pkt.DstPort != port { - return false - } + if pkt.info.DstPort != port { + return false } } return true } func (pkt *PacketBase) MatchesIP(endpoint bool, network *net.IPNet) bool { - if pkt.Direction != endpoint { - if network.Contains(pkt.Src) { + if pkt.info.Direction != endpoint { + if network.Contains(pkt.info.Src) { return true } } else { - if network.Contains(pkt.Dst) { + if network.Contains(pkt.info.Dst) { return true } } return false } -// func (pkt *PacketBase) Accept() error { -// return nil -// } -// -// func (pkt *PacketBase) Drop() error { -// return nil -// } -// -// func (pkt *PacketBase) Block() error { -// return nil -// } -// -// func (pkt *PacketBase) Verdict(verdict Verdict) error { -// return nil -// } - // FORMATTING func (pkt *PacketBase) String() string { @@ -238,45 +225,45 @@ func (pkt *PacketBase) String() string { // FmtPacket returns the most important information about the packet as a string func (pkt *PacketBase) FmtPacket() string { - if pkt.IPHeader.Protocol == TCP || pkt.IPHeader.Protocol == UDP { - if pkt.Direction { - return fmt.Sprintf("IN %s %s:%d <-> %s:%d", pkt.Protocol, pkt.Dst, pkt.DstPort, pkt.Src, pkt.SrcPort) + if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { + if pkt.info.Direction { + return fmt.Sprintf("IN %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) } - return fmt.Sprintf("OUT %s %s:%d <-> %s:%d", pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort) + return fmt.Sprintf("OUT %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort) } - if pkt.Direction { - return fmt.Sprintf("IN %s %s <-> %s", pkt.Protocol, pkt.Dst, pkt.Src) + if pkt.info.Direction { + return fmt.Sprintf("IN %s %s <-> %s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src) } - return fmt.Sprintf("OUT %s %s <-> %s", pkt.Protocol, pkt.Src, pkt.Dst) + return fmt.Sprintf("OUT %s %s <-> %s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) } // FmtProtocol returns the protocol as a string func (pkt *PacketBase) FmtProtocol() string { - return pkt.IPHeader.Protocol.String() + return pkt.info.Protocol.String() } // FmtRemoteIP returns the remote IP address as a string func (pkt *PacketBase) FmtRemoteIP() string { - if pkt.Direction { - return pkt.IPHeader.Src.String() + if pkt.info.Direction { + return pkt.info.Src.String() } - return pkt.IPHeader.Dst.String() + return pkt.info.Dst.String() } // FmtRemotePort returns the remote port as a string func (pkt *PacketBase) FmtRemotePort() string { - if pkt.TCPUDPHeader != nil { - if pkt.Direction { - return fmt.Sprintf("%d", pkt.TCPUDPHeader.SrcPort) + if pkt.info.SrcPort != 0 { + if pkt.info.Direction { + return fmt.Sprintf("%d", pkt.info.SrcPort) } - return fmt.Sprintf("%d", pkt.TCPUDPHeader.DstPort) + return fmt.Sprintf("%d", pkt.info.DstPort) } return "-" } // FmtRemoteAddress returns the full remote address (protocol, IP, port) as a string func (pkt *PacketBase) FmtRemoteAddress() string { - return fmt.Sprintf("%s:%s:%s", pkt.IPHeader.Protocol.String(), pkt.FmtRemoteIP(), pkt.FmtRemotePort()) + return fmt.Sprintf("%s:%s:%s", pkt.info.Protocol.String(), pkt.FmtRemoteIP(), pkt.FmtRemotePort()) } // Packet is an interface to a network packet to provide object behaviour the same across all systems @@ -292,15 +279,15 @@ type Packet interface { RerouteToTunnel() error // INFO - GetIPHeader() *IPHeader - GetTCPUDPHeader() *TCPUDPHeader - GetPayload() []byte + Info() *PacketInfo + SetPacketInfo(PacketInfo) IsInbound() bool IsOutbound() bool SetInbound() SetOutbound() + HasPorts() bool + GetPayload() ([]byte, error) GetLinkID() string - IPVersion() IPVersion // MATCHING MatchesAddress(bool, IPProtocol, *net.IPNet, uint16) bool diff --git a/network/packet/parse.go b/network/packet/parse.go new file mode 100644 index 00000000..7987e40f --- /dev/null +++ b/network/packet/parse.go @@ -0,0 +1,91 @@ +package packet + +import ( + "errors" + "fmt" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +// Parse parses an IP packet and saves the information in the given packet object. +func Parse(packetData []byte, packet *PacketBase) error { + + var parsedPacket gopacket.Packet + + if len(packetData) == 0 { + return errors.New("empty packet") + } + + switch packetData[0] >> 4 { + case 4: + parsedPacket = gopacket.NewPacket(packetData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) + if ipv4Layer := parsedPacket.Layer(layers.LayerTypeIPv4); ipv4Layer != nil { + ipv4, _ := ipv4Layer.(*layers.IPv4) + packet.info.Version = IPv4 + packet.info.Protocol = IPProtocol(ipv4.Protocol) + packet.info.Src = ipv4.SrcIP + packet.info.Dst = ipv4.DstIP + } else { + var err error + if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { + err = errLayer.Error() + } + return fmt.Errorf("failed to parse IPv4 packet: %s", err) + } + case 6: + parsedPacket = gopacket.NewPacket(packetData, layers.LayerTypeIPv6, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) + if ipv6Layer := parsedPacket.Layer(layers.LayerTypeIPv6); ipv6Layer != nil { + ipv6, _ := ipv6Layer.(*layers.IPv6) + packet.info.Version = IPv6 + packet.info.Protocol = IPProtocol(ipv6.NextHeader) + packet.info.Src = ipv6.SrcIP + packet.info.Dst = ipv6.DstIP + } else { + var err error + if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { + err = errLayer.Error() + } + return fmt.Errorf("failed to parse IPv6 packet: %s", err) + } + default: + return errors.New("unknown IP version") + } + + switch packet.info.Protocol { + case TCP: + if tcpLayer := parsedPacket.Layer(layers.LayerTypeTCP); tcpLayer != nil { + tcp, _ := tcpLayer.(*layers.TCP) + packet.info.SrcPort = uint16(tcp.SrcPort) + packet.info.DstPort = uint16(tcp.DstPort) + } else { + var err error + if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { + err = errLayer.Error() + } + return fmt.Errorf("could not parse TCP layer: %s", err) + } + case UDP: + if udpLayer := parsedPacket.Layer(layers.LayerTypeUDP); udpLayer != nil { + udp, _ := udpLayer.(*layers.UDP) + packet.info.SrcPort = uint16(udp.SrcPort) + packet.info.DstPort = uint16(udp.DstPort) + } else { + var err error + if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { + err = errLayer.Error() + } + return fmt.Errorf("could not parse UDP layer: %s", err) + } + } + + if appLayer := parsedPacket.ApplicationLayer(); appLayer != nil { + packet.Payload = appLayer.Payload() + } + + if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { + return errLayer.Error() + } + + return nil +} diff --git a/network/status.go b/network/status.go index d25a28c5..6b2dae35 100644 --- a/network/status.go +++ b/network/status.go @@ -3,7 +3,7 @@ package network // Verdict describes the decision made about a connection or link. -type Verdict uint8 +type Verdict int8 // List of values a Status can have const ( diff --git a/network/unknown.go b/network/unknown.go index c8161208..08503ffa 100644 --- a/network/unknown.go +++ b/network/unknown.go @@ -16,7 +16,7 @@ const ( // GetUnknownCommunication returns the connection to a packet of unknown owner. func GetUnknownCommunication(pkt packet.Packet) (*Communication, error) { if pkt.IsInbound() { - switch netutils.ClassifyIP(pkt.GetIPHeader().Src) { + switch netutils.ClassifyIP(pkt.Info().Src) { case netutils.HostLocal: return getOrCreateUnknownCommunication(pkt, IncomingHost) case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: @@ -28,7 +28,7 @@ func GetUnknownCommunication(pkt packet.Packet) (*Communication, error) { } } - switch netutils.ClassifyIP(pkt.GetIPHeader().Dst) { + switch netutils.ClassifyIP(pkt.Info().Dst) { case netutils.HostLocal: return getOrCreateUnknownCommunication(pkt, PeerHost) case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: diff --git a/pmctl/run.go b/pmctl/run.go index 0d6fa5af..85fe3972 100644 --- a/pmctl/run.go +++ b/pmctl/run.go @@ -5,6 +5,7 @@ import ( "io" "os" "os/exec" + "runtime" "strings" "github.com/spf13/cobra" @@ -100,11 +101,16 @@ func run(identifier string, cmd *cobra.Command, filterDatabaseFlag bool) error { } // check permission - info, err := os.Stat(file.Path()) - if info.Mode() != 0755 { - err := os.Chmod(file.Path(), 0755) + if runtime.GOOS != "windows" { + info, err := os.Stat(file.Path()) if err != nil { - return fmt.Errorf("%s failed to set exec permissions on %s: %s", logPrefix, file.Path(), err) + return fmt.Errorf("%s failed to get file info on %s: %s", logPrefix, file.Path(), err) + } + if info.Mode() != 0755 { + err := os.Chmod(file.Path(), 0755) + if err != nil { + return fmt.Errorf("%s failed to set exec permissions on %s: %s", logPrefix, file.Path(), err) + } } } diff --git a/pmctl/upgrade.go b/pmctl/upgrade.go index f3d600e5..2d2ff5cf 100644 --- a/pmctl/upgrade.go +++ b/pmctl/upgrade.go @@ -5,6 +5,7 @@ import ( "io" "os" "path/filepath" + "runtime" "github.com/Safing/portbase/info" "github.com/Safing/portmaster/updates" @@ -51,14 +52,18 @@ func doSelfUpgrade(file *updates.File) error { } // check permission - info, err := os.Stat(dst) - if info.Mode() != 0755 { - err := os.Chmod(dst, 0755) + if runtime.GOOS != "windows" { + info, err := os.Stat(dst) if err != nil { - return fmt.Errorf("failed to set permissions on %s: %s", dst, err) + return fmt.Errorf("failed to get file info on %s: %s", dst, err) + } + if info.Mode() != 0755 { + err := os.Chmod(dst, 0755) + if err != nil { + return fmt.Errorf("failed to set permissions on %s: %s", dst, err) + } } } - return nil } diff --git a/process/find.go b/process/find.go index 842922b9..dab6ef8e 100644 --- a/process/find.go +++ b/process/find.go @@ -22,30 +22,30 @@ func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) { var remoteIP net.IP var remotePort uint16 if pkt.IsInbound() { - localIP = pkt.GetIPHeader().Dst - remoteIP = pkt.GetIPHeader().Src + localIP = pkt.Info().Dst + remoteIP = pkt.Info().Src } else { - localIP = pkt.GetIPHeader().Src - remoteIP = pkt.GetIPHeader().Dst + localIP = pkt.Info().Src + remoteIP = pkt.Info().Dst } - if pkt.GetIPHeader().Protocol == packet.TCP || pkt.GetIPHeader().Protocol == packet.UDP { + if pkt.HasPorts() { if pkt.IsInbound() { - localPort = pkt.GetTCPUDPHeader().DstPort - remotePort = pkt.GetTCPUDPHeader().SrcPort + localPort = pkt.Info().DstPort + remotePort = pkt.Info().SrcPort } else { - localPort = pkt.GetTCPUDPHeader().SrcPort - remotePort = pkt.GetTCPUDPHeader().DstPort + localPort = pkt.Info().SrcPort + remotePort = pkt.Info().DstPort } } switch { - case pkt.GetIPHeader().Protocol == packet.TCP && pkt.IPVersion() == packet.IPv4: + case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv4: return getTCP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - case pkt.GetIPHeader().Protocol == packet.UDP && pkt.IPVersion() == packet.IPv4: + case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv4: return getUDP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - case pkt.GetIPHeader().Protocol == packet.TCP && pkt.IPVersion() == packet.IPv6: + case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv6: return getTCP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - case pkt.GetIPHeader().Protocol == packet.UDP && pkt.IPVersion() == packet.IPv6: + case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv6: return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) default: return -1, false, errors.New("unsupported protocol for finding process") diff --git a/process/getpid_windows.go b/process/getpid_windows.go index 21835e07..879c9c14 100644 --- a/process/getpid_windows.go +++ b/process/getpid_windows.go @@ -1,7 +1,7 @@ package process import ( - "github.com/Safing/safing-core/process/iphelper" + "github.com/Safing/portmaster/process/iphelper" ) var ( diff --git a/process/iphelper/get.go b/process/iphelper/get.go index 4aec4334..85b788a4 100644 --- a/process/iphelper/get.go +++ b/process/iphelper/get.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "time" ) var ( @@ -21,6 +22,8 @@ var ( ipHelper *IPHelper lock sync.RWMutex + + waitTime = 15 * time.Millisecond ) func checkIPHelper() (err error) { @@ -34,57 +37,71 @@ func checkIPHelper() (err error) { func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { // search - pid, direction = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) + pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) if pid >= 0 { - return + return pid, pktDirection, nil } - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - tcp4Connections, tcp4Listeners, err = ipHelper.GetTables(TCP, IPv4) - } - lock.Unlock() - if err != nil { - return -1, direction, err + for i := 0; i < 3; i++ { + // give kernel some time, then try again + // log.Tracef("process: giving kernel some time to think") + + // if unable to find, refresh + lock.Lock() + err = checkIPHelper() + if err == nil { + tcp4Connections, tcp4Listeners, err = ipHelper.GetTables(TCP, IPv4) + } + lock.Unlock() + if err != nil { + return -1, pktDirection, err + } + + // search + pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) + if pid >= 0 { + return pid, pktDirection, nil + } + + time.Sleep(waitTime) } - // search - pid, direction = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return - } - - return -1, direction, nil + return -1, pktDirection, nil } func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { // search - pid, direction = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) + pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) if pid >= 0 { - return + return pid, pktDirection, nil } - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - tcp6Connections, tcp6Listeners, err = ipHelper.GetTables(TCP, IPv6) - } - lock.Unlock() - if err != nil { - return -1, direction, err + for i := 0; i < 3; i++ { + // give kernel some time, then try again + // log.Tracef("process: giving kernel some time to think") + + // if unable to find, refresh + lock.Lock() + err = checkIPHelper() + if err == nil { + tcp6Connections, tcp6Listeners, err = ipHelper.GetTables(TCP, IPv6) + } + lock.Unlock() + if err != nil { + return -1, pktDirection, err + } + + // search + pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) + if pid >= 0 { + return pid, pktDirection, nil + } + + time.Sleep(waitTime) } - // search - pid, direction = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return - } - - return -1, direction, nil + return -1, pktDirection, nil } func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { @@ -95,21 +112,28 @@ func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote return pid, pktDirection, nil } - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - udp4Connections, udp4Listeners, err = ipHelper.GetTables(UDP, IPv4) - } - lock.Unlock() - if err != nil { - return -1, pktDirection, err - } + for i := 0; i < 3; i++ { + // give kernel some time, then try again + // log.Tracef("process: giving kernel some time to think") - // search - pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil + // if unable to find, refresh + lock.Lock() + err = checkIPHelper() + if err == nil { + udp4Connections, udp4Listeners, err = ipHelper.GetTables(UDP, IPv4) + } + lock.Unlock() + if err != nil { + return -1, pktDirection, err + } + + // search + pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) + if pid >= 0 { + return pid, pktDirection, nil + } + + time.Sleep(waitTime) } return -1, pktDirection, nil @@ -123,21 +147,28 @@ func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote return pid, pktDirection, nil } - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - udp6Connections, udp6Listeners, err = ipHelper.GetTables(UDP, IPv6) - } - lock.Unlock() - if err != nil { - return -1, pktDirection, err - } + for i := 0; i < 3; i++ { + // give kernel some time, then try again + // log.Tracef("process: giving kernel some time to think") - // search - pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil + // if unable to find, refresh + lock.Lock() + err = checkIPHelper() + if err == nil { + udp6Connections, udp6Listeners, err = ipHelper.GetTables(UDP, IPv6) + } + lock.Unlock() + if err != nil { + return -1, pktDirection, err + } + + // search + pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) + if pid >= 0 { + return pid, pktDirection, nil + } + + time.Sleep(waitTime) } return -1, pktDirection, nil @@ -190,8 +221,8 @@ func searchListeners(list []*connectionEntry, localIP net.IP, localPort uint16) for _, entry := range list { if localPort == entry.localPort && - entry.localIP == nil || // nil IP means zero IP, see tables.go - localIP.Equal(entry.localIP) { + (entry.localIP == nil || // nil IP means zero IP, see tables.go + localIP.Equal(entry.localIP)) { return entry.pid } } diff --git a/process/iphelper/tables.go b/process/iphelper/tables.go index b8f73f37..c27f849f 100644 --- a/process/iphelper/tables.go +++ b/process/iphelper/tables.go @@ -3,7 +3,6 @@ package iphelper import ( - "encoding/binary" "errors" "fmt" "net" @@ -125,19 +124,19 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection r1, _, err = ipHelper.getExtendedTcpTable.Call( uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize - 0, // _In_ BOOL bOrder - uintptr(afClass), // _In_ ULONG ulAf - iphelper_TCP_TABLE_OWNER_PID_ALL, // _In_ TCP_TABLE_CLASS TableClass - 0, // _In_ ULONG Reserved + 0, // _In_ BOOL bOrder + uintptr(afClass), // _In_ ULONG ulAf + iphelper_TCP_TABLE_OWNER_PID_ALL, // _In_ TCP_TABLE_CLASS TableClass + 0, // _In_ ULONG Reserved ) case UDP: r1, _, err = ipHelper.getExtendedUdpTable.Call( uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable, uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize, - 0, // _In_ BOOL bOrder, - uintptr(afClass), // _In_ ULONG ulAf, - iphelper_UDP_TABLE_OWNER_PID, // _In_ UDP_TABLE_CLASS TableClass, - 0, // _In_ ULONG Reserved + 0, // _In_ BOOL bOrder, + uintptr(afClass), // _In_ ULONG ulAf, + iphelper_UDP_TABLE_OWNER_PID, // _In_ UDP_TABLE_CLASS TableClass, + 0, // _In_ ULONG Reserved ) } @@ -165,19 +164,16 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection new.pid = int(row.owningPid) // local - new.localIP = make([]byte, 4) - binary.LittleEndian.PutUint32(new.localIP, row.localAddr) + if row.localAddr != 0 { + new.localIP = convertIPv4(row.localAddr) + } new.localPort = uint16(row.localPort>>8 | row.localPort<<8) // remote if row.state == iphelper_TCP_STATE_LISTEN { - if new.localIP.Equal(net.IPv4zero) { - new.localIP = nil - } listeners = append(listeners, new) } else { - new.remoteIP = make([]byte, 4) - binary.LittleEndian.PutUint32(new.remoteIP, row.remoteAddr) + new.remoteIP = convertIPv4(row.remoteAddr) new.remotePort = uint16(row.remotePort>>8 | row.remotePort<<8) connections = append(connections, new) } @@ -229,8 +225,7 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection if row.localAddr == 0 { listeners = append(listeners, new) } else { - new.localIP = make([]byte, 4) - binary.LittleEndian.PutUint32(new.localIP, row.localAddr) + new.localIP = convertIPv4(row.localAddr) connections = append(connections, new) } } @@ -261,3 +256,12 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection return connections, listeners, nil } + +func convertIPv4(input uint32) net.IP { + return net.IPv4( + uint8(input&0xFF), + uint8(input>>8&0xFF), + uint8(input>>16&0xFF), + uint8(input>>24&0xFF), + ) +} diff --git a/process/iphelper/test/main.go b/process/iphelper/test/main.go index b87a56bd..6f2d08ab 100644 --- a/process/iphelper/test/main.go +++ b/process/iphelper/test/main.go @@ -5,7 +5,7 @@ package main import ( "fmt" - "github.com/Safing/safing-core/process/iphelper" + "github.com/Safing/portmaster/process/iphelper" ) func main() { diff --git a/process/process_windows.go b/process/process_windows.go index b1e38349..5e7f2cdc 100644 --- a/process/process_windows.go +++ b/process/process_windows.go @@ -5,12 +5,12 @@ import "strings" // IsUser returns whether the process is run by a normal user. func (m *Process) IsUser() bool { return m.Pid != 4 && // Kernel - !strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!) + !strings.HasPrefix(m.UserName, "NT") // NT-Authority (localized!) } // IsAdmin returns whether the process is run by an admin user. func (m *Process) IsAdmin() bool { - return strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!) + return strings.HasPrefix(m.UserName, "NT") // NT-Authority (localized!) } // IsSystem returns whether the process is run by the operating system. diff --git a/profile/const_windows.go b/profile/const_windows.go new file mode 100644 index 00000000..e4479c63 --- /dev/null +++ b/profile/const_windows.go @@ -0,0 +1,6 @@ +package profile + +// OS Identifier +const ( + osIdentifier = PlatformWindows +) diff --git a/updates/fetch.go b/updates/fetch.go index ccdeac65..d07a0248 100644 --- a/updates/fetch.go +++ b/updates/fetch.go @@ -9,6 +9,7 @@ import ( "os" "path" "path/filepath" + "runtime" "time" "github.com/google/renameio" @@ -70,9 +71,11 @@ func fetchFile(realFilepath, updateFilepath string, tries int) error { return fmt.Errorf("updates: failed to finalize file %s: %s", realFilepath, err) } // set permissions - err = os.Chmod(realFilepath, 0644) - if err != nil { - log.Warningf("updates: failed to set permissions on downloaded file %s: %s", realFilepath, err) + if runtime.GOOS != "windows" { + err = os.Chmod(realFilepath, 0644) + if err != nil { + log.Warningf("updates: failed to set permissions on downloaded file %s: %s", realFilepath, err) + } } log.Infof("updates: fetched %s (stored to %s)", downloadURL, realFilepath) diff --git a/updates/get.go b/updates/get.go index c37c2577..6ce99c63 100644 --- a/updates/get.go +++ b/updates/get.go @@ -56,7 +56,7 @@ func loadOrFetchFile(identifier string) (*File, error) { } // build final filepath - realFilePath := filepath.Join(updateStoragePath, versionedFilePath) + realFilePath := filepath.Join(updateStoragePath, filepath.FromSlash(versionedFilePath)) if _, err := os.Stat(realFilePath); err == nil { // file exists updateUsedStatus(identifier, version) diff --git a/updates/latest.go b/updates/latest.go index 0e6628bc..615cda9f 100644 --- a/updates/latest.go +++ b/updates/latest.go @@ -92,6 +92,7 @@ func ScanForLatest(baseDir string, hardFail bool) (latest map[string]string, las storedVersion, ok := latest[identifierPath] if ok { // FIXME: this will fail on multi-digit version segments! + // FIXME: use https://github.com/hashicorp/go-version if version > storedVersion { latest[identifierPath] = version }