From 5d7caeb4bb8ee5c234c49f8d5c708cfc068ac042 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 20 Jul 2023 13:36:09 +0200 Subject: [PATCH] Expose connection ID creation --- network/packet/packet.go | 19 +------------------ network/packet/packetinfo.go | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/network/packet/packet.go b/network/packet/packet.go index a4c9d780..1ac3047f 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -112,28 +112,11 @@ func (pkt *Base) Payload() []byte { // GetConnectionID returns the link ID for this packet. func (pkt *Base) GetConnectionID() string { if pkt.connID == "" { - pkt.createConnectionID() + pkt.connID = pkt.info.CreateConnectionID() } return pkt.connID } -func (pkt *Base) createConnectionID() { - // TODO: make this ID not depend on the packet direction for better support for forwarded packets. - if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { - if pkt.info.Inbound { - pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) - } else { - pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort) - } - } else { - if pkt.info.Inbound { - pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src) - } else { - pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) - } - } -} - // MatchesAddress checks if a the packet matches a given endpoint (remote or local) in protocol, network and port. // // Comparison matrix: diff --git a/network/packet/packetinfo.go b/network/packet/packetinfo.go index 37f40216..82131625 100644 --- a/network/packet/packetinfo.go +++ b/network/packet/packetinfo.go @@ -1,6 +1,7 @@ package packet import ( + "fmt" "net" "time" ) @@ -50,3 +51,26 @@ func (pi *Info) RemotePort() uint16 { } return pi.DstPort } + +// CreateConnectionID creates a connection ID. +// In most circumstances, this method should not be used directly, but +// packet.GetConnectionID() should be called instead. +func (pi *Info) CreateConnectionID() string { + return CreateConnectionID(pi.Protocol, pi.Src, pi.SrcPort, pi.Dst, pi.DstPort, pi.Inbound) +} + +// CreateConnectionID creates a connection ID. +func CreateConnectionID(protocol IPProtocol, src net.IP, srcPort uint16, dst net.IP, dstPort uint16, inbound bool) string { + // TODO: make this ID not depend on the packet direction for better support for forwarded packets. + if protocol == TCP || protocol == UDP { + if inbound { + return fmt.Sprintf("%d-%s-%d-%s-%d", protocol, dst, dstPort, src, srcPort) + } + return fmt.Sprintf("%d-%s-%d-%s-%d", protocol, src, srcPort, dst, dstPort) + } + + if inbound { + return fmt.Sprintf("%d-%s-%s", protocol, dst, src) + } + return fmt.Sprintf("%d-%s-%s", protocol, src, dst) +}