wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
454
spn/terminal/control_flow.go
Normal file
454
spn/terminal/control_flow.go
Normal file
@@ -0,0 +1,454 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
// FlowControl defines the flow control interface.
|
||||
type FlowControl interface {
|
||||
Deliver(msg *Msg) *Error
|
||||
Receive() <-chan *Msg
|
||||
Send(msg *Msg, timeout time.Duration) *Error
|
||||
ReadyToSend() <-chan struct{}
|
||||
Flush(timeout time.Duration)
|
||||
StartWorkers(m *modules.Module, terminalName string)
|
||||
RecvQueueLen() int
|
||||
SendQueueLen() int
|
||||
}
|
||||
|
||||
// FlowControlType represents a flow control type.
|
||||
type FlowControlType uint8
|
||||
|
||||
// Flow Control Types.
|
||||
const (
|
||||
FlowControlDefault FlowControlType = 0
|
||||
FlowControlDFQ FlowControlType = 1
|
||||
FlowControlNone FlowControlType = 2
|
||||
|
||||
defaultFlowControl = FlowControlDFQ
|
||||
)
|
||||
|
||||
// DefaultSize returns the default flow control size.
|
||||
func (fct FlowControlType) DefaultSize() uint32 {
|
||||
if fct == FlowControlDefault {
|
||||
fct = defaultFlowControl
|
||||
}
|
||||
|
||||
switch fct {
|
||||
case FlowControlDFQ:
|
||||
return 50000
|
||||
case FlowControlNone:
|
||||
return 10000
|
||||
case FlowControlDefault:
|
||||
fallthrough
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// Flow Queue Configuration.
|
||||
const (
|
||||
DefaultQueueSize = 50000
|
||||
MaxQueueSize = 1000000
|
||||
forceReportBelowPercent = 0.75
|
||||
)
|
||||
|
||||
// DuplexFlowQueue is a duplex flow control mechanism using queues.
|
||||
type DuplexFlowQueue struct {
|
||||
// ti is the Terminal that is using the DFQ.
|
||||
ctx context.Context
|
||||
|
||||
// submitUpstream is used to submit messages to the upstream channel.
|
||||
submitUpstream func(msg *Msg, timeout time.Duration)
|
||||
|
||||
// sendQueue holds the messages that are waiting to be sent.
|
||||
sendQueue chan *Msg
|
||||
// prioMsgs holds the number of messages to send with high priority.
|
||||
prioMsgs *int32
|
||||
// sendSpace indicates the amount free slots in the recvQueue on the other end.
|
||||
sendSpace *int32
|
||||
// readyToSend is used to notify sending components that there is free space.
|
||||
readyToSend chan struct{}
|
||||
// wakeSender is used to wake a sender in case the sendSpace was zero and the
|
||||
// sender is waiting for available space.
|
||||
wakeSender chan struct{}
|
||||
|
||||
// recvQueue holds the messages that are waiting to be processed.
|
||||
recvQueue chan *Msg
|
||||
// reportedSpace indicates the amount of free slots that the other end knows
|
||||
// about.
|
||||
reportedSpace *int32
|
||||
// spaceReportLock locks the calculation of space to report.
|
||||
spaceReportLock sync.Mutex
|
||||
// forceSpaceReport forces the sender to send a space report.
|
||||
forceSpaceReport chan struct{}
|
||||
|
||||
// flush is used to send a finish function to the handler, which will write
|
||||
// all pending messages and then call the received function.
|
||||
flush chan func()
|
||||
}
|
||||
|
||||
// NewDuplexFlowQueue returns a new duplex flow queue.
|
||||
func NewDuplexFlowQueue(
|
||||
ctx context.Context,
|
||||
queueSize uint32,
|
||||
submitUpstream func(msg *Msg, timeout time.Duration),
|
||||
) *DuplexFlowQueue {
|
||||
dfq := &DuplexFlowQueue{
|
||||
ctx: ctx,
|
||||
submitUpstream: submitUpstream,
|
||||
sendQueue: make(chan *Msg, queueSize),
|
||||
prioMsgs: new(int32),
|
||||
sendSpace: new(int32),
|
||||
readyToSend: make(chan struct{}),
|
||||
wakeSender: make(chan struct{}, 1),
|
||||
recvQueue: make(chan *Msg, queueSize),
|
||||
reportedSpace: new(int32),
|
||||
forceSpaceReport: make(chan struct{}, 1),
|
||||
flush: make(chan func()),
|
||||
}
|
||||
atomic.StoreInt32(dfq.sendSpace, int32(queueSize))
|
||||
atomic.StoreInt32(dfq.reportedSpace, int32(queueSize))
|
||||
|
||||
return dfq
|
||||
}
|
||||
|
||||
// StartWorkers starts the necessary workers to operate the flow queue.
|
||||
func (dfq *DuplexFlowQueue) StartWorkers(m *modules.Module, terminalName string) {
|
||||
m.StartWorker(terminalName+" flow queue", dfq.FlowHandler)
|
||||
}
|
||||
|
||||
// shouldReportRecvSpace returns whether the receive space should be reported.
|
||||
func (dfq *DuplexFlowQueue) shouldReportRecvSpace() bool {
|
||||
return atomic.LoadInt32(dfq.reportedSpace) < int32(float32(cap(dfq.recvQueue))*forceReportBelowPercent)
|
||||
}
|
||||
|
||||
// decrementReportedRecvSpace decreases the reported recv space by 1 and
|
||||
// returns if the receive space should be reported.
|
||||
func (dfq *DuplexFlowQueue) decrementReportedRecvSpace() (shouldReportRecvSpace bool) {
|
||||
return atomic.AddInt32(dfq.reportedSpace, -1) < int32(float32(cap(dfq.recvQueue))*forceReportBelowPercent)
|
||||
}
|
||||
|
||||
// getSendSpace returns the current send space.
|
||||
func (dfq *DuplexFlowQueue) getSendSpace() int32 {
|
||||
return atomic.LoadInt32(dfq.sendSpace)
|
||||
}
|
||||
|
||||
// decrementSendSpace decreases the send space by 1 and returns it.
|
||||
func (dfq *DuplexFlowQueue) decrementSendSpace() int32 {
|
||||
return atomic.AddInt32(dfq.sendSpace, -1)
|
||||
}
|
||||
|
||||
func (dfq *DuplexFlowQueue) addToSendSpace(n int32) {
|
||||
// Add new space to send space and check if it was zero.
|
||||
atomic.AddInt32(dfq.sendSpace, n)
|
||||
// Wake the sender in case it is waiting.
|
||||
select {
|
||||
case dfq.wakeSender <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// reportableRecvSpace returns how much free space can be reported to the other
|
||||
// end. The returned number must be communicated to the other end and must not
|
||||
// be ignored.
|
||||
func (dfq *DuplexFlowQueue) reportableRecvSpace() int32 {
|
||||
// Changes to the recvQueue during calculation are no problem.
|
||||
// We don't want to report space twice though!
|
||||
dfq.spaceReportLock.Lock()
|
||||
defer dfq.spaceReportLock.Unlock()
|
||||
|
||||
// Calculate reportable receive space and add it to the reported space.
|
||||
reportedSpace := atomic.LoadInt32(dfq.reportedSpace)
|
||||
toReport := int32(cap(dfq.recvQueue)-len(dfq.recvQueue)) - reportedSpace
|
||||
|
||||
// Never report values below zero.
|
||||
// This can happen, as dfq.reportedSpace is decreased after a container is
|
||||
// submitted to dfq.recvQueue by dfq.Deliver(). This race condition can only
|
||||
// lower the space to report, not increase it. A simple check here solved
|
||||
// this problem and keeps performance high.
|
||||
// Also, don't report values of 1, as the benefit is minimal and this might
|
||||
// be commonly triggered due to the buffer of the force report channel.
|
||||
if toReport <= 1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Add space to report to dfq.reportedSpace and return it.
|
||||
atomic.AddInt32(dfq.reportedSpace, toReport)
|
||||
return toReport
|
||||
}
|
||||
|
||||
// FlowHandler handles all flow queue internals and must be started as a worker
|
||||
// in the module where it is used.
|
||||
func (dfq *DuplexFlowQueue) FlowHandler(_ context.Context) error {
|
||||
// The upstreamSender is started by the terminal module, but is tied to the
|
||||
// flow owner instead. Make sure that the flow owner's module depends on the
|
||||
// terminal module so that it is shut down earlier.
|
||||
|
||||
var sendSpaceDepleted bool
|
||||
var flushFinished func()
|
||||
|
||||
// Drain all queues when shutting down.
|
||||
defer func() {
|
||||
for {
|
||||
select {
|
||||
case msg := <-dfq.sendQueue:
|
||||
msg.Finish()
|
||||
case msg := <-dfq.recvQueue:
|
||||
msg.Finish()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
sending:
|
||||
for {
|
||||
// If the send queue is depleted, wait to be woken.
|
||||
if sendSpaceDepleted {
|
||||
select {
|
||||
case <-dfq.wakeSender:
|
||||
if dfq.getSendSpace() > 0 {
|
||||
sendSpaceDepleted = false
|
||||
} else {
|
||||
continue sending
|
||||
}
|
||||
|
||||
case <-dfq.forceSpaceReport:
|
||||
// Forced reporting of space.
|
||||
// We do not need to check if there is enough sending space, as there is
|
||||
// no data included.
|
||||
spaceToReport := dfq.reportableRecvSpace()
|
||||
if spaceToReport > 0 {
|
||||
msg := NewMsg(varint.Pack64(uint64(spaceToReport)))
|
||||
dfq.submitUpstream(msg, 0)
|
||||
}
|
||||
continue sending
|
||||
|
||||
case <-dfq.ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Get message from send queue.
|
||||
|
||||
select {
|
||||
case dfq.readyToSend <- struct{}{}:
|
||||
// Notify that we are ready to send.
|
||||
|
||||
case msg := <-dfq.sendQueue:
|
||||
// Send message from queue.
|
||||
|
||||
// If nil, the queue is being shut down.
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we are handling a high priority message or waiting for one.
|
||||
// Mark any msgs as high priority, when there is one in the pipeline.
|
||||
remainingPrioMsgs := atomic.AddInt32(dfq.prioMsgs, -1)
|
||||
switch {
|
||||
case remainingPrioMsgs >= 0:
|
||||
msg.Unit.MakeHighPriority()
|
||||
case remainingPrioMsgs < -30_000:
|
||||
// Prevent wrap to positive.
|
||||
// Compatible with int16 or bigger.
|
||||
atomic.StoreInt32(dfq.prioMsgs, 0)
|
||||
}
|
||||
|
||||
// Wait for processing slot.
|
||||
msg.Unit.WaitForSlot()
|
||||
|
||||
// Prepend available receiving space.
|
||||
msg.Data.Prepend(varint.Pack64(uint64(dfq.reportableRecvSpace())))
|
||||
|
||||
// Submit for sending upstream.
|
||||
dfq.submitUpstream(msg, 0)
|
||||
// Decrease the send space and set flag if depleted.
|
||||
if dfq.decrementSendSpace() <= 0 {
|
||||
sendSpaceDepleted = true
|
||||
}
|
||||
|
||||
// Check if the send queue is empty now and signal flushers.
|
||||
if flushFinished != nil && len(dfq.sendQueue) == 0 {
|
||||
flushFinished()
|
||||
flushFinished = nil
|
||||
}
|
||||
|
||||
case <-dfq.forceSpaceReport:
|
||||
// Forced reporting of space.
|
||||
// We do not need to check if there is enough sending space, as there is
|
||||
// no data included.
|
||||
spaceToReport := dfq.reportableRecvSpace()
|
||||
if spaceToReport > 0 {
|
||||
msg := NewMsg(varint.Pack64(uint64(spaceToReport)))
|
||||
dfq.submitUpstream(msg, 0)
|
||||
}
|
||||
|
||||
case newFlushFinishedFn := <-dfq.flush:
|
||||
// Signal immediately if send queue is empty.
|
||||
if len(dfq.sendQueue) == 0 {
|
||||
newFlushFinishedFn()
|
||||
} else {
|
||||
// If there already is a flush finished function, stack them.
|
||||
if flushFinished != nil {
|
||||
stackedFlushFinishFn := flushFinished
|
||||
flushFinished = func() {
|
||||
stackedFlushFinishFn()
|
||||
newFlushFinishedFn()
|
||||
}
|
||||
} else {
|
||||
flushFinished = newFlushFinishedFn
|
||||
}
|
||||
}
|
||||
|
||||
case <-dfq.ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush waits for all waiting data to be sent.
|
||||
func (dfq *DuplexFlowQueue) Flush(timeout time.Duration) {
|
||||
// Create channel and function for notifying.
|
||||
wait := make(chan struct{})
|
||||
finished := func() {
|
||||
close(wait)
|
||||
}
|
||||
// Request flush and return when stopping.
|
||||
select {
|
||||
case dfq.flush <- finished:
|
||||
case <-dfq.ctx.Done():
|
||||
return
|
||||
case <-TimedOut(timeout):
|
||||
return
|
||||
}
|
||||
// Wait for flush to finish and return when stopping.
|
||||
select {
|
||||
case <-wait:
|
||||
case <-dfq.ctx.Done():
|
||||
case <-TimedOut(timeout):
|
||||
}
|
||||
}
|
||||
|
||||
var ready = make(chan struct{})
|
||||
|
||||
func init() {
|
||||
close(ready)
|
||||
}
|
||||
|
||||
// ReadyToSend returns a channel that can be read when data can be sent.
|
||||
func (dfq *DuplexFlowQueue) ReadyToSend() <-chan struct{} {
|
||||
if atomic.LoadInt32(dfq.sendSpace) > 0 {
|
||||
return ready
|
||||
}
|
||||
return dfq.readyToSend
|
||||
}
|
||||
|
||||
// Send adds the given container to the send queue.
|
||||
func (dfq *DuplexFlowQueue) Send(msg *Msg, timeout time.Duration) *Error {
|
||||
select {
|
||||
case dfq.sendQueue <- msg:
|
||||
if msg.Unit.IsHighPriority() {
|
||||
// Reset prioMsgs to the current queue size, so that all waiting and the
|
||||
// message we just added are all handled as high priority.
|
||||
atomic.StoreInt32(dfq.prioMsgs, int32(len(dfq.sendQueue)))
|
||||
}
|
||||
return nil
|
||||
|
||||
case <-TimedOut(timeout):
|
||||
msg.Finish()
|
||||
return ErrTimeout
|
||||
|
||||
case <-dfq.ctx.Done():
|
||||
msg.Finish()
|
||||
return ErrStopping
|
||||
}
|
||||
}
|
||||
|
||||
// Receive receives a container from the recv queue.
|
||||
func (dfq *DuplexFlowQueue) Receive() <-chan *Msg {
|
||||
// If the reported recv space is nearing its end, force a report.
|
||||
if dfq.shouldReportRecvSpace() {
|
||||
select {
|
||||
case dfq.forceSpaceReport <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return dfq.recvQueue
|
||||
}
|
||||
|
||||
// Deliver submits a container for receiving from upstream.
|
||||
func (dfq *DuplexFlowQueue) Deliver(msg *Msg) *Error {
|
||||
// Ignore nil containers.
|
||||
if msg == nil || msg.Data == nil {
|
||||
msg.Finish()
|
||||
return ErrMalformedData.With("no data")
|
||||
}
|
||||
|
||||
// Get and add new reported space.
|
||||
addSpace, err := msg.Data.GetNextN16()
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
return ErrMalformedData.With("failed to parse reported space: %w", err)
|
||||
}
|
||||
if addSpace > 0 {
|
||||
dfq.addToSendSpace(int32(addSpace))
|
||||
}
|
||||
// Abort processing if the container only contained a space update.
|
||||
if !msg.Data.HoldsData() {
|
||||
msg.Finish()
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case dfq.recvQueue <- msg:
|
||||
|
||||
// If the recv queue accepted the Container, decrement the recv space.
|
||||
shouldReportRecvSpace := dfq.decrementReportedRecvSpace()
|
||||
// If the reported recv space is nearing its end, force a report, if the
|
||||
// sender worker is idle.
|
||||
if shouldReportRecvSpace {
|
||||
select {
|
||||
case dfq.forceSpaceReport <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
default:
|
||||
// If the recv queue is full, return an error.
|
||||
// The whole point of the flow queue is to guarantee that this never happens.
|
||||
msg.Finish()
|
||||
return ErrQueueOverflow
|
||||
}
|
||||
}
|
||||
|
||||
// FlowStats returns a k=v formatted string of internal stats.
|
||||
func (dfq *DuplexFlowQueue) FlowStats() string {
|
||||
return fmt.Sprintf(
|
||||
"sq=%d rq=%d sends=%d reps=%d",
|
||||
len(dfq.sendQueue),
|
||||
len(dfq.recvQueue),
|
||||
atomic.LoadInt32(dfq.sendSpace),
|
||||
atomic.LoadInt32(dfq.reportedSpace),
|
||||
)
|
||||
}
|
||||
|
||||
// RecvQueueLen returns the current length of the receive queue.
|
||||
func (dfq *DuplexFlowQueue) RecvQueueLen() int {
|
||||
return len(dfq.recvQueue)
|
||||
}
|
||||
|
||||
// SendQueueLen returns the current length of the send queue.
|
||||
func (dfq *DuplexFlowQueue) SendQueueLen() int {
|
||||
return len(dfq.sendQueue)
|
||||
}
|
||||
36
spn/terminal/defaults.go
Normal file
36
spn/terminal/defaults.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package terminal
|
||||
|
||||
const (
|
||||
// UsePriorityDataMsgs defines whether priority data messages should be used.
|
||||
UsePriorityDataMsgs = true
|
||||
)
|
||||
|
||||
// DefaultCraneControllerOpts returns the default terminal options for a crane
|
||||
// controller terminal.
|
||||
func DefaultCraneControllerOpts() *TerminalOpts {
|
||||
return &TerminalOpts{
|
||||
Padding: 0, // Crane already applies padding.
|
||||
FlowControl: FlowControlNone,
|
||||
UsePriorityDataMsgs: UsePriorityDataMsgs,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHomeHubTerminalOpts returns the default terminal options for a crane
|
||||
// terminal used for the home hub.
|
||||
func DefaultHomeHubTerminalOpts() *TerminalOpts {
|
||||
return &TerminalOpts{
|
||||
Padding: 0, // Crane already applies padding.
|
||||
FlowControl: FlowControlDFQ,
|
||||
UsePriorityDataMsgs: UsePriorityDataMsgs,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultExpansionTerminalOpts returns the default terminal options for an
|
||||
// expansion terminal.
|
||||
func DefaultExpansionTerminalOpts() *TerminalOpts {
|
||||
return &TerminalOpts{
|
||||
Padding: 8,
|
||||
FlowControl: FlowControlDFQ,
|
||||
UsePriorityDataMsgs: UsePriorityDataMsgs,
|
||||
}
|
||||
}
|
||||
221
spn/terminal/errors.go
Normal file
221
spn/terminal/errors.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
)
|
||||
|
||||
// Error is a terminal error.
|
||||
type Error struct {
|
||||
// id holds the internal error ID.
|
||||
id uint8
|
||||
// external signifies if the error was received from the outside.
|
||||
external bool
|
||||
// err holds the wrapped error or the default error message.
|
||||
err error
|
||||
}
|
||||
|
||||
// ID returns the internal ID of the error.
|
||||
func (e *Error) ID() uint8 {
|
||||
return e.id
|
||||
}
|
||||
|
||||
// Error returns the human readable format of the error.
|
||||
func (e *Error) Error() string {
|
||||
if e.external {
|
||||
return "[ext] " + e.err.Error()
|
||||
}
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
// IsExternal returns whether the error occurred externally.
|
||||
func (e *Error) IsExternal() bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return e.external
|
||||
}
|
||||
|
||||
// Is returns whether the given error is of the same type.
|
||||
func (e *Error) Is(target error) bool {
|
||||
if e == nil || target == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
t, ok := target.(*Error) //nolint:errorlint // Error implementation, not usage.
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return e.id == t.id
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapped error.
|
||||
func (e *Error) Unwrap() error {
|
||||
if e == nil || e.err == nil {
|
||||
return nil
|
||||
}
|
||||
return e.err
|
||||
}
|
||||
|
||||
// With adds context and details where the error occurred. The provided
|
||||
// message is appended to the error.
|
||||
// A new error with the same ID is returned and must be compared with
|
||||
// errors.Is().
|
||||
func (e *Error) With(format string, a ...interface{}) *Error {
|
||||
// Return nil if error is nil.
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &Error{
|
||||
id: e.id,
|
||||
err: fmt.Errorf(e.Error()+": "+format, a...),
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap adds context higher up in the call chain. The provided message is
|
||||
// prepended to the error.
|
||||
// A new error with the same ID is returned and must be compared with
|
||||
// errors.Is().
|
||||
func (e *Error) Wrap(format string, a ...interface{}) *Error {
|
||||
// Return nil if error is nil.
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &Error{
|
||||
id: e.id,
|
||||
err: fmt.Errorf(format+": "+e.Error(), a...),
|
||||
}
|
||||
}
|
||||
|
||||
// AsExternal creates and returns an external version of the error.
|
||||
func (e *Error) AsExternal() *Error {
|
||||
// Return nil if error is nil.
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &Error{
|
||||
id: e.id,
|
||||
err: e.err,
|
||||
external: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Pack returns the serialized internal error ID. The additional message is
|
||||
// lost and is replaced with the default message upon parsing.
|
||||
func (e *Error) Pack() []byte {
|
||||
// Return nil slice if error is nil.
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return varint.Pack8(e.id)
|
||||
}
|
||||
|
||||
// ParseExternalError parses an external error.
|
||||
func ParseExternalError(id []byte) (*Error, error) {
|
||||
// Return nil for an empty error.
|
||||
if len(id) == 0 {
|
||||
return ErrStopping.AsExternal(), nil
|
||||
}
|
||||
|
||||
parsedID, _, err := varint.Unpack8(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unpack error ID: %w", err)
|
||||
}
|
||||
|
||||
return NewExternalError(parsedID), nil
|
||||
}
|
||||
|
||||
// NewExternalError creates an external error based on the given ID.
|
||||
func NewExternalError(id uint8) *Error {
|
||||
err, ok := errorRegistry[id]
|
||||
if ok {
|
||||
return err.AsExternal()
|
||||
}
|
||||
|
||||
return ErrUnknownError.AsExternal()
|
||||
}
|
||||
|
||||
var errorRegistry = make(map[uint8]*Error)
|
||||
|
||||
func registerError(id uint8, err error) *Error {
|
||||
// Check for duplicate.
|
||||
_, ok := errorRegistry[id]
|
||||
if ok {
|
||||
panic(fmt.Sprintf("error with id %d already registered", id))
|
||||
}
|
||||
|
||||
newErr := &Error{
|
||||
id: id,
|
||||
err: err,
|
||||
}
|
||||
|
||||
errorRegistry[id] = newErr
|
||||
return newErr
|
||||
}
|
||||
|
||||
// func (e *Error) IsSpecial() bool {
|
||||
// if e == nil {
|
||||
// return false
|
||||
// }
|
||||
// return e.id > 0 && e.id < 8
|
||||
// }
|
||||
|
||||
// IsOK returns if the error represents a "OK" or success status.
|
||||
func (e *Error) IsOK() bool {
|
||||
return !e.IsError()
|
||||
}
|
||||
|
||||
// IsError returns if the error represents an erronous condition.
|
||||
func (e *Error) IsError() bool {
|
||||
if e == nil || e.err == nil {
|
||||
return false
|
||||
}
|
||||
if e.id == 0 || e.id >= 8 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Terminal Errors.
|
||||
var (
|
||||
// ErrUnknownError is the default error.
|
||||
ErrUnknownError = registerError(0, errors.New("unknown error"))
|
||||
|
||||
// Error IDs 1-7 are reserved for special "OK" values.
|
||||
|
||||
ErrStopping = registerError(2, errors.New("stopping"))
|
||||
ErrExplicitAck = registerError(3, errors.New("explicit ack"))
|
||||
ErrNoActivity = registerError(4, errors.New("no activity"))
|
||||
|
||||
// Errors IDs 8 and up are for regular errors.
|
||||
|
||||
ErrInternalError = registerError(8, errors.New("internal error"))
|
||||
ErrMalformedData = registerError(9, errors.New("malformed data"))
|
||||
ErrUnexpectedMsgType = registerError(10, errors.New("unexpected message type"))
|
||||
ErrUnknownOperationType = registerError(11, errors.New("unknown operation type"))
|
||||
ErrUnknownOperationID = registerError(12, errors.New("unknown operation id"))
|
||||
ErrPermissionDenied = registerError(13, errors.New("permission denied"))
|
||||
ErrIntegrity = registerError(14, errors.New("integrity violated"))
|
||||
ErrInvalidOptions = registerError(15, errors.New("invalid options"))
|
||||
ErrHubNotReady = registerError(16, errors.New("hub not ready"))
|
||||
ErrRateLimited = registerError(24, errors.New("rate limited"))
|
||||
ErrIncorrectUsage = registerError(22, errors.New("incorrect usage"))
|
||||
ErrTimeout = registerError(62, errors.New("timed out"))
|
||||
ErrUnsupportedVersion = registerError(93, errors.New("unsupported version"))
|
||||
ErrHubUnavailable = registerError(101, errors.New("hub unavailable"))
|
||||
ErrAbandonedTerminal = registerError(102, errors.New("terminal is being abandoned"))
|
||||
ErrShipSunk = registerError(108, errors.New("ship sunk"))
|
||||
ErrDestinationUnavailable = registerError(113, errors.New("destination unavailable"))
|
||||
ErrTryAgainLater = registerError(114, errors.New("try again later"))
|
||||
ErrConnectionError = registerError(121, errors.New("connection error"))
|
||||
ErrQueueOverflow = registerError(122, errors.New("queue overflowed"))
|
||||
ErrCanceled = registerError(125, context.Canceled)
|
||||
)
|
||||
27
spn/terminal/fmt.go
Normal file
27
spn/terminal/fmt.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package terminal
|
||||
|
||||
import "fmt"
|
||||
|
||||
// CustomTerminalIDFormatting defines an interface for terminal to define their custom ID format.
|
||||
type CustomTerminalIDFormatting interface {
|
||||
CustomIDFormat() string
|
||||
}
|
||||
|
||||
// FmtID formats the terminal ID together with the parent's ID.
|
||||
func (t *TerminalBase) FmtID() string {
|
||||
if t.ext != nil {
|
||||
if customFormatting, ok := t.ext.(CustomTerminalIDFormatting); ok {
|
||||
return customFormatting.CustomIDFormat()
|
||||
}
|
||||
}
|
||||
|
||||
return fmtTerminalID(t.parentID, t.id)
|
||||
}
|
||||
|
||||
func fmtTerminalID(craneID string, terminalID uint32) string {
|
||||
return fmt.Sprintf("%s#%d", craneID, terminalID)
|
||||
}
|
||||
|
||||
func fmtOperationID(craneID string, terminalID, operationID uint32) string {
|
||||
return fmt.Sprintf("%s#%d>%d", craneID, terminalID, operationID)
|
||||
}
|
||||
210
spn/terminal/init.go
Normal file
210
spn/terminal/init.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/safing/jess"
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
/*
|
||||
|
||||
Terminal Init Message Format:
|
||||
|
||||
- Version [varint]
|
||||
- Data Block [bytes; not blocked]
|
||||
- TerminalOpts as DSD
|
||||
|
||||
*/
|
||||
|
||||
const (
|
||||
minSupportedTerminalVersion = 1
|
||||
maxSupportedTerminalVersion = 1
|
||||
)
|
||||
|
||||
// TerminalOpts holds configuration for the terminal.
|
||||
type TerminalOpts struct { //nolint:golint,maligned // TODO: Rename.
|
||||
Version uint8 `json:"-"`
|
||||
Encrypt bool `json:"e,omitempty"`
|
||||
Padding uint16 `json:"p,omitempty"`
|
||||
|
||||
FlowControl FlowControlType `json:"fc,omitempty"`
|
||||
FlowControlSize uint32 `json:"qs,omitempty"` // Previously was "QueueSize".
|
||||
|
||||
UsePriorityDataMsgs bool `json:"pr,omitempty"`
|
||||
}
|
||||
|
||||
// ParseTerminalOpts parses terminal options from the container and checks if
|
||||
// they are valid.
|
||||
func ParseTerminalOpts(c *container.Container) (*TerminalOpts, *Error) {
|
||||
// Parse and check version.
|
||||
version, err := c.GetNextN8()
|
||||
if err != nil {
|
||||
return nil, ErrMalformedData.With("failed to parse version: %w", err)
|
||||
}
|
||||
if version < minSupportedTerminalVersion || version > maxSupportedTerminalVersion {
|
||||
return nil, ErrUnsupportedVersion.With("requested terminal version %d", version)
|
||||
}
|
||||
|
||||
// Parse init message.
|
||||
initMsg := &TerminalOpts{}
|
||||
_, err = dsd.Load(c.CompileData(), initMsg)
|
||||
if err != nil {
|
||||
return nil, ErrMalformedData.With("failed to parse init message: %w", err)
|
||||
}
|
||||
initMsg.Version = version
|
||||
|
||||
// Check if options are valid.
|
||||
tErr := initMsg.Check(false)
|
||||
if tErr != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
return initMsg, nil
|
||||
}
|
||||
|
||||
// Pack serialized the terminal options and checks if they are valid.
|
||||
func (opts *TerminalOpts) Pack() (*container.Container, *Error) {
|
||||
// Check if options are valid.
|
||||
tErr := opts.Check(true)
|
||||
if tErr != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
// Pack init message.
|
||||
optsData, err := dsd.Dump(opts, dsd.CBOR)
|
||||
if err != nil {
|
||||
return nil, ErrInternalError.With("failed to pack init message: %w", err)
|
||||
}
|
||||
|
||||
// Compile init message.
|
||||
return container.New(
|
||||
varint.Pack8(opts.Version),
|
||||
optsData,
|
||||
), nil
|
||||
}
|
||||
|
||||
// Check checks if terminal options are valid.
|
||||
func (opts *TerminalOpts) Check(useDefaultsForRequired bool) *Error {
|
||||
// Version is required - use default when permitted.
|
||||
if opts.Version == 0 && useDefaultsForRequired {
|
||||
opts.Version = 1
|
||||
}
|
||||
if opts.Version < minSupportedTerminalVersion || opts.Version > maxSupportedTerminalVersion {
|
||||
return ErrInvalidOptions.With("unsupported terminal version %d", opts.Version)
|
||||
}
|
||||
|
||||
// FlowControl is optional.
|
||||
switch opts.FlowControl {
|
||||
case FlowControlDefault:
|
||||
// Set to default flow control.
|
||||
opts.FlowControl = defaultFlowControl
|
||||
case FlowControlNone, FlowControlDFQ:
|
||||
// Ok.
|
||||
default:
|
||||
return ErrInvalidOptions.With("unknown flow control type: %d", opts.FlowControl)
|
||||
}
|
||||
|
||||
// FlowControlSize is required as it needs to be same on both sides.
|
||||
// Use default when permitted.
|
||||
if opts.FlowControlSize == 0 && useDefaultsForRequired {
|
||||
opts.FlowControlSize = opts.FlowControl.DefaultSize()
|
||||
}
|
||||
if opts.FlowControlSize <= 0 || opts.FlowControlSize > MaxQueueSize {
|
||||
return ErrInvalidOptions.With("invalid flow control size of %d", opts.FlowControlSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewLocalBaseTerminal creates a new local terminal base for use with inheriting terminals.
|
||||
func NewLocalBaseTerminal(
|
||||
ctx context.Context,
|
||||
id uint32,
|
||||
parentID string,
|
||||
remoteHub *hub.Hub,
|
||||
initMsg *TerminalOpts,
|
||||
upstream Upstream,
|
||||
) (
|
||||
t *TerminalBase,
|
||||
initData *container.Container,
|
||||
err *Error,
|
||||
) {
|
||||
// Pack, check and add defaults to init message.
|
||||
initData, err = initMsg.Pack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Create baseline.
|
||||
t, err = createTerminalBase(ctx, id, parentID, false, initMsg, upstream)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Setup encryption if enabled.
|
||||
if remoteHub != nil {
|
||||
initMsg.Encrypt = true
|
||||
|
||||
// Select signet (public key) of remote Hub to use.
|
||||
s := remoteHub.SelectSignet()
|
||||
if s == nil {
|
||||
return nil, nil, ErrHubNotReady.With("failed to select signet of remote hub")
|
||||
}
|
||||
|
||||
// Create new session.
|
||||
env := jess.NewUnconfiguredEnvelope()
|
||||
env.SuiteID = jess.SuiteWireV1
|
||||
env.Recipients = []*jess.Signet{s}
|
||||
jession, err := env.WireCorrespondence(nil)
|
||||
if err != nil {
|
||||
return nil, nil, ErrIntegrity.With("failed to initialize encryption: %w", err)
|
||||
}
|
||||
t.jession = jession
|
||||
|
||||
// Encryption is ready for sending.
|
||||
close(t.encryptionReady)
|
||||
}
|
||||
|
||||
return t, initData, nil
|
||||
}
|
||||
|
||||
// NewRemoteBaseTerminal creates a new remote terminal base for use with inheriting terminals.
|
||||
func NewRemoteBaseTerminal(
|
||||
ctx context.Context,
|
||||
id uint32,
|
||||
parentID string,
|
||||
identity *cabin.Identity,
|
||||
initData *container.Container,
|
||||
upstream Upstream,
|
||||
) (
|
||||
t *TerminalBase,
|
||||
initMsg *TerminalOpts,
|
||||
err *Error,
|
||||
) {
|
||||
// Parse init message.
|
||||
initMsg, err = ParseTerminalOpts(initData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Create baseline.
|
||||
t, err = createTerminalBase(ctx, id, parentID, true, initMsg, upstream)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Setup encryption if enabled.
|
||||
if initMsg.Encrypt {
|
||||
if identity == nil {
|
||||
return nil, nil, ErrInternalError.With("missing identity for setting up incoming encryption")
|
||||
}
|
||||
t.identity = identity
|
||||
}
|
||||
|
||||
return t, initMsg, nil
|
||||
}
|
||||
117
spn/terminal/metrics.go
Normal file
117
spn/terminal/metrics.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/metrics"
|
||||
)
|
||||
|
||||
var metricsRegistered = abool.New()
|
||||
|
||||
func registerMetrics() (err error) {
|
||||
// Only register metrics once.
|
||||
if !metricsRegistered.SetToIf(false, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get scheduler config and calculat scaling.
|
||||
schedulerConfig := getSchedulerConfig()
|
||||
scaleSlotToSecondsFactor := float64(time.Second / schedulerConfig.SlotDuration)
|
||||
|
||||
// Register metrics from scheduler stats.
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/scheduling/unit/slotpace/max",
|
||||
nil,
|
||||
metricFromInt(scheduler.GetMaxSlotPace, scaleSlotToSecondsFactor),
|
||||
&metrics.Options{
|
||||
Name: "SPN Scheduling Max Slot Pace (scaled to per second)",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/scheduling/unit/slotpace/leveled/max",
|
||||
nil,
|
||||
metricFromInt(scheduler.GetMaxLeveledSlotPace, scaleSlotToSecondsFactor),
|
||||
&metrics.Options{
|
||||
Name: "SPN Scheduling Max Leveled Slot Pace (scaled to per second)",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/scheduling/unit/slotpace/avg",
|
||||
nil,
|
||||
metricFromInt(scheduler.GetAvgSlotPace, scaleSlotToSecondsFactor),
|
||||
&metrics.Options{
|
||||
Name: "SPN Scheduling Avg Slot Pace (scaled to per second)",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/scheduling/unit/life/avg/seconds",
|
||||
nil,
|
||||
metricFromNanoseconds(scheduler.GetAvgUnitLife),
|
||||
&metrics.Options{
|
||||
Name: "SPN Scheduling Avg Unit Life",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/scheduling/unit/workslot/avg/seconds",
|
||||
nil,
|
||||
metricFromNanoseconds(scheduler.GetAvgWorkSlotDuration),
|
||||
&metrics.Options{
|
||||
Name: "SPN Scheduling Avg Work Slot Duration",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/scheduling/unit/catchupslot/avg/seconds",
|
||||
nil,
|
||||
metricFromNanoseconds(scheduler.GetAvgCatchUpSlotDuration),
|
||||
&metrics.Options{
|
||||
Name: "SPN Scheduling Avg Catch-Up Slot Duration",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func metricFromInt(fn func() int64, scaleFactor float64) func() float64 {
|
||||
return func() float64 {
|
||||
return float64(fn()) * scaleFactor
|
||||
}
|
||||
}
|
||||
|
||||
func metricFromNanoseconds(fn func() int64) func() float64 {
|
||||
return func() float64 {
|
||||
return float64(fn()) / float64(time.Second)
|
||||
}
|
||||
}
|
||||
80
spn/terminal/module.go
Normal file
80
spn/terminal/module.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/unit"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
rngFeeder *rng.Feeder = rng.NewFeeder()
|
||||
|
||||
scheduler *unit.Scheduler
|
||||
|
||||
debugUnitScheduling bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&debugUnitScheduling, "debug-unit-scheduling", false, "enable debug logs of the SPN unit scheduler")
|
||||
|
||||
module = modules.Register("terminal", nil, start, nil, "base")
|
||||
}
|
||||
|
||||
func start() error {
|
||||
rngFeeder = rng.NewFeeder()
|
||||
|
||||
scheduler = unit.NewScheduler(getSchedulerConfig())
|
||||
if debugUnitScheduling {
|
||||
// Debug unit leaks.
|
||||
scheduler.StartDebugLog()
|
||||
}
|
||||
module.StartServiceWorker("msg unit scheduler", 0, scheduler.SlotScheduler)
|
||||
|
||||
lockOpRegistry()
|
||||
|
||||
return registerMetrics()
|
||||
}
|
||||
|
||||
var waitForever chan time.Time
|
||||
|
||||
// TimedOut returns a channel that triggers when the timeout is reached.
|
||||
func TimedOut(timeout time.Duration) <-chan time.Time {
|
||||
if timeout == 0 {
|
||||
return waitForever
|
||||
}
|
||||
return time.After(timeout)
|
||||
}
|
||||
|
||||
// StopScheduler stops the unit scheduler.
|
||||
func StopScheduler() {
|
||||
if scheduler != nil {
|
||||
scheduler.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func getSchedulerConfig() *unit.SchedulerConfig {
|
||||
// Client Scheduler Config.
|
||||
if conf.Client() {
|
||||
return &unit.SchedulerConfig{
|
||||
SlotDuration: 10 * time.Millisecond, // 100 slots per second
|
||||
MinSlotPace: 10, // 1000pps - Small starting pace for low end devices.
|
||||
WorkSlotPercentage: 0.9, // 90%
|
||||
SlotChangeRatePerStreak: 0.1, // 10% - Increase/Decrease quickly.
|
||||
StatCycleDuration: 1 * time.Minute, // Match metrics report cycle.
|
||||
}
|
||||
}
|
||||
|
||||
// Server Scheduler Config.
|
||||
return &unit.SchedulerConfig{
|
||||
SlotDuration: 10 * time.Millisecond, // 100 slots per second
|
||||
MinSlotPace: 100, // 10000pps - Every server should be able to handle this.
|
||||
WorkSlotPercentage: 0.7, // 70%
|
||||
SlotChangeRatePerStreak: 0.05, // 5%
|
||||
StatCycleDuration: 1 * time.Minute, // Match metrics report cycle.
|
||||
}
|
||||
}
|
||||
13
spn/terminal/module_test.go
Normal file
13
spn/terminal/module_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/service/core/pmtesting"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
conf.EnablePublicHub(true)
|
||||
pmtesting.TestMain(m, module)
|
||||
}
|
||||
106
spn/terminal/msg.go
Normal file
106
spn/terminal/msg.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portmaster/spn/unit"
|
||||
)
|
||||
|
||||
// Msg is a message within the SPN network stack.
|
||||
// It includes metadata and unit scheduling.
|
||||
type Msg struct {
|
||||
FlowID uint32
|
||||
Type MsgType
|
||||
Data *container.Container
|
||||
|
||||
// Unit scheduling.
|
||||
// Note: With just 100B per packet, a uint64 (the Unit ID) is enough for
|
||||
// over 1800 Exabyte. No need for overflow support.
|
||||
Unit *unit.Unit
|
||||
}
|
||||
|
||||
// NewMsg returns a new msg.
|
||||
// The FlowID is unset.
|
||||
// The Type is Data.
|
||||
func NewMsg(data []byte) *Msg {
|
||||
msg := &Msg{
|
||||
Type: MsgTypeData,
|
||||
Data: container.New(data),
|
||||
Unit: scheduler.NewUnit(),
|
||||
}
|
||||
|
||||
// Debug unit leaks.
|
||||
msg.debugWithCaller(2)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// NewEmptyMsg returns a new empty msg with an initialized Unit.
|
||||
// The FlowID is unset.
|
||||
// The Type is Data.
|
||||
// The Data is unset.
|
||||
func NewEmptyMsg() *Msg {
|
||||
msg := &Msg{
|
||||
Type: MsgTypeData,
|
||||
Unit: scheduler.NewUnit(),
|
||||
}
|
||||
|
||||
// Debug unit leaks.
|
||||
msg.debugWithCaller(2)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// Pack prepends the message header (Length and ID+Type) to the data.
|
||||
func (msg *Msg) Pack() {
|
||||
MakeMsg(msg.Data, msg.FlowID, msg.Type)
|
||||
}
|
||||
|
||||
// Consume adds another Message to itself.
|
||||
// The given Msg is packed before adding it to the data.
|
||||
// The data is moved - not copied!
|
||||
// High priority mark is inherited.
|
||||
func (msg *Msg) Consume(other *Msg) {
|
||||
// Pack message to be added.
|
||||
other.Pack()
|
||||
|
||||
// Move data.
|
||||
msg.Data.AppendContainer(other.Data)
|
||||
|
||||
// Inherit high priority.
|
||||
if other.Unit.IsHighPriority() {
|
||||
msg.Unit.MakeHighPriority()
|
||||
}
|
||||
|
||||
// Finish other unit.
|
||||
other.Finish()
|
||||
}
|
||||
|
||||
// Finish signals the unit scheduler that this unit has finished processing.
|
||||
// Will no-op if called on a nil Msg.
|
||||
func (msg *Msg) Finish() {
|
||||
// Proxying is necessary, as a nil msg still panics.
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
msg.Unit.Finish()
|
||||
}
|
||||
|
||||
// Debug registers the unit for debug output with the given source.
|
||||
// Additional calls on the same unit update the unit source.
|
||||
// StartDebugLog() must be called before calling DebugUnit().
|
||||
func (msg *Msg) Debug() {
|
||||
msg.debugWithCaller(2)
|
||||
}
|
||||
|
||||
func (msg *Msg) debugWithCaller(skip int) { //nolint:unparam
|
||||
if !debugUnitScheduling || msg == nil {
|
||||
return
|
||||
}
|
||||
_, file, line, ok := runtime.Caller(skip)
|
||||
if ok {
|
||||
scheduler.DebugUnit(msg.Unit, fmt.Sprintf("%s:%d", file, line))
|
||||
}
|
||||
}
|
||||
66
spn/terminal/msgtypes.go
Normal file
66
spn/terminal/msgtypes.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
)
|
||||
|
||||
/*
|
||||
Terminal and Operation Message Format:
|
||||
|
||||
- Length [varint]
|
||||
- If Length is 0, the remainder of given data is padding.
|
||||
- IDType [varint]
|
||||
- Type [uses least two significant bits]
|
||||
- One of Init, Data, Stop
|
||||
- ID [uses all other bits]
|
||||
- The ID is currently not adapted in order to make reading raw message
|
||||
easier. This means that IDs are currently always a multiple of 4.
|
||||
- Data [bytes; format depends on msg type]
|
||||
- MsgTypeInit:
|
||||
- Data [bytes]
|
||||
- MsgTypeData:
|
||||
- AddAvailableSpace [varint, if Flow Queue is used]
|
||||
- (Encrypted) Data [bytes]
|
||||
- MsgTypeStop:
|
||||
- Error Code [varint]
|
||||
*/
|
||||
|
||||
// MsgType is the message type for both terminals and operations.
|
||||
type MsgType uint8
|
||||
|
||||
const (
|
||||
// MsgTypeInit is used to establish a new terminal or run a new operation.
|
||||
MsgTypeInit MsgType = 1
|
||||
|
||||
// MsgTypeData is used to send data to a terminal or operation.
|
||||
MsgTypeData MsgType = 2
|
||||
|
||||
// MsgTypePriorityData is used to send prioritized data to a terminal or operation.
|
||||
MsgTypePriorityData MsgType = 0
|
||||
|
||||
// MsgTypeStop is used to abandon a terminal or end an operation, with an optional error.
|
||||
MsgTypeStop MsgType = 3
|
||||
)
|
||||
|
||||
// AddIDType prepends the ID and Type header to the message.
|
||||
func AddIDType(c *container.Container, id uint32, msgType MsgType) {
|
||||
c.Prepend(varint.Pack32(id | uint32(msgType)))
|
||||
}
|
||||
|
||||
// MakeMsg prepends the message header (Length and ID+Type) to the data.
|
||||
func MakeMsg(c *container.Container, id uint32, msgType MsgType) {
|
||||
AddIDType(c, id, msgType)
|
||||
c.PrependLength()
|
||||
}
|
||||
|
||||
// ParseIDType parses the combined message ID and type.
|
||||
func ParseIDType(c *container.Container) (id uint32, msgType MsgType, err error) {
|
||||
idType, err := c.GetNextN32()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
msgType = MsgType(idType % 4)
|
||||
return idType - uint32(msgType), msgType, nil
|
||||
}
|
||||
332
spn/terminal/operation.go
Normal file
332
spn/terminal/operation.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/utils"
|
||||
)
|
||||
|
||||
// Operation is an interface for all operations.
|
||||
type Operation interface {
|
||||
// InitOperationBase initialize the operation with the ID and attached terminal.
|
||||
// Should not be overridden by implementations.
|
||||
InitOperationBase(t Terminal, opID uint32)
|
||||
|
||||
// ID returns the ID of the operation.
|
||||
// Should not be overridden by implementations.
|
||||
ID() uint32
|
||||
|
||||
// Type returns the operation's type ID.
|
||||
// Should be overridden by implementations to return correct type ID.
|
||||
Type() string
|
||||
|
||||
// Deliver delivers a message to the operation.
|
||||
// Meant to be overridden by implementations.
|
||||
Deliver(msg *Msg) *Error
|
||||
|
||||
// NewMsg creates a new message from this operation.
|
||||
// Should not be overridden by implementations.
|
||||
NewMsg(data []byte) *Msg
|
||||
|
||||
// Send sends a message to the other side.
|
||||
// Should not be overridden by implementations.
|
||||
Send(msg *Msg, timeout time.Duration) *Error
|
||||
|
||||
// Flush sends all messages waiting in the terminal.
|
||||
// Should not be overridden by implementations.
|
||||
Flush(timeout time.Duration)
|
||||
|
||||
// Stopped returns whether the operation has stopped.
|
||||
// Should not be overridden by implementations.
|
||||
Stopped() bool
|
||||
|
||||
// markStopped marks the operation as stopped.
|
||||
// It returns whether the stop flag was set.
|
||||
markStopped() bool
|
||||
|
||||
// Stop stops the operation by unregistering it from the terminal and calling HandleStop().
|
||||
// Should not be overridden by implementations.
|
||||
Stop(self Operation, err *Error)
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
// Meant to be overridden by implementations.
|
||||
HandleStop(err *Error) (errorToSend *Error)
|
||||
|
||||
// Terminal returns the terminal the operation is linked to.
|
||||
// Should not be overridden by implementations.
|
||||
Terminal() Terminal
|
||||
}
|
||||
|
||||
// OperationFactory defines an operation factory.
|
||||
type OperationFactory struct {
|
||||
// Type is the type id of an operation.
|
||||
Type string
|
||||
// Requires defines the required permissions to run an operation.
|
||||
Requires Permission
|
||||
// Start is the function that starts a new operation.
|
||||
Start OperationStarter
|
||||
}
|
||||
|
||||
// OperationStarter is used to initialize operations remotely.
|
||||
type OperationStarter func(attachedTerminal Terminal, opID uint32, initData *container.Container) (Operation, *Error)
|
||||
|
||||
var (
|
||||
opRegistry = make(map[string]*OperationFactory)
|
||||
opRegistryLock sync.Mutex
|
||||
opRegistryLocked = abool.New()
|
||||
)
|
||||
|
||||
// RegisterOpType registers a new operation type and may only be called during
|
||||
// Go's init and a module's prep phase.
|
||||
func RegisterOpType(factory OperationFactory) {
|
||||
// Check if we can still register an operation type.
|
||||
if opRegistryLocked.IsSet() {
|
||||
log.Errorf("spn/terminal: failed to register operation %s: operation registry is already locked", factory.Type)
|
||||
return
|
||||
}
|
||||
|
||||
opRegistryLock.Lock()
|
||||
defer opRegistryLock.Unlock()
|
||||
|
||||
// Check if the operation type was already registered.
|
||||
if _, ok := opRegistry[factory.Type]; ok {
|
||||
log.Errorf("spn/terminal: failed to register operation type %s: type already registered", factory.Type)
|
||||
return
|
||||
}
|
||||
|
||||
// Save to registry.
|
||||
opRegistry[factory.Type] = &factory
|
||||
}
|
||||
|
||||
func lockOpRegistry() {
|
||||
opRegistryLocked.Set()
|
||||
}
|
||||
|
||||
func (t *TerminalBase) handleOperationStart(opID uint32, initData *container.Container) {
|
||||
// Check if the terminal is being abandoned.
|
||||
if t.Abandoning.IsSet() {
|
||||
t.StopOperation(newUnknownOp(opID, ""), ErrAbandonedTerminal)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract the requested operation name.
|
||||
opType, err := initData.GetNextBlock()
|
||||
if err != nil {
|
||||
t.StopOperation(newUnknownOp(opID, ""), ErrMalformedData.With("failed to get init data: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Get the operation factory from the registry.
|
||||
factory, ok := opRegistry[string(opType)]
|
||||
if !ok {
|
||||
t.StopOperation(newUnknownOp(opID, ""), ErrUnknownOperationType.With(utils.SafeFirst16Bytes(opType)))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the Terminal has the required permission to run the operation.
|
||||
if !t.HasPermission(factory.Requires) {
|
||||
t.StopOperation(newUnknownOp(opID, factory.Type), ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
|
||||
// Get terminal to attach to.
|
||||
attachToTerminal := t.ext
|
||||
if attachToTerminal == nil {
|
||||
attachToTerminal = t
|
||||
}
|
||||
|
||||
// Run the operation.
|
||||
op, opErr := factory.Start(attachToTerminal, opID, initData)
|
||||
switch {
|
||||
case opErr != nil:
|
||||
// Something went wrong.
|
||||
t.StopOperation(newUnknownOp(opID, factory.Type), opErr)
|
||||
case op == nil:
|
||||
// The Operation was successful and is done already.
|
||||
log.Debugf("spn/terminal: operation %s %s executed", factory.Type, fmtOperationID(t.parentID, t.id, opID))
|
||||
t.StopOperation(newUnknownOp(opID, factory.Type), nil)
|
||||
default:
|
||||
// The operation started successfully and requires persistence.
|
||||
t.SetActiveOp(opID, op)
|
||||
log.Debugf("spn/terminal: operation %s %s started", factory.Type, fmtOperationID(t.parentID, t.id, opID))
|
||||
}
|
||||
}
|
||||
|
||||
// StartOperation starts the given operation by assigning it an ID and sending the given operation initialization data.
|
||||
func (t *TerminalBase) StartOperation(op Operation, initData *container.Container, timeout time.Duration) *Error {
|
||||
// Get terminal to attach to.
|
||||
attachToTerminal := t.ext
|
||||
if attachToTerminal == nil {
|
||||
attachToTerminal = t
|
||||
}
|
||||
|
||||
// Get the next operation ID and set it on the operation with the terminal.
|
||||
op.InitOperationBase(attachToTerminal, atomic.AddUint32(t.nextOpID, 8))
|
||||
|
||||
// Always add operation to the active operations, as we need to receive a
|
||||
// reply in any case.
|
||||
t.SetActiveOp(op.ID(), op)
|
||||
|
||||
log.Debugf("spn/terminal: operation %s %s started", op.Type(), fmtOperationID(t.parentID, t.id, op.ID()))
|
||||
|
||||
// Add or create the operation type block.
|
||||
if initData == nil {
|
||||
initData = container.New()
|
||||
initData.AppendAsBlock([]byte(op.Type()))
|
||||
} else {
|
||||
initData.PrependAsBlock([]byte(op.Type()))
|
||||
}
|
||||
|
||||
// Create init msg.
|
||||
msg := NewEmptyMsg()
|
||||
msg.FlowID = op.ID()
|
||||
msg.Type = MsgTypeInit
|
||||
msg.Data = initData
|
||||
msg.Unit.MakeHighPriority()
|
||||
|
||||
// Send init msg.
|
||||
err := op.Send(msg, timeout)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Send sends data via this terminal.
|
||||
// If a timeout is set, sending will fail after the given timeout passed.
|
||||
func (t *TerminalBase) Send(msg *Msg, timeout time.Duration) *Error {
|
||||
// Wait for processing slot.
|
||||
msg.Unit.WaitForSlot()
|
||||
|
||||
// Check if the send queue has available space.
|
||||
select {
|
||||
case t.sendQueue <- msg:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Submit message to buffer, if space is available.
|
||||
select {
|
||||
case t.sendQueue <- msg:
|
||||
return nil
|
||||
case <-TimedOut(timeout):
|
||||
msg.Finish()
|
||||
return ErrTimeout.With("sending via terminal")
|
||||
case <-t.Ctx().Done():
|
||||
msg.Finish()
|
||||
return ErrStopping
|
||||
}
|
||||
}
|
||||
|
||||
// StopOperation sends the end signal with an optional error and then deletes
|
||||
// the operation from the Terminal state and calls HandleStop() on the Operation.
|
||||
func (t *TerminalBase) StopOperation(op Operation, err *Error) {
|
||||
// Check if the operation has already stopped.
|
||||
if !op.markStopped() {
|
||||
return
|
||||
}
|
||||
|
||||
// Log reason the Operation is ending. Override stopping error with nil.
|
||||
switch {
|
||||
case err == nil:
|
||||
log.Debugf("spn/terminal: operation %s %s stopped", op.Type(), fmtOperationID(t.parentID, t.id, op.ID()))
|
||||
case err.IsOK(), err.Is(ErrTryAgainLater), err.Is(ErrRateLimited):
|
||||
log.Debugf("spn/terminal: operation %s %s stopped: %s", op.Type(), fmtOperationID(t.parentID, t.id, op.ID()), err)
|
||||
default:
|
||||
log.Warningf("spn/terminal: operation %s %s failed: %s", op.Type(), fmtOperationID(t.parentID, t.id, op.ID()), err)
|
||||
}
|
||||
|
||||
module.StartWorker("stop operation", func(_ context.Context) error {
|
||||
// Call operation stop handle function for proper shutdown cleaning up.
|
||||
err = op.HandleStop(err)
|
||||
|
||||
// Send error to the connected Operation, if the error is internal.
|
||||
if !err.IsExternal() {
|
||||
if err == nil {
|
||||
err = ErrStopping
|
||||
}
|
||||
|
||||
msg := NewMsg(err.Pack())
|
||||
msg.FlowID = op.ID()
|
||||
msg.Type = MsgTypeStop
|
||||
|
||||
tErr := t.Send(msg, 10*time.Second)
|
||||
if tErr != nil {
|
||||
msg.Finish()
|
||||
log.Warningf("spn/terminal: failed to send stop msg: %s", tErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove operation from terminal.
|
||||
t.DeleteActiveOp(op.ID())
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetActiveOp returns the active operation with the given ID from the
|
||||
// Terminal state.
|
||||
func (t *TerminalBase) GetActiveOp(opID uint32) (op Operation, ok bool) {
|
||||
t.lock.RLock()
|
||||
defer t.lock.RUnlock()
|
||||
|
||||
op, ok = t.operations[opID]
|
||||
return
|
||||
}
|
||||
|
||||
// SetActiveOp saves an active operation to the Terminal state.
|
||||
func (t *TerminalBase) SetActiveOp(opID uint32, op Operation) {
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
|
||||
t.operations[opID] = op
|
||||
}
|
||||
|
||||
// DeleteActiveOp deletes an active operation from the Terminal state.
|
||||
func (t *TerminalBase) DeleteActiveOp(opID uint32) {
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
|
||||
delete(t.operations, opID)
|
||||
}
|
||||
|
||||
// GetActiveOpCount returns the amount of active operations.
|
||||
func (t *TerminalBase) GetActiveOpCount() int {
|
||||
t.lock.RLock()
|
||||
defer t.lock.RUnlock()
|
||||
|
||||
return len(t.operations)
|
||||
}
|
||||
|
||||
func newUnknownOp(id uint32, typeID string) *unknownOp {
|
||||
op := &unknownOp{
|
||||
typeID: typeID,
|
||||
}
|
||||
op.id = id
|
||||
return op
|
||||
}
|
||||
|
||||
type unknownOp struct {
|
||||
OperationBase
|
||||
typeID string
|
||||
}
|
||||
|
||||
func (op *unknownOp) Type() string {
|
||||
if op.typeID != "" {
|
||||
return op.typeID
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (op *unknownOp) Deliver(msg *Msg) *Error {
|
||||
return ErrIncorrectUsage.With("unknown op shim cannot receive")
|
||||
}
|
||||
185
spn/terminal/operation_base.go
Normal file
185
spn/terminal/operation_base.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// OperationBase provides the basic operation functionality.
|
||||
type OperationBase struct {
|
||||
terminal Terminal
|
||||
id uint32
|
||||
stopped abool.AtomicBool
|
||||
}
|
||||
|
||||
// InitOperationBase initialize the operation with the ID and attached terminal.
|
||||
// Should not be overridden by implementations.
|
||||
func (op *OperationBase) InitOperationBase(t Terminal, opID uint32) {
|
||||
op.id = opID
|
||||
op.terminal = t
|
||||
}
|
||||
|
||||
// ID returns the ID of the operation.
|
||||
// Should not be overridden by implementations.
|
||||
func (op *OperationBase) ID() uint32 {
|
||||
return op.id
|
||||
}
|
||||
|
||||
// Type returns the operation's type ID.
|
||||
// Should be overridden by implementations to return correct type ID.
|
||||
func (op *OperationBase) Type() string {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the operation.
|
||||
// Meant to be overridden by implementations.
|
||||
func (op *OperationBase) Deliver(_ *Msg) *Error {
|
||||
return ErrIncorrectUsage.With("Deliver not implemented for this operation")
|
||||
}
|
||||
|
||||
// NewMsg creates a new message from this operation.
|
||||
// Should not be overridden by implementations.
|
||||
func (op *OperationBase) NewMsg(data []byte) *Msg {
|
||||
msg := NewMsg(data)
|
||||
msg.FlowID = op.id
|
||||
msg.Type = MsgTypeData
|
||||
|
||||
// Debug unit leaks.
|
||||
msg.debugWithCaller(2)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// NewEmptyMsg creates a new empty message from this operation.
|
||||
// Should not be overridden by implementations.
|
||||
func (op *OperationBase) NewEmptyMsg() *Msg {
|
||||
msg := NewEmptyMsg()
|
||||
msg.FlowID = op.id
|
||||
msg.Type = MsgTypeData
|
||||
|
||||
// Debug unit leaks.
|
||||
msg.debugWithCaller(2)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// Send sends a message to the other side.
|
||||
// Should not be overridden by implementations.
|
||||
func (op *OperationBase) Send(msg *Msg, timeout time.Duration) *Error {
|
||||
// Add and update metadata.
|
||||
msg.FlowID = op.id
|
||||
if msg.Type == MsgTypeData && msg.Unit.IsHighPriority() && UsePriorityDataMsgs {
|
||||
msg.Type = MsgTypePriorityData
|
||||
}
|
||||
|
||||
// Wait for processing slot.
|
||||
msg.Unit.WaitForSlot()
|
||||
|
||||
// Send message.
|
||||
tErr := op.terminal.Send(msg, timeout)
|
||||
if tErr != nil {
|
||||
// Finish message unit on failure.
|
||||
msg.Finish()
|
||||
}
|
||||
return tErr
|
||||
}
|
||||
|
||||
// Flush sends all messages waiting in the terminal.
|
||||
// Meant to be overridden by implementations.
|
||||
func (op *OperationBase) Flush(timeout time.Duration) {
|
||||
op.terminal.Flush(timeout)
|
||||
}
|
||||
|
||||
// Stopped returns whether the operation has stopped.
|
||||
// Should not be overridden by implementations.
|
||||
func (op *OperationBase) Stopped() bool {
|
||||
return op.stopped.IsSet()
|
||||
}
|
||||
|
||||
// markStopped marks the operation as stopped.
|
||||
// It returns whether the stop flag was set.
|
||||
func (op *OperationBase) markStopped() bool {
|
||||
return op.stopped.SetToIf(false, true)
|
||||
}
|
||||
|
||||
// Stop stops the operation by unregistering it from the terminal and calling HandleStop().
|
||||
// Should not be overridden by implementations.
|
||||
func (op *OperationBase) Stop(self Operation, err *Error) {
|
||||
// Stop operation from terminal.
|
||||
op.terminal.StopOperation(self, err)
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
// Meant to be overridden by implementations.
|
||||
func (op *OperationBase) HandleStop(err *Error) (errorToSend *Error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Terminal returns the terminal the operation is linked to.
|
||||
// Should not be overridden by implementations.
|
||||
func (op *OperationBase) Terminal() Terminal {
|
||||
return op.terminal
|
||||
}
|
||||
|
||||
// OneOffOperationBase is an operation base for operations that just have one
|
||||
// message and a error return.
|
||||
type OneOffOperationBase struct {
|
||||
OperationBase
|
||||
|
||||
Result chan *Error
|
||||
}
|
||||
|
||||
// Init initializes the single operation base.
|
||||
func (op *OneOffOperationBase) Init() {
|
||||
op.Result = make(chan *Error, 1)
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *OneOffOperationBase) HandleStop(err *Error) (errorToSend *Error) {
|
||||
select {
|
||||
case op.Result <- err:
|
||||
default:
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// MessageStreamOperationBase is an operation base for receiving a message stream.
|
||||
// Every received message must be finished by the implementing operation.
|
||||
type MessageStreamOperationBase struct {
|
||||
OperationBase
|
||||
|
||||
Delivered chan *Msg
|
||||
Ended chan *Error
|
||||
}
|
||||
|
||||
// Init initializes the operation base.
|
||||
func (op *MessageStreamOperationBase) Init(deliverQueueSize int) {
|
||||
op.Delivered = make(chan *Msg, deliverQueueSize)
|
||||
op.Ended = make(chan *Error, 1)
|
||||
}
|
||||
|
||||
// Deliver delivers data to the operation.
|
||||
func (op *MessageStreamOperationBase) Deliver(msg *Msg) *Error {
|
||||
select {
|
||||
case op.Delivered <- msg:
|
||||
return nil
|
||||
default:
|
||||
return ErrIncorrectUsage.With("request was not waiting for data")
|
||||
}
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *MessageStreamOperationBase) HandleStop(err *Error) (errorToSend *Error) {
|
||||
select {
|
||||
case op.Ended <- err:
|
||||
default:
|
||||
}
|
||||
return err
|
||||
}
|
||||
255
spn/terminal/operation_counter.go
Normal file
255
spn/terminal/operation_counter.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
// CounterOpType is the type ID for the Counter Operation.
|
||||
const CounterOpType string = "debug/count"
|
||||
|
||||
// CounterOp sends increasing numbers on both sides.
|
||||
type CounterOp struct { //nolint:maligned
|
||||
OperationBase
|
||||
|
||||
wg sync.WaitGroup
|
||||
server bool
|
||||
opts *CounterOpts
|
||||
|
||||
counterLock sync.Mutex
|
||||
ClientCounter uint64
|
||||
ServerCounter uint64
|
||||
Error error
|
||||
}
|
||||
|
||||
// CounterOpts holds the options for CounterOp.
|
||||
type CounterOpts struct {
|
||||
ClientCountTo uint64
|
||||
ServerCountTo uint64
|
||||
Wait time.Duration
|
||||
Flush bool
|
||||
|
||||
suppressWorker bool
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterOpType(OperationFactory{
|
||||
Type: CounterOpType,
|
||||
Start: startCounterOp,
|
||||
})
|
||||
}
|
||||
|
||||
// NewCounterOp returns a new CounterOp.
|
||||
func NewCounterOp(t Terminal, opts CounterOpts) (*CounterOp, *Error) {
|
||||
// Create operation.
|
||||
op := &CounterOp{
|
||||
opts: &opts,
|
||||
}
|
||||
op.wg.Add(1)
|
||||
|
||||
// Create argument container.
|
||||
data, err := dsd.Dump(op.opts, dsd.JSON)
|
||||
if err != nil {
|
||||
return nil, ErrInternalError.With("failed to pack options: %w", err)
|
||||
}
|
||||
|
||||
// Initialize operation.
|
||||
tErr := t.StartOperation(op, container.New(data), 3*time.Second)
|
||||
if tErr != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
// Start worker if needed.
|
||||
if op.getRemoteCounterTarget() > 0 && !op.opts.suppressWorker {
|
||||
module.StartWorker("counter sender", op.CounterWorker)
|
||||
}
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func startCounterOp(t Terminal, opID uint32, data *container.Container) (Operation, *Error) {
|
||||
// Create operation.
|
||||
op := &CounterOp{
|
||||
server: true,
|
||||
}
|
||||
op.InitOperationBase(t, opID)
|
||||
op.wg.Add(1)
|
||||
|
||||
// Parse arguments.
|
||||
opts := &CounterOpts{}
|
||||
_, err := dsd.Load(data.CompileData(), opts)
|
||||
if err != nil {
|
||||
return nil, ErrInternalError.With("failed to unpack options: %w", err)
|
||||
}
|
||||
op.opts = opts
|
||||
|
||||
// Start worker if needed.
|
||||
if op.getRemoteCounterTarget() > 0 {
|
||||
module.StartWorker("counter sender", op.CounterWorker)
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
// Type returns the operation's type ID.
|
||||
func (op *CounterOp) Type() string {
|
||||
return CounterOpType
|
||||
}
|
||||
|
||||
func (op *CounterOp) getCounter(sending, increase bool) uint64 {
|
||||
op.counterLock.Lock()
|
||||
defer op.counterLock.Unlock()
|
||||
|
||||
// Use server counter, when op is server or for sending, but not when both.
|
||||
if op.server != sending {
|
||||
if increase {
|
||||
op.ServerCounter++
|
||||
}
|
||||
return op.ServerCounter
|
||||
}
|
||||
|
||||
if increase {
|
||||
op.ClientCounter++
|
||||
}
|
||||
return op.ClientCounter
|
||||
}
|
||||
|
||||
func (op *CounterOp) getRemoteCounterTarget() uint64 {
|
||||
if op.server {
|
||||
return op.opts.ClientCountTo
|
||||
}
|
||||
return op.opts.ServerCountTo
|
||||
}
|
||||
|
||||
func (op *CounterOp) isDone() bool {
|
||||
op.counterLock.Lock()
|
||||
defer op.counterLock.Unlock()
|
||||
|
||||
return op.ClientCounter >= op.opts.ClientCountTo &&
|
||||
op.ServerCounter >= op.opts.ServerCountTo
|
||||
}
|
||||
|
||||
// Deliver delivers data to the operation.
|
||||
func (op *CounterOp) Deliver(msg *Msg) *Error {
|
||||
defer msg.Finish()
|
||||
|
||||
nextStep, err := msg.Data.GetNextN64()
|
||||
if err != nil {
|
||||
op.Stop(op, ErrMalformedData.With("failed to parse next number: %w", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count and compare.
|
||||
counter := op.getCounter(false, true)
|
||||
|
||||
// Debugging:
|
||||
// if counter < 100 ||
|
||||
// counter < 1000 && counter%100 == 0 ||
|
||||
// counter < 10000 && counter%1000 == 0 ||
|
||||
// counter < 100000 && counter%10000 == 0 ||
|
||||
// counter < 1000000 && counter%100000 == 0 {
|
||||
// log.Errorf("spn/terminal: counter %s>%d recvd, now at %d", op.t.FmtID(), op.id, counter)
|
||||
// }
|
||||
|
||||
if counter != nextStep {
|
||||
log.Warningf(
|
||||
"terminal: integrity of counter op violated: received %d, expected %d",
|
||||
nextStep,
|
||||
counter,
|
||||
)
|
||||
op.Stop(op, ErrIntegrity.With("counters mismatched"))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we are done.
|
||||
if op.isDone() {
|
||||
op.Stop(op, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleStop handles stopping the operation.
|
||||
func (op *CounterOp) HandleStop(err *Error) (errorToSend *Error) {
|
||||
// Check if counting finished.
|
||||
if !op.isDone() {
|
||||
err := fmt.Errorf(
|
||||
"counter op %d: did not finish counting (%d<-%d %d->%d)",
|
||||
op.id,
|
||||
op.opts.ClientCountTo, op.ClientCounter,
|
||||
op.ServerCounter, op.opts.ServerCountTo,
|
||||
)
|
||||
op.Error = err
|
||||
}
|
||||
|
||||
op.wg.Done()
|
||||
return err
|
||||
}
|
||||
|
||||
// SendCounter sends the next counter.
|
||||
func (op *CounterOp) SendCounter() *Error {
|
||||
if op.Stopped() {
|
||||
return ErrStopping
|
||||
}
|
||||
|
||||
// Increase sending counter.
|
||||
counter := op.getCounter(true, true)
|
||||
|
||||
// Debugging:
|
||||
// if counter < 100 ||
|
||||
// counter < 1000 && counter%100 == 0 ||
|
||||
// counter < 10000 && counter%1000 == 0 ||
|
||||
// counter < 100000 && counter%10000 == 0 ||
|
||||
// counter < 1000000 && counter%100000 == 0 {
|
||||
// defer log.Errorf("spn/terminal: counter %s>%d sent, now at %d", op.t.FmtID(), op.id, counter)
|
||||
// }
|
||||
|
||||
return op.Send(op.NewMsg(varint.Pack64(counter)), 3*time.Second)
|
||||
}
|
||||
|
||||
// Wait waits for the Counter Op to finish.
|
||||
func (op *CounterOp) Wait() {
|
||||
op.wg.Wait()
|
||||
}
|
||||
|
||||
// CounterWorker is a worker that sends counters.
|
||||
func (op *CounterOp) CounterWorker(ctx context.Context) error {
|
||||
for {
|
||||
// Send counter msg.
|
||||
err := op.SendCounter()
|
||||
switch err {
|
||||
case nil:
|
||||
// All good, continue.
|
||||
case ErrStopping:
|
||||
// Done!
|
||||
return nil
|
||||
default:
|
||||
// Something went wrong.
|
||||
err := fmt.Errorf("counter op %d: failed to send counter: %w", op.id, err)
|
||||
op.Error = err
|
||||
op.Stop(op, ErrInternalError.With(err.Error()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Maybe flush message.
|
||||
if op.opts.Flush {
|
||||
op.terminal.Flush(1 * time.Second)
|
||||
}
|
||||
|
||||
// Check if we are done with sending.
|
||||
if op.getCounter(true, false) >= op.getRemoteCounterTarget() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Maybe wait a little.
|
||||
if op.opts.Wait > 0 {
|
||||
time.Sleep(op.opts.Wait)
|
||||
}
|
||||
}
|
||||
}
|
||||
50
spn/terminal/permission.go
Normal file
50
spn/terminal/permission.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package terminal
|
||||
|
||||
// Permission is a bit-map of granted permissions.
|
||||
type Permission uint16
|
||||
|
||||
// Permissions.
|
||||
const (
|
||||
NoPermission Permission = 0x0
|
||||
MayExpand Permission = 0x1
|
||||
MayConnect Permission = 0x2
|
||||
IsHubOwner Permission = 0x100
|
||||
IsHubAdvisor Permission = 0x200
|
||||
IsCraneController Permission = 0x8000
|
||||
)
|
||||
|
||||
// AuthorizingTerminal is an interface for terminals that support authorization.
|
||||
type AuthorizingTerminal interface {
|
||||
GrantPermission(grant Permission)
|
||||
HasPermission(required Permission) bool
|
||||
}
|
||||
|
||||
// GrantPermission grants the specified permissions to the Terminal.
|
||||
func (t *TerminalBase) GrantPermission(grant Permission) {
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
|
||||
t.permission |= grant
|
||||
}
|
||||
|
||||
// HasPermission returns if the Terminal has the specified permission.
|
||||
func (t *TerminalBase) HasPermission(required Permission) bool {
|
||||
t.lock.RLock()
|
||||
defer t.lock.RUnlock()
|
||||
|
||||
return t.permission.Has(required)
|
||||
}
|
||||
|
||||
// Has returns if the permission includes the specified permission.
|
||||
func (p Permission) Has(required Permission) bool {
|
||||
return p&required == required
|
||||
}
|
||||
|
||||
// AddPermissions combines multiple permissions.
|
||||
func AddPermissions(perms ...Permission) Permission {
|
||||
var all Permission
|
||||
for _, p := range perms {
|
||||
all |= p
|
||||
}
|
||||
return all
|
||||
}
|
||||
39
spn/terminal/rate_limit.go
Normal file
39
spn/terminal/rate_limit.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package terminal
|
||||
|
||||
import "time"
|
||||
|
||||
// RateLimiter is a data flow rate limiter.
|
||||
type RateLimiter struct {
|
||||
maxBytesPerSlot uint64
|
||||
slotBytes uint64
|
||||
slotStarted time.Time
|
||||
}
|
||||
|
||||
// NewRateLimiter returns a new rate limiter.
|
||||
// The given MBit/s are transformed to bytes, so giving a multiple of 8 is
|
||||
// advised for accurate results.
|
||||
func NewRateLimiter(mbits uint64) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
maxBytesPerSlot: (mbits / 8) * 1_000_000,
|
||||
slotStarted: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Limit is given the current transferred bytes and blocks until they may be sent.
|
||||
func (rl *RateLimiter) Limit(xferBytes uint64) {
|
||||
// Check if we need to limit transfer if we go over to max bytes per slot.
|
||||
if rl.slotBytes > rl.maxBytesPerSlot {
|
||||
// Wait if we are still within the slot.
|
||||
sinceSlotStart := time.Since(rl.slotStarted)
|
||||
if sinceSlotStart < time.Second {
|
||||
time.Sleep(time.Second - sinceSlotStart)
|
||||
}
|
||||
|
||||
// Reset state for next slot.
|
||||
rl.slotBytes = 0
|
||||
rl.slotStarted = time.Now()
|
||||
}
|
||||
|
||||
// Add new bytes after checking, as first step over the limit is fully using the limit.
|
||||
rl.slotBytes += xferBytes
|
||||
}
|
||||
166
spn/terminal/session.go
Normal file
166
spn/terminal/session.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
const (
|
||||
rateLimitMinOps = 250
|
||||
rateLimitMaxOpsPerSecond = 5
|
||||
|
||||
rateLimitMinSuspicion = 25
|
||||
rateLimitMinPermaSuspicion = rateLimitMinSuspicion * 100
|
||||
rateLimitMaxSuspicionPerSecond = 1
|
||||
|
||||
// Make this big enough to trigger suspicion limit in first blast.
|
||||
concurrencyPoolSize = 30
|
||||
)
|
||||
|
||||
// Session holds terminal metadata for operations.
|
||||
type Session struct {
|
||||
sync.RWMutex
|
||||
|
||||
// Rate Limiting.
|
||||
|
||||
// started holds the unix timestamp in seconds when the session was started.
|
||||
// It is set when the Session is created and may be treated as a constant.
|
||||
started int64
|
||||
|
||||
// opCount is the amount of operations started (and not rate limited by suspicion).
|
||||
opCount atomic.Int64
|
||||
|
||||
// suspicionScore holds a score of suspicious activity.
|
||||
// Every suspicious operations is counted as at least 1.
|
||||
// Rate limited operations because of suspicion are also counted as 1.
|
||||
suspicionScore atomic.Int64
|
||||
|
||||
concurrencyPool chan struct{}
|
||||
}
|
||||
|
||||
// SessionTerminal is an interface for terminals that support authorization.
|
||||
type SessionTerminal interface {
|
||||
GetSession() *Session
|
||||
}
|
||||
|
||||
// SessionAddOn can be inherited by terminals to add support for sessions.
|
||||
type SessionAddOn struct {
|
||||
lock sync.Mutex
|
||||
|
||||
// session holds the terminal session.
|
||||
session *Session
|
||||
}
|
||||
|
||||
// GetSession returns the terminal's session.
|
||||
func (t *SessionAddOn) GetSession() *Session {
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
|
||||
// Create session if it does not exist.
|
||||
if t.session == nil {
|
||||
t.session = NewSession()
|
||||
}
|
||||
|
||||
return t.session
|
||||
}
|
||||
|
||||
// NewSession returns a new session.
|
||||
func NewSession() *Session {
|
||||
return &Session{
|
||||
started: time.Now().Unix() - 1, // Ensure a 1 second difference to current time.
|
||||
concurrencyPool: make(chan struct{}, concurrencyPoolSize),
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitInfo returns some basic information about the status of the rate limiter.
|
||||
func (s *Session) RateLimitInfo() string {
|
||||
secondsActive := time.Now().Unix() - s.started
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%do/s %ds/s %ds",
|
||||
s.opCount.Load()/secondsActive,
|
||||
s.suspicionScore.Load()/secondsActive,
|
||||
secondsActive,
|
||||
)
|
||||
}
|
||||
|
||||
// RateLimit enforces a rate and suspicion limit.
|
||||
func (s *Session) RateLimit() *Error {
|
||||
secondsActive := time.Now().Unix() - s.started
|
||||
|
||||
// Check the suspicion limit.
|
||||
score := s.suspicionScore.Load()
|
||||
if score > rateLimitMinSuspicion {
|
||||
scorePerSecond := score / secondsActive
|
||||
if scorePerSecond >= rateLimitMaxSuspicionPerSecond {
|
||||
// Add current try to suspicion score.
|
||||
s.suspicionScore.Add(1)
|
||||
|
||||
return ErrRateLimited
|
||||
}
|
||||
|
||||
// Permanently rate limit if suspicion goes over the perma min limit and
|
||||
// the suspicion score is greater than 80% of the operation count.
|
||||
if score > rateLimitMinPermaSuspicion &&
|
||||
score*5 > s.opCount.Load()*4 { // Think: 80*5 == 100*4
|
||||
return ErrRateLimited
|
||||
}
|
||||
}
|
||||
|
||||
// Check the rate limit.
|
||||
count := s.opCount.Add(1)
|
||||
if count > rateLimitMinOps {
|
||||
opsPerSecond := count / secondsActive
|
||||
if opsPerSecond >= rateLimitMaxOpsPerSecond {
|
||||
return ErrRateLimited
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Suspicion Factors.
|
||||
const (
|
||||
SusFactorCommon = 1
|
||||
SusFactorWeirdButOK = 5
|
||||
SusFactorQuiteUnusual = 10
|
||||
SusFactorMustBeMalicious = 100
|
||||
)
|
||||
|
||||
// ReportSuspiciousActivity reports suspicious activity of the terminal.
|
||||
func (s *Session) ReportSuspiciousActivity(factor int64) {
|
||||
s.suspicionScore.Add(factor)
|
||||
}
|
||||
|
||||
// LimitConcurrency limits concurrent executions.
|
||||
// If over the limit, waiting goroutines are selected randomly.
|
||||
// It returns the context error if it was canceled.
|
||||
func (s *Session) LimitConcurrency(ctx context.Context, f func()) error {
|
||||
// Wait for place in pool.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case s.concurrencyPool <- struct{}{}:
|
||||
// We added our entry to the pool, continue with execution.
|
||||
}
|
||||
|
||||
// Drain own spot if pool after execution.
|
||||
defer func() {
|
||||
select {
|
||||
case <-s.concurrencyPool:
|
||||
// Own entry drained.
|
||||
default:
|
||||
// This should never happen, but let's play safe and not deadlock when pool is empty.
|
||||
log.Warningf("spn/session: failed to drain own entry from concurrency pool")
|
||||
}
|
||||
}()
|
||||
|
||||
// Execute and return.
|
||||
f()
|
||||
return nil
|
||||
}
|
||||
94
spn/terminal/session_test.go
Normal file
94
spn/terminal/session_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var tErr *Error
|
||||
s := NewSession()
|
||||
|
||||
// Everything should be okay within the min limit.
|
||||
for i := 0; i < rateLimitMinOps; i++ {
|
||||
tErr = s.RateLimit()
|
||||
if tErr != nil {
|
||||
t.Error("should not rate limit within min limit")
|
||||
}
|
||||
}
|
||||
|
||||
// Somewhere here we should rate limiting.
|
||||
for i := 0; i < rateLimitMaxOpsPerSecond; i++ {
|
||||
tErr = s.RateLimit()
|
||||
}
|
||||
assert.ErrorIs(t, tErr, ErrRateLimited, "should rate limit")
|
||||
}
|
||||
|
||||
func TestSuspicionLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var tErr *Error
|
||||
s := NewSession()
|
||||
|
||||
// Everything should be okay within the min limit.
|
||||
for i := 0; i < rateLimitMinSuspicion; i++ {
|
||||
tErr = s.RateLimit()
|
||||
if tErr != nil {
|
||||
t.Error("should not rate limit within min limit")
|
||||
}
|
||||
s.ReportSuspiciousActivity(SusFactorCommon)
|
||||
}
|
||||
|
||||
// Somewhere here we should rate limiting.
|
||||
for i := 0; i < rateLimitMaxSuspicionPerSecond; i++ {
|
||||
s.ReportSuspiciousActivity(SusFactorCommon)
|
||||
tErr = s.RateLimit()
|
||||
}
|
||||
if tErr == nil {
|
||||
t.Error("should rate limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrencyLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := NewSession()
|
||||
started := time.Now()
|
||||
wg := sync.WaitGroup{}
|
||||
workTime := 1 * time.Millisecond
|
||||
workers := concurrencyPoolSize * 10
|
||||
|
||||
// Start many workers to test concurrency.
|
||||
wg.Add(workers)
|
||||
for i := 0; i < workers; i++ {
|
||||
workerNum := i
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = recover()
|
||||
}()
|
||||
_ = s.LimitConcurrency(context.Background(), func() {
|
||||
time.Sleep(workTime)
|
||||
wg.Done()
|
||||
|
||||
// Panic sometimes.
|
||||
if workerNum%concurrencyPoolSize == 0 {
|
||||
panic("test")
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait and check time needed.
|
||||
wg.Wait()
|
||||
if time.Since(started) < (time.Duration(workers) * workTime / concurrencyPoolSize) {
|
||||
t.Errorf("workers were too quick - only took %s", time.Since(started))
|
||||
} else {
|
||||
t.Logf("workers were correctly limited - took %s", time.Since(started))
|
||||
}
|
||||
}
|
||||
909
spn/terminal/terminal.go
Normal file
909
spn/terminal/terminal.go
Normal file
@@ -0,0 +1,909 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/jess"
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
const (
|
||||
timeoutTicks = 5
|
||||
|
||||
clientTerminalAbandonTimeout = 15 * time.Second
|
||||
serverTerminalAbandonTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
// Terminal represents a terminal.
|
||||
type Terminal interface { //nolint:golint // Being explicit is helpful here.
|
||||
// ID returns the terminal ID.
|
||||
ID() uint32
|
||||
// Ctx returns the terminal context.
|
||||
Ctx() context.Context
|
||||
|
||||
// Deliver delivers a message to the terminal.
|
||||
// Should not be overridden by implementations.
|
||||
Deliver(msg *Msg) *Error
|
||||
// Send is used by others to send a message through the terminal.
|
||||
// Should not be overridden by implementations.
|
||||
Send(msg *Msg, timeout time.Duration) *Error
|
||||
// Flush sends all messages waiting in the terminal.
|
||||
// Should not be overridden by implementations.
|
||||
Flush(timeout time.Duration)
|
||||
|
||||
// StartOperation starts the given operation by assigning it an ID and sending the given operation initialization data.
|
||||
// Should not be overridden by implementations.
|
||||
StartOperation(op Operation, initData *container.Container, timeout time.Duration) *Error
|
||||
// StopOperation stops the given operation.
|
||||
// Should not be overridden by implementations.
|
||||
StopOperation(op Operation, err *Error)
|
||||
|
||||
// Abandon shuts down the terminal unregistering it from upstream and calling HandleAbandon().
|
||||
// Should not be overridden by implementations.
|
||||
Abandon(err *Error)
|
||||
// HandleAbandon gives the terminal the ability to cleanly shut down.
|
||||
// The terminal is still fully functional at this point.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
// Meant to be overridden by implementations.
|
||||
HandleAbandon(err *Error) (errorToSend *Error)
|
||||
// HandleDestruction gives the terminal the ability to clean up.
|
||||
// The terminal has already fully shut down at this point.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
// Meant to be overridden by implementations.
|
||||
HandleDestruction(err *Error)
|
||||
|
||||
// FmtID formats the terminal ID (including parent IDs).
|
||||
// May be overridden by implementations.
|
||||
FmtID() string
|
||||
}
|
||||
|
||||
// TerminalBase contains the basic functions of a terminal.
|
||||
type TerminalBase struct { //nolint:golint,maligned // Being explicit is helpful here.
|
||||
// TODO: Fix maligned.
|
||||
Terminal // Interface check.
|
||||
|
||||
lock sync.RWMutex
|
||||
|
||||
// id is the underlying id of the Terminal.
|
||||
id uint32
|
||||
// parentID is the id of the parent component.
|
||||
parentID string
|
||||
|
||||
// ext holds the extended terminal so that the base terminal can access custom functions.
|
||||
ext Terminal
|
||||
// sendQueue holds message to be sent.
|
||||
sendQueue chan *Msg
|
||||
// flowControl holds the flow control system.
|
||||
flowControl FlowControl
|
||||
// upstream represents the upstream (parent) terminal.
|
||||
upstream Upstream
|
||||
|
||||
// deliverProxy is populated with the configured deliver function
|
||||
deliverProxy func(msg *Msg) *Error
|
||||
// recvProxy is populated with the configured recv function
|
||||
recvProxy func() <-chan *Msg
|
||||
|
||||
// ctx is the context of the Terminal.
|
||||
ctx context.Context
|
||||
// cancelCtx cancels ctx.
|
||||
cancelCtx context.CancelFunc
|
||||
|
||||
// waitForFlush signifies if sending should be delayed until the next call
|
||||
// to Flush()
|
||||
waitForFlush *abool.AtomicBool
|
||||
// flush is used to send a finish function to the handler, which will write
|
||||
// all pending messages and then call the received function.
|
||||
flush chan func()
|
||||
// idleTicker ticks for increasing and checking the idle counter.
|
||||
idleTicker *time.Ticker
|
||||
// idleCounter counts the ticks the terminal has been idle.
|
||||
idleCounter *uint32
|
||||
|
||||
// jession is the jess session used for encryption.
|
||||
jession *jess.Session
|
||||
// jessionLock locks jession.
|
||||
jessionLock sync.Mutex
|
||||
// encryptionReady is set when the encryption is ready for sending messages.
|
||||
encryptionReady chan struct{}
|
||||
// identity is the identity used by a remote Terminal.
|
||||
identity *cabin.Identity
|
||||
|
||||
// operations holds references to all active operations that require persistence.
|
||||
operations map[uint32]Operation
|
||||
// nextOpID holds the next operation ID.
|
||||
nextOpID *uint32
|
||||
// permission holds the permissions of the terminal.
|
||||
permission Permission
|
||||
|
||||
// opts holds the terminal options. It must not be modified after the terminal
|
||||
// has started.
|
||||
opts *TerminalOpts
|
||||
|
||||
// lastUnknownOpID holds the operation ID of the last data message received
|
||||
// for an unknown operation ID.
|
||||
lastUnknownOpID uint32
|
||||
// lastUnknownOpMsgs holds the amount of continuous data messages received
|
||||
// for the operation ID in lastUnknownOpID.
|
||||
lastUnknownOpMsgs uint32
|
||||
|
||||
// Abandoning indicates if the Terminal is being abandoned. The main handlers
|
||||
// will keep running until the context has been canceled by the abandon
|
||||
// procedure.
|
||||
// No new operations should be started.
|
||||
// Whoever initiates the abandoning must also start the abandon procedure.
|
||||
Abandoning *abool.AtomicBool
|
||||
}
|
||||
|
||||
func createTerminalBase(
|
||||
ctx context.Context,
|
||||
id uint32,
|
||||
parentID string,
|
||||
remote bool,
|
||||
initMsg *TerminalOpts,
|
||||
upstream Upstream,
|
||||
) (*TerminalBase, *Error) {
|
||||
t := &TerminalBase{
|
||||
id: id,
|
||||
parentID: parentID,
|
||||
sendQueue: make(chan *Msg),
|
||||
upstream: upstream,
|
||||
waitForFlush: abool.New(),
|
||||
flush: make(chan func()),
|
||||
idleTicker: time.NewTicker(time.Minute),
|
||||
idleCounter: new(uint32),
|
||||
encryptionReady: make(chan struct{}),
|
||||
operations: make(map[uint32]Operation),
|
||||
nextOpID: new(uint32),
|
||||
opts: initMsg,
|
||||
Abandoning: abool.New(),
|
||||
}
|
||||
// Stop ticking to disable timeout.
|
||||
t.idleTicker.Stop()
|
||||
// Shift next operation ID if remote.
|
||||
if remote {
|
||||
atomic.AddUint32(t.nextOpID, 4)
|
||||
}
|
||||
// Create context.
|
||||
t.ctx, t.cancelCtx = context.WithCancel(ctx)
|
||||
|
||||
// Create flow control.
|
||||
switch initMsg.FlowControl {
|
||||
case FlowControlDFQ:
|
||||
t.flowControl = NewDuplexFlowQueue(t.Ctx(), initMsg.FlowControlSize, t.submitToUpstream)
|
||||
t.deliverProxy = t.flowControl.Deliver
|
||||
t.recvProxy = t.flowControl.Receive
|
||||
case FlowControlNone:
|
||||
deliver := make(chan *Msg, initMsg.FlowControlSize)
|
||||
t.deliverProxy = MakeDirectDeliveryDeliverFunc(ctx, deliver)
|
||||
t.recvProxy = MakeDirectDeliveryRecvFunc(deliver)
|
||||
case FlowControlDefault:
|
||||
fallthrough
|
||||
default:
|
||||
return nil, ErrInternalError.With("unknown flow control type %d", initMsg.FlowControl)
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// ID returns the Terminal's ID.
|
||||
func (t *TerminalBase) ID() uint32 {
|
||||
return t.id
|
||||
}
|
||||
|
||||
// Ctx returns the Terminal's context.
|
||||
func (t *TerminalBase) Ctx() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
||||
// SetTerminalExtension sets the Terminal's extension. This function is not
|
||||
// guarded and may only be used during initialization.
|
||||
func (t *TerminalBase) SetTerminalExtension(ext Terminal) {
|
||||
t.ext = ext
|
||||
}
|
||||
|
||||
// SetTimeout sets the Terminal's idle timeout duration.
|
||||
// It is broken down into slots internally.
|
||||
func (t *TerminalBase) SetTimeout(d time.Duration) {
|
||||
t.idleTicker.Reset(d / timeoutTicks)
|
||||
}
|
||||
|
||||
// Deliver on TerminalBase only exists to conform to the interface. It must be
|
||||
// overridden by an actual implementation.
|
||||
func (t *TerminalBase) Deliver(msg *Msg) *Error {
|
||||
// Deliver via configured proxy.
|
||||
err := t.deliverProxy(msg)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// StartWorkers starts the necessary workers to operate the Terminal.
|
||||
func (t *TerminalBase) StartWorkers(m *modules.Module, terminalName string) {
|
||||
// Start terminal workers.
|
||||
m.StartWorker(terminalName+" handler", t.Handler)
|
||||
m.StartWorker(terminalName+" sender", t.Sender)
|
||||
|
||||
// Start any flow control workers.
|
||||
if t.flowControl != nil {
|
||||
t.flowControl.StartWorkers(m, terminalName)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
sendThresholdLength = 100 // bytes
|
||||
sendMaxLength = 4000 // bytes
|
||||
sendThresholdMaxWait = 20 * time.Millisecond
|
||||
)
|
||||
|
||||
// Handler receives and handles messages and must be started as a worker in the
|
||||
// module where the Terminal is used.
|
||||
func (t *TerminalBase) Handler(_ context.Context) error {
|
||||
defer t.Abandon(ErrInternalError.With("handler died"))
|
||||
|
||||
var msg *Msg
|
||||
defer msg.Finish()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
// Call Abandon just in case.
|
||||
// Normally, only the StopProcedure function should cancel the context.
|
||||
t.Abandon(nil)
|
||||
return nil // Controlled worker exit.
|
||||
|
||||
case <-t.idleTicker.C:
|
||||
// If nothing happens for a while, end the session.
|
||||
if atomic.AddUint32(t.idleCounter, 1) > timeoutTicks {
|
||||
// Abandon the terminal and reset the counter.
|
||||
t.Abandon(ErrNoActivity)
|
||||
atomic.StoreUint32(t.idleCounter, 0)
|
||||
}
|
||||
|
||||
case msg = <-t.recvProxy():
|
||||
err := t.handleReceive(msg)
|
||||
if err != nil {
|
||||
t.Abandon(err.Wrap("failed to handle"))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register activity.
|
||||
atomic.StoreUint32(t.idleCounter, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// submit is used to send message from the terminal to upstream, including
|
||||
// going through flow control, if configured.
|
||||
// This function should be used to send message from the terminal to upstream.
|
||||
func (t *TerminalBase) submit(msg *Msg, timeout time.Duration) {
|
||||
// Submit directly if no flow control is configured.
|
||||
if t.flowControl == nil {
|
||||
t.submitToUpstream(msg, timeout)
|
||||
return
|
||||
}
|
||||
|
||||
// Hand over to flow control.
|
||||
err := t.flowControl.Send(msg, timeout)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
t.Abandon(err.Wrap("failed to submit to flow control"))
|
||||
}
|
||||
}
|
||||
|
||||
// submitToUpstream is used to directly submit messages to upstream.
|
||||
// This function should only be used by the flow control or submit function.
|
||||
func (t *TerminalBase) submitToUpstream(msg *Msg, timeout time.Duration) {
|
||||
// Add terminal ID as flow ID.
|
||||
msg.FlowID = t.ID()
|
||||
|
||||
// Debug unit leaks.
|
||||
msg.debugWithCaller(2)
|
||||
|
||||
// Submit to upstream.
|
||||
err := t.upstream.Send(msg, timeout)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
t.Abandon(err.Wrap("failed to submit to upstream"))
|
||||
}
|
||||
}
|
||||
|
||||
// Sender handles sending messages and must be started as a worker in the
|
||||
// module where the Terminal is used.
|
||||
func (t *TerminalBase) Sender(_ context.Context) error {
|
||||
// Don't send messages, if the encryption is net yet set up.
|
||||
// The server encryption session is only initialized with the first
|
||||
// operative message, not on Terminal creation.
|
||||
if t.opts.Encrypt {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
// Call Abandon just in case.
|
||||
// Normally, the only the StopProcedure function should cancel the context.
|
||||
t.Abandon(nil)
|
||||
return nil // Controlled worker exit.
|
||||
case <-t.encryptionReady:
|
||||
}
|
||||
}
|
||||
|
||||
// Be sure to call Stop even in case of sudden death.
|
||||
defer t.Abandon(ErrInternalError.With("sender died"))
|
||||
|
||||
var msgBufferMsg *Msg
|
||||
var msgBufferLen int
|
||||
var msgBufferLimitReached bool
|
||||
var sendMsgs bool
|
||||
var sendMaxWait *time.Timer
|
||||
var flushFinished func()
|
||||
|
||||
// Finish any current unit when returning.
|
||||
defer msgBufferMsg.Finish()
|
||||
|
||||
// Only receive message when not sending the current msg buffer.
|
||||
sendQueueOpMsgs := func() <-chan *Msg {
|
||||
// Don't handle more messages, if the buffer is full.
|
||||
if msgBufferLimitReached {
|
||||
return nil
|
||||
}
|
||||
return t.sendQueue
|
||||
}
|
||||
|
||||
// Only wait for sending slot when the current msg buffer is ready to be sent.
|
||||
readyToSend := func() <-chan struct{} {
|
||||
switch {
|
||||
case !sendMsgs:
|
||||
// Wait until there is something to send.
|
||||
return nil
|
||||
case t.flowControl != nil:
|
||||
// Let flow control decide when we are ready.
|
||||
return t.flowControl.ReadyToSend()
|
||||
default:
|
||||
// Always ready.
|
||||
return ready
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate current max wait time to send the msg buffer.
|
||||
getSendMaxWait := func() <-chan time.Time {
|
||||
if sendMaxWait != nil {
|
||||
return sendMaxWait.C
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
handling:
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
// Call Stop just in case.
|
||||
// Normally, the only the StopProcedure function should cancel the context.
|
||||
t.Abandon(nil)
|
||||
return nil // Controlled worker exit.
|
||||
|
||||
case <-t.idleTicker.C:
|
||||
// If nothing happens for a while, end the session.
|
||||
if atomic.AddUint32(t.idleCounter, 1) > timeoutTicks {
|
||||
// Abandon the terminal and reset the counter.
|
||||
t.Abandon(ErrNoActivity)
|
||||
atomic.StoreUint32(t.idleCounter, 0)
|
||||
}
|
||||
|
||||
case msg := <-sendQueueOpMsgs():
|
||||
if msg == nil {
|
||||
continue handling
|
||||
}
|
||||
|
||||
// Add unit to buffer unit, or use it as new buffer.
|
||||
if msgBufferMsg != nil {
|
||||
// Pack, append and finish additional message.
|
||||
msgBufferMsg.Consume(msg)
|
||||
} else {
|
||||
// Pack operation message.
|
||||
msg.Pack()
|
||||
// Convert to message of terminal.
|
||||
msgBufferMsg = msg
|
||||
msgBufferMsg.FlowID = t.ID()
|
||||
msgBufferMsg.Type = MsgTypeData
|
||||
}
|
||||
msgBufferLen += msg.Data.Length()
|
||||
|
||||
// Check if there is enough data to hit the sending threshold.
|
||||
if msgBufferLen >= sendThresholdLength {
|
||||
sendMsgs = true
|
||||
} else if sendMaxWait == nil && t.waitForFlush.IsNotSet() {
|
||||
sendMaxWait = time.NewTimer(sendThresholdMaxWait)
|
||||
}
|
||||
|
||||
// Check if we have reached the maximum buffer size.
|
||||
if msgBufferLen >= sendMaxLength {
|
||||
msgBufferLimitReached = true
|
||||
}
|
||||
|
||||
// Register activity.
|
||||
atomic.StoreUint32(t.idleCounter, 0)
|
||||
|
||||
case <-getSendMaxWait():
|
||||
// The timer for waiting for more data has ended.
|
||||
// Send all available data if not forced to wait for a flush.
|
||||
if t.waitForFlush.IsNotSet() {
|
||||
sendMsgs = true
|
||||
}
|
||||
|
||||
case newFlushFinishedFn := <-t.flush:
|
||||
// We are flushing - stop waiting.
|
||||
t.waitForFlush.UnSet()
|
||||
|
||||
// Signal immediately if msg buffer is empty.
|
||||
if msgBufferLen == 0 {
|
||||
newFlushFinishedFn()
|
||||
} else {
|
||||
// If there already is a flush finished function, stack them.
|
||||
if flushFinished != nil {
|
||||
stackedFlushFinishFn := flushFinished
|
||||
flushFinished = func() {
|
||||
stackedFlushFinishFn()
|
||||
newFlushFinishedFn()
|
||||
}
|
||||
} else {
|
||||
flushFinished = newFlushFinishedFn
|
||||
}
|
||||
}
|
||||
|
||||
// Force sending data now.
|
||||
sendMsgs = true
|
||||
|
||||
case <-readyToSend():
|
||||
// Reset sending flags.
|
||||
sendMsgs = false
|
||||
msgBufferLimitReached = false
|
||||
|
||||
// Send if there is anything to send.
|
||||
var err *Error
|
||||
if msgBufferLen > 0 {
|
||||
// Update message type to include priority.
|
||||
if msgBufferMsg.Type == MsgTypeData &&
|
||||
msgBufferMsg.Unit.IsHighPriority() &&
|
||||
t.opts.UsePriorityDataMsgs {
|
||||
msgBufferMsg.Type = MsgTypePriorityData
|
||||
}
|
||||
|
||||
// Wait for clearance on initial msg only.
|
||||
msgBufferMsg.Unit.WaitForSlot()
|
||||
|
||||
err = t.sendOpMsgs(msgBufferMsg)
|
||||
}
|
||||
|
||||
// Reset buffer.
|
||||
msgBufferMsg = nil
|
||||
msgBufferLen = 0
|
||||
|
||||
// Reset send wait timer.
|
||||
if sendMaxWait != nil {
|
||||
sendMaxWait.Stop()
|
||||
sendMaxWait = nil
|
||||
}
|
||||
|
||||
// Check if we are flushing and need to notify.
|
||||
if flushFinished != nil {
|
||||
flushFinished()
|
||||
flushFinished = nil
|
||||
}
|
||||
|
||||
// Handle error after state updates.
|
||||
if err != nil {
|
||||
t.Abandon(err.With("failed to send"))
|
||||
continue handling
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForFlush makes the terminal pause all sending until the next call to
|
||||
// Flush().
|
||||
func (t *TerminalBase) WaitForFlush() {
|
||||
t.waitForFlush.Set()
|
||||
}
|
||||
|
||||
// Flush sends all data waiting to be sent.
|
||||
func (t *TerminalBase) Flush(timeout time.Duration) {
|
||||
// Create channel and function for notifying.
|
||||
wait := make(chan struct{})
|
||||
finished := func() {
|
||||
close(wait)
|
||||
}
|
||||
// Request flush and return when stopping.
|
||||
select {
|
||||
case t.flush <- finished:
|
||||
case <-t.Ctx().Done():
|
||||
return
|
||||
case <-TimedOut(timeout):
|
||||
return
|
||||
}
|
||||
// Wait for flush to finish and return when stopping.
|
||||
select {
|
||||
case <-wait:
|
||||
case <-t.Ctx().Done():
|
||||
return
|
||||
case <-TimedOut(timeout):
|
||||
return
|
||||
}
|
||||
|
||||
// Flush flow control, if configured.
|
||||
if t.flowControl != nil {
|
||||
t.flowControl.Flush(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TerminalBase) encrypt(c *container.Container) (*container.Container, *Error) {
|
||||
if !t.opts.Encrypt {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
t.jessionLock.Lock()
|
||||
defer t.jessionLock.Unlock()
|
||||
|
||||
letter, err := t.jession.Close(c.CompileData())
|
||||
if err != nil {
|
||||
return nil, ErrIntegrity.With("failed to encrypt: %w", err)
|
||||
}
|
||||
|
||||
encryptedData, err := letter.ToWire()
|
||||
if err != nil {
|
||||
return nil, ErrInternalError.With("failed to pack letter: %w", err)
|
||||
}
|
||||
|
||||
return encryptedData, nil
|
||||
}
|
||||
|
||||
func (t *TerminalBase) decrypt(c *container.Container) (*container.Container, *Error) {
|
||||
if !t.opts.Encrypt {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
t.jessionLock.Lock()
|
||||
defer t.jessionLock.Unlock()
|
||||
|
||||
letter, err := jess.LetterFromWire(c)
|
||||
if err != nil {
|
||||
return nil, ErrMalformedData.With("failed to parse letter: %w", err)
|
||||
}
|
||||
|
||||
// Setup encryption if not yet done.
|
||||
if t.jession == nil {
|
||||
if t.identity == nil {
|
||||
return nil, ErrInternalError.With("missing identity for setting up incoming encryption")
|
||||
}
|
||||
|
||||
// Create jess session.
|
||||
t.jession, err = letter.WireCorrespondence(t.identity)
|
||||
if err != nil {
|
||||
return nil, ErrIntegrity.With("failed to initialize incoming encryption: %w", err)
|
||||
}
|
||||
|
||||
// Don't need that anymore.
|
||||
t.identity = nil
|
||||
|
||||
// Encryption is ready for sending.
|
||||
close(t.encryptionReady)
|
||||
}
|
||||
|
||||
decryptedData, err := t.jession.Open(letter)
|
||||
if err != nil {
|
||||
return nil, ErrIntegrity.With("failed to decrypt: %w", err)
|
||||
}
|
||||
|
||||
return container.New(decryptedData), nil
|
||||
}
|
||||
|
||||
func (t *TerminalBase) handleReceive(msg *Msg) *Error {
|
||||
msg.Unit.WaitForSlot()
|
||||
defer msg.Finish()
|
||||
|
||||
// Debugging:
|
||||
// log.Errorf("spn/terminal %s handling tmsg: %s", t.FmtID(), spew.Sdump(c.CompileData()))
|
||||
|
||||
// Check if message is empty. This will be the case if a message was only
|
||||
// for updated the available space of the flow queue.
|
||||
if !msg.Data.HoldsData() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrypt if enabled.
|
||||
var tErr *Error
|
||||
msg.Data, tErr = t.decrypt(msg.Data)
|
||||
if tErr != nil {
|
||||
return tErr
|
||||
}
|
||||
|
||||
// Handle operation messages.
|
||||
for msg.Data.HoldsData() {
|
||||
// Get next message length.
|
||||
msgLength, err := msg.Data.GetNextN32()
|
||||
if err != nil {
|
||||
return ErrMalformedData.With("failed to get operation msg length: %w", err)
|
||||
}
|
||||
if msgLength == 0 {
|
||||
// Remainder is padding.
|
||||
// Padding can only be at the end of the segment.
|
||||
t.handlePaddingMsg(msg.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get op msg data.
|
||||
msgData, err := msg.Data.GetAsContainer(int(msgLength))
|
||||
if err != nil {
|
||||
return ErrMalformedData.With("failed to get operation msg data (%d/%d bytes): %w", msg.Data.Length(), msgLength, err)
|
||||
}
|
||||
|
||||
// Handle op msg.
|
||||
if handleErr := t.handleOpMsg(msgData); handleErr != nil {
|
||||
return handleErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TerminalBase) handleOpMsg(data *container.Container) *Error {
|
||||
// Debugging:
|
||||
// log.Errorf("spn/terminal %s handling opmsg: %s", t.FmtID(), spew.Sdump(data.CompileData()))
|
||||
|
||||
// Parse message operation id, type.
|
||||
opID, msgType, err := ParseIDType(data)
|
||||
if err != nil {
|
||||
return ErrMalformedData.With("failed to parse operation msg id/type: %w", err)
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case MsgTypeInit:
|
||||
t.handleOperationStart(opID, data)
|
||||
|
||||
case MsgTypeData, MsgTypePriorityData:
|
||||
op, ok := t.GetActiveOp(opID)
|
||||
if ok && !op.Stopped() {
|
||||
// Create message from data.
|
||||
msg := NewEmptyMsg()
|
||||
msg.FlowID = opID
|
||||
msg.Type = msgType
|
||||
msg.Data = data
|
||||
if msg.Type == MsgTypePriorityData {
|
||||
msg.Unit.MakeHighPriority()
|
||||
}
|
||||
|
||||
// Deliver message to operation.
|
||||
tErr := op.Deliver(msg)
|
||||
if tErr != nil {
|
||||
// Also stop on "success" errors!
|
||||
msg.Finish()
|
||||
t.StopOperation(op, tErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If an active op is not found, this is likely just left-overs from a
|
||||
// stopped or failed operation.
|
||||
// log.Tracef("spn/terminal: %s received data msg for unknown op %d", fmtTerminalID(t.parentID, t.id), opID)
|
||||
|
||||
// Send a stop error if this happens too often.
|
||||
if opID == t.lastUnknownOpID {
|
||||
// OpID is the same as last time.
|
||||
t.lastUnknownOpMsgs++
|
||||
|
||||
// Log an warning (via StopOperation) and send a stop message every thousand.
|
||||
if t.lastUnknownOpMsgs%1000 == 0 {
|
||||
t.StopOperation(newUnknownOp(opID, ""), ErrUnknownOperationID.With("received %d unsolicited data msgs", t.lastUnknownOpMsgs))
|
||||
}
|
||||
|
||||
// TODO: Abandon terminal at over 10000?
|
||||
} else {
|
||||
// OpID changed, set new ID and reset counter.
|
||||
t.lastUnknownOpID = opID
|
||||
t.lastUnknownOpMsgs = 1
|
||||
}
|
||||
|
||||
case MsgTypeStop:
|
||||
// Parse received error.
|
||||
opErr, parseErr := ParseExternalError(data.CompileData())
|
||||
if parseErr != nil {
|
||||
log.Warningf("spn/terminal: %s failed to parse stop error: %s", fmtTerminalID(t.parentID, t.id), parseErr)
|
||||
opErr = ErrUnknownError.AsExternal()
|
||||
}
|
||||
|
||||
// End operation.
|
||||
op, ok := t.GetActiveOp(opID)
|
||||
if ok {
|
||||
t.StopOperation(op, opErr)
|
||||
} else {
|
||||
log.Tracef("spn/terminal: %s received stop msg for unknown op %d", fmtTerminalID(t.parentID, t.id), opID)
|
||||
}
|
||||
|
||||
default:
|
||||
log.Warningf("spn/terminal: %s received unexpected message type: %d", t.FmtID(), msgType)
|
||||
return ErrUnexpectedMsgType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TerminalBase) handlePaddingMsg(c *container.Container) {
|
||||
padding := c.GetAll()
|
||||
if len(padding) > 0 {
|
||||
rngFeeder.SupplyEntropyIfNeeded(padding, len(padding))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TerminalBase) sendOpMsgs(msg *Msg) *Error {
|
||||
msg.Unit.WaitForSlot()
|
||||
|
||||
// Add Padding if needed.
|
||||
if t.opts.Padding > 0 {
|
||||
paddingNeeded := (int(t.opts.Padding) - msg.Data.Length()) % int(t.opts.Padding)
|
||||
if paddingNeeded > 0 {
|
||||
// Add padding message header.
|
||||
msg.Data.Append([]byte{0})
|
||||
paddingNeeded--
|
||||
|
||||
// Add needed padding data.
|
||||
if paddingNeeded > 0 {
|
||||
padding, err := rng.Bytes(paddingNeeded)
|
||||
if err != nil {
|
||||
log.Debugf("spn/terminal: %s failed to get random data, using zeros instead", t.FmtID())
|
||||
padding = make([]byte, paddingNeeded)
|
||||
}
|
||||
msg.Data.Append(padding)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt operative data.
|
||||
var tErr *Error
|
||||
msg.Data, tErr = t.encrypt(msg.Data)
|
||||
if tErr != nil {
|
||||
return tErr
|
||||
}
|
||||
|
||||
// Send data.
|
||||
t.submit(msg, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Abandon shuts down the terminal unregistering it from upstream and calling HandleAbandon().
|
||||
// Should not be overridden by implementations.
|
||||
func (t *TerminalBase) Abandon(err *Error) {
|
||||
if t.Abandoning.SetToIf(false, true) {
|
||||
module.StartWorker("terminal abandon procedure", func(_ context.Context) error {
|
||||
t.handleAbandonProcedure(err)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAbandon gives the terminal the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
// Meant to be overridden by implementations.
|
||||
func (t *TerminalBase) HandleAbandon(err *Error) (errorToSend *Error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleDestruction gives the terminal the ability to clean up.
|
||||
// The terminal has already fully shut down at this point.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
// Meant to be overridden by implementations.
|
||||
func (t *TerminalBase) HandleDestruction(err *Error) {}
|
||||
|
||||
func (t *TerminalBase) handleAbandonProcedure(err *Error) {
|
||||
// End all operations.
|
||||
for _, op := range t.allOps() {
|
||||
t.StopOperation(op, nil)
|
||||
}
|
||||
|
||||
// Prepare timeouts for waiting for ops.
|
||||
timeout := clientTerminalAbandonTimeout
|
||||
if conf.PublicHub() {
|
||||
timeout = serverTerminalAbandonTimeout
|
||||
}
|
||||
checkTicker := time.NewTicker(50 * time.Millisecond)
|
||||
defer checkTicker.Stop()
|
||||
abortWaiting := time.After(timeout)
|
||||
|
||||
// Wait for all operations to end.
|
||||
waitForOps:
|
||||
for {
|
||||
select {
|
||||
case <-checkTicker.C:
|
||||
if t.GetActiveOpCount() <= 0 {
|
||||
break waitForOps
|
||||
}
|
||||
case <-abortWaiting:
|
||||
log.Warningf(
|
||||
"spn/terminal: terminal %s is continuing shutdown with %d active operations",
|
||||
t.FmtID(),
|
||||
t.GetActiveOpCount(),
|
||||
)
|
||||
break waitForOps
|
||||
}
|
||||
}
|
||||
|
||||
// Call operation stop handle function for proper shutdown cleaning up.
|
||||
if t.ext != nil {
|
||||
err = t.ext.HandleAbandon(err)
|
||||
}
|
||||
|
||||
// Send error to the connected Operation, if the error is internal.
|
||||
if !err.IsExternal() {
|
||||
if err == nil {
|
||||
err = ErrStopping
|
||||
}
|
||||
|
||||
msg := NewMsg(err.Pack())
|
||||
msg.FlowID = t.ID()
|
||||
msg.Type = MsgTypeStop
|
||||
t.submit(msg, 1*time.Second)
|
||||
}
|
||||
|
||||
// If terminal was ended locally, send all data before abandoning.
|
||||
// If terminal was ended remotely, don't bother sending remaining data.
|
||||
if !err.IsExternal() {
|
||||
// Flushing could mean sending a full buffer of 50000 packets.
|
||||
t.Flush(5 * time.Minute)
|
||||
}
|
||||
|
||||
// Stop all other connected workers.
|
||||
t.cancelCtx()
|
||||
t.idleTicker.Stop()
|
||||
|
||||
// Call operation destruction handle function for proper shutdown cleaning up.
|
||||
if t.ext != nil {
|
||||
t.ext.HandleDestruction(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TerminalBase) allOps() []Operation {
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
|
||||
ops := make([]Operation, 0, len(t.operations))
|
||||
for _, op := range t.operations {
|
||||
ops = append(ops, op)
|
||||
}
|
||||
|
||||
return ops
|
||||
}
|
||||
|
||||
// MakeDirectDeliveryDeliverFunc creates a submit upstream function with the
|
||||
// given delivery channel.
|
||||
func MakeDirectDeliveryDeliverFunc(
|
||||
ctx context.Context,
|
||||
deliver chan *Msg,
|
||||
) func(c *Msg) *Error {
|
||||
return func(c *Msg) *Error {
|
||||
select {
|
||||
case deliver <- c:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ErrStopping
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MakeDirectDeliveryRecvFunc makes a delivery receive function with the given
|
||||
// delivery channel.
|
||||
func MakeDirectDeliveryRecvFunc(
|
||||
deliver chan *Msg,
|
||||
) func() <-chan *Msg {
|
||||
return func() <-chan *Msg {
|
||||
return deliver
|
||||
}
|
||||
}
|
||||
311
spn/terminal/terminal_test.go
Normal file
311
spn/terminal/terminal_test.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
func TestTerminals(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
identity, erro := cabin.CreateIdentity(module.Ctx, "test")
|
||||
if erro != nil {
|
||||
t.Fatalf("failed to create identity: %s", erro)
|
||||
}
|
||||
|
||||
// Test without and with encryption.
|
||||
for _, encrypt := range []bool{false, true} {
|
||||
// Test with different flow controls.
|
||||
for _, fc := range []struct {
|
||||
flowControl FlowControlType
|
||||
flowControlSize uint32
|
||||
}{
|
||||
{
|
||||
flowControl: FlowControlNone,
|
||||
flowControlSize: 5,
|
||||
},
|
||||
{
|
||||
flowControl: FlowControlDFQ,
|
||||
flowControlSize: defaultTestQueueSize,
|
||||
},
|
||||
} {
|
||||
// Run tests with combined options.
|
||||
testTerminals(t, identity, &TerminalOpts{
|
||||
Encrypt: encrypt,
|
||||
Padding: defaultTestPadding,
|
||||
FlowControl: fc.flowControl,
|
||||
FlowControlSize: fc.flowControlSize,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testTerminals(t *testing.T, identity *cabin.Identity, terminalOpts *TerminalOpts) {
|
||||
t.Helper()
|
||||
|
||||
// Prepare encryption.
|
||||
var dstHub *hub.Hub
|
||||
if terminalOpts.Encrypt {
|
||||
dstHub = identity.Hub
|
||||
} else {
|
||||
identity = nil
|
||||
}
|
||||
|
||||
// Create test terminals.
|
||||
var term1 *TestTerminal
|
||||
var term2 *TestTerminal
|
||||
var initData *container.Container
|
||||
var err *Error
|
||||
term1, initData, err = NewLocalTestTerminal(
|
||||
module.Ctx, 127, "c1", dstHub, terminalOpts, createForwardingUpstream(
|
||||
t, "c1", "c2", func(msg *Msg) *Error {
|
||||
return term2.Deliver(msg)
|
||||
},
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create local terminal: %s", err)
|
||||
}
|
||||
term2, _, err = NewRemoteTestTerminal(
|
||||
module.Ctx, 127, "c2", identity, initData, createForwardingUpstream(
|
||||
t, "c2", "c1", func(msg *Msg) *Error {
|
||||
return term1.Deliver(msg)
|
||||
},
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create remote terminal: %s", err)
|
||||
}
|
||||
|
||||
// Start testing with counters.
|
||||
countToQueueSize := uint64(terminalOpts.FlowControlSize)
|
||||
optionsSuffix := fmt.Sprintf(
|
||||
"encrypt=%v,flowType=%d",
|
||||
terminalOpts.Encrypt,
|
||||
terminalOpts.FlowControl,
|
||||
)
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "onlyup-flushing-waiting:" + optionsSuffix,
|
||||
flush: true,
|
||||
serverCountTo: countToQueueSize * 2,
|
||||
waitBetweenMsgs: sendThresholdMaxWait * 2,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "onlyup-waiting:" + optionsSuffix,
|
||||
serverCountTo: 10,
|
||||
waitBetweenMsgs: sendThresholdMaxWait * 2,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "onlyup-flushing:" + optionsSuffix,
|
||||
flush: true,
|
||||
serverCountTo: countToQueueSize * 2,
|
||||
waitBetweenMsgs: time.Millisecond,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "onlyup:" + optionsSuffix,
|
||||
serverCountTo: countToQueueSize * 2,
|
||||
waitBetweenMsgs: time.Millisecond,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "onlydown-flushing-waiting:" + optionsSuffix,
|
||||
flush: true,
|
||||
clientCountTo: countToQueueSize * 2,
|
||||
waitBetweenMsgs: sendThresholdMaxWait * 2,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "onlydown-waiting:" + optionsSuffix,
|
||||
clientCountTo: 10,
|
||||
waitBetweenMsgs: sendThresholdMaxWait * 2,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "onlydown-flushing:" + optionsSuffix,
|
||||
flush: true,
|
||||
clientCountTo: countToQueueSize * 2,
|
||||
waitBetweenMsgs: time.Millisecond,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "onlydown:" + optionsSuffix,
|
||||
clientCountTo: countToQueueSize * 2,
|
||||
waitBetweenMsgs: time.Millisecond,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "twoway-flushing-waiting:" + optionsSuffix,
|
||||
flush: true,
|
||||
clientCountTo: countToQueueSize * 2,
|
||||
serverCountTo: countToQueueSize * 2,
|
||||
waitBetweenMsgs: sendThresholdMaxWait * 2,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "twoway-waiting:" + optionsSuffix,
|
||||
flush: true,
|
||||
clientCountTo: 10,
|
||||
serverCountTo: 10,
|
||||
waitBetweenMsgs: sendThresholdMaxWait * 2,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "twoway-flushing:" + optionsSuffix,
|
||||
flush: true,
|
||||
clientCountTo: countToQueueSize * 2,
|
||||
serverCountTo: countToQueueSize * 2,
|
||||
waitBetweenMsgs: time.Millisecond,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "twoway:" + optionsSuffix,
|
||||
clientCountTo: countToQueueSize * 2,
|
||||
serverCountTo: countToQueueSize * 2,
|
||||
waitBetweenMsgs: time.Millisecond,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "stresstest-down:" + optionsSuffix,
|
||||
clientCountTo: countToQueueSize * 1000,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "stresstest-up:" + optionsSuffix,
|
||||
serverCountTo: countToQueueSize * 1000,
|
||||
})
|
||||
|
||||
testTerminalWithCounters(t, term1, term2, &testWithCounterOpts{
|
||||
testName: "stresstest-duplex:" + optionsSuffix,
|
||||
clientCountTo: countToQueueSize * 1000,
|
||||
serverCountTo: countToQueueSize * 1000,
|
||||
})
|
||||
|
||||
// Clean up.
|
||||
term1.Abandon(nil)
|
||||
term2.Abandon(nil)
|
||||
|
||||
// Give some time for the last log messages and clean up.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func createForwardingUpstream(t *testing.T, srcName, dstName string, deliverFunc func(*Msg) *Error) Upstream {
|
||||
t.Helper()
|
||||
|
||||
return UpstreamSendFunc(func(msg *Msg, _ time.Duration) *Error {
|
||||
// Fast track nil containers.
|
||||
if msg == nil {
|
||||
dErr := deliverFunc(msg)
|
||||
if dErr != nil {
|
||||
t.Errorf("%s>%s: failed to deliver nil msg to terminal: %s", srcName, dstName, dErr)
|
||||
return dErr.With("failed to deliver nil msg to terminal")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Log messages.
|
||||
if logTestCraneMsgs {
|
||||
t.Logf("%s>%s: %v\n", srcName, dstName, msg.Data.CompileData())
|
||||
}
|
||||
|
||||
// Deliver to other terminal.
|
||||
dErr := deliverFunc(msg)
|
||||
if dErr != nil {
|
||||
t.Errorf("%s>%s: failed to deliver to terminal: %s", srcName, dstName, dErr)
|
||||
return dErr.With("failed to deliver to terminal")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type testWithCounterOpts struct {
|
||||
testName string
|
||||
flush bool
|
||||
clientCountTo uint64
|
||||
serverCountTo uint64
|
||||
waitBetweenMsgs time.Duration
|
||||
}
|
||||
|
||||
func testTerminalWithCounters(t *testing.T, term1, term2 *TestTerminal, opts *testWithCounterOpts) {
|
||||
t.Helper()
|
||||
|
||||
// Wait async for test to complete, print stack after timeout.
|
||||
finished := make(chan struct{})
|
||||
maxTestDuration := 60 * time.Second
|
||||
go func() {
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(maxTestDuration):
|
||||
fmt.Printf("terminal test %s is taking more than %s, printing stack:\n", opts.testName, maxTestDuration)
|
||||
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Logf("starting terminal counter test %s", opts.testName)
|
||||
defer t.Logf("stopping terminal counter test %s", opts.testName)
|
||||
|
||||
// Start counters.
|
||||
counter, tErr := NewCounterOp(term1, CounterOpts{
|
||||
ClientCountTo: opts.clientCountTo,
|
||||
ServerCountTo: opts.serverCountTo,
|
||||
Flush: opts.flush,
|
||||
Wait: opts.waitBetweenMsgs,
|
||||
})
|
||||
if tErr != nil {
|
||||
t.Fatalf("terminal test %s failed to start counter: %s", opts.testName, tErr)
|
||||
}
|
||||
|
||||
// Wait until counters are done.
|
||||
counter.Wait()
|
||||
close(finished)
|
||||
|
||||
// Check for error.
|
||||
if counter.Error != nil {
|
||||
t.Fatalf("terminal test %s failed to count: %s", opts.testName, counter.Error)
|
||||
}
|
||||
|
||||
// Log stats.
|
||||
printCTStats(t, opts.testName, "term1", term1)
|
||||
printCTStats(t, opts.testName, "term2", term2)
|
||||
|
||||
// Check if stats match, if DFQ is used on both sides.
|
||||
dfq1, ok1 := term1.flowControl.(*DuplexFlowQueue)
|
||||
dfq2, ok2 := term2.flowControl.(*DuplexFlowQueue)
|
||||
if ok1 && ok2 &&
|
||||
(atomic.LoadInt32(dfq1.sendSpace) != atomic.LoadInt32(dfq2.reportedSpace) ||
|
||||
atomic.LoadInt32(dfq2.sendSpace) != atomic.LoadInt32(dfq1.reportedSpace)) {
|
||||
t.Fatalf("terminal test %s has non-matching space counters", opts.testName)
|
||||
}
|
||||
}
|
||||
|
||||
func printCTStats(t *testing.T, testName, name string, term *TestTerminal) {
|
||||
t.Helper()
|
||||
|
||||
dfq, ok := term.flowControl.(*DuplexFlowQueue)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf(
|
||||
"%s: %s: sq=%d rq=%d sends=%d reps=%d",
|
||||
testName,
|
||||
name,
|
||||
len(dfq.sendQueue),
|
||||
len(dfq.recvQueue),
|
||||
atomic.LoadInt32(dfq.sendSpace),
|
||||
atomic.LoadInt32(dfq.reportedSpace),
|
||||
)
|
||||
}
|
||||
243
spn/terminal/testing.go
Normal file
243
spn/terminal/testing.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTestQueueSize = 16
|
||||
defaultTestPadding = 8
|
||||
logTestCraneMsgs = false
|
||||
)
|
||||
|
||||
// TestTerminal is a terminal for running tests.
|
||||
type TestTerminal struct {
|
||||
*TerminalBase
|
||||
}
|
||||
|
||||
// NewLocalTestTerminal returns a new local test terminal.
|
||||
func NewLocalTestTerminal(
|
||||
ctx context.Context,
|
||||
id uint32,
|
||||
parentID string,
|
||||
remoteHub *hub.Hub,
|
||||
initMsg *TerminalOpts,
|
||||
upstream Upstream,
|
||||
) (*TestTerminal, *container.Container, *Error) {
|
||||
// Create Terminal Base.
|
||||
t, initData, err := NewLocalBaseTerminal(ctx, id, parentID, remoteHub, initMsg, upstream)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
t.StartWorkers(module, "test terminal")
|
||||
|
||||
return &TestTerminal{t}, initData, nil
|
||||
}
|
||||
|
||||
// NewRemoteTestTerminal returns a new remote test terminal.
|
||||
func NewRemoteTestTerminal(
|
||||
ctx context.Context,
|
||||
id uint32,
|
||||
parentID string,
|
||||
identity *cabin.Identity,
|
||||
initData *container.Container,
|
||||
upstream Upstream,
|
||||
) (*TestTerminal, *TerminalOpts, *Error) {
|
||||
// Create Terminal Base.
|
||||
t, initMsg, err := NewRemoteBaseTerminal(ctx, id, parentID, identity, initData, upstream)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
t.StartWorkers(module, "test terminal")
|
||||
|
||||
return &TestTerminal{t}, initMsg, nil
|
||||
}
|
||||
|
||||
type delayedMsg struct {
|
||||
msg *Msg
|
||||
timeout time.Duration
|
||||
delayUntil time.Time
|
||||
}
|
||||
|
||||
func createDelayingTestForwardingFunc(
|
||||
srcName,
|
||||
dstName string,
|
||||
delay time.Duration,
|
||||
delayQueueSize int,
|
||||
deliverFunc func(msg *Msg, timeout time.Duration) *Error,
|
||||
) func(msg *Msg, timeout time.Duration) *Error {
|
||||
// Return simple forward func if no delay is given.
|
||||
if delay == 0 {
|
||||
return func(msg *Msg, timeout time.Duration) *Error {
|
||||
// Deliver to other terminal.
|
||||
dErr := deliverFunc(msg, timeout)
|
||||
if dErr != nil {
|
||||
log.Errorf("spn/testing: %s>%s: failed to deliver to terminal: %s", srcName, dstName, dErr)
|
||||
return dErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// If there is delay, create a delaying channel and handler.
|
||||
delayedMsgs := make(chan *delayedMsg, delayQueueSize)
|
||||
go func() {
|
||||
for {
|
||||
// Read from chan
|
||||
msg := <-delayedMsgs
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if we need to wait.
|
||||
waitFor := time.Until(msg.delayUntil)
|
||||
if waitFor > 0 {
|
||||
time.Sleep(waitFor)
|
||||
}
|
||||
|
||||
// Deliver to other terminal.
|
||||
dErr := deliverFunc(msg.msg, msg.timeout)
|
||||
if dErr != nil {
|
||||
log.Errorf("spn/testing: %s>%s: failed to deliver to terminal: %s", srcName, dstName, dErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return func(msg *Msg, timeout time.Duration) *Error {
|
||||
// Add msg to delaying msg channel.
|
||||
delayedMsgs <- &delayedMsg{
|
||||
msg: msg,
|
||||
timeout: timeout,
|
||||
delayUntil: time.Now().Add(delay),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAbandon gives the terminal the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
func (t *TestTerminal) HandleAbandon(err *Error) (errorToSend *Error) {
|
||||
switch err {
|
||||
case nil:
|
||||
// nil means that the Terminal is being shutdown by the owner.
|
||||
log.Tracef("spn/terminal: %s is closing", fmtTerminalID(t.parentID, t.id))
|
||||
default:
|
||||
// All other errors are faults.
|
||||
log.Warningf("spn/terminal: %s: %s", fmtTerminalID(t.parentID, t.id), err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// NewSimpleTestTerminalPair provides a simple conntected terminal pair for tests.
|
||||
func NewSimpleTestTerminalPair(delay time.Duration, delayQueueSize int, opts *TerminalOpts) (a, b *TestTerminal, err error) {
|
||||
if opts == nil {
|
||||
opts = &TerminalOpts{
|
||||
Padding: defaultTestPadding,
|
||||
FlowControl: FlowControlDFQ,
|
||||
FlowControlSize: defaultTestQueueSize,
|
||||
}
|
||||
}
|
||||
|
||||
var initData *container.Container
|
||||
var tErr *Error
|
||||
a, initData, tErr = NewLocalTestTerminal(
|
||||
module.Ctx, 127, "a", nil, opts, UpstreamSendFunc(createDelayingTestForwardingFunc(
|
||||
"a", "b", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error {
|
||||
return b.Deliver(msg)
|
||||
},
|
||||
)),
|
||||
)
|
||||
if tErr != nil {
|
||||
return nil, nil, tErr.Wrap("failed to create local test terminal")
|
||||
}
|
||||
b, _, tErr = NewRemoteTestTerminal(
|
||||
module.Ctx, 127, "b", nil, initData, UpstreamSendFunc(createDelayingTestForwardingFunc(
|
||||
"b", "a", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error {
|
||||
return a.Deliver(msg)
|
||||
},
|
||||
)),
|
||||
)
|
||||
if tErr != nil {
|
||||
return nil, nil, tErr.Wrap("failed to create remote test terminal")
|
||||
}
|
||||
|
||||
return a, b, nil
|
||||
}
|
||||
|
||||
// BareTerminal is a bare terminal that just returns errors for testing.
|
||||
type BareTerminal struct{}
|
||||
|
||||
var (
|
||||
_ Terminal = &BareTerminal{}
|
||||
|
||||
errNotImplementedByBareTerminal = ErrInternalError.With("not implemented by bare terminal")
|
||||
)
|
||||
|
||||
// ID returns the terminal ID.
|
||||
func (t *BareTerminal) ID() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Ctx returns the terminal context.
|
||||
func (t *BareTerminal) Ctx() context.Context {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the terminal.
|
||||
// Should not be overridden by implementations.
|
||||
func (t *BareTerminal) Deliver(msg *Msg) *Error {
|
||||
return errNotImplementedByBareTerminal
|
||||
}
|
||||
|
||||
// Send is used by others to send a message through the terminal.
|
||||
// Should not be overridden by implementations.
|
||||
func (t *BareTerminal) Send(msg *Msg, timeout time.Duration) *Error {
|
||||
return errNotImplementedByBareTerminal
|
||||
}
|
||||
|
||||
// Flush sends all messages waiting in the terminal.
|
||||
// Should not be overridden by implementations.
|
||||
func (t *BareTerminal) Flush(timeout time.Duration) {}
|
||||
|
||||
// StartOperation starts the given operation by assigning it an ID and sending the given operation initialization data.
|
||||
// Should not be overridden by implementations.
|
||||
func (t *BareTerminal) StartOperation(op Operation, initData *container.Container, timeout time.Duration) *Error {
|
||||
return errNotImplementedByBareTerminal
|
||||
}
|
||||
|
||||
// StopOperation stops the given operation.
|
||||
// Should not be overridden by implementations.
|
||||
func (t *BareTerminal) StopOperation(op Operation, err *Error) {}
|
||||
|
||||
// Abandon shuts down the terminal unregistering it from upstream and calling HandleAbandon().
|
||||
// Should not be overridden by implementations.
|
||||
func (t *BareTerminal) Abandon(err *Error) {}
|
||||
|
||||
// HandleAbandon gives the terminal the ability to cleanly shut down.
|
||||
// The terminal is still fully functional at this point.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
// Meant to be overridden by implementations.
|
||||
func (t *BareTerminal) HandleAbandon(err *Error) (errorToSend *Error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleDestruction gives the terminal the ability to clean up.
|
||||
// The terminal has already fully shut down at this point.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
// Meant to be overridden by implementations.
|
||||
func (t *BareTerminal) HandleDestruction(err *Error) {}
|
||||
|
||||
// FmtID formats the terminal ID (including parent IDs).
|
||||
// May be overridden by implementations.
|
||||
func (t *BareTerminal) FmtID() string {
|
||||
return "bare"
|
||||
}
|
||||
16
spn/terminal/upstream.go
Normal file
16
spn/terminal/upstream.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package terminal
|
||||
|
||||
import "time"
|
||||
|
||||
// Upstream defines the interface for upstream (parent) components.
|
||||
type Upstream interface {
|
||||
Send(msg *Msg, timeout time.Duration) *Error
|
||||
}
|
||||
|
||||
// UpstreamSendFunc is a helper to be able to satisfy the Upstream interface.
|
||||
type UpstreamSendFunc func(msg *Msg, timeout time.Duration) *Error
|
||||
|
||||
// Send is used to send a message through this upstream.
|
||||
func (fn UpstreamSendFunc) Send(msg *Msg, timeout time.Duration) *Error {
|
||||
return fn(msg, timeout)
|
||||
}
|
||||
Reference in New Issue
Block a user