From 1376e46d216ec0e54f69717168381ac958d6d093 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 6 Sep 2023 11:08:02 +0200 Subject: [PATCH] Improve IP/Port parsing --- firewall/api.go | 28 ++++--------------------- network/netutils/address.go | 42 +++++++++++++++++++++---------------- process/find.go | 2 +- 3 files changed, 29 insertions(+), 43 deletions(-) diff --git a/firewall/api.go b/firewall/api.go index ba4c0585..b17efe6d 100644 --- a/firewall/api.go +++ b/firewall/api.go @@ -2,12 +2,10 @@ package firewall import ( "context" - "errors" "fmt" "net" "net/http" "path/filepath" - "strconv" "strings" "time" @@ -16,6 +14,7 @@ import ( "github.com/safing/portbase/log" "github.com/safing/portbase/utils" "github.com/safing/portmaster/netenv" + "github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/process" "github.com/safing/portmaster/updates" @@ -53,7 +52,7 @@ func prepAPIAuth() error { func startAPIAuth() { var err error - apiIP, apiPort, err = parseHostPort(apiListenAddress()) + apiIP, apiPort, err = netutils.ParseIPPort(apiListenAddress()) if err != nil { log.Warningf("filter: failed to parse API address for improved api auth mechanism: %s", err) return @@ -71,13 +70,13 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er } // get local IP/Port - localIP, localPort, err := parseHostPort(s.Addr) + localIP, localPort, err := netutils.ParseIPPort(s.Addr) if err != nil { return nil, fmt.Errorf("failed to get local IP/Port: %w", err) } // get remote IP/Port - remoteIP, remotePort, err := parseHostPort(r.RemoteAddr) + remoteIP, remotePort, err := netutils.ParseIPPort(r.RemoteAddr) if err != nil { return nil, fmt.Errorf("failed to get remote IP/Port: %w", err) } @@ -214,22 +213,3 @@ func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bo ) } } - -func parseHostPort(address string) (net.IP, uint16, error) { - ipString, portString, err := net.SplitHostPort(address) - if err != nil { - return nil, 0, err - } - - ip := net.ParseIP(ipString) - if ip == nil { - return nil, 0, errors.New("invalid IP address") - } - - port, err := strconv.ParseUint(portString, 10, 16) - if err != nil { - return nil, 0, err - } - - return ip, uint16(port), nil -} diff --git a/network/netutils/address.go b/network/netutils/address.go index 59d2a262..3d89c39c 100644 --- a/network/netutils/address.go +++ b/network/netutils/address.go @@ -2,39 +2,45 @@ package netutils import ( "errors" - "fmt" "net" "strconv" + + "github.com/safing/portmaster/network/packet" ) var errInvalidIP = errors.New("invalid IP address") -// IPFromAddr extracts or parses the IP address contained in the given address. -func IPFromAddr(addr net.Addr) (net.IP, error) { +// IPPortFromAddr extracts or parses the IP address and port contained in the given address. +func IPPortFromAddr(addr net.Addr) (ip net.IP, port uint16, err error) { // Convert addr to IP if needed. switch v := addr.(type) { case *net.TCPAddr: - return v.IP, nil + return v.IP, uint16(v.Port), nil case *net.UDPAddr: - return v.IP, nil + return v.IP, uint16(v.Port), nil case *net.IPAddr: - return v.IP, nil + return v.IP, 0, nil + case *net.UnixAddr: + return nil, 0, errors.New("unix addresses don't have IPs") default: - // Parse via string. - host, _, err := net.SplitHostPort(addr.String()) - if err != nil { - return nil, fmt.Errorf("failed to split host and port of %q: %w", addr, err) - } - ip := net.ParseIP(host) - if ip == nil { - return nil, fmt.Errorf("address %q does not contain a valid IP address", addr) - } - return ip, nil + return ParseIPPort(addr.String()) } } -// ParseHostPort parses a :port formatted address. -func ParseHostPort(address string) (net.IP, uint16, error) { +// ProtocolFromNetwork returns the protocol from the given net, as used in the "net" golang stdlib. +func ProtocolFromNetwork(net string) (protocol packet.IPProtocol) { + switch net { + case "tcp", "tcp4", "tcp6": + return packet.TCP + case "udp", "udp4", "udp6": + return packet.UDP + default: + return 0 + } +} + +// ParseIPPort parses a :port formatted address. +func ParseIPPort(address string) (net.IP, uint16, error) { ipString, portString, err := net.SplitHostPort(address) if err != nil { return nil, 0, err diff --git a/process/find.go b/process/find.go index 9c96e1fe..c0a209e9 100644 --- a/process/find.go +++ b/process/find.go @@ -115,7 +115,7 @@ func GetNetworkHost(ctx context.Context, remoteIP net.IP) (process *Process, err // GetProcessByRequestOrigin returns the process that initiated the API request ar. func GetProcessByRequestOrigin(ar *api.Request) (*Process, error) { // get remote IP/Port - remoteIP, remotePort, err := netutils.ParseHostPort(ar.RemoteAddr) + remoteIP, remotePort, err := netutils.ParseIPPort(ar.RemoteAddr) if err != nil { return nil, fmt.Errorf("failed to get remote IP/Port: %w", err) }