wip: migrate to mono-repo. SPN has already been moved to spn/

This commit is contained in:
Patrick Pacher
2024-03-15 11:55:13 +01:00
parent b30fd00ccf
commit 8579430db9
577 changed files with 35981 additions and 818 deletions

View File

@@ -0,0 +1,454 @@
package terminal
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portbase/modules"
)
// FlowControl defines the flow control interface.
type FlowControl interface {
Deliver(msg *Msg) *Error
Receive() <-chan *Msg
Send(msg *Msg, timeout time.Duration) *Error
ReadyToSend() <-chan struct{}
Flush(timeout time.Duration)
StartWorkers(m *modules.Module, terminalName string)
RecvQueueLen() int
SendQueueLen() int
}
// FlowControlType represents a flow control type.
type FlowControlType uint8
// Flow Control Types.
const (
FlowControlDefault FlowControlType = 0
FlowControlDFQ FlowControlType = 1
FlowControlNone FlowControlType = 2
defaultFlowControl = FlowControlDFQ
)
// DefaultSize returns the default flow control size.
func (fct FlowControlType) DefaultSize() uint32 {
if fct == FlowControlDefault {
fct = defaultFlowControl
}
switch fct {
case FlowControlDFQ:
return 50000
case FlowControlNone:
return 10000
case FlowControlDefault:
fallthrough
default:
return 0
}
}
// Flow Queue Configuration.
const (
DefaultQueueSize = 50000
MaxQueueSize = 1000000
forceReportBelowPercent = 0.75
)
// DuplexFlowQueue is a duplex flow control mechanism using queues.
type DuplexFlowQueue struct {
// ti is the Terminal that is using the DFQ.
ctx context.Context
// submitUpstream is used to submit messages to the upstream channel.
submitUpstream func(msg *Msg, timeout time.Duration)
// sendQueue holds the messages that are waiting to be sent.
sendQueue chan *Msg
// prioMsgs holds the number of messages to send with high priority.
prioMsgs *int32
// sendSpace indicates the amount free slots in the recvQueue on the other end.
sendSpace *int32
// readyToSend is used to notify sending components that there is free space.
readyToSend chan struct{}
// wakeSender is used to wake a sender in case the sendSpace was zero and the
// sender is waiting for available space.
wakeSender chan struct{}
// recvQueue holds the messages that are waiting to be processed.
recvQueue chan *Msg
// reportedSpace indicates the amount of free slots that the other end knows
// about.
reportedSpace *int32
// spaceReportLock locks the calculation of space to report.
spaceReportLock sync.Mutex
// forceSpaceReport forces the sender to send a space report.
forceSpaceReport chan struct{}
// flush is used to send a finish function to the handler, which will write
// all pending messages and then call the received function.
flush chan func()
}
// NewDuplexFlowQueue returns a new duplex flow queue.
func NewDuplexFlowQueue(
ctx context.Context,
queueSize uint32,
submitUpstream func(msg *Msg, timeout time.Duration),
) *DuplexFlowQueue {
dfq := &DuplexFlowQueue{
ctx: ctx,
submitUpstream: submitUpstream,
sendQueue: make(chan *Msg, queueSize),
prioMsgs: new(int32),
sendSpace: new(int32),
readyToSend: make(chan struct{}),
wakeSender: make(chan struct{}, 1),
recvQueue: make(chan *Msg, queueSize),
reportedSpace: new(int32),
forceSpaceReport: make(chan struct{}, 1),
flush: make(chan func()),
}
atomic.StoreInt32(dfq.sendSpace, int32(queueSize))
atomic.StoreInt32(dfq.reportedSpace, int32(queueSize))
return dfq
}
// StartWorkers starts the necessary workers to operate the flow queue.
func (dfq *DuplexFlowQueue) StartWorkers(m *modules.Module, terminalName string) {
m.StartWorker(terminalName+" flow queue", dfq.FlowHandler)
}
// shouldReportRecvSpace returns whether the receive space should be reported.
func (dfq *DuplexFlowQueue) shouldReportRecvSpace() bool {
return atomic.LoadInt32(dfq.reportedSpace) < int32(float32(cap(dfq.recvQueue))*forceReportBelowPercent)
}
// decrementReportedRecvSpace decreases the reported recv space by 1 and
// returns if the receive space should be reported.
func (dfq *DuplexFlowQueue) decrementReportedRecvSpace() (shouldReportRecvSpace bool) {
return atomic.AddInt32(dfq.reportedSpace, -1) < int32(float32(cap(dfq.recvQueue))*forceReportBelowPercent)
}
// getSendSpace returns the current send space.
func (dfq *DuplexFlowQueue) getSendSpace() int32 {
return atomic.LoadInt32(dfq.sendSpace)
}
// decrementSendSpace decreases the send space by 1 and returns it.
func (dfq *DuplexFlowQueue) decrementSendSpace() int32 {
return atomic.AddInt32(dfq.sendSpace, -1)
}
func (dfq *DuplexFlowQueue) addToSendSpace(n int32) {
// Add new space to send space and check if it was zero.
atomic.AddInt32(dfq.sendSpace, n)
// Wake the sender in case it is waiting.
select {
case dfq.wakeSender <- struct{}{}:
default:
}
}
// reportableRecvSpace returns how much free space can be reported to the other
// end. The returned number must be communicated to the other end and must not
// be ignored.
func (dfq *DuplexFlowQueue) reportableRecvSpace() int32 {
// Changes to the recvQueue during calculation are no problem.
// We don't want to report space twice though!
dfq.spaceReportLock.Lock()
defer dfq.spaceReportLock.Unlock()
// Calculate reportable receive space and add it to the reported space.
reportedSpace := atomic.LoadInt32(dfq.reportedSpace)
toReport := int32(cap(dfq.recvQueue)-len(dfq.recvQueue)) - reportedSpace
// Never report values below zero.
// This can happen, as dfq.reportedSpace is decreased after a container is
// submitted to dfq.recvQueue by dfq.Deliver(). This race condition can only
// lower the space to report, not increase it. A simple check here solved
// this problem and keeps performance high.
// Also, don't report values of 1, as the benefit is minimal and this might
// be commonly triggered due to the buffer of the force report channel.
if toReport <= 1 {
return 0
}
// Add space to report to dfq.reportedSpace and return it.
atomic.AddInt32(dfq.reportedSpace, toReport)
return toReport
}
// FlowHandler handles all flow queue internals and must be started as a worker
// in the module where it is used.
func (dfq *DuplexFlowQueue) FlowHandler(_ context.Context) error {
// The upstreamSender is started by the terminal module, but is tied to the
// flow owner instead. Make sure that the flow owner's module depends on the
// terminal module so that it is shut down earlier.
var sendSpaceDepleted bool
var flushFinished func()
// Drain all queues when shutting down.
defer func() {
for {
select {
case msg := <-dfq.sendQueue:
msg.Finish()
case msg := <-dfq.recvQueue:
msg.Finish()
default:
return
}
}
}()
sending:
for {
// If the send queue is depleted, wait to be woken.
if sendSpaceDepleted {
select {
case <-dfq.wakeSender:
if dfq.getSendSpace() > 0 {
sendSpaceDepleted = false
} else {
continue sending
}
case <-dfq.forceSpaceReport:
// Forced reporting of space.
// We do not need to check if there is enough sending space, as there is
// no data included.
spaceToReport := dfq.reportableRecvSpace()
if spaceToReport > 0 {
msg := NewMsg(varint.Pack64(uint64(spaceToReport)))
dfq.submitUpstream(msg, 0)
}
continue sending
case <-dfq.ctx.Done():
return nil
}
}
// Get message from send queue.
select {
case dfq.readyToSend <- struct{}{}:
// Notify that we are ready to send.
case msg := <-dfq.sendQueue:
// Send message from queue.
// If nil, the queue is being shut down.
if msg == nil {
return nil
}
// Check if we are handling a high priority message or waiting for one.
// Mark any msgs as high priority, when there is one in the pipeline.
remainingPrioMsgs := atomic.AddInt32(dfq.prioMsgs, -1)
switch {
case remainingPrioMsgs >= 0:
msg.Unit.MakeHighPriority()
case remainingPrioMsgs < -30_000:
// Prevent wrap to positive.
// Compatible with int16 or bigger.
atomic.StoreInt32(dfq.prioMsgs, 0)
}
// Wait for processing slot.
msg.Unit.WaitForSlot()
// Prepend available receiving space.
msg.Data.Prepend(varint.Pack64(uint64(dfq.reportableRecvSpace())))
// Submit for sending upstream.
dfq.submitUpstream(msg, 0)
// Decrease the send space and set flag if depleted.
if dfq.decrementSendSpace() <= 0 {
sendSpaceDepleted = true
}
// Check if the send queue is empty now and signal flushers.
if flushFinished != nil && len(dfq.sendQueue) == 0 {
flushFinished()
flushFinished = nil
}
case <-dfq.forceSpaceReport:
// Forced reporting of space.
// We do not need to check if there is enough sending space, as there is
// no data included.
spaceToReport := dfq.reportableRecvSpace()
if spaceToReport > 0 {
msg := NewMsg(varint.Pack64(uint64(spaceToReport)))
dfq.submitUpstream(msg, 0)
}
case newFlushFinishedFn := <-dfq.flush:
// Signal immediately if send queue is empty.
if len(dfq.sendQueue) == 0 {
newFlushFinishedFn()
} else {
// If there already is a flush finished function, stack them.
if flushFinished != nil {
stackedFlushFinishFn := flushFinished
flushFinished = func() {
stackedFlushFinishFn()
newFlushFinishedFn()
}
} else {
flushFinished = newFlushFinishedFn
}
}
case <-dfq.ctx.Done():
return nil
}
}
}
// Flush waits for all waiting data to be sent.
func (dfq *DuplexFlowQueue) Flush(timeout time.Duration) {
// Create channel and function for notifying.
wait := make(chan struct{})
finished := func() {
close(wait)
}
// Request flush and return when stopping.
select {
case dfq.flush <- finished:
case <-dfq.ctx.Done():
return
case <-TimedOut(timeout):
return
}
// Wait for flush to finish and return when stopping.
select {
case <-wait:
case <-dfq.ctx.Done():
case <-TimedOut(timeout):
}
}
var ready = make(chan struct{})
func init() {
close(ready)
}
// ReadyToSend returns a channel that can be read when data can be sent.
func (dfq *DuplexFlowQueue) ReadyToSend() <-chan struct{} {
if atomic.LoadInt32(dfq.sendSpace) > 0 {
return ready
}
return dfq.readyToSend
}
// Send adds the given container to the send queue.
func (dfq *DuplexFlowQueue) Send(msg *Msg, timeout time.Duration) *Error {
select {
case dfq.sendQueue <- msg:
if msg.Unit.IsHighPriority() {
// Reset prioMsgs to the current queue size, so that all waiting and the
// message we just added are all handled as high priority.
atomic.StoreInt32(dfq.prioMsgs, int32(len(dfq.sendQueue)))
}
return nil
case <-TimedOut(timeout):
msg.Finish()
return ErrTimeout
case <-dfq.ctx.Done():
msg.Finish()
return ErrStopping
}
}
// Receive receives a container from the recv queue.
func (dfq *DuplexFlowQueue) Receive() <-chan *Msg {
// If the reported recv space is nearing its end, force a report.
if dfq.shouldReportRecvSpace() {
select {
case dfq.forceSpaceReport <- struct{}{}:
default:
}
}
return dfq.recvQueue
}
// Deliver submits a container for receiving from upstream.
func (dfq *DuplexFlowQueue) Deliver(msg *Msg) *Error {
// Ignore nil containers.
if msg == nil || msg.Data == nil {
msg.Finish()
return ErrMalformedData.With("no data")
}
// Get and add new reported space.
addSpace, err := msg.Data.GetNextN16()
if err != nil {
msg.Finish()
return ErrMalformedData.With("failed to parse reported space: %w", err)
}
if addSpace > 0 {
dfq.addToSendSpace(int32(addSpace))
}
// Abort processing if the container only contained a space update.
if !msg.Data.HoldsData() {
msg.Finish()
return nil
}
select {
case dfq.recvQueue <- msg:
// If the recv queue accepted the Container, decrement the recv space.
shouldReportRecvSpace := dfq.decrementReportedRecvSpace()
// If the reported recv space is nearing its end, force a report, if the
// sender worker is idle.
if shouldReportRecvSpace {
select {
case dfq.forceSpaceReport <- struct{}{}:
default:
}
}
return nil
default:
// If the recv queue is full, return an error.
// The whole point of the flow queue is to guarantee that this never happens.
msg.Finish()
return ErrQueueOverflow
}
}
// FlowStats returns a k=v formatted string of internal stats.
func (dfq *DuplexFlowQueue) FlowStats() string {
return fmt.Sprintf(
"sq=%d rq=%d sends=%d reps=%d",
len(dfq.sendQueue),
len(dfq.recvQueue),
atomic.LoadInt32(dfq.sendSpace),
atomic.LoadInt32(dfq.reportedSpace),
)
}
// RecvQueueLen returns the current length of the receive queue.
func (dfq *DuplexFlowQueue) RecvQueueLen() int {
return len(dfq.recvQueue)
}
// SendQueueLen returns the current length of the send queue.
func (dfq *DuplexFlowQueue) SendQueueLen() int {
return len(dfq.sendQueue)
}

36
spn/terminal/defaults.go Normal file
View File

@@ -0,0 +1,36 @@
package terminal
const (
// UsePriorityDataMsgs defines whether priority data messages should be used.
UsePriorityDataMsgs = true
)
// DefaultCraneControllerOpts returns the default terminal options for a crane
// controller terminal.
func DefaultCraneControllerOpts() *TerminalOpts {
return &TerminalOpts{
Padding: 0, // Crane already applies padding.
FlowControl: FlowControlNone,
UsePriorityDataMsgs: UsePriorityDataMsgs,
}
}
// DefaultHomeHubTerminalOpts returns the default terminal options for a crane
// terminal used for the home hub.
func DefaultHomeHubTerminalOpts() *TerminalOpts {
return &TerminalOpts{
Padding: 0, // Crane already applies padding.
FlowControl: FlowControlDFQ,
UsePriorityDataMsgs: UsePriorityDataMsgs,
}
}
// DefaultExpansionTerminalOpts returns the default terminal options for an
// expansion terminal.
func DefaultExpansionTerminalOpts() *TerminalOpts {
return &TerminalOpts{
Padding: 8,
FlowControl: FlowControlDFQ,
UsePriorityDataMsgs: UsePriorityDataMsgs,
}
}

221
spn/terminal/errors.go Normal file
View File

@@ -0,0 +1,221 @@
package terminal
import (
"context"
"errors"
"fmt"
"github.com/safing/portbase/formats/varint"
)
// Error is a terminal error.
type Error struct {
// id holds the internal error ID.
id uint8
// external signifies if the error was received from the outside.
external bool
// err holds the wrapped error or the default error message.
err error
}
// ID returns the internal ID of the error.
func (e *Error) ID() uint8 {
return e.id
}
// Error returns the human readable format of the error.
func (e *Error) Error() string {
if e.external {
return "[ext] " + e.err.Error()
}
return e.err.Error()
}
// IsExternal returns whether the error occurred externally.
func (e *Error) IsExternal() bool {
if e == nil {
return false
}
return e.external
}
// Is returns whether the given error is of the same type.
func (e *Error) Is(target error) bool {
if e == nil || target == nil {
return false
}
t, ok := target.(*Error) //nolint:errorlint // Error implementation, not usage.
if !ok {
return false
}
return e.id == t.id
}
// Unwrap returns the wrapped error.
func (e *Error) Unwrap() error {
if e == nil || e.err == nil {
return nil
}
return e.err
}
// With adds context and details where the error occurred. The provided
// message is appended to the error.
// A new error with the same ID is returned and must be compared with
// errors.Is().
func (e *Error) With(format string, a ...interface{}) *Error {
// Return nil if error is nil.
if e == nil {
return nil
}
return &Error{
id: e.id,
err: fmt.Errorf(e.Error()+": "+format, a...),
}
}
// Wrap adds context higher up in the call chain. The provided message is
// prepended to the error.
// A new error with the same ID is returned and must be compared with
// errors.Is().
func (e *Error) Wrap(format string, a ...interface{}) *Error {
// Return nil if error is nil.
if e == nil {
return nil
}
return &Error{
id: e.id,
err: fmt.Errorf(format+": "+e.Error(), a...),
}
}
// AsExternal creates and returns an external version of the error.
func (e *Error) AsExternal() *Error {
// Return nil if error is nil.
if e == nil {
return nil
}
return &Error{
id: e.id,
err: e.err,
external: true,
}
}
// Pack returns the serialized internal error ID. The additional message is
// lost and is replaced with the default message upon parsing.
func (e *Error) Pack() []byte {
// Return nil slice if error is nil.
if e == nil {
return nil
}
return varint.Pack8(e.id)
}
// ParseExternalError parses an external error.
func ParseExternalError(id []byte) (*Error, error) {
// Return nil for an empty error.
if len(id) == 0 {
return ErrStopping.AsExternal(), nil
}
parsedID, _, err := varint.Unpack8(id)
if err != nil {
return nil, fmt.Errorf("failed to unpack error ID: %w", err)
}
return NewExternalError(parsedID), nil
}
// NewExternalError creates an external error based on the given ID.
func NewExternalError(id uint8) *Error {
err, ok := errorRegistry[id]
if ok {
return err.AsExternal()
}
return ErrUnknownError.AsExternal()
}
var errorRegistry = make(map[uint8]*Error)
func registerError(id uint8, err error) *Error {
// Check for duplicate.
_, ok := errorRegistry[id]
if ok {
panic(fmt.Sprintf("error with id %d already registered", id))
}
newErr := &Error{
id: id,
err: err,
}
errorRegistry[id] = newErr
return newErr
}
// func (e *Error) IsSpecial() bool {
// if e == nil {
// return false
// }
// return e.id > 0 && e.id < 8
// }
// IsOK returns if the error represents a "OK" or success status.
func (e *Error) IsOK() bool {
return !e.IsError()
}
// IsError returns if the error represents an erronous condition.
func (e *Error) IsError() bool {
if e == nil || e.err == nil {
return false
}
if e.id == 0 || e.id >= 8 {
return true
}
return false
}
// Terminal Errors.
var (
// ErrUnknownError is the default error.
ErrUnknownError = registerError(0, errors.New("unknown error"))
// Error IDs 1-7 are reserved for special "OK" values.
ErrStopping = registerError(2, errors.New("stopping"))
ErrExplicitAck = registerError(3, errors.New("explicit ack"))
ErrNoActivity = registerError(4, errors.New("no activity"))
// Errors IDs 8 and up are for regular errors.
ErrInternalError = registerError(8, errors.New("internal error"))
ErrMalformedData = registerError(9, errors.New("malformed data"))
ErrUnexpectedMsgType = registerError(10, errors.New("unexpected message type"))
ErrUnknownOperationType = registerError(11, errors.New("unknown operation type"))
ErrUnknownOperationID = registerError(12, errors.New("unknown operation id"))
ErrPermissionDenied = registerError(13, errors.New("permission denied"))
ErrIntegrity = registerError(14, errors.New("integrity violated"))
ErrInvalidOptions = registerError(15, errors.New("invalid options"))
ErrHubNotReady = registerError(16, errors.New("hub not ready"))
ErrRateLimited = registerError(24, errors.New("rate limited"))
ErrIncorrectUsage = registerError(22, errors.New("incorrect usage"))
ErrTimeout = registerError(62, errors.New("timed out"))
ErrUnsupportedVersion = registerError(93, errors.New("unsupported version"))
ErrHubUnavailable = registerError(101, errors.New("hub unavailable"))
ErrAbandonedTerminal = registerError(102, errors.New("terminal is being abandoned"))
ErrShipSunk = registerError(108, errors.New("ship sunk"))
ErrDestinationUnavailable = registerError(113, errors.New("destination unavailable"))
ErrTryAgainLater = registerError(114, errors.New("try again later"))
ErrConnectionError = registerError(121, errors.New("connection error"))
ErrQueueOverflow = registerError(122, errors.New("queue overflowed"))
ErrCanceled = registerError(125, context.Canceled)
)

27
spn/terminal/fmt.go Normal file
View File

@@ -0,0 +1,27 @@
package terminal
import "fmt"
// CustomTerminalIDFormatting defines an interface for terminal to define their custom ID format.
type CustomTerminalIDFormatting interface {
CustomIDFormat() string
}
// FmtID formats the terminal ID together with the parent's ID.
func (t *TerminalBase) FmtID() string {
if t.ext != nil {
if customFormatting, ok := t.ext.(CustomTerminalIDFormatting); ok {
return customFormatting.CustomIDFormat()
}
}
return fmtTerminalID(t.parentID, t.id)
}
func fmtTerminalID(craneID string, terminalID uint32) string {
return fmt.Sprintf("%s#%d", craneID, terminalID)
}
func fmtOperationID(craneID string, terminalID, operationID uint32) string {
return fmt.Sprintf("%s#%d>%d", craneID, terminalID, operationID)
}

210
spn/terminal/init.go Normal file
View File

@@ -0,0 +1,210 @@
package terminal
import (
"context"
"github.com/safing/jess"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/hub"
)
/*
Terminal Init Message Format:
- Version [varint]
- Data Block [bytes; not blocked]
- TerminalOpts as DSD
*/
const (
minSupportedTerminalVersion = 1
maxSupportedTerminalVersion = 1
)
// TerminalOpts holds configuration for the terminal.
type TerminalOpts struct { //nolint:golint,maligned // TODO: Rename.
Version uint8 `json:"-"`
Encrypt bool `json:"e,omitempty"`
Padding uint16 `json:"p,omitempty"`
FlowControl FlowControlType `json:"fc,omitempty"`
FlowControlSize uint32 `json:"qs,omitempty"` // Previously was "QueueSize".
UsePriorityDataMsgs bool `json:"pr,omitempty"`
}
// ParseTerminalOpts parses terminal options from the container and checks if
// they are valid.
func ParseTerminalOpts(c *container.Container) (*TerminalOpts, *Error) {
// Parse and check version.
version, err := c.GetNextN8()
if err != nil {
return nil, ErrMalformedData.With("failed to parse version: %w", err)
}
if version < minSupportedTerminalVersion || version > maxSupportedTerminalVersion {
return nil, ErrUnsupportedVersion.With("requested terminal version %d", version)
}
// Parse init message.
initMsg := &TerminalOpts{}
_, err = dsd.Load(c.CompileData(), initMsg)
if err != nil {
return nil, ErrMalformedData.With("failed to parse init message: %w", err)
}
initMsg.Version = version
// Check if options are valid.
tErr := initMsg.Check(false)
if tErr != nil {
return nil, tErr
}
return initMsg, nil
}
// Pack serialized the terminal options and checks if they are valid.
func (opts *TerminalOpts) Pack() (*container.Container, *Error) {
// Check if options are valid.
tErr := opts.Check(true)
if tErr != nil {
return nil, tErr
}
// Pack init message.
optsData, err := dsd.Dump(opts, dsd.CBOR)
if err != nil {
return nil, ErrInternalError.With("failed to pack init message: %w", err)
}
// Compile init message.
return container.New(
varint.Pack8(opts.Version),
optsData,
), nil
}
// Check checks if terminal options are valid.
func (opts *TerminalOpts) Check(useDefaultsForRequired bool) *Error {
// Version is required - use default when permitted.
if opts.Version == 0 && useDefaultsForRequired {
opts.Version = 1
}
if opts.Version < minSupportedTerminalVersion || opts.Version > maxSupportedTerminalVersion {
return ErrInvalidOptions.With("unsupported terminal version %d", opts.Version)
}
// FlowControl is optional.
switch opts.FlowControl {
case FlowControlDefault:
// Set to default flow control.
opts.FlowControl = defaultFlowControl
case FlowControlNone, FlowControlDFQ:
// Ok.
default:
return ErrInvalidOptions.With("unknown flow control type: %d", opts.FlowControl)
}
// FlowControlSize is required as it needs to be same on both sides.
// Use default when permitted.
if opts.FlowControlSize == 0 && useDefaultsForRequired {
opts.FlowControlSize = opts.FlowControl.DefaultSize()
}
if opts.FlowControlSize <= 0 || opts.FlowControlSize > MaxQueueSize {
return ErrInvalidOptions.With("invalid flow control size of %d", opts.FlowControlSize)
}
return nil
}
// NewLocalBaseTerminal creates a new local terminal base for use with inheriting terminals.
func NewLocalBaseTerminal(
ctx context.Context,
id uint32,
parentID string,
remoteHub *hub.Hub,
initMsg *TerminalOpts,
upstream Upstream,
) (
t *TerminalBase,
initData *container.Container,
err *Error,
) {
// Pack, check and add defaults to init message.
initData, err = initMsg.Pack()
if err != nil {
return nil, nil, err
}
// Create baseline.
t, err = createTerminalBase(ctx, id, parentID, false, initMsg, upstream)
if err != nil {
return nil, nil, err
}
// Setup encryption if enabled.
if remoteHub != nil {
initMsg.Encrypt = true
// Select signet (public key) of remote Hub to use.
s := remoteHub.SelectSignet()
if s == nil {
return nil, nil, ErrHubNotReady.With("failed to select signet of remote hub")
}
// Create new session.
env := jess.NewUnconfiguredEnvelope()
env.SuiteID = jess.SuiteWireV1
env.Recipients = []*jess.Signet{s}
jession, err := env.WireCorrespondence(nil)
if err != nil {
return nil, nil, ErrIntegrity.With("failed to initialize encryption: %w", err)
}
t.jession = jession
// Encryption is ready for sending.
close(t.encryptionReady)
}
return t, initData, nil
}
// NewRemoteBaseTerminal creates a new remote terminal base for use with inheriting terminals.
func NewRemoteBaseTerminal(
ctx context.Context,
id uint32,
parentID string,
identity *cabin.Identity,
initData *container.Container,
upstream Upstream,
) (
t *TerminalBase,
initMsg *TerminalOpts,
err *Error,
) {
// Parse init message.
initMsg, err = ParseTerminalOpts(initData)
if err != nil {
return nil, nil, err
}
// Create baseline.
t, err = createTerminalBase(ctx, id, parentID, true, initMsg, upstream)
if err != nil {
return nil, nil, err
}
// Setup encryption if enabled.
if initMsg.Encrypt {
if identity == nil {
return nil, nil, ErrInternalError.With("missing identity for setting up incoming encryption")
}
t.identity = identity
}
return t, initMsg, nil
}

117
spn/terminal/metrics.go Normal file
View File

@@ -0,0 +1,117 @@
package terminal
import (
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/api"
"github.com/safing/portbase/metrics"
)
var metricsRegistered = abool.New()
func registerMetrics() (err error) {
// Only register metrics once.
if !metricsRegistered.SetToIf(false, true) {
return nil
}
// Get scheduler config and calculat scaling.
schedulerConfig := getSchedulerConfig()
scaleSlotToSecondsFactor := float64(time.Second / schedulerConfig.SlotDuration)
// Register metrics from scheduler stats.
_, err = metrics.NewGauge(
"spn/scheduling/unit/slotpace/max",
nil,
metricFromInt(scheduler.GetMaxSlotPace, scaleSlotToSecondsFactor),
&metrics.Options{
Name: "SPN Scheduling Max Slot Pace (scaled to per second)",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/scheduling/unit/slotpace/leveled/max",
nil,
metricFromInt(scheduler.GetMaxLeveledSlotPace, scaleSlotToSecondsFactor),
&metrics.Options{
Name: "SPN Scheduling Max Leveled Slot Pace (scaled to per second)",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/scheduling/unit/slotpace/avg",
nil,
metricFromInt(scheduler.GetAvgSlotPace, scaleSlotToSecondsFactor),
&metrics.Options{
Name: "SPN Scheduling Avg Slot Pace (scaled to per second)",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/scheduling/unit/life/avg/seconds",
nil,
metricFromNanoseconds(scheduler.GetAvgUnitLife),
&metrics.Options{
Name: "SPN Scheduling Avg Unit Life",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/scheduling/unit/workslot/avg/seconds",
nil,
metricFromNanoseconds(scheduler.GetAvgWorkSlotDuration),
&metrics.Options{
Name: "SPN Scheduling Avg Work Slot Duration",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/scheduling/unit/catchupslot/avg/seconds",
nil,
metricFromNanoseconds(scheduler.GetAvgCatchUpSlotDuration),
&metrics.Options{
Name: "SPN Scheduling Avg Catch-Up Slot Duration",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
return nil
}
func metricFromInt(fn func() int64, scaleFactor float64) func() float64 {
return func() float64 {
return float64(fn()) * scaleFactor
}
}
func metricFromNanoseconds(fn func() int64) func() float64 {
return func() float64 {
return float64(fn()) / float64(time.Second)
}
}

80
spn/terminal/module.go Normal file
View File

@@ -0,0 +1,80 @@
package terminal
import (
"flag"
"time"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/rng"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/unit"
)
var (
module *modules.Module
rngFeeder *rng.Feeder = rng.NewFeeder()
scheduler *unit.Scheduler
debugUnitScheduling bool
)
func init() {
flag.BoolVar(&debugUnitScheduling, "debug-unit-scheduling", false, "enable debug logs of the SPN unit scheduler")
module = modules.Register("terminal", nil, start, nil, "base")
}
func start() error {
rngFeeder = rng.NewFeeder()
scheduler = unit.NewScheduler(getSchedulerConfig())
if debugUnitScheduling {
// Debug unit leaks.
scheduler.StartDebugLog()
}
module.StartServiceWorker("msg unit scheduler", 0, scheduler.SlotScheduler)
lockOpRegistry()
return registerMetrics()
}
var waitForever chan time.Time
// TimedOut returns a channel that triggers when the timeout is reached.
func TimedOut(timeout time.Duration) <-chan time.Time {
if timeout == 0 {
return waitForever
}
return time.After(timeout)
}
// StopScheduler stops the unit scheduler.
func StopScheduler() {
if scheduler != nil {
scheduler.Stop()
}
}
func getSchedulerConfig() *unit.SchedulerConfig {
// Client Scheduler Config.
if conf.Client() {
return &unit.SchedulerConfig{
SlotDuration: 10 * time.Millisecond, // 100 slots per second
MinSlotPace: 10, // 1000pps - Small starting pace for low end devices.
WorkSlotPercentage: 0.9, // 90%
SlotChangeRatePerStreak: 0.1, // 10% - Increase/Decrease quickly.
StatCycleDuration: 1 * time.Minute, // Match metrics report cycle.
}
}
// Server Scheduler Config.
return &unit.SchedulerConfig{
SlotDuration: 10 * time.Millisecond, // 100 slots per second
MinSlotPace: 100, // 10000pps - Every server should be able to handle this.
WorkSlotPercentage: 0.7, // 70%
SlotChangeRatePerStreak: 0.05, // 5%
StatCycleDuration: 1 * time.Minute, // Match metrics report cycle.
}
}

View File

@@ -0,0 +1,13 @@
package terminal
import (
"testing"
"github.com/safing/portmaster/service/core/pmtesting"
"github.com/safing/portmaster/spn/conf"
)
func TestMain(m *testing.M) {
conf.EnablePublicHub(true)
pmtesting.TestMain(m, module)
}

106
spn/terminal/msg.go Normal file
View File

@@ -0,0 +1,106 @@
package terminal
import (
"fmt"
"runtime"
"github.com/safing/portbase/container"
"github.com/safing/portmaster/spn/unit"
)
// Msg is a message within the SPN network stack.
// It includes metadata and unit scheduling.
type Msg struct {
FlowID uint32
Type MsgType
Data *container.Container
// Unit scheduling.
// Note: With just 100B per packet, a uint64 (the Unit ID) is enough for
// over 1800 Exabyte. No need for overflow support.
Unit *unit.Unit
}
// NewMsg returns a new msg.
// The FlowID is unset.
// The Type is Data.
func NewMsg(data []byte) *Msg {
msg := &Msg{
Type: MsgTypeData,
Data: container.New(data),
Unit: scheduler.NewUnit(),
}
// Debug unit leaks.
msg.debugWithCaller(2)
return msg
}
// NewEmptyMsg returns a new empty msg with an initialized Unit.
// The FlowID is unset.
// The Type is Data.
// The Data is unset.
func NewEmptyMsg() *Msg {
msg := &Msg{
Type: MsgTypeData,
Unit: scheduler.NewUnit(),
}
// Debug unit leaks.
msg.debugWithCaller(2)
return msg
}
// Pack prepends the message header (Length and ID+Type) to the data.
func (msg *Msg) Pack() {
MakeMsg(msg.Data, msg.FlowID, msg.Type)
}
// Consume adds another Message to itself.
// The given Msg is packed before adding it to the data.
// The data is moved - not copied!
// High priority mark is inherited.
func (msg *Msg) Consume(other *Msg) {
// Pack message to be added.
other.Pack()
// Move data.
msg.Data.AppendContainer(other.Data)
// Inherit high priority.
if other.Unit.IsHighPriority() {
msg.Unit.MakeHighPriority()
}
// Finish other unit.
other.Finish()
}
// Finish signals the unit scheduler that this unit has finished processing.
// Will no-op if called on a nil Msg.
func (msg *Msg) Finish() {
// Proxying is necessary, as a nil msg still panics.
if msg == nil {
return
}
msg.Unit.Finish()
}
// Debug registers the unit for debug output with the given source.
// Additional calls on the same unit update the unit source.
// StartDebugLog() must be called before calling DebugUnit().
func (msg *Msg) Debug() {
msg.debugWithCaller(2)
}
func (msg *Msg) debugWithCaller(skip int) { //nolint:unparam
if !debugUnitScheduling || msg == nil {
return
}
_, file, line, ok := runtime.Caller(skip)
if ok {
scheduler.DebugUnit(msg.Unit, fmt.Sprintf("%s:%d", file, line))
}
}

66
spn/terminal/msgtypes.go Normal file
View File

@@ -0,0 +1,66 @@
package terminal
import (
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/varint"
)
/*
Terminal and Operation Message Format:
- Length [varint]
- If Length is 0, the remainder of given data is padding.
- IDType [varint]
- Type [uses least two significant bits]
- One of Init, Data, Stop
- ID [uses all other bits]
- The ID is currently not adapted in order to make reading raw message
easier. This means that IDs are currently always a multiple of 4.
- Data [bytes; format depends on msg type]
- MsgTypeInit:
- Data [bytes]
- MsgTypeData:
- AddAvailableSpace [varint, if Flow Queue is used]
- (Encrypted) Data [bytes]
- MsgTypeStop:
- Error Code [varint]
*/
// MsgType is the message type for both terminals and operations.
type MsgType uint8
const (
// MsgTypeInit is used to establish a new terminal or run a new operation.
MsgTypeInit MsgType = 1
// MsgTypeData is used to send data to a terminal or operation.
MsgTypeData MsgType = 2
// MsgTypePriorityData is used to send prioritized data to a terminal or operation.
MsgTypePriorityData MsgType = 0
// MsgTypeStop is used to abandon a terminal or end an operation, with an optional error.
MsgTypeStop MsgType = 3
)
// AddIDType prepends the ID and Type header to the message.
func AddIDType(c *container.Container, id uint32, msgType MsgType) {
c.Prepend(varint.Pack32(id | uint32(msgType)))
}
// MakeMsg prepends the message header (Length and ID+Type) to the data.
func MakeMsg(c *container.Container, id uint32, msgType MsgType) {
AddIDType(c, id, msgType)
c.PrependLength()
}
// ParseIDType parses the combined message ID and type.
func ParseIDType(c *container.Container) (id uint32, msgType MsgType, err error) {
idType, err := c.GetNextN32()
if err != nil {
return 0, 0, err
}
msgType = MsgType(idType % 4)
return idType - uint32(msgType), msgType, nil
}

332
spn/terminal/operation.go Normal file
View File

@@ -0,0 +1,332 @@
package terminal
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/container"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
)
// Operation is an interface for all operations.
type Operation interface {
// InitOperationBase initialize the operation with the ID and attached terminal.
// Should not be overridden by implementations.
InitOperationBase(t Terminal, opID uint32)
// ID returns the ID of the operation.
// Should not be overridden by implementations.
ID() uint32
// Type returns the operation's type ID.
// Should be overridden by implementations to return correct type ID.
Type() string
// Deliver delivers a message to the operation.
// Meant to be overridden by implementations.
Deliver(msg *Msg) *Error
// NewMsg creates a new message from this operation.
// Should not be overridden by implementations.
NewMsg(data []byte) *Msg
// Send sends a message to the other side.
// Should not be overridden by implementations.
Send(msg *Msg, timeout time.Duration) *Error
// Flush sends all messages waiting in the terminal.
// Should not be overridden by implementations.
Flush(timeout time.Duration)
// Stopped returns whether the operation has stopped.
// Should not be overridden by implementations.
Stopped() bool
// markStopped marks the operation as stopped.
// It returns whether the stop flag was set.
markStopped() bool
// Stop stops the operation by unregistering it from the terminal and calling HandleStop().
// Should not be overridden by implementations.
Stop(self Operation, err *Error)
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
// Meant to be overridden by implementations.
HandleStop(err *Error) (errorToSend *Error)
// Terminal returns the terminal the operation is linked to.
// Should not be overridden by implementations.
Terminal() Terminal
}
// OperationFactory defines an operation factory.
type OperationFactory struct {
// Type is the type id of an operation.
Type string
// Requires defines the required permissions to run an operation.
Requires Permission
// Start is the function that starts a new operation.
Start OperationStarter
}
// OperationStarter is used to initialize operations remotely.
type OperationStarter func(attachedTerminal Terminal, opID uint32, initData *container.Container) (Operation, *Error)
var (
opRegistry = make(map[string]*OperationFactory)
opRegistryLock sync.Mutex
opRegistryLocked = abool.New()
)
// RegisterOpType registers a new operation type and may only be called during
// Go's init and a module's prep phase.
func RegisterOpType(factory OperationFactory) {
// Check if we can still register an operation type.
if opRegistryLocked.IsSet() {
log.Errorf("spn/terminal: failed to register operation %s: operation registry is already locked", factory.Type)
return
}
opRegistryLock.Lock()
defer opRegistryLock.Unlock()
// Check if the operation type was already registered.
if _, ok := opRegistry[factory.Type]; ok {
log.Errorf("spn/terminal: failed to register operation type %s: type already registered", factory.Type)
return
}
// Save to registry.
opRegistry[factory.Type] = &factory
}
func lockOpRegistry() {
opRegistryLocked.Set()
}
func (t *TerminalBase) handleOperationStart(opID uint32, initData *container.Container) {
// Check if the terminal is being abandoned.
if t.Abandoning.IsSet() {
t.StopOperation(newUnknownOp(opID, ""), ErrAbandonedTerminal)
return
}
// Extract the requested operation name.
opType, err := initData.GetNextBlock()
if err != nil {
t.StopOperation(newUnknownOp(opID, ""), ErrMalformedData.With("failed to get init data: %w", err))
return
}
// Get the operation factory from the registry.
factory, ok := opRegistry[string(opType)]
if !ok {
t.StopOperation(newUnknownOp(opID, ""), ErrUnknownOperationType.With(utils.SafeFirst16Bytes(opType)))
return
}
// Check if the Terminal has the required permission to run the operation.
if !t.HasPermission(factory.Requires) {
t.StopOperation(newUnknownOp(opID, factory.Type), ErrPermissionDenied)
return
}
// Get terminal to attach to.
attachToTerminal := t.ext
if attachToTerminal == nil {
attachToTerminal = t
}
// Run the operation.
op, opErr := factory.Start(attachToTerminal, opID, initData)
switch {
case opErr != nil:
// Something went wrong.
t.StopOperation(newUnknownOp(opID, factory.Type), opErr)
case op == nil:
// The Operation was successful and is done already.
log.Debugf("spn/terminal: operation %s %s executed", factory.Type, fmtOperationID(t.parentID, t.id, opID))
t.StopOperation(newUnknownOp(opID, factory.Type), nil)
default:
// The operation started successfully and requires persistence.
t.SetActiveOp(opID, op)
log.Debugf("spn/terminal: operation %s %s started", factory.Type, fmtOperationID(t.parentID, t.id, opID))
}
}
// StartOperation starts the given operation by assigning it an ID and sending the given operation initialization data.
func (t *TerminalBase) StartOperation(op Operation, initData *container.Container, timeout time.Duration) *Error {
// Get terminal to attach to.
attachToTerminal := t.ext
if attachToTerminal == nil {
attachToTerminal = t
}
// Get the next operation ID and set it on the operation with the terminal.
op.InitOperationBase(attachToTerminal, atomic.AddUint32(t.nextOpID, 8))
// Always add operation to the active operations, as we need to receive a
// reply in any case.
t.SetActiveOp(op.ID(), op)
log.Debugf("spn/terminal: operation %s %s started", op.Type(), fmtOperationID(t.parentID, t.id, op.ID()))
// Add or create the operation type block.
if initData == nil {
initData = container.New()
initData.AppendAsBlock([]byte(op.Type()))
} else {
initData.PrependAsBlock([]byte(op.Type()))
}
// Create init msg.
msg := NewEmptyMsg()
msg.FlowID = op.ID()
msg.Type = MsgTypeInit
msg.Data = initData
msg.Unit.MakeHighPriority()
// Send init msg.
err := op.Send(msg, timeout)
if err != nil {
msg.Finish()
}
return err
}
// Send sends data via this terminal.
// If a timeout is set, sending will fail after the given timeout passed.
func (t *TerminalBase) Send(msg *Msg, timeout time.Duration) *Error {
// Wait for processing slot.
msg.Unit.WaitForSlot()
// Check if the send queue has available space.
select {
case t.sendQueue <- msg:
return nil
default:
}
// Submit message to buffer, if space is available.
select {
case t.sendQueue <- msg:
return nil
case <-TimedOut(timeout):
msg.Finish()
return ErrTimeout.With("sending via terminal")
case <-t.Ctx().Done():
msg.Finish()
return ErrStopping
}
}
// StopOperation sends the end signal with an optional error and then deletes
// the operation from the Terminal state and calls HandleStop() on the Operation.
func (t *TerminalBase) StopOperation(op Operation, err *Error) {
// Check if the operation has already stopped.
if !op.markStopped() {
return
}
// Log reason the Operation is ending. Override stopping error with nil.
switch {
case err == nil:
log.Debugf("spn/terminal: operation %s %s stopped", op.Type(), fmtOperationID(t.parentID, t.id, op.ID()))
case err.IsOK(), err.Is(ErrTryAgainLater), err.Is(ErrRateLimited):
log.Debugf("spn/terminal: operation %s %s stopped: %s", op.Type(), fmtOperationID(t.parentID, t.id, op.ID()), err)
default:
log.Warningf("spn/terminal: operation %s %s failed: %s", op.Type(), fmtOperationID(t.parentID, t.id, op.ID()), err)
}
module.StartWorker("stop operation", func(_ context.Context) error {
// Call operation stop handle function for proper shutdown cleaning up.
err = op.HandleStop(err)
// Send error to the connected Operation, if the error is internal.
if !err.IsExternal() {
if err == nil {
err = ErrStopping
}
msg := NewMsg(err.Pack())
msg.FlowID = op.ID()
msg.Type = MsgTypeStop
tErr := t.Send(msg, 10*time.Second)
if tErr != nil {
msg.Finish()
log.Warningf("spn/terminal: failed to send stop msg: %s", tErr)
}
}
// Remove operation from terminal.
t.DeleteActiveOp(op.ID())
return nil
})
}
// GetActiveOp returns the active operation with the given ID from the
// Terminal state.
func (t *TerminalBase) GetActiveOp(opID uint32) (op Operation, ok bool) {
t.lock.RLock()
defer t.lock.RUnlock()
op, ok = t.operations[opID]
return
}
// SetActiveOp saves an active operation to the Terminal state.
func (t *TerminalBase) SetActiveOp(opID uint32, op Operation) {
t.lock.Lock()
defer t.lock.Unlock()
t.operations[opID] = op
}
// DeleteActiveOp deletes an active operation from the Terminal state.
func (t *TerminalBase) DeleteActiveOp(opID uint32) {
t.lock.Lock()
defer t.lock.Unlock()
delete(t.operations, opID)
}
// GetActiveOpCount returns the amount of active operations.
func (t *TerminalBase) GetActiveOpCount() int {
t.lock.RLock()
defer t.lock.RUnlock()
return len(t.operations)
}
func newUnknownOp(id uint32, typeID string) *unknownOp {
op := &unknownOp{
typeID: typeID,
}
op.id = id
return op
}
type unknownOp struct {
OperationBase
typeID string
}
func (op *unknownOp) Type() string {
if op.typeID != "" {
return op.typeID
}
return "unknown"
}
func (op *unknownOp) Deliver(msg *Msg) *Error {
return ErrIncorrectUsage.With("unknown op shim cannot receive")
}

View File

@@ -0,0 +1,185 @@
package terminal
import (
"time"
"github.com/tevino/abool"
)
// OperationBase provides the basic operation functionality.
type OperationBase struct {
terminal Terminal
id uint32
stopped abool.AtomicBool
}
// InitOperationBase initialize the operation with the ID and attached terminal.
// Should not be overridden by implementations.
func (op *OperationBase) InitOperationBase(t Terminal, opID uint32) {
op.id = opID
op.terminal = t
}
// ID returns the ID of the operation.
// Should not be overridden by implementations.
func (op *OperationBase) ID() uint32 {
return op.id
}
// Type returns the operation's type ID.
// Should be overridden by implementations to return correct type ID.
func (op *OperationBase) Type() string {
return "unknown"
}
// Deliver delivers a message to the operation.
// Meant to be overridden by implementations.
func (op *OperationBase) Deliver(_ *Msg) *Error {
return ErrIncorrectUsage.With("Deliver not implemented for this operation")
}
// NewMsg creates a new message from this operation.
// Should not be overridden by implementations.
func (op *OperationBase) NewMsg(data []byte) *Msg {
msg := NewMsg(data)
msg.FlowID = op.id
msg.Type = MsgTypeData
// Debug unit leaks.
msg.debugWithCaller(2)
return msg
}
// NewEmptyMsg creates a new empty message from this operation.
// Should not be overridden by implementations.
func (op *OperationBase) NewEmptyMsg() *Msg {
msg := NewEmptyMsg()
msg.FlowID = op.id
msg.Type = MsgTypeData
// Debug unit leaks.
msg.debugWithCaller(2)
return msg
}
// Send sends a message to the other side.
// Should not be overridden by implementations.
func (op *OperationBase) Send(msg *Msg, timeout time.Duration) *Error {
// Add and update metadata.
msg.FlowID = op.id
if msg.Type == MsgTypeData && msg.Unit.IsHighPriority() && UsePriorityDataMsgs {
msg.Type = MsgTypePriorityData
}
// Wait for processing slot.
msg.Unit.WaitForSlot()
// Send message.
tErr := op.terminal.Send(msg, timeout)
if tErr != nil {
// Finish message unit on failure.
msg.Finish()
}
return tErr
}
// Flush sends all messages waiting in the terminal.
// Meant to be overridden by implementations.
func (op *OperationBase) Flush(timeout time.Duration) {
op.terminal.Flush(timeout)
}
// Stopped returns whether the operation has stopped.
// Should not be overridden by implementations.
func (op *OperationBase) Stopped() bool {
return op.stopped.IsSet()
}
// markStopped marks the operation as stopped.
// It returns whether the stop flag was set.
func (op *OperationBase) markStopped() bool {
return op.stopped.SetToIf(false, true)
}
// Stop stops the operation by unregistering it from the terminal and calling HandleStop().
// Should not be overridden by implementations.
func (op *OperationBase) Stop(self Operation, err *Error) {
// Stop operation from terminal.
op.terminal.StopOperation(self, err)
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
// Meant to be overridden by implementations.
func (op *OperationBase) HandleStop(err *Error) (errorToSend *Error) {
return err
}
// Terminal returns the terminal the operation is linked to.
// Should not be overridden by implementations.
func (op *OperationBase) Terminal() Terminal {
return op.terminal
}
// OneOffOperationBase is an operation base for operations that just have one
// message and a error return.
type OneOffOperationBase struct {
OperationBase
Result chan *Error
}
// Init initializes the single operation base.
func (op *OneOffOperationBase) Init() {
op.Result = make(chan *Error, 1)
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *OneOffOperationBase) HandleStop(err *Error) (errorToSend *Error) {
select {
case op.Result <- err:
default:
}
return err
}
// MessageStreamOperationBase is an operation base for receiving a message stream.
// Every received message must be finished by the implementing operation.
type MessageStreamOperationBase struct {
OperationBase
Delivered chan *Msg
Ended chan *Error
}
// Init initializes the operation base.
func (op *MessageStreamOperationBase) Init(deliverQueueSize int) {
op.Delivered = make(chan *Msg, deliverQueueSize)
op.Ended = make(chan *Error, 1)
}
// Deliver delivers data to the operation.
func (op *MessageStreamOperationBase) Deliver(msg *Msg) *Error {
select {
case op.Delivered <- msg:
return nil
default:
return ErrIncorrectUsage.With("request was not waiting for data")
}
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *MessageStreamOperationBase) HandleStop(err *Error) (errorToSend *Error) {
select {
case op.Ended <- err:
default:
}
return err
}

View File

@@ -0,0 +1,255 @@
package terminal
import (
"context"
"fmt"
"sync"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portbase/log"
)
// CounterOpType is the type ID for the Counter Operation.
const CounterOpType string = "debug/count"
// CounterOp sends increasing numbers on both sides.
type CounterOp struct { //nolint:maligned
OperationBase
wg sync.WaitGroup
server bool
opts *CounterOpts
counterLock sync.Mutex
ClientCounter uint64
ServerCounter uint64
Error error
}
// CounterOpts holds the options for CounterOp.
type CounterOpts struct {
ClientCountTo uint64
ServerCountTo uint64
Wait time.Duration
Flush bool
suppressWorker bool
}
func init() {
RegisterOpType(OperationFactory{
Type: CounterOpType,
Start: startCounterOp,
})
}
// NewCounterOp returns a new CounterOp.
func NewCounterOp(t Terminal, opts CounterOpts) (*CounterOp, *Error) {
// Create operation.
op := &CounterOp{
opts: &opts,
}
op.wg.Add(1)
// Create argument container.
data, err := dsd.Dump(op.opts, dsd.JSON)
if err != nil {
return nil, ErrInternalError.With("failed to pack options: %w", err)
}
// Initialize operation.
tErr := t.StartOperation(op, container.New(data), 3*time.Second)
if tErr != nil {
return nil, tErr
}
// Start worker if needed.
if op.getRemoteCounterTarget() > 0 && !op.opts.suppressWorker {
module.StartWorker("counter sender", op.CounterWorker)
}
return op, nil
}
func startCounterOp(t Terminal, opID uint32, data *container.Container) (Operation, *Error) {
// Create operation.
op := &CounterOp{
server: true,
}
op.InitOperationBase(t, opID)
op.wg.Add(1)
// Parse arguments.
opts := &CounterOpts{}
_, err := dsd.Load(data.CompileData(), opts)
if err != nil {
return nil, ErrInternalError.With("failed to unpack options: %w", err)
}
op.opts = opts
// Start worker if needed.
if op.getRemoteCounterTarget() > 0 {
module.StartWorker("counter sender", op.CounterWorker)
}
return op, nil
}
// Type returns the operation's type ID.
func (op *CounterOp) Type() string {
return CounterOpType
}
func (op *CounterOp) getCounter(sending, increase bool) uint64 {
op.counterLock.Lock()
defer op.counterLock.Unlock()
// Use server counter, when op is server or for sending, but not when both.
if op.server != sending {
if increase {
op.ServerCounter++
}
return op.ServerCounter
}
if increase {
op.ClientCounter++
}
return op.ClientCounter
}
func (op *CounterOp) getRemoteCounterTarget() uint64 {
if op.server {
return op.opts.ClientCountTo
}
return op.opts.ServerCountTo
}
func (op *CounterOp) isDone() bool {
op.counterLock.Lock()
defer op.counterLock.Unlock()
return op.ClientCounter >= op.opts.ClientCountTo &&
op.ServerCounter >= op.opts.ServerCountTo
}
// Deliver delivers data to the operation.
func (op *CounterOp) Deliver(msg *Msg) *Error {
defer msg.Finish()
nextStep, err := msg.Data.GetNextN64()
if err != nil {
op.Stop(op, ErrMalformedData.With("failed to parse next number: %w", err))
return nil
}
// Count and compare.
counter := op.getCounter(false, true)
// Debugging:
// if counter < 100 ||
// counter < 1000 && counter%100 == 0 ||
// counter < 10000 && counter%1000 == 0 ||
// counter < 100000 && counter%10000 == 0 ||
// counter < 1000000 && counter%100000 == 0 {
// log.Errorf("spn/terminal: counter %s>%d recvd, now at %d", op.t.FmtID(), op.id, counter)
// }
if counter != nextStep {
log.Warningf(
"terminal: integrity of counter op violated: received %d, expected %d",
nextStep,
counter,
)
op.Stop(op, ErrIntegrity.With("counters mismatched"))
return nil
}
// Check if we are done.
if op.isDone() {
op.Stop(op, nil)
}
return nil
}
// HandleStop handles stopping the operation.
func (op *CounterOp) HandleStop(err *Error) (errorToSend *Error) {
// Check if counting finished.
if !op.isDone() {
err := fmt.Errorf(
"counter op %d: did not finish counting (%d<-%d %d->%d)",
op.id,
op.opts.ClientCountTo, op.ClientCounter,
op.ServerCounter, op.opts.ServerCountTo,
)
op.Error = err
}
op.wg.Done()
return err
}
// SendCounter sends the next counter.
func (op *CounterOp) SendCounter() *Error {
if op.Stopped() {
return ErrStopping
}
// Increase sending counter.
counter := op.getCounter(true, true)
// Debugging:
// if counter < 100 ||
// counter < 1000 && counter%100 == 0 ||
// counter < 10000 && counter%1000 == 0 ||
// counter < 100000 && counter%10000 == 0 ||
// counter < 1000000 && counter%100000 == 0 {
// defer log.Errorf("spn/terminal: counter %s>%d sent, now at %d", op.t.FmtID(), op.id, counter)
// }
return op.Send(op.NewMsg(varint.Pack64(counter)), 3*time.Second)
}
// Wait waits for the Counter Op to finish.
func (op *CounterOp) Wait() {
op.wg.Wait()
}
// CounterWorker is a worker that sends counters.
func (op *CounterOp) CounterWorker(ctx context.Context) error {
for {
// Send counter msg.
err := op.SendCounter()
switch err {
case nil:
// All good, continue.
case ErrStopping:
// Done!
return nil
default:
// Something went wrong.
err := fmt.Errorf("counter op %d: failed to send counter: %w", op.id, err)
op.Error = err
op.Stop(op, ErrInternalError.With(err.Error()))
return nil
}
// Maybe flush message.
if op.opts.Flush {
op.terminal.Flush(1 * time.Second)
}
// Check if we are done with sending.
if op.getCounter(true, false) >= op.getRemoteCounterTarget() {
return nil
}
// Maybe wait a little.
if op.opts.Wait > 0 {
time.Sleep(op.opts.Wait)
}
}
}

View File

@@ -0,0 +1,50 @@
package terminal
// Permission is a bit-map of granted permissions.
type Permission uint16
// Permissions.
const (
NoPermission Permission = 0x0
MayExpand Permission = 0x1
MayConnect Permission = 0x2
IsHubOwner Permission = 0x100
IsHubAdvisor Permission = 0x200
IsCraneController Permission = 0x8000
)
// AuthorizingTerminal is an interface for terminals that support authorization.
type AuthorizingTerminal interface {
GrantPermission(grant Permission)
HasPermission(required Permission) bool
}
// GrantPermission grants the specified permissions to the Terminal.
func (t *TerminalBase) GrantPermission(grant Permission) {
t.lock.Lock()
defer t.lock.Unlock()
t.permission |= grant
}
// HasPermission returns if the Terminal has the specified permission.
func (t *TerminalBase) HasPermission(required Permission) bool {
t.lock.RLock()
defer t.lock.RUnlock()
return t.permission.Has(required)
}
// Has returns if the permission includes the specified permission.
func (p Permission) Has(required Permission) bool {
return p&required == required
}
// AddPermissions combines multiple permissions.
func AddPermissions(perms ...Permission) Permission {
var all Permission
for _, p := range perms {
all |= p
}
return all
}

View File

@@ -0,0 +1,39 @@
package terminal
import "time"
// RateLimiter is a data flow rate limiter.
type RateLimiter struct {
maxBytesPerSlot uint64
slotBytes uint64
slotStarted time.Time
}
// NewRateLimiter returns a new rate limiter.
// The given MBit/s are transformed to bytes, so giving a multiple of 8 is
// advised for accurate results.
func NewRateLimiter(mbits uint64) *RateLimiter {
return &RateLimiter{
maxBytesPerSlot: (mbits / 8) * 1_000_000,
slotStarted: time.Now(),
}
}
// Limit is given the current transferred bytes and blocks until they may be sent.
func (rl *RateLimiter) Limit(xferBytes uint64) {
// Check if we need to limit transfer if we go over to max bytes per slot.
if rl.slotBytes > rl.maxBytesPerSlot {
// Wait if we are still within the slot.
sinceSlotStart := time.Since(rl.slotStarted)
if sinceSlotStart < time.Second {
time.Sleep(time.Second - sinceSlotStart)
}
// Reset state for next slot.
rl.slotBytes = 0
rl.slotStarted = time.Now()
}
// Add new bytes after checking, as first step over the limit is fully using the limit.
rl.slotBytes += xferBytes
}

166
spn/terminal/session.go Normal file
View File

@@ -0,0 +1,166 @@
package terminal
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/safing/portbase/log"
)
const (
rateLimitMinOps = 250
rateLimitMaxOpsPerSecond = 5
rateLimitMinSuspicion = 25
rateLimitMinPermaSuspicion = rateLimitMinSuspicion * 100
rateLimitMaxSuspicionPerSecond = 1
// Make this big enough to trigger suspicion limit in first blast.
concurrencyPoolSize = 30
)
// Session holds terminal metadata for operations.
type Session struct {
sync.RWMutex
// Rate Limiting.
// started holds the unix timestamp in seconds when the session was started.
// It is set when the Session is created and may be treated as a constant.
started int64
// opCount is the amount of operations started (and not rate limited by suspicion).
opCount atomic.Int64
// suspicionScore holds a score of suspicious activity.
// Every suspicious operations is counted as at least 1.
// Rate limited operations because of suspicion are also counted as 1.
suspicionScore atomic.Int64
concurrencyPool chan struct{}
}
// SessionTerminal is an interface for terminals that support authorization.
type SessionTerminal interface {
GetSession() *Session
}
// SessionAddOn can be inherited by terminals to add support for sessions.
type SessionAddOn struct {
lock sync.Mutex
// session holds the terminal session.
session *Session
}
// GetSession returns the terminal's session.
func (t *SessionAddOn) GetSession() *Session {
t.lock.Lock()
defer t.lock.Unlock()
// Create session if it does not exist.
if t.session == nil {
t.session = NewSession()
}
return t.session
}
// NewSession returns a new session.
func NewSession() *Session {
return &Session{
started: time.Now().Unix() - 1, // Ensure a 1 second difference to current time.
concurrencyPool: make(chan struct{}, concurrencyPoolSize),
}
}
// RateLimitInfo returns some basic information about the status of the rate limiter.
func (s *Session) RateLimitInfo() string {
secondsActive := time.Now().Unix() - s.started
return fmt.Sprintf(
"%do/s %ds/s %ds",
s.opCount.Load()/secondsActive,
s.suspicionScore.Load()/secondsActive,
secondsActive,
)
}
// RateLimit enforces a rate and suspicion limit.
func (s *Session) RateLimit() *Error {
secondsActive := time.Now().Unix() - s.started
// Check the suspicion limit.
score := s.suspicionScore.Load()
if score > rateLimitMinSuspicion {
scorePerSecond := score / secondsActive
if scorePerSecond >= rateLimitMaxSuspicionPerSecond {
// Add current try to suspicion score.
s.suspicionScore.Add(1)
return ErrRateLimited
}
// Permanently rate limit if suspicion goes over the perma min limit and
// the suspicion score is greater than 80% of the operation count.
if score > rateLimitMinPermaSuspicion &&
score*5 > s.opCount.Load()*4 { // Think: 80*5 == 100*4
return ErrRateLimited
}
}
// Check the rate limit.
count := s.opCount.Add(1)
if count > rateLimitMinOps {
opsPerSecond := count / secondsActive
if opsPerSecond >= rateLimitMaxOpsPerSecond {
return ErrRateLimited
}
}
return nil
}
// Suspicion Factors.
const (
SusFactorCommon = 1
SusFactorWeirdButOK = 5
SusFactorQuiteUnusual = 10
SusFactorMustBeMalicious = 100
)
// ReportSuspiciousActivity reports suspicious activity of the terminal.
func (s *Session) ReportSuspiciousActivity(factor int64) {
s.suspicionScore.Add(factor)
}
// LimitConcurrency limits concurrent executions.
// If over the limit, waiting goroutines are selected randomly.
// It returns the context error if it was canceled.
func (s *Session) LimitConcurrency(ctx context.Context, f func()) error {
// Wait for place in pool.
select {
case <-ctx.Done():
return ctx.Err()
case s.concurrencyPool <- struct{}{}:
// We added our entry to the pool, continue with execution.
}
// Drain own spot if pool after execution.
defer func() {
select {
case <-s.concurrencyPool:
// Own entry drained.
default:
// This should never happen, but let's play safe and not deadlock when pool is empty.
log.Warningf("spn/session: failed to drain own entry from concurrency pool")
}
}()
// Execute and return.
f()
return nil
}

View File

@@ -0,0 +1,94 @@
package terminal
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestRateLimit(t *testing.T) {
t.Parallel()
var tErr *Error
s := NewSession()
// Everything should be okay within the min limit.
for i := 0; i < rateLimitMinOps; i++ {
tErr = s.RateLimit()
if tErr != nil {
t.Error("should not rate limit within min limit")
}
}
// Somewhere here we should rate limiting.
for i := 0; i < rateLimitMaxOpsPerSecond; i++ {
tErr = s.RateLimit()
}
assert.ErrorIs(t, tErr, ErrRateLimited, "should rate limit")
}
func TestSuspicionLimit(t *testing.T) {
t.Parallel()
var tErr *Error
s := NewSession()
// Everything should be okay within the min limit.
for i := 0; i < rateLimitMinSuspicion; i++ {
tErr = s.RateLimit()
if tErr != nil {
t.Error("should not rate limit within min limit")
}
s.ReportSuspiciousActivity(SusFactorCommon)
}
// Somewhere here we should rate limiting.
for i := 0; i < rateLimitMaxSuspicionPerSecond; i++ {
s.ReportSuspiciousActivity(SusFactorCommon)
tErr = s.RateLimit()
}
if tErr == nil {
t.Error("should rate limit")
}
}
func TestConcurrencyLimit(t *testing.T) {
t.Parallel()
s := NewSession()
started := time.Now()
wg := sync.WaitGroup{}
workTime := 1 * time.Millisecond
workers := concurrencyPoolSize * 10
// Start many workers to test concurrency.
wg.Add(workers)
for i := 0; i < workers; i++ {
workerNum := i
go func() {
defer func() {
_ = recover()
}()
_ = s.LimitConcurrency(context.Background(), func() {
time.Sleep(workTime)
wg.Done()
// Panic sometimes.
if workerNum%concurrencyPoolSize == 0 {
panic("test")
}
})
}()
}
// Wait and check time needed.
wg.Wait()
if time.Since(started) < (time.Duration(workers) * workTime / concurrencyPoolSize) {
t.Errorf("workers were too quick - only took %s", time.Since(started))
} else {
t.Logf("workers were correctly limited - took %s", time.Since(started))
}
}

909
spn/terminal/terminal.go Normal file
View File

@@ -0,0 +1,909 @@
package terminal
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/tevino/abool"
"github.com/safing/jess"
"github.com/safing/portbase/container"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/rng"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/conf"
)
const (
timeoutTicks = 5
clientTerminalAbandonTimeout = 15 * time.Second
serverTerminalAbandonTimeout = 5 * time.Minute
)
// Terminal represents a terminal.
type Terminal interface { //nolint:golint // Being explicit is helpful here.
// ID returns the terminal ID.
ID() uint32
// Ctx returns the terminal context.
Ctx() context.Context
// Deliver delivers a message to the terminal.
// Should not be overridden by implementations.
Deliver(msg *Msg) *Error
// Send is used by others to send a message through the terminal.
// Should not be overridden by implementations.
Send(msg *Msg, timeout time.Duration) *Error
// Flush sends all messages waiting in the terminal.
// Should not be overridden by implementations.
Flush(timeout time.Duration)
// StartOperation starts the given operation by assigning it an ID and sending the given operation initialization data.
// Should not be overridden by implementations.
StartOperation(op Operation, initData *container.Container, timeout time.Duration) *Error
// StopOperation stops the given operation.
// Should not be overridden by implementations.
StopOperation(op Operation, err *Error)
// Abandon shuts down the terminal unregistering it from upstream and calling HandleAbandon().
// Should not be overridden by implementations.
Abandon(err *Error)
// HandleAbandon gives the terminal the ability to cleanly shut down.
// The terminal is still fully functional at this point.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Abandon() instead.
// Meant to be overridden by implementations.
HandleAbandon(err *Error) (errorToSend *Error)
// HandleDestruction gives the terminal the ability to clean up.
// The terminal has already fully shut down at this point.
// Should never be called directly. Call Abandon() instead.
// Meant to be overridden by implementations.
HandleDestruction(err *Error)
// FmtID formats the terminal ID (including parent IDs).
// May be overridden by implementations.
FmtID() string
}
// TerminalBase contains the basic functions of a terminal.
type TerminalBase struct { //nolint:golint,maligned // Being explicit is helpful here.
// TODO: Fix maligned.
Terminal // Interface check.
lock sync.RWMutex
// id is the underlying id of the Terminal.
id uint32
// parentID is the id of the parent component.
parentID string
// ext holds the extended terminal so that the base terminal can access custom functions.
ext Terminal
// sendQueue holds message to be sent.
sendQueue chan *Msg
// flowControl holds the flow control system.
flowControl FlowControl
// upstream represents the upstream (parent) terminal.
upstream Upstream
// deliverProxy is populated with the configured deliver function
deliverProxy func(msg *Msg) *Error
// recvProxy is populated with the configured recv function
recvProxy func() <-chan *Msg
// ctx is the context of the Terminal.
ctx context.Context
// cancelCtx cancels ctx.
cancelCtx context.CancelFunc
// waitForFlush signifies if sending should be delayed until the next call
// to Flush()
waitForFlush *abool.AtomicBool
// flush is used to send a finish function to the handler, which will write
// all pending messages and then call the received function.
flush chan func()
// idleTicker ticks for increasing and checking the idle counter.
idleTicker *time.Ticker
// idleCounter counts the ticks the terminal has been idle.
idleCounter *uint32
// jession is the jess session used for encryption.
jession *jess.Session
// jessionLock locks jession.
jessionLock sync.Mutex
// encryptionReady is set when the encryption is ready for sending messages.
encryptionReady chan struct{}
// identity is the identity used by a remote Terminal.
identity *cabin.Identity
// operations holds references to all active operations that require persistence.
operations map[uint32]Operation
// nextOpID holds the next operation ID.
nextOpID *uint32
// permission holds the permissions of the terminal.
permission Permission
// opts holds the terminal options. It must not be modified after the terminal
// has started.
opts *TerminalOpts
// lastUnknownOpID holds the operation ID of the last data message received
// for an unknown operation ID.
lastUnknownOpID uint32
// lastUnknownOpMsgs holds the amount of continuous data messages received
// for the operation ID in lastUnknownOpID.
lastUnknownOpMsgs uint32
// Abandoning indicates if the Terminal is being abandoned. The main handlers
// will keep running until the context has been canceled by the abandon
// procedure.
// No new operations should be started.
// Whoever initiates the abandoning must also start the abandon procedure.
Abandoning *abool.AtomicBool
}
func createTerminalBase(
ctx context.Context,
id uint32,
parentID string,
remote bool,
initMsg *TerminalOpts,
upstream Upstream,
) (*TerminalBase, *Error) {
t := &TerminalBase{
id: id,
parentID: parentID,
sendQueue: make(chan *Msg),
upstream: upstream,
waitForFlush: abool.New(),
flush: make(chan func()),
idleTicker: time.NewTicker(time.Minute),
idleCounter: new(uint32),
encryptionReady: make(chan struct{}),
operations: make(map[uint32]Operation),
nextOpID: new(uint32),
opts: initMsg,
Abandoning: abool.New(),
}
// Stop ticking to disable timeout.
t.idleTicker.Stop()
// Shift next operation ID if remote.
if remote {
atomic.AddUint32(t.nextOpID, 4)
}
// Create context.
t.ctx, t.cancelCtx = context.WithCancel(ctx)
// Create flow control.
switch initMsg.FlowControl {
case FlowControlDFQ:
t.flowControl = NewDuplexFlowQueue(t.Ctx(), initMsg.FlowControlSize, t.submitToUpstream)
t.deliverProxy = t.flowControl.Deliver
t.recvProxy = t.flowControl.Receive
case FlowControlNone:
deliver := make(chan *Msg, initMsg.FlowControlSize)
t.deliverProxy = MakeDirectDeliveryDeliverFunc(ctx, deliver)
t.recvProxy = MakeDirectDeliveryRecvFunc(deliver)
case FlowControlDefault:
fallthrough
default:
return nil, ErrInternalError.With("unknown flow control type %d", initMsg.FlowControl)
}
return t, nil
}
// ID returns the Terminal's ID.
func (t *TerminalBase) ID() uint32 {
return t.id
}
// Ctx returns the Terminal's context.
func (t *TerminalBase) Ctx() context.Context {
return t.ctx
}
// SetTerminalExtension sets the Terminal's extension. This function is not
// guarded and may only be used during initialization.
func (t *TerminalBase) SetTerminalExtension(ext Terminal) {
t.ext = ext
}
// SetTimeout sets the Terminal's idle timeout duration.
// It is broken down into slots internally.
func (t *TerminalBase) SetTimeout(d time.Duration) {
t.idleTicker.Reset(d / timeoutTicks)
}
// Deliver on TerminalBase only exists to conform to the interface. It must be
// overridden by an actual implementation.
func (t *TerminalBase) Deliver(msg *Msg) *Error {
// Deliver via configured proxy.
err := t.deliverProxy(msg)
if err != nil {
msg.Finish()
}
return err
}
// StartWorkers starts the necessary workers to operate the Terminal.
func (t *TerminalBase) StartWorkers(m *modules.Module, terminalName string) {
// Start terminal workers.
m.StartWorker(terminalName+" handler", t.Handler)
m.StartWorker(terminalName+" sender", t.Sender)
// Start any flow control workers.
if t.flowControl != nil {
t.flowControl.StartWorkers(m, terminalName)
}
}
const (
sendThresholdLength = 100 // bytes
sendMaxLength = 4000 // bytes
sendThresholdMaxWait = 20 * time.Millisecond
)
// Handler receives and handles messages and must be started as a worker in the
// module where the Terminal is used.
func (t *TerminalBase) Handler(_ context.Context) error {
defer t.Abandon(ErrInternalError.With("handler died"))
var msg *Msg
defer msg.Finish()
for {
select {
case <-t.ctx.Done():
// Call Abandon just in case.
// Normally, only the StopProcedure function should cancel the context.
t.Abandon(nil)
return nil // Controlled worker exit.
case <-t.idleTicker.C:
// If nothing happens for a while, end the session.
if atomic.AddUint32(t.idleCounter, 1) > timeoutTicks {
// Abandon the terminal and reset the counter.
t.Abandon(ErrNoActivity)
atomic.StoreUint32(t.idleCounter, 0)
}
case msg = <-t.recvProxy():
err := t.handleReceive(msg)
if err != nil {
t.Abandon(err.Wrap("failed to handle"))
return nil
}
// Register activity.
atomic.StoreUint32(t.idleCounter, 0)
}
}
}
// submit is used to send message from the terminal to upstream, including
// going through flow control, if configured.
// This function should be used to send message from the terminal to upstream.
func (t *TerminalBase) submit(msg *Msg, timeout time.Duration) {
// Submit directly if no flow control is configured.
if t.flowControl == nil {
t.submitToUpstream(msg, timeout)
return
}
// Hand over to flow control.
err := t.flowControl.Send(msg, timeout)
if err != nil {
msg.Finish()
t.Abandon(err.Wrap("failed to submit to flow control"))
}
}
// submitToUpstream is used to directly submit messages to upstream.
// This function should only be used by the flow control or submit function.
func (t *TerminalBase) submitToUpstream(msg *Msg, timeout time.Duration) {
// Add terminal ID as flow ID.
msg.FlowID = t.ID()
// Debug unit leaks.
msg.debugWithCaller(2)
// Submit to upstream.
err := t.upstream.Send(msg, timeout)
if err != nil {
msg.Finish()
t.Abandon(err.Wrap("failed to submit to upstream"))
}
}
// Sender handles sending messages and must be started as a worker in the
// module where the Terminal is used.
func (t *TerminalBase) Sender(_ context.Context) error {
// Don't send messages, if the encryption is net yet set up.
// The server encryption session is only initialized with the first
// operative message, not on Terminal creation.
if t.opts.Encrypt {
select {
case <-t.ctx.Done():
// Call Abandon just in case.
// Normally, the only the StopProcedure function should cancel the context.
t.Abandon(nil)
return nil // Controlled worker exit.
case <-t.encryptionReady:
}
}
// Be sure to call Stop even in case of sudden death.
defer t.Abandon(ErrInternalError.With("sender died"))
var msgBufferMsg *Msg
var msgBufferLen int
var msgBufferLimitReached bool
var sendMsgs bool
var sendMaxWait *time.Timer
var flushFinished func()
// Finish any current unit when returning.
defer msgBufferMsg.Finish()
// Only receive message when not sending the current msg buffer.
sendQueueOpMsgs := func() <-chan *Msg {
// Don't handle more messages, if the buffer is full.
if msgBufferLimitReached {
return nil
}
return t.sendQueue
}
// Only wait for sending slot when the current msg buffer is ready to be sent.
readyToSend := func() <-chan struct{} {
switch {
case !sendMsgs:
// Wait until there is something to send.
return nil
case t.flowControl != nil:
// Let flow control decide when we are ready.
return t.flowControl.ReadyToSend()
default:
// Always ready.
return ready
}
}
// Calculate current max wait time to send the msg buffer.
getSendMaxWait := func() <-chan time.Time {
if sendMaxWait != nil {
return sendMaxWait.C
}
return nil
}
handling:
for {
select {
case <-t.ctx.Done():
// Call Stop just in case.
// Normally, the only the StopProcedure function should cancel the context.
t.Abandon(nil)
return nil // Controlled worker exit.
case <-t.idleTicker.C:
// If nothing happens for a while, end the session.
if atomic.AddUint32(t.idleCounter, 1) > timeoutTicks {
// Abandon the terminal and reset the counter.
t.Abandon(ErrNoActivity)
atomic.StoreUint32(t.idleCounter, 0)
}
case msg := <-sendQueueOpMsgs():
if msg == nil {
continue handling
}
// Add unit to buffer unit, or use it as new buffer.
if msgBufferMsg != nil {
// Pack, append and finish additional message.
msgBufferMsg.Consume(msg)
} else {
// Pack operation message.
msg.Pack()
// Convert to message of terminal.
msgBufferMsg = msg
msgBufferMsg.FlowID = t.ID()
msgBufferMsg.Type = MsgTypeData
}
msgBufferLen += msg.Data.Length()
// Check if there is enough data to hit the sending threshold.
if msgBufferLen >= sendThresholdLength {
sendMsgs = true
} else if sendMaxWait == nil && t.waitForFlush.IsNotSet() {
sendMaxWait = time.NewTimer(sendThresholdMaxWait)
}
// Check if we have reached the maximum buffer size.
if msgBufferLen >= sendMaxLength {
msgBufferLimitReached = true
}
// Register activity.
atomic.StoreUint32(t.idleCounter, 0)
case <-getSendMaxWait():
// The timer for waiting for more data has ended.
// Send all available data if not forced to wait for a flush.
if t.waitForFlush.IsNotSet() {
sendMsgs = true
}
case newFlushFinishedFn := <-t.flush:
// We are flushing - stop waiting.
t.waitForFlush.UnSet()
// Signal immediately if msg buffer is empty.
if msgBufferLen == 0 {
newFlushFinishedFn()
} else {
// If there already is a flush finished function, stack them.
if flushFinished != nil {
stackedFlushFinishFn := flushFinished
flushFinished = func() {
stackedFlushFinishFn()
newFlushFinishedFn()
}
} else {
flushFinished = newFlushFinishedFn
}
}
// Force sending data now.
sendMsgs = true
case <-readyToSend():
// Reset sending flags.
sendMsgs = false
msgBufferLimitReached = false
// Send if there is anything to send.
var err *Error
if msgBufferLen > 0 {
// Update message type to include priority.
if msgBufferMsg.Type == MsgTypeData &&
msgBufferMsg.Unit.IsHighPriority() &&
t.opts.UsePriorityDataMsgs {
msgBufferMsg.Type = MsgTypePriorityData
}
// Wait for clearance on initial msg only.
msgBufferMsg.Unit.WaitForSlot()
err = t.sendOpMsgs(msgBufferMsg)
}
// Reset buffer.
msgBufferMsg = nil
msgBufferLen = 0
// Reset send wait timer.
if sendMaxWait != nil {
sendMaxWait.Stop()
sendMaxWait = nil
}
// Check if we are flushing and need to notify.
if flushFinished != nil {
flushFinished()
flushFinished = nil
}
// Handle error after state updates.
if err != nil {
t.Abandon(err.With("failed to send"))
continue handling
}
}
}
}
// WaitForFlush makes the terminal pause all sending until the next call to
// Flush().
func (t *TerminalBase) WaitForFlush() {
t.waitForFlush.Set()
}
// Flush sends all data waiting to be sent.
func (t *TerminalBase) Flush(timeout time.Duration) {
// Create channel and function for notifying.
wait := make(chan struct{})
finished := func() {
close(wait)
}
// Request flush and return when stopping.
select {
case t.flush <- finished:
case <-t.Ctx().Done():
return
case <-TimedOut(timeout):
return
}
// Wait for flush to finish and return when stopping.
select {
case <-wait:
case <-t.Ctx().Done():
return
case <-TimedOut(timeout):
return
}
// Flush flow control, if configured.
if t.flowControl != nil {
t.flowControl.Flush(timeout)
}
}
func (t *TerminalBase) encrypt(c *container.Container) (*container.Container, *Error) {
if !t.opts.Encrypt {
return c, nil
}
t.jessionLock.Lock()
defer t.jessionLock.Unlock()
letter, err := t.jession.Close(c.CompileData())
if err != nil {
return nil, ErrIntegrity.With("failed to encrypt: %w", err)
}
encryptedData, err := letter.ToWire()
if err != nil {
return nil, ErrInternalError.With("failed to pack letter: %w", err)
}
return encryptedData, nil
}
func (t *TerminalBase) decrypt(c *container.Container) (*container.Container, *Error) {
if !t.opts.Encrypt {
return c, nil
}
t.jessionLock.Lock()
defer t.jessionLock.Unlock()
letter, err := jess.LetterFromWire(c)
if err != nil {
return nil, ErrMalformedData.With("failed to parse letter: %w", err)
}
// Setup encryption if not yet done.
if t.jession == nil {
if t.identity == nil {
return nil, ErrInternalError.With("missing identity for setting up incoming encryption")
}
// Create jess session.
t.jession, err = letter.WireCorrespondence(t.identity)
if err != nil {
return nil, ErrIntegrity.With("failed to initialize incoming encryption: %w", err)
}
// Don't need that anymore.
t.identity = nil
// Encryption is ready for sending.
close(t.encryptionReady)
}
decryptedData, err := t.jession.Open(letter)
if err != nil {
return nil, ErrIntegrity.With("failed to decrypt: %w", err)
}
return container.New(decryptedData), nil
}
func (t *TerminalBase) handleReceive(msg *Msg) *Error {
msg.Unit.WaitForSlot()
defer msg.Finish()
// Debugging:
// log.Errorf("spn/terminal %s handling tmsg: %s", t.FmtID(), spew.Sdump(c.CompileData()))
// Check if message is empty. This will be the case if a message was only
// for updated the available space of the flow queue.
if !msg.Data.HoldsData() {
return nil
}
// Decrypt if enabled.
var tErr *Error
msg.Data, tErr = t.decrypt(msg.Data)
if tErr != nil {
return tErr
}
// Handle operation messages.
for msg.Data.HoldsData() {
// Get next message length.
msgLength, err := msg.Data.GetNextN32()
if err != nil {
return ErrMalformedData.With("failed to get operation msg length: %w", err)
}
if msgLength == 0 {
// Remainder is padding.
// Padding can only be at the end of the segment.
t.handlePaddingMsg(msg.Data)
return nil
}
// Get op msg data.
msgData, err := msg.Data.GetAsContainer(int(msgLength))
if err != nil {
return ErrMalformedData.With("failed to get operation msg data (%d/%d bytes): %w", msg.Data.Length(), msgLength, err)
}
// Handle op msg.
if handleErr := t.handleOpMsg(msgData); handleErr != nil {
return handleErr
}
}
return nil
}
func (t *TerminalBase) handleOpMsg(data *container.Container) *Error {
// Debugging:
// log.Errorf("spn/terminal %s handling opmsg: %s", t.FmtID(), spew.Sdump(data.CompileData()))
// Parse message operation id, type.
opID, msgType, err := ParseIDType(data)
if err != nil {
return ErrMalformedData.With("failed to parse operation msg id/type: %w", err)
}
switch msgType {
case MsgTypeInit:
t.handleOperationStart(opID, data)
case MsgTypeData, MsgTypePriorityData:
op, ok := t.GetActiveOp(opID)
if ok && !op.Stopped() {
// Create message from data.
msg := NewEmptyMsg()
msg.FlowID = opID
msg.Type = msgType
msg.Data = data
if msg.Type == MsgTypePriorityData {
msg.Unit.MakeHighPriority()
}
// Deliver message to operation.
tErr := op.Deliver(msg)
if tErr != nil {
// Also stop on "success" errors!
msg.Finish()
t.StopOperation(op, tErr)
}
return nil
}
// If an active op is not found, this is likely just left-overs from a
// stopped or failed operation.
// log.Tracef("spn/terminal: %s received data msg for unknown op %d", fmtTerminalID(t.parentID, t.id), opID)
// Send a stop error if this happens too often.
if opID == t.lastUnknownOpID {
// OpID is the same as last time.
t.lastUnknownOpMsgs++
// Log an warning (via StopOperation) and send a stop message every thousand.
if t.lastUnknownOpMsgs%1000 == 0 {
t.StopOperation(newUnknownOp(opID, ""), ErrUnknownOperationID.With("received %d unsolicited data msgs", t.lastUnknownOpMsgs))
}
// TODO: Abandon terminal at over 10000?
} else {
// OpID changed, set new ID and reset counter.
t.lastUnknownOpID = opID
t.lastUnknownOpMsgs = 1
}
case MsgTypeStop:
// Parse received error.
opErr, parseErr := ParseExternalError(data.CompileData())
if parseErr != nil {
log.Warningf("spn/terminal: %s failed to parse stop error: %s", fmtTerminalID(t.parentID, t.id), parseErr)
opErr = ErrUnknownError.AsExternal()
}
// End operation.
op, ok := t.GetActiveOp(opID)
if ok {
t.StopOperation(op, opErr)
} else {
log.Tracef("spn/terminal: %s received stop msg for unknown op %d", fmtTerminalID(t.parentID, t.id), opID)
}
default:
log.Warningf("spn/terminal: %s received unexpected message type: %d", t.FmtID(), msgType)
return ErrUnexpectedMsgType
}
return nil
}
func (t *TerminalBase) handlePaddingMsg(c *container.Container) {
padding := c.GetAll()
if len(padding) > 0 {
rngFeeder.SupplyEntropyIfNeeded(padding, len(padding))
}
}
func (t *TerminalBase) sendOpMsgs(msg *Msg) *Error {
msg.Unit.WaitForSlot()
// Add Padding if needed.
if t.opts.Padding > 0 {
paddingNeeded := (int(t.opts.Padding) - msg.Data.Length()) % int(t.opts.Padding)
if paddingNeeded > 0 {
// Add padding message header.
msg.Data.Append([]byte{0})
paddingNeeded--
// Add needed padding data.
if paddingNeeded > 0 {
padding, err := rng.Bytes(paddingNeeded)
if err != nil {
log.Debugf("spn/terminal: %s failed to get random data, using zeros instead", t.FmtID())
padding = make([]byte, paddingNeeded)
}
msg.Data.Append(padding)
}
}
}
// Encrypt operative data.
var tErr *Error
msg.Data, tErr = t.encrypt(msg.Data)
if tErr != nil {
return tErr
}
// Send data.
t.submit(msg, 0)
return nil
}
// Abandon shuts down the terminal unregistering it from upstream and calling HandleAbandon().
// Should not be overridden by implementations.
func (t *TerminalBase) Abandon(err *Error) {
if t.Abandoning.SetToIf(false, true) {
module.StartWorker("terminal abandon procedure", func(_ context.Context) error {
t.handleAbandonProcedure(err)
return nil
})
}
}
// HandleAbandon gives the terminal the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Abandon() instead.
// Meant to be overridden by implementations.
func (t *TerminalBase) HandleAbandon(err *Error) (errorToSend *Error) {
return err
}
// HandleDestruction gives the terminal the ability to clean up.
// The terminal has already fully shut down at this point.
// Should never be called directly. Call Abandon() instead.
// Meant to be overridden by implementations.
func (t *TerminalBase) HandleDestruction(err *Error) {}
func (t *TerminalBase) handleAbandonProcedure(err *Error) {
// End all operations.
for _, op := range t.allOps() {
t.StopOperation(op, nil)
}
// Prepare timeouts for waiting for ops.
timeout := clientTerminalAbandonTimeout
if conf.PublicHub() {
timeout = serverTerminalAbandonTimeout
}
checkTicker := time.NewTicker(50 * time.Millisecond)
defer checkTicker.Stop()
abortWaiting := time.After(timeout)
// Wait for all operations to end.
waitForOps:
for {
select {
case <-checkTicker.C:
if t.GetActiveOpCount() <= 0 {
break waitForOps
}
case <-abortWaiting:
log.Warningf(
"spn/terminal: terminal %s is continuing shutdown with %d active operations",
t.FmtID(),
t.GetActiveOpCount(),
)
break waitForOps
}
}
// Call operation stop handle function for proper shutdown cleaning up.
if t.ext != nil {
err = t.ext.HandleAbandon(err)
}
// Send error to the connected Operation, if the error is internal.
if !err.IsExternal() {
if err == nil {
err = ErrStopping
}
msg := NewMsg(err.Pack())
msg.FlowID = t.ID()
msg.Type = MsgTypeStop
t.submit(msg, 1*time.Second)
}
// If terminal was ended locally, send all data before abandoning.
// If terminal was ended remotely, don't bother sending remaining data.
if !err.IsExternal() {
// Flushing could mean sending a full buffer of 50000 packets.
t.Flush(5 * time.Minute)
}
// Stop all other connected workers.
t.cancelCtx()
t.idleTicker.Stop()
// Call operation destruction handle function for proper shutdown cleaning up.
if t.ext != nil {
t.ext.HandleDestruction(err)
}
}
func (t *TerminalBase) allOps() []Operation {
t.lock.Lock()
defer t.lock.Unlock()
ops := make([]Operation, 0, len(t.operations))
for _, op := range t.operations {
ops = append(ops, op)
}
return ops
}
// MakeDirectDeliveryDeliverFunc creates a submit upstream function with the
// given delivery channel.
func MakeDirectDeliveryDeliverFunc(
ctx context.Context,
deliver chan *Msg,
) func(c *Msg) *Error {
return func(c *Msg) *Error {
select {
case deliver <- c:
return nil
case <-ctx.Done():
return ErrStopping
}
}
}
// MakeDirectDeliveryRecvFunc makes a delivery receive function with the given
// delivery channel.
func MakeDirectDeliveryRecvFunc(
deliver chan *Msg,
) func() <-chan *Msg {
return func() <-chan *Msg {
return deliver
}
}

View File

@@ -0,0 +1,311 @@
package terminal
import (
"fmt"
"os"
"runtime/pprof"
"sync/atomic"
"testing"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/hub"
)
func TestTerminals(t *testing.T) {
t.Parallel()
identity, erro := cabin.CreateIdentity(module.Ctx, "test")
if erro != nil {
t.Fatalf("failed to create identity: %s", erro)
}
// Test without and with encryption.
for _, encrypt := range []bool{false, true} {
// Test with different flow controls.
for _, fc := range []struct {
flowControl FlowControlType
flowControlSize uint32
}{
{
flowControl: FlowControlNone,
flowControlSize: 5,
},
{
flowControl: FlowControlDFQ,
flowControlSize: defaultTestQueueSize,
},
} {
// Run tests with combined options.
testTerminals(t, identity, &TerminalOpts{
Encrypt: encrypt,
Padding: defaultTestPadding,
FlowControl: fc.flowControl,
FlowControlSize: fc.flowControlSize,
})
}
}
}
func testTerminals(t *testing.T, identity *cabin.Identity, terminalOpts *TerminalOpts) {
t.Helper()
// Prepare encryption.
var dstHub *hub.Hub
if terminalOpts.Encrypt {
dstHub = identity.Hub
} else {
identity = nil
}
// Create test terminals.
var term1 *TestTerminal
var term2 *TestTerminal
var initData *container.Container
var err *Error
term1, initData, err = NewLocalTestTerminal(
module.Ctx, 127, "c1", dstHub, terminalOpts, createForwardingUpstream(
t, "c1", "c2", func(msg *Msg) *Error {
return term2.Deliver(msg)
},
),
)
if err != nil {
t.Fatalf("failed to create local terminal: %s", err)
}
term2, _, err = NewRemoteTestTerminal(
module.Ctx, 127, "c2", identity, initData, createForwardingUpstream(
t, "c2", "c1", func(msg *Msg) *Error {
return term1.Deliver(msg)
},
),
)
if err != nil {
t.Fatalf("failed to create remote terminal: %s", err)
}
// Start testing with counters.
countToQueueSize := uint64(terminalOpts.FlowControlSize)
optionsSuffix := fmt.Sprintf(
"encrypt=%v,flowType=%d",
terminalOpts.Encrypt,
terminalOpts.FlowControl,
)
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "onlyup-flushing-waiting:" + optionsSuffix,
flush: true,
serverCountTo: countToQueueSize * 2,
waitBetweenMsgs: sendThresholdMaxWait * 2,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "onlyup-waiting:" + optionsSuffix,
serverCountTo: 10,
waitBetweenMsgs: sendThresholdMaxWait * 2,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "onlyup-flushing:" + optionsSuffix,
flush: true,
serverCountTo: countToQueueSize * 2,
waitBetweenMsgs: time.Millisecond,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "onlyup:" + optionsSuffix,
serverCountTo: countToQueueSize * 2,
waitBetweenMsgs: time.Millisecond,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "onlydown-flushing-waiting:" + optionsSuffix,
flush: true,
clientCountTo: countToQueueSize * 2,
waitBetweenMsgs: sendThresholdMaxWait * 2,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "onlydown-waiting:" + optionsSuffix,
clientCountTo: 10,
waitBetweenMsgs: sendThresholdMaxWait * 2,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "onlydown-flushing:" + optionsSuffix,
flush: true,
clientCountTo: countToQueueSize * 2,
waitBetweenMsgs: time.Millisecond,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "onlydown:" + optionsSuffix,
clientCountTo: countToQueueSize * 2,
waitBetweenMsgs: time.Millisecond,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "twoway-flushing-waiting:" + optionsSuffix,
flush: true,
clientCountTo: countToQueueSize * 2,
serverCountTo: countToQueueSize * 2,
waitBetweenMsgs: sendThresholdMaxWait * 2,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "twoway-waiting:" + optionsSuffix,
flush: true,
clientCountTo: 10,
serverCountTo: 10,
waitBetweenMsgs: sendThresholdMaxWait * 2,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "twoway-flushing:" + optionsSuffix,
flush: true,
clientCountTo: countToQueueSize * 2,
serverCountTo: countToQueueSize * 2,
waitBetweenMsgs: time.Millisecond,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "twoway:" + optionsSuffix,
clientCountTo: countToQueueSize * 2,
serverCountTo: countToQueueSize * 2,
waitBetweenMsgs: time.Millisecond,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "stresstest-down:" + optionsSuffix,
clientCountTo: countToQueueSize * 1000,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "stresstest-up:" + optionsSuffix,
serverCountTo: countToQueueSize * 1000,
})
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
testName: "stresstest-duplex:" + optionsSuffix,
clientCountTo: countToQueueSize * 1000,
serverCountTo: countToQueueSize * 1000,
})
// Clean up.
term1.Abandon(nil)
term2.Abandon(nil)
// Give some time for the last log messages and clean up.
time.Sleep(100 * time.Millisecond)
}
func createForwardingUpstream(t *testing.T, srcName, dstName string, deliverFunc func(*Msg) *Error) Upstream {
t.Helper()
return UpstreamSendFunc(func(msg *Msg, _ time.Duration) *Error {
// Fast track nil containers.
if msg == nil {
dErr := deliverFunc(msg)
if dErr != nil {
t.Errorf("%s>%s: failed to deliver nil msg to terminal: %s", srcName, dstName, dErr)
return dErr.With("failed to deliver nil msg to terminal")
}
return nil
}
// Log messages.
if logTestCraneMsgs {
t.Logf("%s>%s: %v\n", srcName, dstName, msg.Data.CompileData())
}
// Deliver to other terminal.
dErr := deliverFunc(msg)
if dErr != nil {
t.Errorf("%s>%s: failed to deliver to terminal: %s", srcName, dstName, dErr)
return dErr.With("failed to deliver to terminal")
}
return nil
})
}
type testWithCounterOpts struct {
testName string
flush bool
clientCountTo uint64
serverCountTo uint64
waitBetweenMsgs time.Duration
}
func testTerminalWithCounters(t *testing.T, term1, term2 *TestTerminal, opts *testWithCounterOpts) {
t.Helper()
// Wait async for test to complete, print stack after timeout.
finished := make(chan struct{})
maxTestDuration := 60 * time.Second
go func() {
select {
case <-finished:
case <-time.After(maxTestDuration):
fmt.Printf("terminal test %s is taking more than %s, printing stack:\n", opts.testName, maxTestDuration)
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
os.Exit(1)
}
}()
t.Logf("starting terminal counter test %s", opts.testName)
defer t.Logf("stopping terminal counter test %s", opts.testName)
// Start counters.
counter, tErr := NewCounterOp(term1, CounterOpts{
ClientCountTo: opts.clientCountTo,
ServerCountTo: opts.serverCountTo,
Flush: opts.flush,
Wait: opts.waitBetweenMsgs,
})
if tErr != nil {
t.Fatalf("terminal test %s failed to start counter: %s", opts.testName, tErr)
}
// Wait until counters are done.
counter.Wait()
close(finished)
// Check for error.
if counter.Error != nil {
t.Fatalf("terminal test %s failed to count: %s", opts.testName, counter.Error)
}
// Log stats.
printCTStats(t, opts.testName, "term1", term1)
printCTStats(t, opts.testName, "term2", term2)
// Check if stats match, if DFQ is used on both sides.
dfq1, ok1 := term1.flowControl.(*DuplexFlowQueue)
dfq2, ok2 := term2.flowControl.(*DuplexFlowQueue)
if ok1 && ok2 &&
(atomic.LoadInt32(dfq1.sendSpace) != atomic.LoadInt32(dfq2.reportedSpace) ||
atomic.LoadInt32(dfq2.sendSpace) != atomic.LoadInt32(dfq1.reportedSpace)) {
t.Fatalf("terminal test %s has non-matching space counters", opts.testName)
}
}
func printCTStats(t *testing.T, testName, name string, term *TestTerminal) {
t.Helper()
dfq, ok := term.flowControl.(*DuplexFlowQueue)
if !ok {
return
}
t.Logf(
"%s: %s: sq=%d rq=%d sends=%d reps=%d",
testName,
name,
len(dfq.sendQueue),
len(dfq.recvQueue),
atomic.LoadInt32(dfq.sendSpace),
atomic.LoadInt32(dfq.reportedSpace),
)
}

243
spn/terminal/testing.go Normal file
View File

@@ -0,0 +1,243 @@
package terminal
import (
"context"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/hub"
)
const (
defaultTestQueueSize = 16
defaultTestPadding = 8
logTestCraneMsgs = false
)
// TestTerminal is a terminal for running tests.
type TestTerminal struct {
*TerminalBase
}
// NewLocalTestTerminal returns a new local test terminal.
func NewLocalTestTerminal(
ctx context.Context,
id uint32,
parentID string,
remoteHub *hub.Hub,
initMsg *TerminalOpts,
upstream Upstream,
) (*TestTerminal, *container.Container, *Error) {
// Create Terminal Base.
t, initData, err := NewLocalBaseTerminal(ctx, id, parentID, remoteHub, initMsg, upstream)
if err != nil {
return nil, nil, err
}
t.StartWorkers(module, "test terminal")
return &TestTerminal{t}, initData, nil
}
// NewRemoteTestTerminal returns a new remote test terminal.
func NewRemoteTestTerminal(
ctx context.Context,
id uint32,
parentID string,
identity *cabin.Identity,
initData *container.Container,
upstream Upstream,
) (*TestTerminal, *TerminalOpts, *Error) {
// Create Terminal Base.
t, initMsg, err := NewRemoteBaseTerminal(ctx, id, parentID, identity, initData, upstream)
if err != nil {
return nil, nil, err
}
t.StartWorkers(module, "test terminal")
return &TestTerminal{t}, initMsg, nil
}
type delayedMsg struct {
msg *Msg
timeout time.Duration
delayUntil time.Time
}
func createDelayingTestForwardingFunc(
srcName,
dstName string,
delay time.Duration,
delayQueueSize int,
deliverFunc func(msg *Msg, timeout time.Duration) *Error,
) func(msg *Msg, timeout time.Duration) *Error {
// Return simple forward func if no delay is given.
if delay == 0 {
return func(msg *Msg, timeout time.Duration) *Error {
// Deliver to other terminal.
dErr := deliverFunc(msg, timeout)
if dErr != nil {
log.Errorf("spn/testing: %s>%s: failed to deliver to terminal: %s", srcName, dstName, dErr)
return dErr
}
return nil
}
}
// If there is delay, create a delaying channel and handler.
delayedMsgs := make(chan *delayedMsg, delayQueueSize)
go func() {
for {
// Read from chan
msg := <-delayedMsgs
if msg == nil {
return
}
// Check if we need to wait.
waitFor := time.Until(msg.delayUntil)
if waitFor > 0 {
time.Sleep(waitFor)
}
// Deliver to other terminal.
dErr := deliverFunc(msg.msg, msg.timeout)
if dErr != nil {
log.Errorf("spn/testing: %s>%s: failed to deliver to terminal: %s", srcName, dstName, dErr)
}
}
}()
return func(msg *Msg, timeout time.Duration) *Error {
// Add msg to delaying msg channel.
delayedMsgs <- &delayedMsg{
msg: msg,
timeout: timeout,
delayUntil: time.Now().Add(delay),
}
return nil
}
}
// HandleAbandon gives the terminal the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Abandon() instead.
func (t *TestTerminal) HandleAbandon(err *Error) (errorToSend *Error) {
switch err {
case nil:
// nil means that the Terminal is being shutdown by the owner.
log.Tracef("spn/terminal: %s is closing", fmtTerminalID(t.parentID, t.id))
default:
// All other errors are faults.
log.Warningf("spn/terminal: %s: %s", fmtTerminalID(t.parentID, t.id), err)
}
return
}
// NewSimpleTestTerminalPair provides a simple conntected terminal pair for tests.
func NewSimpleTestTerminalPair(delay time.Duration, delayQueueSize int, opts *TerminalOpts) (a, b *TestTerminal, err error) {
if opts == nil {
opts = &TerminalOpts{
Padding: defaultTestPadding,
FlowControl: FlowControlDFQ,
FlowControlSize: defaultTestQueueSize,
}
}
var initData *container.Container
var tErr *Error
a, initData, tErr = NewLocalTestTerminal(
module.Ctx, 127, "a", nil, opts, UpstreamSendFunc(createDelayingTestForwardingFunc(
"a", "b", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error {
return b.Deliver(msg)
},
)),
)
if tErr != nil {
return nil, nil, tErr.Wrap("failed to create local test terminal")
}
b, _, tErr = NewRemoteTestTerminal(
module.Ctx, 127, "b", nil, initData, UpstreamSendFunc(createDelayingTestForwardingFunc(
"b", "a", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error {
return a.Deliver(msg)
},
)),
)
if tErr != nil {
return nil, nil, tErr.Wrap("failed to create remote test terminal")
}
return a, b, nil
}
// BareTerminal is a bare terminal that just returns errors for testing.
type BareTerminal struct{}
var (
_ Terminal = &BareTerminal{}
errNotImplementedByBareTerminal = ErrInternalError.With("not implemented by bare terminal")
)
// ID returns the terminal ID.
func (t *BareTerminal) ID() uint32 {
return 0
}
// Ctx returns the terminal context.
func (t *BareTerminal) Ctx() context.Context {
return context.Background()
}
// Deliver delivers a message to the terminal.
// Should not be overridden by implementations.
func (t *BareTerminal) Deliver(msg *Msg) *Error {
return errNotImplementedByBareTerminal
}
// Send is used by others to send a message through the terminal.
// Should not be overridden by implementations.
func (t *BareTerminal) Send(msg *Msg, timeout time.Duration) *Error {
return errNotImplementedByBareTerminal
}
// Flush sends all messages waiting in the terminal.
// Should not be overridden by implementations.
func (t *BareTerminal) Flush(timeout time.Duration) {}
// StartOperation starts the given operation by assigning it an ID and sending the given operation initialization data.
// Should not be overridden by implementations.
func (t *BareTerminal) StartOperation(op Operation, initData *container.Container, timeout time.Duration) *Error {
return errNotImplementedByBareTerminal
}
// StopOperation stops the given operation.
// Should not be overridden by implementations.
func (t *BareTerminal) StopOperation(op Operation, err *Error) {}
// Abandon shuts down the terminal unregistering it from upstream and calling HandleAbandon().
// Should not be overridden by implementations.
func (t *BareTerminal) Abandon(err *Error) {}
// HandleAbandon gives the terminal the ability to cleanly shut down.
// The terminal is still fully functional at this point.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Abandon() instead.
// Meant to be overridden by implementations.
func (t *BareTerminal) HandleAbandon(err *Error) (errorToSend *Error) {
return err
}
// HandleDestruction gives the terminal the ability to clean up.
// The terminal has already fully shut down at this point.
// Should never be called directly. Call Abandon() instead.
// Meant to be overridden by implementations.
func (t *BareTerminal) HandleDestruction(err *Error) {}
// FmtID formats the terminal ID (including parent IDs).
// May be overridden by implementations.
func (t *BareTerminal) FmtID() string {
return "bare"
}

16
spn/terminal/upstream.go Normal file
View File

@@ -0,0 +1,16 @@
package terminal
import "time"
// Upstream defines the interface for upstream (parent) components.
type Upstream interface {
Send(msg *Msg, timeout time.Duration) *Error
}
// UpstreamSendFunc is a helper to be able to satisfy the Upstream interface.
type UpstreamSendFunc func(msg *Msg, timeout time.Duration) *Error
// Send is used to send a message through this upstream.
func (fn UpstreamSendFunc) Send(msg *Msg, timeout time.Duration) *Error {
return fn(msg, timeout)
}