wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
46
spn/sluice/module.go
Normal file
46
spn/sluice/module.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package sluice
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
entrypointInfoMsg = []byte("You have reached the local SPN entry port, but your connection could not be matched to an SPN tunnel.\n")
|
||||
|
||||
// EnableListener indicates if it should start the sluice listeners. Must be set at startup.
|
||||
EnableListener bool = true
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("sluice", nil, start, stop, "terminal")
|
||||
}
|
||||
|
||||
func start() error {
|
||||
// TODO:
|
||||
// Listening on all interfaces for now, as we need this for Windows.
|
||||
// Handle similarly to the nameserver listener.
|
||||
|
||||
if conf.Client() && EnableListener {
|
||||
StartSluice("tcp4", "0.0.0.0:717")
|
||||
StartSluice("udp4", "0.0.0.0:717")
|
||||
|
||||
if netenv.IPv6Enabled() {
|
||||
StartSluice("tcp6", "[::]:717")
|
||||
StartSluice("udp6", "[::]:717")
|
||||
} else {
|
||||
log.Warningf("spn/sluice: no IPv6 stack detected, disabling IPv6 SPN entry endpoints")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
stopAllSluices()
|
||||
return nil
|
||||
}
|
||||
277
spn/sluice/packet_listener.go
Normal file
277
spn/sluice/packet_listener.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package sluice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// PacketListener is a listener for packet based protocols.
|
||||
type PacketListener struct {
|
||||
sock net.PacketConn
|
||||
closed *abool.AtomicBool
|
||||
newConns chan *PacketConn
|
||||
|
||||
lock sync.Mutex
|
||||
conns map[string]*PacketConn
|
||||
err error
|
||||
}
|
||||
|
||||
// ListenPacket creates a packet listener.
|
||||
func ListenPacket(network, address string) (net.Listener, error) {
|
||||
// Create a new listening packet socket.
|
||||
sock, err := net.ListenPacket(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create listener and start workers.
|
||||
ln := &PacketListener{
|
||||
sock: sock,
|
||||
closed: abool.New(),
|
||||
newConns: make(chan *PacketConn),
|
||||
conns: make(map[string]*PacketConn),
|
||||
}
|
||||
module.StartServiceWorker("packet listener reader", 0, ln.reader)
|
||||
module.StartServiceWorker("packet listener cleaner", time.Minute, ln.cleaner)
|
||||
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (ln *PacketListener) Accept() (net.Conn, error) {
|
||||
conn := <-ln.newConns
|
||||
if conn != nil {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check if there is a socket error.
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
if ln.err != nil {
|
||||
return nil, ln.err
|
||||
}
|
||||
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
// Any blocked Accept operations will be unblocked and return errors.
|
||||
func (ln *PacketListener) Close() error {
|
||||
if !ln.closed.SetToIf(false, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close all channels.
|
||||
close(ln.newConns)
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
for _, conn := range ln.conns {
|
||||
close(conn.in)
|
||||
}
|
||||
|
||||
// Close socket.
|
||||
return ln.sock.Close()
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (ln *PacketListener) Addr() net.Addr {
|
||||
return ln.sock.LocalAddr()
|
||||
}
|
||||
|
||||
func (ln *PacketListener) getConn(remoteAddr string) (conn *PacketConn, ok bool) {
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
|
||||
conn, ok = ln.conns[remoteAddr]
|
||||
return
|
||||
}
|
||||
|
||||
func (ln *PacketListener) setConn(conn *PacketConn) {
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
|
||||
ln.conns[conn.addr.String()] = conn
|
||||
}
|
||||
|
||||
func (ln *PacketListener) reader(_ context.Context) error {
|
||||
for {
|
||||
// Read data from connection.
|
||||
buf := make([]byte, 512)
|
||||
n, addr, err := ln.sock.ReadFrom(buf)
|
||||
if err != nil {
|
||||
// Set socket error.
|
||||
ln.lock.Lock()
|
||||
ln.err = err
|
||||
ln.lock.Unlock()
|
||||
// Close and return
|
||||
_ = ln.Close()
|
||||
return nil //nolint:nilerr
|
||||
}
|
||||
buf = buf[:n]
|
||||
|
||||
// Get connection and supply data.
|
||||
conn, ok := ln.getConn(addr.String())
|
||||
if ok {
|
||||
// Ignore if conn is closed.
|
||||
if conn.closed.IsSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case conn.in <- buf:
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Or create a new connection.
|
||||
conn = &PacketConn{
|
||||
ln: ln,
|
||||
addr: addr,
|
||||
closed: abool.New(),
|
||||
closing: make(chan struct{}),
|
||||
buf: buf,
|
||||
in: make(chan []byte, 1),
|
||||
inactivityCnt: new(uint32),
|
||||
}
|
||||
ln.setConn(conn)
|
||||
ln.newConns <- conn
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *PacketListener) cleaner(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(1 * time.Minute):
|
||||
// Check if listener has died.
|
||||
if ln.closed.IsSet() {
|
||||
return nil
|
||||
}
|
||||
// Clean connections.
|
||||
ln.cleanInactiveConns(10)
|
||||
|
||||
case <-ctx.Done():
|
||||
// Exit with module stop.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *PacketListener) cleanInactiveConns(overInactivityCnt uint32) {
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
|
||||
for k, conn := range ln.conns {
|
||||
cnt := atomic.AddUint32(conn.inactivityCnt, 1)
|
||||
switch {
|
||||
case cnt > overInactivityCnt*2:
|
||||
delete(ln.conns, k)
|
||||
case cnt > overInactivityCnt:
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PacketConn simulates a connection for a stateless protocol.
|
||||
type PacketConn struct {
|
||||
ln *PacketListener
|
||||
addr net.Addr
|
||||
closed *abool.AtomicBool
|
||||
closing chan struct{}
|
||||
|
||||
buf []byte
|
||||
in chan []byte
|
||||
|
||||
inactivityCnt *uint32
|
||||
}
|
||||
|
||||
// Read reads data from the connection.
|
||||
// Read can be made to time out and return an error after a fixed
|
||||
// time limit; see SetDeadline and SetReadDeadline.
|
||||
func (conn *PacketConn) Read(b []byte) (n int, err error) {
|
||||
// Check if connection is closed.
|
||||
if conn.closed.IsSet() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Mark as active.
|
||||
atomic.StoreUint32(conn.inactivityCnt, 0)
|
||||
|
||||
// Get new buffer.
|
||||
if conn.buf == nil {
|
||||
select {
|
||||
case conn.buf = <-conn.in:
|
||||
if conn.buf == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
case <-conn.closing:
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// Serve from buffer.
|
||||
copy(b, conn.buf)
|
||||
if len(b) >= len(conn.buf) {
|
||||
copied := len(conn.buf)
|
||||
conn.buf = nil
|
||||
return copied, nil
|
||||
}
|
||||
copied := len(b)
|
||||
conn.buf = conn.buf[copied:]
|
||||
return copied, nil
|
||||
}
|
||||
|
||||
// Write writes data to the connection.
|
||||
// Write can be made to time out and return an error after a fixed
|
||||
// time limit; see SetDeadline and SetWriteDeadline.
|
||||
func (conn *PacketConn) Write(b []byte) (n int, err error) {
|
||||
// Check if connection is closed.
|
||||
if conn.closed.IsSet() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Mark as active.
|
||||
atomic.StoreUint32(conn.inactivityCnt, 0)
|
||||
|
||||
return conn.ln.sock.WriteTo(b, conn.addr)
|
||||
}
|
||||
|
||||
// Close is a no-op as UDP connections share a single socket. Just stop sending
|
||||
// packets without closing.
|
||||
func (conn *PacketConn) Close() error {
|
||||
if conn.closed.SetToIf(false, true) {
|
||||
close(conn.closing)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (conn *PacketConn) LocalAddr() net.Addr {
|
||||
return conn.ln.sock.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (conn *PacketConn) RemoteAddr() net.Addr {
|
||||
return conn.addr
|
||||
}
|
||||
|
||||
// SetDeadline is a no-op as UDP connections share a single socket.
|
||||
func (conn *PacketConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline is a no-op as UDP connections share a single socket.
|
||||
func (conn *PacketConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline is a no-op as UDP connections share a single socket.
|
||||
func (conn *PacketConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
78
spn/sluice/request.go
Normal file
78
spn/sluice/request.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package sluice
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSluiceTTL = 30 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUnsupported is returned when a protocol is not supported.
|
||||
ErrUnsupported = errors.New("unsupported protocol")
|
||||
|
||||
// ErrSluiceOffline is returned when the sluice for a network is offline.
|
||||
ErrSluiceOffline = errors.New("is offline")
|
||||
)
|
||||
|
||||
// Request holds request data for a sluice entry.
|
||||
type Request struct {
|
||||
ConnInfo *network.Connection
|
||||
CallbackFn RequestCallbackFunc
|
||||
Expires time.Time
|
||||
}
|
||||
|
||||
// RequestCallbackFunc is called for taking a over handling connection that arrived at the sluice.
|
||||
type RequestCallbackFunc func(connInfo *network.Connection, conn net.Conn)
|
||||
|
||||
// AwaitRequest pre-registers a connection at the sluice for initializing it when it arrives.
|
||||
func AwaitRequest(connInfo *network.Connection, callbackFn RequestCallbackFunc) error {
|
||||
network := getNetworkFromConnInfo(connInfo)
|
||||
if network == "" {
|
||||
return ErrUnsupported
|
||||
}
|
||||
|
||||
sluice, ok := getSluice(network)
|
||||
if !ok {
|
||||
return fmt.Errorf("sluice for network %s %w", network, ErrSluiceOffline)
|
||||
}
|
||||
|
||||
return sluice.AwaitRequest(&Request{
|
||||
ConnInfo: connInfo,
|
||||
CallbackFn: callbackFn,
|
||||
Expires: time.Now().Add(defaultSluiceTTL),
|
||||
})
|
||||
}
|
||||
|
||||
func getNetworkFromConnInfo(connInfo *network.Connection) string {
|
||||
var network string
|
||||
|
||||
// protocol
|
||||
switch connInfo.IPProtocol { //nolint:exhaustive // Looking for specific values.
|
||||
case packet.TCP:
|
||||
network = "tcp"
|
||||
case packet.UDP:
|
||||
network = "udp"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
// IP version
|
||||
switch connInfo.IPVersion {
|
||||
case packet.IPv4:
|
||||
network += "4"
|
||||
case packet.IPv6:
|
||||
network += "6"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
return network
|
||||
}
|
||||
229
spn/sluice/sluice.go
Normal file
229
spn/sluice/sluice.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package sluice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
)
|
||||
|
||||
// Sluice is a tunnel entry listener.
|
||||
type Sluice struct {
|
||||
network string
|
||||
address string
|
||||
createListener ListenerFactory
|
||||
|
||||
lock sync.Mutex
|
||||
listener net.Listener
|
||||
pendingRequests map[string]*Request
|
||||
abandoned bool
|
||||
}
|
||||
|
||||
// ListenerFactory defines a function to create a listener.
|
||||
type ListenerFactory func(network, address string) (net.Listener, error)
|
||||
|
||||
// StartSluice starts a sluice listener at the given address.
|
||||
func StartSluice(network, address string) {
|
||||
s := &Sluice{
|
||||
network: network,
|
||||
address: address,
|
||||
pendingRequests: make(map[string]*Request),
|
||||
}
|
||||
|
||||
switch s.network {
|
||||
case "tcp4", "tcp6":
|
||||
s.createListener = net.Listen
|
||||
case "udp4", "udp6":
|
||||
s.createListener = ListenUDP
|
||||
default:
|
||||
log.Errorf("spn/sluice: cannot start sluice for %s: unsupported network", network)
|
||||
return
|
||||
}
|
||||
|
||||
// Start service worker.
|
||||
module.StartServiceWorker(
|
||||
fmt.Sprintf("%s sluice listener", s.network),
|
||||
10*time.Second,
|
||||
s.listenHandler,
|
||||
)
|
||||
}
|
||||
|
||||
// AwaitRequest pre-registers a connection.
|
||||
func (s *Sluice) AwaitRequest(r *Request) error {
|
||||
// Set default expiry.
|
||||
if r.Expires.IsZero() {
|
||||
r.Expires = time.Now().Add(defaultSluiceTTL)
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
// Check if a pending request already exists for this local address.
|
||||
key := net.JoinHostPort(r.ConnInfo.LocalIP.String(), strconv.Itoa(int(r.ConnInfo.LocalPort)))
|
||||
_, exists := s.pendingRequests[key]
|
||||
if exists {
|
||||
return fmt.Errorf("a pending request for %s already exists", key)
|
||||
}
|
||||
|
||||
// Add to pending requests.
|
||||
s.pendingRequests[key] = r
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Sluice) getRequest(address string) (r *Request, ok bool) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
r, ok = s.pendingRequests[address]
|
||||
if ok {
|
||||
delete(s.pendingRequests, address)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Sluice) init() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.abandoned = false
|
||||
|
||||
// start listening
|
||||
s.listener = nil
|
||||
ln, err := s.createListener(s.network, s.address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
s.listener = ln
|
||||
|
||||
// Add to registry.
|
||||
addSluice(s)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Sluice) abandon() {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
if s.abandoned {
|
||||
return
|
||||
}
|
||||
s.abandoned = true
|
||||
|
||||
// Remove from registry.
|
||||
removeSluice(s.network)
|
||||
|
||||
// Close listener.
|
||||
if s.listener != nil {
|
||||
_ = s.listener.Close()
|
||||
}
|
||||
|
||||
// Notify pending requests.
|
||||
for i, r := range s.pendingRequests {
|
||||
r.CallbackFn(r.ConnInfo, nil)
|
||||
delete(s.pendingRequests, i)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sluice) handleConnection(conn net.Conn) {
|
||||
// Close the connection if handling is not successful.
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Get IP address.
|
||||
var remoteIP net.IP
|
||||
switch typedAddr := conn.RemoteAddr().(type) {
|
||||
case *net.TCPAddr:
|
||||
remoteIP = typedAddr.IP
|
||||
case *net.UDPAddr:
|
||||
remoteIP = typedAddr.IP
|
||||
default:
|
||||
log.Warningf("spn/sluice: cannot handle connection for unsupported network %s", conn.RemoteAddr().Network())
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the request is local.
|
||||
local, err := netenv.IsMyIP(remoteIP)
|
||||
if err != nil {
|
||||
log.Warningf("spn/sluice: failed to check if request from %s is local: %s", remoteIP, err)
|
||||
return
|
||||
}
|
||||
if !local {
|
||||
log.Warningf("spn/sluice: received external request from %s, ignoring", remoteIP)
|
||||
|
||||
// TODO:
|
||||
// Do not allow this to be spammed.
|
||||
// Only allow one trigger per second.
|
||||
// Do not trigger by same "remote IP" in a row.
|
||||
netenv.TriggerNetworkChangeCheck()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Get waiting request.
|
||||
r, ok := s.getRequest(conn.RemoteAddr().String())
|
||||
if !ok {
|
||||
_, err := conn.Write(entrypointInfoMsg)
|
||||
if err != nil {
|
||||
log.Warningf("spn/sluice: new %s request from %s without pending request, but failed to reply with info msg: %s", s.network, conn.RemoteAddr(), err)
|
||||
} else {
|
||||
log.Debugf("spn/sluice: new %s request from %s without pending request, replied with info msg", s.network, conn.RemoteAddr())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Hand over to callback.
|
||||
log.Tracef(
|
||||
"spn/sluice: new %s request from %s for %s (%s:%d)",
|
||||
s.network, conn.RemoteAddr(),
|
||||
r.ConnInfo.Entity.Domain, r.ConnInfo.Entity.IP, r.ConnInfo.Entity.Port,
|
||||
)
|
||||
r.CallbackFn(r.ConnInfo, conn)
|
||||
success = true
|
||||
}
|
||||
|
||||
func (s *Sluice) listenHandler(_ context.Context) error {
|
||||
defer s.abandon()
|
||||
err := s.init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle new connections.
|
||||
log.Infof("spn/sluice: started listening for %s requests on %s", s.network, s.listener.Addr())
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if module.IsStopping() {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to accept connection: %w", err)
|
||||
}
|
||||
|
||||
// Handle accepted connection.
|
||||
s.handleConnection(conn)
|
||||
|
||||
// Clean up old leftovers.
|
||||
s.cleanConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sluice) cleanConnections() {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for address, request := range s.pendingRequests {
|
||||
if now.After(request.Expires) {
|
||||
delete(s.pendingRequests, address)
|
||||
log.Debugf("spn/sluice: removed expired pending %s connection %s", s.network, request.ConnInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
47
spn/sluice/sluices.go
Normal file
47
spn/sluice/sluices.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package sluice
|
||||
|
||||
import "sync"
|
||||
|
||||
var (
|
||||
sluices = make(map[string]*Sluice)
|
||||
sluicesLock sync.RWMutex
|
||||
)
|
||||
|
||||
func getSluice(network string) (s *Sluice, ok bool) {
|
||||
sluicesLock.RLock()
|
||||
defer sluicesLock.RUnlock()
|
||||
|
||||
s, ok = sluices[network]
|
||||
return
|
||||
}
|
||||
|
||||
func addSluice(s *Sluice) {
|
||||
sluicesLock.Lock()
|
||||
defer sluicesLock.Unlock()
|
||||
|
||||
sluices[s.network] = s
|
||||
}
|
||||
|
||||
func removeSluice(network string) {
|
||||
sluicesLock.Lock()
|
||||
defer sluicesLock.Unlock()
|
||||
|
||||
delete(sluices, network)
|
||||
}
|
||||
|
||||
func copySluices() map[string]*Sluice {
|
||||
sluicesLock.Lock()
|
||||
defer sluicesLock.Unlock()
|
||||
|
||||
copied := make(map[string]*Sluice, len(sluices))
|
||||
for k, v := range sluices {
|
||||
copied[k] = v
|
||||
}
|
||||
return copied
|
||||
}
|
||||
|
||||
func stopAllSluices() {
|
||||
for _, sluice := range copySluices() {
|
||||
sluice.abandon()
|
||||
}
|
||||
}
|
||||
334
spn/sluice/udp_listener.go
Normal file
334
spn/sluice/udp_listener.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package sluice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
const onWindows = runtime.GOOS == "windows"
|
||||
|
||||
// UDPListener is a listener for UDP.
|
||||
type UDPListener struct {
|
||||
sock *net.UDPConn
|
||||
closed *abool.AtomicBool
|
||||
newConns chan *UDPConn
|
||||
oobSize int
|
||||
|
||||
lock sync.Mutex
|
||||
conns map[string]*UDPConn
|
||||
err error
|
||||
}
|
||||
|
||||
// ListenUDP creates a packet listener.
|
||||
func ListenUDP(network, address string) (net.Listener, error) {
|
||||
// Parse address.
|
||||
udpAddr, err := net.ResolveUDPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine oob data size.
|
||||
oobSize := 40 // IPv6 (measured)
|
||||
if udpAddr.IP.To4() != nil {
|
||||
oobSize = 32 // IPv4 (measured)
|
||||
}
|
||||
|
||||
// Create a new listening UDP socket.
|
||||
sock, err := net.ListenUDP(network, udpAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create listener.
|
||||
ln := &UDPListener{
|
||||
sock: sock,
|
||||
closed: abool.New(),
|
||||
newConns: make(chan *UDPConn),
|
||||
oobSize: oobSize,
|
||||
conns: make(map[string]*UDPConn),
|
||||
}
|
||||
|
||||
// Set socket options on listener.
|
||||
err = ln.setSocketOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start workers.
|
||||
module.StartServiceWorker("udp listener reader", 0, ln.reader)
|
||||
module.StartServiceWorker("udp listener cleaner", time.Minute, ln.cleaner)
|
||||
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (ln *UDPListener) Accept() (net.Conn, error) {
|
||||
conn := <-ln.newConns
|
||||
if conn != nil {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check if there is a socket error.
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
if ln.err != nil {
|
||||
return nil, ln.err
|
||||
}
|
||||
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
// Any blocked Accept operations will be unblocked and return errors.
|
||||
func (ln *UDPListener) Close() error {
|
||||
if !ln.closed.SetToIf(false, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close all channels.
|
||||
close(ln.newConns)
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
for _, conn := range ln.conns {
|
||||
close(conn.in)
|
||||
}
|
||||
|
||||
// Close socket.
|
||||
return ln.sock.Close()
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (ln *UDPListener) Addr() net.Addr {
|
||||
return ln.sock.LocalAddr()
|
||||
}
|
||||
|
||||
func (ln *UDPListener) getConn(remoteAddr string) (conn *UDPConn, ok bool) {
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
|
||||
conn, ok = ln.conns[remoteAddr]
|
||||
return
|
||||
}
|
||||
|
||||
func (ln *UDPListener) setConn(conn *UDPConn) {
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
|
||||
ln.conns[conn.addr.String()] = conn
|
||||
}
|
||||
|
||||
func (ln *UDPListener) reader(_ context.Context) error {
|
||||
for {
|
||||
// TODO: Find good buf size.
|
||||
// With a buf size of 512 we have seen this error on Windows:
|
||||
// wsarecvmsg: A message sent on a datagram socket was larger than the internal message buffer or some other network limit, or the buffer used to receive a datagram into was smaller than the datagram itself.
|
||||
// UDP is not (yet) heavily used, so we can go for the 1500 bytes size for now.
|
||||
|
||||
// Read data from connection.
|
||||
buf := make([]byte, 1500) // TODO: see comment above.
|
||||
oob := make([]byte, ln.oobSize)
|
||||
n, oobn, _, addr, err := ln.sock.ReadMsgUDP(buf, oob)
|
||||
if err != nil {
|
||||
// Set socket error.
|
||||
ln.lock.Lock()
|
||||
ln.err = err
|
||||
ln.lock.Unlock()
|
||||
// Close and return
|
||||
_ = ln.Close()
|
||||
return nil //nolint:nilerr
|
||||
}
|
||||
buf = buf[:n]
|
||||
oob = oob[:oobn]
|
||||
|
||||
// Get connection and supply data.
|
||||
conn, ok := ln.getConn(addr.String())
|
||||
if ok {
|
||||
// Ignore if conn is closed.
|
||||
if conn.closed.IsSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case conn.in <- buf:
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Or create a new connection.
|
||||
conn = &UDPConn{
|
||||
ln: ln,
|
||||
addr: addr,
|
||||
oob: oob,
|
||||
closed: abool.New(),
|
||||
closing: make(chan struct{}),
|
||||
buf: buf,
|
||||
in: make(chan []byte, 1),
|
||||
inactivityCnt: new(uint32),
|
||||
}
|
||||
ln.setConn(conn)
|
||||
ln.newConns <- conn
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *UDPListener) cleaner(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(1 * time.Minute):
|
||||
// Check if listener has died.
|
||||
if ln.closed.IsSet() {
|
||||
return nil
|
||||
}
|
||||
// Clean connections.
|
||||
ln.cleanInactiveConns(10)
|
||||
|
||||
case <-ctx.Done():
|
||||
// Exit with module stop.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *UDPListener) cleanInactiveConns(overInactivityCnt uint32) {
|
||||
ln.lock.Lock()
|
||||
defer ln.lock.Unlock()
|
||||
|
||||
for k, conn := range ln.conns {
|
||||
cnt := atomic.AddUint32(conn.inactivityCnt, 1)
|
||||
switch {
|
||||
case cnt > overInactivityCnt*2:
|
||||
delete(ln.conns, k)
|
||||
case cnt > overInactivityCnt:
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setUDPSocketOptions sets socket options so that the source address for
|
||||
// replies is correct.
|
||||
func (ln *UDPListener) setSocketOptions() error {
|
||||
// Setting socket options is not supported on windows.
|
||||
if onWindows {
|
||||
return nil
|
||||
}
|
||||
|
||||
// As we might be listening on an interface that supports both IPv4 and IPv6,
|
||||
// try to set the socket options on both.
|
||||
// Only report an error if it fails on both.
|
||||
err4 := ipv4.NewPacketConn(ln.sock).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
|
||||
err6 := ipv6.NewPacketConn(ln.sock).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
|
||||
if err4 != nil && err6 != nil {
|
||||
return err4
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UDPConn simulates a connection for a stateless protocol.
|
||||
type UDPConn struct {
|
||||
ln *UDPListener
|
||||
addr *net.UDPAddr
|
||||
oob []byte
|
||||
closed *abool.AtomicBool
|
||||
closing chan struct{}
|
||||
|
||||
buf []byte
|
||||
in chan []byte
|
||||
|
||||
inactivityCnt *uint32
|
||||
}
|
||||
|
||||
// Read reads data from the connection.
|
||||
// Read can be made to time out and return an error after a fixed
|
||||
// time limit; see SetDeadline and SetReadDeadline.
|
||||
func (conn *UDPConn) Read(b []byte) (n int, err error) {
|
||||
// Check if connection is closed.
|
||||
if conn.closed.IsSet() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Mark as active.
|
||||
atomic.StoreUint32(conn.inactivityCnt, 0)
|
||||
|
||||
// Get new buffer.
|
||||
if conn.buf == nil {
|
||||
select {
|
||||
case conn.buf = <-conn.in:
|
||||
if conn.buf == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
case <-conn.closing:
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// Serve from buffer.
|
||||
copy(b, conn.buf)
|
||||
if len(b) >= len(conn.buf) {
|
||||
copied := len(conn.buf)
|
||||
conn.buf = nil
|
||||
return copied, nil
|
||||
}
|
||||
copied := len(b)
|
||||
conn.buf = conn.buf[copied:]
|
||||
return copied, nil
|
||||
}
|
||||
|
||||
// Write writes data to the connection.
|
||||
// Write can be made to time out and return an error after a fixed
|
||||
// time limit; see SetDeadline and SetWriteDeadline.
|
||||
func (conn *UDPConn) Write(b []byte) (n int, err error) {
|
||||
// Check if connection is closed.
|
||||
if conn.closed.IsSet() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Mark as active.
|
||||
atomic.StoreUint32(conn.inactivityCnt, 0)
|
||||
|
||||
n, _, err = conn.ln.sock.WriteMsgUDP(b, conn.oob, conn.addr)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close is a no-op as UDP connections share a single socket. Just stop sending
|
||||
// packets without closing.
|
||||
func (conn *UDPConn) Close() error {
|
||||
if conn.closed.SetToIf(false, true) {
|
||||
close(conn.closing)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (conn *UDPConn) LocalAddr() net.Addr {
|
||||
return conn.ln.sock.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (conn *UDPConn) RemoteAddr() net.Addr {
|
||||
return conn.addr
|
||||
}
|
||||
|
||||
// SetDeadline is a no-op as UDP connections share a single socket.
|
||||
func (conn *UDPConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline is a no-op as UDP connections share a single socket.
|
||||
func (conn *UDPConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline is a no-op as UDP connections share a single socket.
|
||||
func (conn *UDPConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user