Implement review suggestions
This commit is contained in:
@@ -11,7 +11,11 @@ const (
|
||||
UDPConnectionTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// Exists checks if the given connection is present in the system state tables.
|
||||
func Exists(pktInfo *packet.Info, now time.Time) (exists bool) {
|
||||
|
||||
// TODO: create lookup maps before running a flurry of Exists() checks.
|
||||
|
||||
switch {
|
||||
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
|
||||
tcp4Lock.Lock()
|
||||
@@ -76,7 +80,10 @@ func existsUDP(
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
|
||||
|
||||
udpConnState, ok := getUDPConnState(socketInfo, udpStates, remoteIP, remotePort)
|
||||
udpConnState, ok := getUDPConnState(socketInfo, udpStates, socket.Address{
|
||||
IP: remoteIP,
|
||||
Port: remotePort,
|
||||
})
|
||||
switch {
|
||||
case !ok:
|
||||
return false
|
||||
|
||||
@@ -24,10 +24,6 @@ import (
|
||||
// - switch direction to outbound if outbound packet is seen?
|
||||
// - IP: Unidentified Process
|
||||
|
||||
const (
|
||||
UnidentifiedProcessID = -1
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
|
||||
@@ -75,7 +71,7 @@ func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) {
|
||||
return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo)
|
||||
|
||||
default:
|
||||
return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
|
||||
return socket.UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +115,7 @@ func searchTCP(
|
||||
connections, listeners = updateTables()
|
||||
}
|
||||
|
||||
return UnidentifiedProcessID, false, ErrConnectionNotFound
|
||||
return socket.UnidentifiedProcessID, false, ErrConnectionNotFound
|
||||
}
|
||||
|
||||
func searchUDP(
|
||||
@@ -170,5 +166,5 @@ func searchUDP(
|
||||
binds = updateTable()
|
||||
}
|
||||
|
||||
return UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
|
||||
return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
|
||||
}
|
||||
|
||||
@@ -14,24 +14,14 @@ var (
|
||||
|
||||
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
|
||||
if socketInfo.PID == proc.UnfetchedProcessID {
|
||||
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode)
|
||||
if ok {
|
||||
socketInfo.PID = pid
|
||||
} else {
|
||||
socketInfo.PID = UnidentifiedProcessID
|
||||
}
|
||||
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
|
||||
}
|
||||
return socketInfo.PID, connInbound, nil
|
||||
}
|
||||
|
||||
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
|
||||
if socketInfo.PID == proc.UnfetchedProcessID {
|
||||
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode)
|
||||
if ok {
|
||||
socketInfo.PID = pid
|
||||
} else {
|
||||
socketInfo.PID = UnidentifiedProcessID
|
||||
}
|
||||
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
|
||||
}
|
||||
return socketInfo.PID, connInbound, nil
|
||||
}
|
||||
|
||||
@@ -18,9 +18,8 @@ var (
|
||||
)
|
||||
|
||||
func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
|
||||
// FIXME: repeatable once
|
||||
|
||||
connections, listeners, err := getTCP4Table()
|
||||
var err error
|
||||
connections, listeners, err = getTCP4Table()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get TCP4 socket table: %s", err)
|
||||
return
|
||||
@@ -28,11 +27,12 @@ func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*sock
|
||||
|
||||
tcp4Connections = connections
|
||||
tcp4Listeners = listeners
|
||||
return tcp4Connections, tcp4Listeners
|
||||
return
|
||||
}
|
||||
|
||||
func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
|
||||
connections, listeners, err := getTCP6Table()
|
||||
var err error
|
||||
connections, listeners, err = getTCP6Table()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get TCP6 socket table: %s", err)
|
||||
return
|
||||
@@ -40,27 +40,29 @@ func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*sock
|
||||
|
||||
tcp6Connections = connections
|
||||
tcp6Listeners = listeners
|
||||
return tcp6Connections, tcp6Listeners
|
||||
return
|
||||
}
|
||||
|
||||
func updateUDP4Table() (binds []*socket.BindInfo) {
|
||||
binds, err := getUDP4Table()
|
||||
var err error
|
||||
binds, err = getUDP4Table()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get UDP4 socket table: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
udp4Binds = binds
|
||||
return udp4Binds
|
||||
return
|
||||
}
|
||||
|
||||
func updateUDP6Table() (binds []*socket.BindInfo) {
|
||||
binds, err := getUDP6Table()
|
||||
var err error
|
||||
binds, err = getUDP6Table()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get UDP6 socket table: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
udp6Binds = binds
|
||||
return udp6Binds
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
@@ -15,7 +14,7 @@ type udpState struct {
|
||||
}
|
||||
|
||||
const (
|
||||
UpdConnStateTTL = 72 * time.Hour
|
||||
UdpConnStateTTL = 72 * time.Hour
|
||||
UdpConnStateShortenedTTL = 3 * time.Hour
|
||||
AggressiveCleaningThreshold = 256
|
||||
)
|
||||
@@ -25,10 +24,10 @@ var (
|
||||
udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock
|
||||
)
|
||||
|
||||
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16) (udpConnState *udpState, ok bool) {
|
||||
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)]
|
||||
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteAddress socket.Address) (udpConnState *udpState, ok bool) {
|
||||
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local)]
|
||||
if ok {
|
||||
udpConnState, ok = bindMap[makeUDPStateKey(remoteIP, remotePort)]
|
||||
udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)]
|
||||
return
|
||||
}
|
||||
|
||||
@@ -36,7 +35,7 @@ func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[strin
|
||||
}
|
||||
|
||||
func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) {
|
||||
localKey := makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)
|
||||
localKey := makeUDPStateKey(socketInfo.Local)
|
||||
|
||||
bindMap, ok := udpStates[localKey]
|
||||
if !ok {
|
||||
@@ -44,7 +43,10 @@ func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[strin
|
||||
udpStates[localKey] = bindMap
|
||||
}
|
||||
|
||||
remoteKey := makeUDPStateKey(pktInfo.RemoteIP(), pktInfo.RemotePort())
|
||||
remoteKey := makeUDPStateKey(socket.Address{
|
||||
IP: pktInfo.RemoteIP(),
|
||||
Port: pktInfo.RemotePort(),
|
||||
})
|
||||
udpConnState, ok := bindMap[remoteKey]
|
||||
if !ok {
|
||||
bindMap[remoteKey] = &udpState{
|
||||
@@ -79,19 +81,18 @@ func cleanStates(
|
||||
now time.Time,
|
||||
) {
|
||||
// compute thresholds
|
||||
threshold := now.Add(-UpdConnStateTTL)
|
||||
threshold := now.Add(-UdpConnStateTTL)
|
||||
shortThreshhold := now.Add(-UdpConnStateShortenedTTL)
|
||||
|
||||
// make list of all active keys
|
||||
// make lookup map of all active keys
|
||||
bindKeys := make(map[string]struct{})
|
||||
for _, socketInfo := range binds {
|
||||
bindKeys[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] = struct{}{}
|
||||
bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{}
|
||||
}
|
||||
|
||||
// clean the udp state storage
|
||||
for localKey, bindMap := range udpStates {
|
||||
_, active := bindKeys[localKey]
|
||||
if active {
|
||||
if _, active := bindKeys[localKey]; active {
|
||||
// clean old entries
|
||||
for remoteKey, udpConnState := range bindMap {
|
||||
if udpConnState.lastSeen.Before(threshold) {
|
||||
@@ -113,7 +114,7 @@ func cleanStates(
|
||||
}
|
||||
}
|
||||
|
||||
func makeUDPStateKey(ip net.IP, port uint16) string {
|
||||
func makeUDPStateKey(address socket.Address) string {
|
||||
// This could potentially go wrong, but as all IPs are created by the same source, everything should be fine.
|
||||
return string(ip) + string(port)
|
||||
return string(address.IP) + string(address.Port)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user