diff --git a/firewall/api.go b/firewall/api.go new file mode 100644 index 00000000..9b00f49b --- /dev/null +++ b/firewall/api.go @@ -0,0 +1,103 @@ +package firewall + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "strconv" + "strings" + + "github.com/safing/portbase/database" + "github.com/safing/portbase/log" + + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/process" + + "github.com/safing/portbase/api" +) + +var ( + dbRoot string + + apiAddressSet bool + apiIP net.IP + apiPort uint16 +) + +func prepAPIAuth() error { + dbRoot = database.GetDatabaseRoot() + return api.SetAuthenticator(apiAuthenticator) +} + +func startAPIAuth() { + var err error + apiIP, apiPort, err = parseHostPort(api.GetAPIAddress()) + if err != nil { + log.Warningf("firewall: failed to parse API address for improved api auth mechanism: %s", err) + } + apiAddressSet = true +} + +func apiAuthenticator(s *http.Server, r *http.Request) (grantAccess bool, err error) { + // get local IP/Port + localIP, localPort, err := parseHostPort(s.Addr) + if err != nil { + return false, fmt.Errorf("failed to get local IP/Port: %s", err) + } + + // get remote IP/Port + remoteIP, remotePort, err := parseHostPort(r.RemoteAddr) + if err != nil { + return false, fmt.Errorf("failed to get remote IP/Port: %s", err) + } + + var procsChecked []string + + // get process + proc, err := process.GetProcessByEndpoints(r.Context(), remoteIP, remotePort, localIP, localPort, packet.TCP) // switch reverse/local to get remote process + if err != nil { + return false, fmt.Errorf("failed to get process: %s", err) + } + + // go up up to two levels, if we don't match + for i := 0; i < 3; i++ { + // check if the requesting process is in database root / updates dir + if strings.HasPrefix(proc.Path, dbRoot) { + return true, nil + } + // add checked process to list + procsChecked = append(procsChecked, proc.Path) + + if i < 2 { + // get parent process + proc, err = process.GetOrFindProcess(context.Background(), proc.ParentPid) + if err != nil { + return false, fmt.Errorf("failed to get process: %s", err) + } + } + } + + log.Debugf("firewall: denying api access to %s - also checked %s (trusted root is %s)", procsChecked[0], strings.Join(procsChecked[1:], " "), dbRoot) + return true, nil +} + +func parseHostPort(address string) (net.IP, uint16, error) { + ipString, portString, err := net.SplitHostPort(address) + if err != nil { + return nil, 0, err + } + + ip := net.ParseIP(ipString) + if ip == nil { + return nil, 0, errors.New("invalid IP address") + } + + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, 0, err + } + + return ip, uint16(port), nil +} diff --git a/firewall/firewall.go b/firewall/firewall.go index 74fb95bb..96620eda 100644 --- a/firewall/firewall.go +++ b/firewall/firewall.go @@ -50,6 +50,11 @@ func prep() (err error) { return err } + err = prepAPIAuth() + if err != nil { + return err + } + _, localNet4, err = net.ParseCIDR("127.0.0.0/24") // Yes, this would normally be 127.0.0.0/8 // TODO: figure out any side effects @@ -77,12 +82,9 @@ func prep() (err error) { } func start() error { + startAPIAuth() go statLogger() go run() - // go run() - // go run() - // go run() - go portsInUseCleaner() return interception.Start() @@ -108,6 +110,15 @@ func handlePacket(pkt packet.Packet) { return } + // allow api access, if address was parsed successfully + if apiAddressSet { + if (pkt.Info().DstPort == apiPort && pkt.Info().Dst.Equal(apiIP)) || (pkt.Info().SrcPort == apiPort && pkt.Info().Src.Equal(apiIP)) { + log.Debugf("accepting api connection: %s", pkt) + pkt.PermanentAccept() + return + } + } + // // redirect dns (if we know that it's not our own request) // if pkt.IsOutbound() && intel.RemoteIsActiveNameserver(pkt) { // log.Debugf("redirecting dns: %s", pkt)