Implement review suggestions

This commit is contained in:
Daniel
2020-05-19 16:57:55 +02:00
parent 65a3456165
commit e65ae8b55d
14 changed files with 130 additions and 100 deletions

View File

@@ -11,7 +11,11 @@ const (
UDPConnectionTTL = 10 * time.Minute
)
// Exists checks if the given connection is present in the system state tables.
func Exists(pktInfo *packet.Info, now time.Time) (exists bool) {
// TODO: create lookup maps before running a flurry of Exists() checks.
switch {
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
tcp4Lock.Lock()
@@ -76,7 +80,10 @@ func existsUDP(
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
udpConnState, ok := getUDPConnState(socketInfo, udpStates, remoteIP, remotePort)
udpConnState, ok := getUDPConnState(socketInfo, udpStates, socket.Address{
IP: remoteIP,
Port: remotePort,
})
switch {
case !ok:
return false

View File

@@ -24,10 +24,6 @@ import (
// - switch direction to outbound if outbound packet is seen?
// - IP: Unidentified Process
const (
UnidentifiedProcessID = -1
)
// Errors
var (
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
@@ -75,7 +71,7 @@ func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) {
return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo)
default:
return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
return socket.UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
}
}
@@ -119,7 +115,7 @@ func searchTCP(
connections, listeners = updateTables()
}
return UnidentifiedProcessID, false, ErrConnectionNotFound
return socket.UnidentifiedProcessID, false, ErrConnectionNotFound
}
func searchUDP(
@@ -170,5 +166,5 @@ func searchUDP(
binds = updateTable()
}
return UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
}

View File

@@ -14,24 +14,14 @@ var (
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID {
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode)
if ok {
socketInfo.PID = pid
} else {
socketInfo.PID = UnidentifiedProcessID
}
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
}
return socketInfo.PID, connInbound, nil
}
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID {
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode)
if ok {
socketInfo.PID = pid
} else {
socketInfo.PID = UnidentifiedProcessID
}
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
}
return socketInfo.PID, connInbound, nil
}

View File

@@ -18,9 +18,8 @@ var (
)
func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
// FIXME: repeatable once
connections, listeners, err := getTCP4Table()
var err error
connections, listeners, err = getTCP4Table()
if err != nil {
log.Warningf("state: failed to get TCP4 socket table: %s", err)
return
@@ -28,11 +27,12 @@ func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*sock
tcp4Connections = connections
tcp4Listeners = listeners
return tcp4Connections, tcp4Listeners
return
}
func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
connections, listeners, err := getTCP6Table()
var err error
connections, listeners, err = getTCP6Table()
if err != nil {
log.Warningf("state: failed to get TCP6 socket table: %s", err)
return
@@ -40,27 +40,29 @@ func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*sock
tcp6Connections = connections
tcp6Listeners = listeners
return tcp6Connections, tcp6Listeners
return
}
func updateUDP4Table() (binds []*socket.BindInfo) {
binds, err := getUDP4Table()
var err error
binds, err = getUDP4Table()
if err != nil {
log.Warningf("state: failed to get UDP4 socket table: %s", err)
return
}
udp4Binds = binds
return udp4Binds
return
}
func updateUDP6Table() (binds []*socket.BindInfo) {
binds, err := getUDP6Table()
var err error
binds, err = getUDP6Table()
if err != nil {
log.Warningf("state: failed to get UDP6 socket table: %s", err)
return
}
udp6Binds = binds
return udp6Binds
return
}

View File

@@ -2,7 +2,6 @@ package state
import (
"context"
"net"
"time"
"github.com/safing/portmaster/network/packet"
@@ -15,7 +14,7 @@ type udpState struct {
}
const (
UpdConnStateTTL = 72 * time.Hour
UdpConnStateTTL = 72 * time.Hour
UdpConnStateShortenedTTL = 3 * time.Hour
AggressiveCleaningThreshold = 256
)
@@ -25,10 +24,10 @@ var (
udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock
)
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16) (udpConnState *udpState, ok bool) {
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)]
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteAddress socket.Address) (udpConnState *udpState, ok bool) {
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local)]
if ok {
udpConnState, ok = bindMap[makeUDPStateKey(remoteIP, remotePort)]
udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)]
return
}
@@ -36,7 +35,7 @@ func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[strin
}
func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) {
localKey := makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)
localKey := makeUDPStateKey(socketInfo.Local)
bindMap, ok := udpStates[localKey]
if !ok {
@@ -44,7 +43,10 @@ func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[strin
udpStates[localKey] = bindMap
}
remoteKey := makeUDPStateKey(pktInfo.RemoteIP(), pktInfo.RemotePort())
remoteKey := makeUDPStateKey(socket.Address{
IP: pktInfo.RemoteIP(),
Port: pktInfo.RemotePort(),
})
udpConnState, ok := bindMap[remoteKey]
if !ok {
bindMap[remoteKey] = &udpState{
@@ -79,19 +81,18 @@ func cleanStates(
now time.Time,
) {
// compute thresholds
threshold := now.Add(-UpdConnStateTTL)
threshold := now.Add(-UdpConnStateTTL)
shortThreshhold := now.Add(-UdpConnStateShortenedTTL)
// make list of all active keys
// make lookup map of all active keys
bindKeys := make(map[string]struct{})
for _, socketInfo := range binds {
bindKeys[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] = struct{}{}
bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{}
}
// clean the udp state storage
for localKey, bindMap := range udpStates {
_, active := bindKeys[localKey]
if active {
if _, active := bindKeys[localKey]; active {
// clean old entries
for remoteKey, udpConnState := range bindMap {
if udpConnState.lastSeen.Before(threshold) {
@@ -113,7 +114,7 @@ func cleanStates(
}
}
func makeUDPStateKey(ip net.IP, port uint16) string {
func makeUDPStateKey(address socket.Address) string {
// This could potentially go wrong, but as all IPs are created by the same source, everything should be fine.
return string(ip) + string(port)
return string(address.IP) + string(address.Port)
}