wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
131
spn/ships/connection_test.go
Normal file
131
spn/ships/connection_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
var (
|
||||
testPort uint16 = 65000
|
||||
testData = []byte("The quick brown fox jumps over the lazy dog")
|
||||
localhost = net.IPv4(127, 0, 0, 1)
|
||||
)
|
||||
|
||||
func getTestPort() uint16 {
|
||||
testPort++
|
||||
return testPort
|
||||
}
|
||||
|
||||
func getTestBuf() []byte {
|
||||
return make([]byte, len(testData))
|
||||
}
|
||||
|
||||
func TestConnections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registryLock.Lock()
|
||||
t.Cleanup(func() {
|
||||
registryLock.Unlock()
|
||||
})
|
||||
|
||||
for k, v := range registry { //nolint:paralleltest // False positive.
|
||||
protocol, builder := k, v
|
||||
t.Run(protocol, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ctx, cancelCtx := context.WithCancel(context.Background())
|
||||
|
||||
// docking requests
|
||||
dockingRequests := make(chan Ship, 1)
|
||||
transport := &hub.Transport{
|
||||
Protocol: protocol,
|
||||
Port: getTestPort(),
|
||||
}
|
||||
|
||||
// create listener
|
||||
pier, err := builder.EstablishPier(transport, dockingRequests)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// connect to listener
|
||||
ship, err := builder.LaunchShip(ctx, transport, localhost)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// client send
|
||||
err = ship.Load(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// dock client
|
||||
srvShip := <-dockingRequests
|
||||
if srvShip == nil {
|
||||
t.Fatalf("%s failed to dock", pier)
|
||||
}
|
||||
|
||||
// server recv
|
||||
buf := getTestBuf()
|
||||
_, err = srvShip.UnloadTo(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// check data
|
||||
assert.Equal(t, testData, buf, "should match")
|
||||
fmt.Print(".")
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
// server send
|
||||
err = srvShip.Load(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// client recv
|
||||
buf = getTestBuf()
|
||||
_, err = ship.UnloadTo(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// check data
|
||||
assert.Equal(t, testData, buf, "should match")
|
||||
fmt.Print(".")
|
||||
|
||||
// client send
|
||||
err = ship.Load(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// server recv
|
||||
buf = getTestBuf()
|
||||
_, err = srvShip.UnloadTo(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// check data
|
||||
assert.Equal(t, testData, buf, "should match")
|
||||
fmt.Print(".")
|
||||
}
|
||||
|
||||
ship.Sink()
|
||||
srvShip.Sink()
|
||||
pier.Abolish()
|
||||
cancelCtx()
|
||||
wg.Wait() // wait for docking procedure to end
|
||||
})
|
||||
}
|
||||
}
|
||||
230
spn/ships/http.go
Normal file
230
spn/ships/http.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
// HTTPShip is a ship that uses HTTP.
|
||||
type HTTPShip struct {
|
||||
ShipBase
|
||||
}
|
||||
|
||||
// HTTPPier is a pier that uses HTTP.
|
||||
type HTTPPier struct {
|
||||
PierBase
|
||||
|
||||
newDockings chan net.Conn
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("http", &Builder{
|
||||
LaunchShip: launchHTTPShip,
|
||||
EstablishPier: establishHTTPPier,
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
HTTP Transport Variants:
|
||||
|
||||
1. Hijack connection and switch to raw SPN protocol:
|
||||
|
||||
Request:
|
||||
|
||||
GET <path> HTTP/1.1
|
||||
Connection: Upgrade
|
||||
Upgrade: SPN
|
||||
|
||||
Response:
|
||||
|
||||
HTTP/1.1 101 Switching Protocols
|
||||
Connection: Upgrade
|
||||
Upgrade: SPN
|
||||
|
||||
*/
|
||||
|
||||
func launchHTTPShip(ctx context.Context, transport *hub.Transport, ip net.IP) (Ship, error) {
|
||||
// Default to root path.
|
||||
path := transport.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
// Build request for Variant 1.
|
||||
variant := 1
|
||||
request, err := http.NewRequest(http.MethodGet, path, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build HTTP request: %w", err)
|
||||
}
|
||||
request.Header.Set("Connection", "Upgrade")
|
||||
request.Header.Set("Upgrade", "SPN")
|
||||
|
||||
// Create connection.
|
||||
var dialNet string
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
dialNet = "tcp4"
|
||||
} else {
|
||||
dialNet = "tcp6"
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
LocalAddr: conf.GetBindAddr(dialNet),
|
||||
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
|
||||
KeepAlive: -1, // Disable keep-alive.
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, dialNet, net.JoinHostPort(ip.String(), portToA(transport.Port)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
// Send HTTP request.
|
||||
err = request.Write(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send HTTP request: %w", err)
|
||||
}
|
||||
|
||||
// Receive HTTP response.
|
||||
response, err := http.ReadResponse(bufio.NewReader(conn), request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read HTTP response: %w", err)
|
||||
}
|
||||
defer response.Body.Close() //nolint:errcheck,gosec
|
||||
|
||||
// Handle response according to variant.
|
||||
switch variant {
|
||||
case 1:
|
||||
if response.StatusCode == http.StatusSwitchingProtocols &&
|
||||
response.Header.Get("Connection") == "Upgrade" &&
|
||||
response.Header.Get("Upgrade") == "SPN" {
|
||||
// Continue
|
||||
} else {
|
||||
return nil, fmt.Errorf("received unexpected response for variant 1: %s", response.Status)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("internal error: unsupported http transport variant: %d", variant)
|
||||
}
|
||||
|
||||
// Create ship.
|
||||
ship := &HTTPShip{
|
||||
ShipBase: ShipBase{
|
||||
conn: conn,
|
||||
transport: transport,
|
||||
mine: true,
|
||||
secure: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Init and return.
|
||||
ship.calculateLoadSize(ip, nil, TCPHeaderMTUSize)
|
||||
ship.initBase()
|
||||
return ship, nil
|
||||
}
|
||||
|
||||
func (pier *HTTPPier) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet &&
|
||||
r.Header.Get("Connection") == "Upgrade" &&
|
||||
r.Header.Get("Upgrade") == "SPN":
|
||||
// Request for Variant 1.
|
||||
|
||||
// Hijack connection.
|
||||
var conn net.Conn
|
||||
if hijacker, ok := w.(http.Hijacker); ok {
|
||||
// Empty body, so the hijacked connection starts with a clean buffer.
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
log.Warningf("ships: failed to empty body for hijack for %s: %s", r.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
_ = r.Body.Close()
|
||||
|
||||
// Reply with upgrade confirmation.
|
||||
w.Header().Set("Connection", "Upgrade")
|
||||
w.Header().Set("Upgrade", "SPN")
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
|
||||
// Get connection.
|
||||
conn, _, err = hijacker.Hijack()
|
||||
if err != nil {
|
||||
log.Warningf("ships: failed to hijack http connection from %s: %s", r.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
log.Warningf("ships: connection from %s cannot be hijacked", r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
// Create new ship.
|
||||
ship := &HTTPShip{
|
||||
ShipBase: ShipBase{
|
||||
transport: pier.transport,
|
||||
conn: conn,
|
||||
mine: false,
|
||||
secure: false,
|
||||
},
|
||||
}
|
||||
ship.calculateLoadSize(nil, conn.RemoteAddr(), TCPHeaderMTUSize)
|
||||
ship.initBase()
|
||||
|
||||
// Submit new docking request.
|
||||
select {
|
||||
case pier.dockingRequests <- ship:
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
// Reply with info page if no variant matches the request.
|
||||
ServeInfoPage(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func establishHTTPPier(transport *hub.Transport, dockingRequests chan Ship) (Pier, error) {
|
||||
// Default to root path.
|
||||
path := transport.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
// Create pier.
|
||||
pier := &HTTPPier{
|
||||
newDockings: make(chan net.Conn),
|
||||
PierBase: PierBase{
|
||||
transport: transport,
|
||||
dockingRequests: dockingRequests,
|
||||
},
|
||||
}
|
||||
pier.initBase()
|
||||
|
||||
// Register handler.
|
||||
err := addHTTPHandler(transport.Port, path, pier.ServeHTTP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add HTTP handler: %w", err)
|
||||
}
|
||||
|
||||
return pier, nil
|
||||
}
|
||||
|
||||
// Abolish closes the underlying listener and cleans up any related resources.
|
||||
func (pier *HTTPPier) Abolish() {
|
||||
// Only abolish once.
|
||||
if !pier.abolishing.SetToIf(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// Do not close the listener, as it is shared.
|
||||
// Instead, remove the HTTP handler and the shared server will shutdown itself when needed.
|
||||
_ = removeHTTPHandler(pier.transport.Port, pier.transport.Path)
|
||||
}
|
||||
83
spn/ships/http_info.go
Normal file
83
spn/ships/http_info.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/info"
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed http_info_page.html.tmpl
|
||||
infoPageData string
|
||||
|
||||
infoPageTemplate *template.Template
|
||||
|
||||
// DisplayHubID holds the Hub ID for displaying it on the info page.
|
||||
DisplayHubID string
|
||||
)
|
||||
|
||||
type infoPageInput struct {
|
||||
Version string
|
||||
Info *info.Info
|
||||
ID string
|
||||
Name string
|
||||
Group string
|
||||
ContactAddress string
|
||||
ContactService string
|
||||
}
|
||||
|
||||
var (
|
||||
pageInputName config.StringOption
|
||||
pageInputGroup config.StringOption
|
||||
pageInputContactAddress config.StringOption
|
||||
pageInputContactService config.StringOption
|
||||
)
|
||||
|
||||
func initPageInput() {
|
||||
infoPageTemplate = template.Must(template.New("info-page").Parse(infoPageData))
|
||||
|
||||
pageInputName = config.Concurrent.GetAsString("spn/publicHub/name", "")
|
||||
pageInputGroup = config.Concurrent.GetAsString("spn/publicHub/group", "")
|
||||
pageInputContactAddress = config.Concurrent.GetAsString("spn/publicHub/contactAddress", "")
|
||||
pageInputContactService = config.Concurrent.GetAsString("spn/publicHub/contactService", "")
|
||||
}
|
||||
|
||||
// ServeInfoPage serves the info page for the given request.
|
||||
func ServeInfoPage(w http.ResponseWriter, r *http.Request) {
|
||||
pageData, err := renderInfoPage()
|
||||
if err != nil {
|
||||
log.Warningf("ships: failed to render SPN info page: %s", err)
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = w.Write(pageData)
|
||||
if err != nil {
|
||||
log.Warningf("ships: failed to write info page: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func renderInfoPage() ([]byte, error) {
|
||||
input := &infoPageInput{
|
||||
Version: info.Version(),
|
||||
Info: info.GetInfo(),
|
||||
ID: DisplayHubID,
|
||||
Name: pageInputName(),
|
||||
Group: pageInputGroup(),
|
||||
ContactAddress: pageInputContactAddress(),
|
||||
ContactService: pageInputContactService(),
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
err := infoPageTemplate.ExecuteTemplate(buf, "info-page", input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
112
spn/ships/http_info_page.html.tmpl
Normal file
112
spn/ships/http_info_page.html.tmpl
Normal file
File diff suppressed because one or more lines are too long
26
spn/ships/http_info_test.go
Normal file
26
spn/ships/http_info_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
)
|
||||
|
||||
func TestInfoPageTemplate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
infoPageTemplate = template.Must(template.New("info-page").Parse(infoPageData))
|
||||
pageInputName = config.Concurrent.GetAsString("spn/publicHub/name", "node-name")
|
||||
pageInputGroup = config.Concurrent.GetAsString("spn/publicHub/group", "node-group")
|
||||
pageInputContactAddress = config.Concurrent.GetAsString("spn/publicHub/contactAddress", "john@doe.com")
|
||||
pageInputContactService = config.Concurrent.GetAsString("spn/publicHub/contactService", "email")
|
||||
|
||||
pageData, err := renderInfoPage()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = pageData
|
||||
// t.Log(string(pageData))
|
||||
}
|
||||
188
spn/ships/http_shared.go
Normal file
188
spn/ships/http_shared.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
type sharedServer struct {
|
||||
server *http.Server
|
||||
|
||||
handlers map[string]http.HandlerFunc
|
||||
handlersLock sync.RWMutex
|
||||
}
|
||||
|
||||
// ServeHTTP forwards requests to registered handler or uses defaults.
|
||||
func (shared *sharedServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
shared.handlersLock.Lock()
|
||||
defer shared.handlersLock.Unlock()
|
||||
|
||||
// Get and forward to registered handler.
|
||||
handler, ok := shared.handlers[r.URL.Path]
|
||||
if ok {
|
||||
handler(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// If there is registered handler and path is "/", respond with info page.
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
||||
ServeInfoPage(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, respond with error.
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
|
||||
var (
|
||||
sharedHTTPServers = make(map[uint16]*sharedServer)
|
||||
sharedHTTPServersLock sync.Mutex
|
||||
)
|
||||
|
||||
func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) error {
|
||||
// Check params.
|
||||
if port == 0 {
|
||||
return errors.New("cannot listen on port 0")
|
||||
}
|
||||
|
||||
// Default to root path.
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
sharedHTTPServersLock.Lock()
|
||||
defer sharedHTTPServersLock.Unlock()
|
||||
|
||||
// Get http server of the port.
|
||||
shared, ok := sharedHTTPServers[port]
|
||||
if ok {
|
||||
// Set path to handler.
|
||||
shared.handlersLock.Lock()
|
||||
defer shared.handlersLock.Unlock()
|
||||
|
||||
// Check if path is already registered.
|
||||
_, ok := shared.handlers[path]
|
||||
if ok {
|
||||
return errors.New("path already registered")
|
||||
}
|
||||
|
||||
// Else, register handler at path.
|
||||
shared.handlers[path] = handler
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shared server does not exist - create one.
|
||||
shared = &sharedServer{
|
||||
handlers: make(map[string]http.HandlerFunc),
|
||||
}
|
||||
|
||||
// Add first handler.
|
||||
shared.handlers[path] = handler
|
||||
|
||||
// Define new server.
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", port),
|
||||
Handler: shared,
|
||||
ReadTimeout: 1 * time.Minute,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
WriteTimeout: 1 * time.Minute,
|
||||
IdleTimeout: 1 * time.Minute,
|
||||
MaxHeaderBytes: 4096,
|
||||
// ErrorLog: &log.Logger{}, // FIXME
|
||||
BaseContext: func(net.Listener) context.Context { return module.Ctx },
|
||||
}
|
||||
shared.server = server
|
||||
|
||||
// Start listeners.
|
||||
bindIPs := conf.GetBindIPs()
|
||||
listeners := make([]net.Listener, 0, len(bindIPs))
|
||||
for _, bindIP := range bindIPs {
|
||||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: bindIP,
|
||||
Port: int(port),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
listeners = append(listeners, listener)
|
||||
log.Infof("spn/ships: http transport pier established on %s", listener.Addr())
|
||||
}
|
||||
|
||||
// Add shared http server to list.
|
||||
sharedHTTPServers[port] = shared
|
||||
|
||||
// Start servers in service workers.
|
||||
for _, listener := range listeners {
|
||||
serviceListener := listener
|
||||
module.StartServiceWorker(
|
||||
fmt.Sprintf("shared http server listener on %s", listener.Addr()), 0,
|
||||
func(ctx context.Context) error {
|
||||
err := shared.server.Serve(serviceListener)
|
||||
if !errors.Is(http.ErrServerClosed, err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeHTTPHandler(port uint16, path string) error {
|
||||
// Check params.
|
||||
if port == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Default to root path.
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
sharedHTTPServersLock.Lock()
|
||||
defer sharedHTTPServersLock.Unlock()
|
||||
|
||||
// Get http server of the port.
|
||||
shared, ok := sharedHTTPServers[port]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set path to handler.
|
||||
shared.handlersLock.Lock()
|
||||
defer shared.handlersLock.Unlock()
|
||||
|
||||
// Check if path is registered.
|
||||
_, ok = shared.handlers[path]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove path from handler.
|
||||
delete(shared.handlers, path)
|
||||
|
||||
// Shutdown shared HTTP server if no more handlers are registered.
|
||||
if len(shared.handlers) == 0 {
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
10*time.Second,
|
||||
)
|
||||
defer cancel()
|
||||
return shared.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// Remove shared HTTP server from map.
|
||||
delete(sharedHTTPServers, port)
|
||||
|
||||
return nil
|
||||
}
|
||||
33
spn/ships/http_shared_test.go
Normal file
33
spn/ships/http_shared_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSharedHTTP(t *testing.T) { //nolint:paralleltest // Test checks global state.
|
||||
const testPort = 65100
|
||||
|
||||
// Register multiple handlers.
|
||||
err := addHTTPHandler(testPort, "", ServeInfoPage)
|
||||
assert.NoError(t, err, "should be able to share http listener")
|
||||
err = addHTTPHandler(testPort, "/test", ServeInfoPage)
|
||||
assert.NoError(t, err, "should be able to share http listener")
|
||||
err = addHTTPHandler(testPort, "/test2", ServeInfoPage)
|
||||
assert.NoError(t, err, "should be able to share http listener")
|
||||
err = addHTTPHandler(testPort, "/", ServeInfoPage)
|
||||
assert.Error(t, err, "should fail to register path twice")
|
||||
|
||||
// Unregister
|
||||
assert.NoError(t, removeHTTPHandler(testPort, ""))
|
||||
assert.NoError(t, removeHTTPHandler(testPort, "/test"))
|
||||
assert.NoError(t, removeHTTPHandler(testPort, "/not-registered")) // removing unregistered handler does not error
|
||||
assert.NoError(t, removeHTTPHandler(testPort, "/test2"))
|
||||
assert.NoError(t, removeHTTPHandler(testPort, "/not-registered")) // removing unregistered handler does not error
|
||||
|
||||
// Check if all handlers are gone again.
|
||||
sharedHTTPServersLock.Lock()
|
||||
defer sharedHTTPServersLock.Unlock()
|
||||
assert.Equal(t, 0, len(sharedHTTPServers), "shared http handlers should be back to zero")
|
||||
}
|
||||
81
spn/ships/kcp.go
Normal file
81
spn/ships/kcp.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package ships
|
||||
|
||||
// KCPShip is a ship that uses KCP.
|
||||
type KCPShip struct {
|
||||
ShipBase
|
||||
}
|
||||
|
||||
// KCPPier is a pier that uses KCP.
|
||||
type KCPPier struct {
|
||||
PierBase
|
||||
}
|
||||
|
||||
// TODO: Find a replacement for kcp, which turned out to not fit our use case.
|
||||
/*
|
||||
func init() {
|
||||
Register("kcp", &Builder{
|
||||
LaunchShip: launchKCPShip,
|
||||
EstablishPier: establishKCPPier,
|
||||
})
|
||||
}
|
||||
|
||||
func launchKCPShip(ctx context.Context, transport *hub.Transport, ip net.IP) (Ship, error) {
|
||||
conn, err := kcp.Dial(net.JoinHostPort(ip.String(), portToA(transport.Port)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ship := &KCPShip{
|
||||
ShipBase: ShipBase{
|
||||
conn: conn,
|
||||
transport: transport,
|
||||
mine: true,
|
||||
secure: false,
|
||||
// Calculate KCP's MSS.
|
||||
loadSize: kcp.IKCP_MTU_DEF - kcp.IKCP_OVERHEAD,
|
||||
},
|
||||
}
|
||||
|
||||
ship.initBase()
|
||||
return ship, nil
|
||||
}
|
||||
|
||||
func establishKCPPier(transport *hub.Transport, dockingRequests chan *DockingRequest) (Pier, error) {
|
||||
listener, err := kcp.Listen(net.JoinHostPort("", portToA(transport.Port)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pier := &KCPPier{
|
||||
PierBase: PierBase{
|
||||
transport: transport,
|
||||
listener: listener,
|
||||
dockingRequests: dockingRequests,
|
||||
},
|
||||
}
|
||||
pier.PierBase.dockShip = pier.dockShip
|
||||
pier.initBase()
|
||||
return pier, nil
|
||||
}
|
||||
|
||||
func (pier *KCPPier) dockShip() (Ship, error) {
|
||||
conn, err := pier.listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ship := &KCPShip{
|
||||
ShipBase: ShipBase{
|
||||
conn: conn,
|
||||
transport: pier.transport,
|
||||
mine: false,
|
||||
secure: false,
|
||||
// Calculate KCP's MSS.
|
||||
loadSize: kcp.IKCP_MTU_DEF - kcp.IKCP_OVERHEAD,
|
||||
},
|
||||
}
|
||||
|
||||
ship.initBase()
|
||||
return ship, nil
|
||||
}
|
||||
*/
|
||||
114
spn/ships/launch.go
Normal file
114
spn/ships/launch.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
// Launch launches a new ship to the given Hub.
|
||||
func Launch(ctx context.Context, h *hub.Hub, transport *hub.Transport, ip net.IP) (Ship, error) {
|
||||
var transports []*hub.Transport
|
||||
var ips []net.IP
|
||||
|
||||
// choose transports
|
||||
if transport != nil {
|
||||
transports = []*hub.Transport{transport}
|
||||
} else {
|
||||
if h.Info == nil {
|
||||
return nil, hub.ErrMissingInfo
|
||||
}
|
||||
transports = h.Info.ParsedTransports()
|
||||
// If there are no transports, check if they were parsed.
|
||||
if len(transports) == 0 && len(h.Info.Transports) > 0 {
|
||||
log.Errorf("ships: %s has no parsed transports, but transports are %v", h, h.Info.Transports)
|
||||
// Attempt to parse transports now.
|
||||
transports, _ = hub.ParseTransports(h.Info.Transports)
|
||||
}
|
||||
// Fail if there are not transports.
|
||||
if len(transports) == 0 {
|
||||
return nil, hub.ErrMissingTransports
|
||||
}
|
||||
}
|
||||
|
||||
// choose IPs
|
||||
if ip != nil {
|
||||
ips = []net.IP{ip}
|
||||
} else {
|
||||
if h.Info == nil {
|
||||
return nil, hub.ErrMissingInfo
|
||||
}
|
||||
ips = make([]net.IP, 0, 3)
|
||||
// If IPs have been verified, check if we can use a virtual network address.
|
||||
var vnetForced bool
|
||||
if h.VerifiedIPs {
|
||||
vnet := GetVirtualNetworkConfig()
|
||||
if vnet != nil {
|
||||
virtIP := vnet.Mapping[h.ID]
|
||||
if virtIP != nil {
|
||||
ips = append(ips, virtIP)
|
||||
if vnet.Force {
|
||||
vnetForced = true
|
||||
log.Infof("spn/ships: forcing virtual network address %s for %s", virtIP, h)
|
||||
} else {
|
||||
log.Infof("spn/ships: using virtual network address %s for %s", virtIP, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add Hub's IPs if no virtual address was forced.
|
||||
if !vnetForced {
|
||||
// prioritize IPv4
|
||||
if h.Info.IPv4 != nil {
|
||||
ips = append(ips, h.Info.IPv4)
|
||||
}
|
||||
if h.Info.IPv6 != nil && netenv.IPv6Enabled() {
|
||||
ips = append(ips, h.Info.IPv6)
|
||||
}
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil, hub.ErrMissingIPs
|
||||
}
|
||||
}
|
||||
|
||||
// connect
|
||||
var firstErr error
|
||||
for _, ip := range ips {
|
||||
for _, tr := range transports {
|
||||
ship, err := connectTo(ctx, h, tr, ip)
|
||||
if err == nil {
|
||||
return ship, nil // return on success
|
||||
}
|
||||
|
||||
// Check if context is canceled.
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// Save first error.
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
func connectTo(ctx context.Context, h *hub.Hub, transport *hub.Transport, ip net.IP) (Ship, error) {
|
||||
builder := GetBuilder(transport.Protocol)
|
||||
if builder == nil {
|
||||
return nil, fmt.Errorf("protocol %s not supported", transport.Protocol)
|
||||
}
|
||||
|
||||
ship, err := builder.LaunchShip(ctx, transport, ip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to %s using %s (%s): %w", h, transport, ip, err)
|
||||
}
|
||||
|
||||
return ship, nil
|
||||
}
|
||||
63
spn/ships/masking.go
Normal file
63
spn/ships/masking.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"net"
|
||||
|
||||
"github.com/mr-tron/base58"
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
var (
|
||||
maskingEnabled = abool.New()
|
||||
maskingActive = abool.New()
|
||||
maskingBytes []byte
|
||||
)
|
||||
|
||||
// EnableMasking enables masking with the given salt.
|
||||
func EnableMasking(salt []byte) {
|
||||
if maskingEnabled.SetToIf(false, true) {
|
||||
maskingBytes = salt
|
||||
maskingActive.Set()
|
||||
}
|
||||
}
|
||||
|
||||
// MaskAddress masks the given address if masking is enabled and the ship is
|
||||
// not public.
|
||||
func (ship *ShipBase) MaskAddress(addr net.Addr) string {
|
||||
// Return in plain if masking is not enabled or if ship is public.
|
||||
if maskingActive.IsNotSet() || ship.Public() {
|
||||
return addr.String()
|
||||
}
|
||||
|
||||
switch typedAddr := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return ship.MaskIP(typedAddr.IP)
|
||||
case *net.UDPAddr:
|
||||
return ship.MaskIP(typedAddr.IP)
|
||||
default:
|
||||
return ship.Mask([]byte(addr.String()))
|
||||
}
|
||||
}
|
||||
|
||||
// MaskIP masks the given IP if masking is enabled and the ship is not public.
|
||||
func (ship *ShipBase) MaskIP(ip net.IP) string {
|
||||
// Return in plain if masking is not enabled or if ship is public.
|
||||
if maskingActive.IsNotSet() || ship.Public() {
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
return ship.Mask(ip)
|
||||
}
|
||||
|
||||
// Mask masks the given value.
|
||||
func (ship *ShipBase) Mask(value []byte) string {
|
||||
// Hash the IP with masking bytes.
|
||||
hasher := sha1.New() //nolint:gosec // Not used for cryptography.
|
||||
hasher.Write(maskingBytes)
|
||||
hasher.Write(value)
|
||||
masked := hasher.Sum(nil)
|
||||
|
||||
// Return first 8 characters from the base58-encoded hash.
|
||||
return "masked:" + base58.Encode(masked)[:8]
|
||||
}
|
||||
20
spn/ships/module.go
Normal file
20
spn/ships/module.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
var module *modules.Module
|
||||
|
||||
func init() {
|
||||
module = modules.Register("ships", start, nil, nil, "cabin")
|
||||
}
|
||||
|
||||
func start() error {
|
||||
if conf.PublicHub() {
|
||||
initPageInput()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
47
spn/ships/mtu.go
Normal file
47
spn/ships/mtu.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package ships
|
||||
|
||||
import "net"
|
||||
|
||||
// MTU Calculation Configuration.
|
||||
const (
|
||||
BaseMTU = 1460 // 1500 with 40 bytes extra space for special cases.
|
||||
IPv4HeaderMTUSize = 20 // Without options, as not common.
|
||||
IPv6HeaderMTUSize = 40 // Without options, as not common.
|
||||
TCPHeaderMTUSize = 60 // Maximum size with options.
|
||||
UDPHeaderMTUSize = 8 // Has no options.
|
||||
)
|
||||
|
||||
func (ship *ShipBase) calculateLoadSize(ip net.IP, addr net.Addr, subtract ...int) {
|
||||
ship.loadSize = BaseMTU
|
||||
|
||||
// Convert addr to IP if needed.
|
||||
if ip == nil && addr != nil {
|
||||
switch v := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
ip = v.IP
|
||||
case *net.UDPAddr:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
}
|
||||
}
|
||||
|
||||
// Subtract IP Header, if IP is available.
|
||||
if ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
ship.loadSize -= IPv4HeaderMTUSize
|
||||
} else {
|
||||
ship.loadSize -= IPv6HeaderMTUSize
|
||||
}
|
||||
}
|
||||
|
||||
// Subtract others.
|
||||
for sub := range subtract {
|
||||
ship.loadSize -= sub
|
||||
}
|
||||
|
||||
// Raise buf size to at least load size.
|
||||
if ship.bufSize < ship.loadSize {
|
||||
ship.bufSize = ship.loadSize
|
||||
}
|
||||
}
|
||||
82
spn/ships/pier.go
Normal file
82
spn/ships/pier.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
// Pier represents a network connection listener.
|
||||
type Pier interface {
|
||||
// String returns a human readable informational summary about the ship.
|
||||
String() string
|
||||
|
||||
// Transport returns the transport used for this ship.
|
||||
Transport() *hub.Transport
|
||||
|
||||
// Abolish closes the underlying listener and cleans up any related resources.
|
||||
Abolish()
|
||||
}
|
||||
|
||||
// DockingRequest is a uniform request that Piers emit when a new ship arrives.
|
||||
type DockingRequest struct {
|
||||
Pier Pier
|
||||
Ship Ship
|
||||
Err error
|
||||
}
|
||||
|
||||
// EstablishPier is shorthand function to get the transport's builder and establish a pier.
|
||||
func EstablishPier(transport *hub.Transport, dockingRequests chan Ship) (Pier, error) {
|
||||
builder := GetBuilder(transport.Protocol)
|
||||
if builder == nil {
|
||||
return nil, fmt.Errorf("protocol %s not supported", transport.Protocol)
|
||||
}
|
||||
|
||||
pier, err := builder.EstablishPier(transport, dockingRequests)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to establish pier on %s: %w", transport, err)
|
||||
}
|
||||
|
||||
return pier, nil
|
||||
}
|
||||
|
||||
// PierBase implements common functions to comply with the Pier interface.
|
||||
type PierBase struct {
|
||||
// transport holds the transport definition of the pier.
|
||||
transport *hub.Transport
|
||||
// listeners holds the actual underlying listeners.
|
||||
listeners []net.Listener
|
||||
|
||||
// dockingRequests is used to report new connections to the higher layer.
|
||||
dockingRequests chan Ship
|
||||
|
||||
// abolishing specifies if the pier and listener is being closed.
|
||||
abolishing *abool.AtomicBool
|
||||
}
|
||||
|
||||
func (pier *PierBase) initBase() {
|
||||
// init
|
||||
pier.abolishing = abool.New()
|
||||
}
|
||||
|
||||
// String returns a human readable informational summary about the ship.
|
||||
func (pier *PierBase) String() string {
|
||||
return fmt.Sprintf("<Pier %s>", pier.transport)
|
||||
}
|
||||
|
||||
// Transport returns the transport used for this ship.
|
||||
func (pier *PierBase) Transport() *hub.Transport {
|
||||
return pier.transport
|
||||
}
|
||||
|
||||
// Abolish closes the underlying listener and cleans up any related resources.
|
||||
func (pier *PierBase) Abolish() {
|
||||
if pier.abolishing.SetToIf(false, true) {
|
||||
for _, listener := range pier.listeners {
|
||||
_ = listener.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
55
spn/ships/registry.go
Normal file
55
spn/ships/registry.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
// Builder is a factory that can build ships and piers of it's protocol.
|
||||
type Builder struct {
|
||||
LaunchShip func(ctx context.Context, transport *hub.Transport, ip net.IP) (Ship, error)
|
||||
EstablishPier func(transport *hub.Transport, dockingRequests chan Ship) (Pier, error)
|
||||
}
|
||||
|
||||
var (
|
||||
registry = make(map[string]*Builder)
|
||||
allProtocols []string
|
||||
registryLock sync.Mutex
|
||||
)
|
||||
|
||||
// Register registers a new builder for a protocol.
|
||||
func Register(protocol string, builder *Builder) {
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
registry[protocol] = builder
|
||||
}
|
||||
|
||||
// GetBuilder returns the builder for the given protocol, or nil if it does not exist.
|
||||
func GetBuilder(protocol string) *Builder {
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
builder, ok := registry[protocol]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
||||
// Protocols returns a slice with all registered protocol names. The return slice must not be edited.
|
||||
func Protocols() []string {
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
return allProtocols
|
||||
}
|
||||
|
||||
// portToA transforms the given port into a string.
|
||||
func portToA(port uint16) string {
|
||||
return strconv.FormatUint(uint64(port), 10)
|
||||
}
|
||||
220
spn/ships/ship.go
Normal file
220
spn/ships/ship.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultLoadSize = 4096
|
||||
)
|
||||
|
||||
// ErrSunk is returned when a ship sunk, ie. the connection was lost.
|
||||
var ErrSunk = errors.New("ship sunk")
|
||||
|
||||
// Ship represents a network layer connection.
|
||||
type Ship interface {
|
||||
// String returns a human readable informational summary about the ship.
|
||||
String() string
|
||||
|
||||
// Transport returns the transport used for this ship.
|
||||
Transport() *hub.Transport
|
||||
|
||||
// IsMine returns whether the ship was launched from here.
|
||||
IsMine() bool
|
||||
|
||||
// IsSecure returns whether the ship provides transport security.
|
||||
IsSecure() bool
|
||||
|
||||
// Public returns whether the ship is marked as public.
|
||||
Public() bool
|
||||
|
||||
// MarkPublic marks the ship as public.
|
||||
MarkPublic()
|
||||
|
||||
// LoadSize returns the recommended data size that should be handed to Load().
|
||||
// This value will be most likely somehow related to the connection's MTU.
|
||||
// Alternatively, using a multiple of LoadSize is also recommended.
|
||||
LoadSize() int
|
||||
|
||||
// Load loads data into the ship - ie. sends the data via the connection.
|
||||
// Returns ErrSunk if the ship has already sunk earlier.
|
||||
Load(data []byte) error
|
||||
|
||||
// UnloadTo unloads data from the ship - ie. receives data from the
|
||||
// connection - puts it into the buf. It returns the amount of data
|
||||
// written and an optional error.
|
||||
// Returns ErrSunk if the ship has already sunk earlier.
|
||||
UnloadTo(buf []byte) (n int, err error)
|
||||
|
||||
// LocalAddr returns the underlying local net.Addr of the connection.
|
||||
LocalAddr() net.Addr
|
||||
|
||||
// RemoteAddr returns the underlying remote net.Addr of the connection.
|
||||
RemoteAddr() net.Addr
|
||||
|
||||
// Sink closes the underlying connection and cleans up any related resources.
|
||||
Sink()
|
||||
|
||||
// MaskAddress masks the address, if enabled.
|
||||
MaskAddress(addr net.Addr) string
|
||||
// MaskIP masks an IP, if enabled.
|
||||
MaskIP(ip net.IP) string
|
||||
// Mask masks a value.
|
||||
Mask(value []byte) string
|
||||
}
|
||||
|
||||
// ShipBase implements common functions to comply with the Ship interface.
|
||||
type ShipBase struct {
|
||||
// conn is the actual underlying connection.
|
||||
conn net.Conn
|
||||
// transport holds the transport definition of the ship.
|
||||
transport *hub.Transport
|
||||
|
||||
// mine specifies whether the ship was launched from here.
|
||||
mine bool
|
||||
// secure specifies whether the ship provides transport security.
|
||||
secure bool
|
||||
// public specifies whether the ship is public.
|
||||
public *abool.AtomicBool
|
||||
// bufSize specifies the size of the receive buffer.
|
||||
bufSize int
|
||||
// loadSize specifies the recommended data size that should be handed to Load().
|
||||
loadSize int
|
||||
|
||||
// initial holds initial data from setting up the ship.
|
||||
initial []byte
|
||||
// sinking specifies if the connection is being closed.
|
||||
sinking *abool.AtomicBool
|
||||
}
|
||||
|
||||
func (ship *ShipBase) initBase() {
|
||||
// init
|
||||
ship.sinking = abool.New()
|
||||
ship.public = abool.New()
|
||||
|
||||
// set default
|
||||
if ship.loadSize == 0 {
|
||||
ship.loadSize = defaultLoadSize
|
||||
}
|
||||
if ship.bufSize == 0 {
|
||||
ship.bufSize = ship.loadSize
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a human readable informational summary about the ship.
|
||||
func (ship *ShipBase) String() string {
|
||||
if ship.mine {
|
||||
return fmt.Sprintf("<Ship to %s using %s>", ship.MaskAddress(ship.RemoteAddr()), ship.transport)
|
||||
}
|
||||
return fmt.Sprintf("<Ship from %s using %s>", ship.MaskAddress(ship.RemoteAddr()), ship.transport)
|
||||
}
|
||||
|
||||
// Transport returns the transport used for this ship.
|
||||
func (ship *ShipBase) Transport() *hub.Transport {
|
||||
return ship.transport
|
||||
}
|
||||
|
||||
// IsMine returns whether the ship was launched from here.
|
||||
func (ship *ShipBase) IsMine() bool {
|
||||
return ship.mine
|
||||
}
|
||||
|
||||
// IsSecure returns whether the ship provides transport security.
|
||||
func (ship *ShipBase) IsSecure() bool {
|
||||
return ship.secure
|
||||
}
|
||||
|
||||
// Public returns whether the ship is marked as public.
|
||||
func (ship *ShipBase) Public() bool {
|
||||
return ship.public.IsSet()
|
||||
}
|
||||
|
||||
// MarkPublic marks the ship as public.
|
||||
func (ship *ShipBase) MarkPublic() {
|
||||
ship.public.Set()
|
||||
}
|
||||
|
||||
// LoadSize returns the recommended data size that should be handed to Load().
|
||||
// This value will be most likely somehow related to the connection's MTU.
|
||||
// Alternatively, using a multiple of LoadSize is also recommended.
|
||||
func (ship *ShipBase) LoadSize() int {
|
||||
return ship.loadSize
|
||||
}
|
||||
|
||||
// Load loads data into the ship - ie. sends the data via the connection.
|
||||
// Returns ErrSunk if the ship has already sunk earlier.
|
||||
func (ship *ShipBase) Load(data []byte) error {
|
||||
// Empty load is used as a signal to cease operaetion.
|
||||
if len(data) == 0 {
|
||||
if ship.sinking.SetToIf(false, true) {
|
||||
_ = ship.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send all given data.
|
||||
n, err := ship.conn.Write(data)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case n == 0:
|
||||
return errors.New("loaded 0 bytes")
|
||||
case n < len(data):
|
||||
// If not all data was sent, try again.
|
||||
log.Debugf("spn/ships: %s only loaded %d/%d bytes", ship, n, len(data))
|
||||
data = data[n:]
|
||||
return ship.Load(data)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnloadTo unloads data from the ship - ie. receives data from the
|
||||
// connection - puts it into the buf. It returns the amount of data
|
||||
// written and an optional error.
|
||||
// Returns ErrSunk if the ship has already sunk earlier.
|
||||
func (ship *ShipBase) UnloadTo(buf []byte) (n int, err error) {
|
||||
// Process initial data, if there is any.
|
||||
if ship.initial != nil {
|
||||
// Copy as much data as possible.
|
||||
copy(buf, ship.initial)
|
||||
|
||||
// If buf was too small, skip the copied section.
|
||||
if len(buf) < len(ship.initial) {
|
||||
ship.initial = ship.initial[len(buf):]
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
// If everything was copied, unset the initial data.
|
||||
n := len(ship.initial)
|
||||
ship.initial = nil
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Receive data.
|
||||
return ship.conn.Read(buf)
|
||||
}
|
||||
|
||||
// LocalAddr returns the underlying local net.Addr of the connection.
|
||||
func (ship *ShipBase) LocalAddr() net.Addr {
|
||||
return ship.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the underlying remote net.Addr of the connection.
|
||||
func (ship *ShipBase) RemoteAddr() net.Addr {
|
||||
return ship.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// Sink closes the underlying connection and cleans up any related resources.
|
||||
func (ship *ShipBase) Sink() {
|
||||
if ship.sinking.SetToIf(false, true) {
|
||||
_ = ship.conn.Close()
|
||||
}
|
||||
}
|
||||
145
spn/ships/tcp.go
Normal file
145
spn/ships/tcp.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
// TCPShip is a ship that uses TCP.
|
||||
type TCPShip struct {
|
||||
ShipBase
|
||||
}
|
||||
|
||||
// TCPPier is a pier that uses TCP.
|
||||
type TCPPier struct {
|
||||
PierBase
|
||||
|
||||
ctx context.Context
|
||||
cancelCtx context.CancelFunc
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("tcp", &Builder{
|
||||
LaunchShip: launchTCPShip,
|
||||
EstablishPier: establishTCPPier,
|
||||
})
|
||||
}
|
||||
|
||||
func launchTCPShip(ctx context.Context, transport *hub.Transport, ip net.IP) (Ship, error) {
|
||||
var dialNet string
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
dialNet = "tcp4"
|
||||
} else {
|
||||
dialNet = "tcp6"
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
LocalAddr: conf.GetBindAddr(dialNet),
|
||||
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
|
||||
KeepAlive: -1, // Disable keep-alive.
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, dialNet, net.JoinHostPort(ip.String(), portToA(transport.Port)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
ship := &TCPShip{
|
||||
ShipBase: ShipBase{
|
||||
conn: conn,
|
||||
transport: transport,
|
||||
mine: true,
|
||||
secure: false,
|
||||
},
|
||||
}
|
||||
|
||||
ship.calculateLoadSize(ip, nil, TCPHeaderMTUSize)
|
||||
ship.initBase()
|
||||
return ship, nil
|
||||
}
|
||||
|
||||
func establishTCPPier(transport *hub.Transport, dockingRequests chan Ship) (Pier, error) {
|
||||
// Start listeners.
|
||||
bindIPs := conf.GetBindIPs()
|
||||
listeners := make([]net.Listener, 0, len(bindIPs))
|
||||
for _, bindIP := range bindIPs {
|
||||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: bindIP,
|
||||
Port: int(transport.Port),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
listeners = append(listeners, listener)
|
||||
log.Infof("spn/ships: tcp transport pier established on %s", listener.Addr())
|
||||
}
|
||||
|
||||
// Create new pier.
|
||||
pierCtx, cancelCtx := context.WithCancel(module.Ctx)
|
||||
pier := &TCPPier{
|
||||
PierBase: PierBase{
|
||||
transport: transport,
|
||||
listeners: listeners,
|
||||
dockingRequests: dockingRequests,
|
||||
},
|
||||
ctx: pierCtx,
|
||||
cancelCtx: cancelCtx,
|
||||
}
|
||||
pier.initBase()
|
||||
|
||||
// Start workers.
|
||||
for _, listener := range pier.listeners {
|
||||
serviceListener := listener
|
||||
module.StartServiceWorker("accept TCP docking requests", 0, func(ctx context.Context) error {
|
||||
return pier.dockingWorker(ctx, serviceListener)
|
||||
})
|
||||
}
|
||||
|
||||
return pier, nil
|
||||
}
|
||||
|
||||
func (pier *TCPPier) dockingWorker(_ context.Context, listener net.Listener) error {
|
||||
for {
|
||||
// Block until something happens.
|
||||
conn, err := listener.Accept()
|
||||
|
||||
// Check for errors.
|
||||
switch {
|
||||
case pier.ctx.Err() != nil:
|
||||
return pier.ctx.Err()
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
// Create new ship.
|
||||
ship := &TCPShip{
|
||||
ShipBase: ShipBase{
|
||||
transport: pier.transport,
|
||||
conn: conn,
|
||||
mine: false,
|
||||
secure: false,
|
||||
},
|
||||
}
|
||||
ship.calculateLoadSize(nil, conn.RemoteAddr(), TCPHeaderMTUSize)
|
||||
ship.initBase()
|
||||
|
||||
// Submit new docking request.
|
||||
select {
|
||||
case pier.dockingRequests <- ship:
|
||||
case <-pier.ctx.Done():
|
||||
return pier.ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Abolish closes the underlying listener and cleans up any related resources.
|
||||
func (pier *TCPPier) Abolish() {
|
||||
pier.cancelCtx()
|
||||
pier.PierBase.Abolish()
|
||||
}
|
||||
154
spn/ships/testship.go
Normal file
154
spn/ships/testship.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/mr-tron/base58"
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
// TestShip is a simulated ship that is used for testing higher level components.
|
||||
type TestShip struct {
|
||||
mine bool
|
||||
secure bool
|
||||
loadSize int
|
||||
forward chan []byte
|
||||
backward chan []byte
|
||||
unloadTmp []byte
|
||||
sinking *abool.AtomicBool
|
||||
}
|
||||
|
||||
// NewTestShip returns a new TestShip for simulation.
|
||||
func NewTestShip(secure bool, loadSize int) *TestShip {
|
||||
return &TestShip{
|
||||
mine: true,
|
||||
secure: secure,
|
||||
loadSize: loadSize,
|
||||
forward: make(chan []byte, 100),
|
||||
backward: make(chan []byte, 100),
|
||||
sinking: abool.NewBool(false),
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a human readable informational summary about the ship.
|
||||
func (ship *TestShip) String() string {
|
||||
if ship.mine {
|
||||
return "<TestShip outbound>"
|
||||
}
|
||||
return "<TestShip inbound>"
|
||||
}
|
||||
|
||||
// Transport returns the transport used for this ship.
|
||||
func (ship *TestShip) Transport() *hub.Transport {
|
||||
return &hub.Transport{
|
||||
Protocol: "dummy",
|
||||
}
|
||||
}
|
||||
|
||||
// IsMine returns whether the ship was launched from here.
|
||||
func (ship *TestShip) IsMine() bool {
|
||||
return ship.mine
|
||||
}
|
||||
|
||||
// IsSecure returns whether the ship provides transport security.
|
||||
func (ship *TestShip) IsSecure() bool {
|
||||
return ship.secure
|
||||
}
|
||||
|
||||
// LoadSize returns the recommended data size that should be handed to Load().
|
||||
// This value will be most likely somehow related to the connection's MTU.
|
||||
// Alternatively, using a multiple of LoadSize is also recommended.
|
||||
func (ship *TestShip) LoadSize() int {
|
||||
return ship.loadSize
|
||||
}
|
||||
|
||||
// Reverse creates a connected TestShip. This is used to simulate a connection instead of using a Pier.
|
||||
func (ship *TestShip) Reverse() *TestShip {
|
||||
return &TestShip{
|
||||
mine: !ship.mine,
|
||||
secure: ship.secure,
|
||||
loadSize: ship.loadSize,
|
||||
forward: ship.backward,
|
||||
backward: ship.forward,
|
||||
sinking: abool.NewBool(false),
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads data into the ship - ie. sends the data via the connection.
|
||||
// Returns ErrSunk if the ship has already sunk earlier.
|
||||
func (ship *TestShip) Load(data []byte) error {
|
||||
// Debugging:
|
||||
// log.Debugf("spn/ship: loading %s", spew.Sdump(data))
|
||||
|
||||
// Check if ship is alive.
|
||||
if ship.sinking.IsSet() {
|
||||
return ErrSunk
|
||||
}
|
||||
|
||||
// Empty load is used as a signal to cease operaetion.
|
||||
if len(data) == 0 {
|
||||
ship.Sink()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send all given data.
|
||||
ship.forward <- data
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnloadTo unloads data from the ship - ie. receives data from the
|
||||
// connection - puts it into the buf. It returns the amount of data
|
||||
// written and an optional error.
|
||||
// Returns ErrSunk if the ship has already sunk earlier.
|
||||
func (ship *TestShip) UnloadTo(buf []byte) (n int, err error) {
|
||||
// Process unload tmp data, if there is any.
|
||||
if ship.unloadTmp != nil {
|
||||
// Copy as much data as possible.
|
||||
copy(buf, ship.unloadTmp)
|
||||
|
||||
// If buf was too small, skip the copied section.
|
||||
if len(buf) < len(ship.unloadTmp) {
|
||||
ship.unloadTmp = ship.unloadTmp[len(buf):]
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
// If everything was copied, unset the unloadTmp data.
|
||||
n := len(ship.unloadTmp)
|
||||
ship.unloadTmp = nil
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Receive data.
|
||||
data := <-ship.backward
|
||||
if len(data) == 0 {
|
||||
return 0, ErrSunk
|
||||
}
|
||||
|
||||
// Copy data, possibly save remainder for later.
|
||||
copy(buf, data)
|
||||
if len(buf) < len(data) {
|
||||
ship.unloadTmp = data[len(buf):]
|
||||
return len(buf), nil
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// Sink closes the underlying connection and cleans up any related resources.
|
||||
func (ship *TestShip) Sink() {
|
||||
if ship.sinking.SetToIf(false, true) {
|
||||
close(ship.forward)
|
||||
}
|
||||
}
|
||||
|
||||
// Dummy methods to conform to interface for testing.
|
||||
|
||||
func (ship *TestShip) LocalAddr() net.Addr { return nil } //nolint:golint
|
||||
func (ship *TestShip) RemoteAddr() net.Addr { return nil } //nolint:golint
|
||||
func (ship *TestShip) Public() bool { return true } //nolint:golint
|
||||
func (ship *TestShip) MarkPublic() {} //nolint:golint
|
||||
func (ship *TestShip) MaskAddress(addr net.Addr) string { return addr.String() } //nolint:golint
|
||||
func (ship *TestShip) MaskIP(ip net.IP) string { return ip.String() } //nolint:golint
|
||||
func (ship *TestShip) Mask(value []byte) string { return base58.Encode(value) } //nolint:golint
|
||||
58
spn/ships/testship_test.go
Normal file
58
spn/ships/testship_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTestShip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tShip := NewTestShip(true, 100)
|
||||
|
||||
// interface conformance test
|
||||
var ship Ship = tShip
|
||||
|
||||
srvShip := tShip.Reverse()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
// client send
|
||||
err := ship.Load(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// server recv
|
||||
buf := getTestBuf()
|
||||
_, err = srvShip.UnloadTo(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// check data
|
||||
assert.Equal(t, testData, buf, "should match")
|
||||
fmt.Print(".")
|
||||
|
||||
// server send
|
||||
err = srvShip.Load(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// client recv
|
||||
buf = getTestBuf()
|
||||
_, err = ship.UnloadTo(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %s", ship, err)
|
||||
}
|
||||
|
||||
// check data
|
||||
assert.Equal(t, testData, buf, "should match")
|
||||
fmt.Print(".")
|
||||
}
|
||||
|
||||
ship.Sink()
|
||||
srvShip.Sink()
|
||||
}
|
||||
43
spn/ships/virtual_network.go
Normal file
43
spn/ships/virtual_network.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package ships
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
var (
|
||||
virtNetLock sync.Mutex
|
||||
virtNetConfig *hub.VirtualNetworkConfig
|
||||
)
|
||||
|
||||
// SetVirtualNetworkConfig sets the virtual networking config.
|
||||
func SetVirtualNetworkConfig(config *hub.VirtualNetworkConfig) {
|
||||
virtNetLock.Lock()
|
||||
defer virtNetLock.Unlock()
|
||||
|
||||
virtNetConfig = config
|
||||
}
|
||||
|
||||
// GetVirtualNetworkConfig returns the virtual networking config.
|
||||
func GetVirtualNetworkConfig() *hub.VirtualNetworkConfig {
|
||||
virtNetLock.Lock()
|
||||
defer virtNetLock.Unlock()
|
||||
|
||||
return virtNetConfig
|
||||
}
|
||||
|
||||
// GetVirtualNetworkAddress returns the virtual network IP for the given Hub.
|
||||
func GetVirtualNetworkAddress(dstHubID string) net.IP {
|
||||
virtNetLock.Lock()
|
||||
defer virtNetLock.Unlock()
|
||||
|
||||
// Check if we have a virtual network config.
|
||||
if virtNetConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return mapping for given Hub ID.
|
||||
return virtNetConfig.Mapping[dstHubID]
|
||||
}
|
||||
Reference in New Issue
Block a user