Improve IP/Port parsing
This commit is contained in:
@@ -2,12 +2,10 @@ package firewall
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -16,6 +14,7 @@ import (
|
|||||||
"github.com/safing/portbase/log"
|
"github.com/safing/portbase/log"
|
||||||
"github.com/safing/portbase/utils"
|
"github.com/safing/portbase/utils"
|
||||||
"github.com/safing/portmaster/netenv"
|
"github.com/safing/portmaster/netenv"
|
||||||
|
"github.com/safing/portmaster/network/netutils"
|
||||||
"github.com/safing/portmaster/network/packet"
|
"github.com/safing/portmaster/network/packet"
|
||||||
"github.com/safing/portmaster/process"
|
"github.com/safing/portmaster/process"
|
||||||
"github.com/safing/portmaster/updates"
|
"github.com/safing/portmaster/updates"
|
||||||
@@ -53,7 +52,7 @@ func prepAPIAuth() error {
|
|||||||
|
|
||||||
func startAPIAuth() {
|
func startAPIAuth() {
|
||||||
var err error
|
var err error
|
||||||
apiIP, apiPort, err = parseHostPort(apiListenAddress())
|
apiIP, apiPort, err = netutils.ParseIPPort(apiListenAddress())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warningf("filter: failed to parse API address for improved api auth mechanism: %s", err)
|
log.Warningf("filter: failed to parse API address for improved api auth mechanism: %s", err)
|
||||||
return
|
return
|
||||||
@@ -71,13 +70,13 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get local IP/Port
|
// get local IP/Port
|
||||||
localIP, localPort, err := parseHostPort(s.Addr)
|
localIP, localPort, err := netutils.ParseIPPort(s.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get local IP/Port: %w", err)
|
return nil, fmt.Errorf("failed to get local IP/Port: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// get remote IP/Port
|
// get remote IP/Port
|
||||||
remoteIP, remotePort, err := parseHostPort(r.RemoteAddr)
|
remoteIP, remotePort, err := netutils.ParseIPPort(r.RemoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get remote IP/Port: %w", err)
|
return nil, fmt.Errorf("failed to get remote IP/Port: %w", err)
|
||||||
}
|
}
|
||||||
@@ -214,22 +213,3 @@ func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bo
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseHostPort(address string) (net.IP, uint16, error) {
|
|
||||||
ipString, portString, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP(ipString)
|
|
||||||
if ip == nil {
|
|
||||||
return nil, 0, errors.New("invalid IP address")
|
|
||||||
}
|
|
||||||
|
|
||||||
port, err := strconv.ParseUint(portString, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ip, uint16(port), nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,39 +2,45 @@ package netutils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/safing/portmaster/network/packet"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errInvalidIP = errors.New("invalid IP address")
|
var errInvalidIP = errors.New("invalid IP address")
|
||||||
|
|
||||||
// IPFromAddr extracts or parses the IP address contained in the given address.
|
// IPPortFromAddr extracts or parses the IP address and port contained in the given address.
|
||||||
func IPFromAddr(addr net.Addr) (net.IP, error) {
|
func IPPortFromAddr(addr net.Addr) (ip net.IP, port uint16, err error) {
|
||||||
// Convert addr to IP if needed.
|
// Convert addr to IP if needed.
|
||||||
switch v := addr.(type) {
|
switch v := addr.(type) {
|
||||||
case *net.TCPAddr:
|
case *net.TCPAddr:
|
||||||
return v.IP, nil
|
return v.IP, uint16(v.Port), nil
|
||||||
case *net.UDPAddr:
|
case *net.UDPAddr:
|
||||||
return v.IP, nil
|
return v.IP, uint16(v.Port), nil
|
||||||
case *net.IPAddr:
|
case *net.IPAddr:
|
||||||
return v.IP, nil
|
return v.IP, 0, nil
|
||||||
|
case *net.UnixAddr:
|
||||||
|
return nil, 0, errors.New("unix addresses don't have IPs")
|
||||||
default:
|
default:
|
||||||
// Parse via string.
|
return ParseIPPort(addr.String())
|
||||||
host, _, err := net.SplitHostPort(addr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to split host and port of %q: %w", addr, err)
|
|
||||||
}
|
|
||||||
ip := net.ParseIP(host)
|
|
||||||
if ip == nil {
|
|
||||||
return nil, fmt.Errorf("address %q does not contain a valid IP address", addr)
|
|
||||||
}
|
|
||||||
return ip, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseHostPort parses a <ip>:port formatted address.
|
// ProtocolFromNetwork returns the protocol from the given net, as used in the "net" golang stdlib.
|
||||||
func ParseHostPort(address string) (net.IP, uint16, error) {
|
func ProtocolFromNetwork(net string) (protocol packet.IPProtocol) {
|
||||||
|
switch net {
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
return packet.TCP
|
||||||
|
case "udp", "udp4", "udp6":
|
||||||
|
return packet.UDP
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseIPPort parses a <ip>:port formatted address.
|
||||||
|
func ParseIPPort(address string) (net.IP, uint16, error) {
|
||||||
ipString, portString, err := net.SplitHostPort(address)
|
ipString, portString, err := net.SplitHostPort(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ func GetNetworkHost(ctx context.Context, remoteIP net.IP) (process *Process, err
|
|||||||
// GetProcessByRequestOrigin returns the process that initiated the API request ar.
|
// GetProcessByRequestOrigin returns the process that initiated the API request ar.
|
||||||
func GetProcessByRequestOrigin(ar *api.Request) (*Process, error) {
|
func GetProcessByRequestOrigin(ar *api.Request) (*Process, error) {
|
||||||
// get remote IP/Port
|
// get remote IP/Port
|
||||||
remoteIP, remotePort, err := netutils.ParseHostPort(ar.RemoteAddr)
|
remoteIP, remotePort, err := netutils.ParseIPPort(ar.RemoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get remote IP/Port: %w", err)
|
return nil, fmt.Errorf("failed to get remote IP/Port: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user