wip: migrate to mono-repo. SPN has already been moved to spn/

This commit is contained in:
Patrick Pacher
2024-03-15 11:55:13 +01:00
parent b30fd00ccf
commit 8579430db9
577 changed files with 35981 additions and 818 deletions

26
spn/TESTING.md Normal file
View File

@@ -0,0 +1,26 @@
# Testing SPN
This page documents ways to test if the SPN works as intended.
⚠ Work in Progress. Currently we are just collecting helpful things we find.
## Test Multi-Identity Routing
In order to test if the multi-identity routing is working, you can request multiple websites to display your public IP.
If they show different values, multi-identity routing is working.
### Websites
- <https://icanhazip.com>
- <https://ipecho.net>
- <https://ipinfo.io>
- <https://ipinfo.tw>
### Terminal
```sh
curl https://icanhazip.com
curl https://ipecho.net/plain
curl https://ipinfo.io/ip
curl https://ipinfo.tw/ip
```

5
spn/TRADEMARKS Normal file
View File

@@ -0,0 +1,5 @@
The names "Safing", "Portmaster", "SPN" and their logos are trademarks owned by Safing ICS Technologies GmbH (Austria).
Although our code is free, it is very important that we strictly enforce our trademark rights, in order to be able to protect our users against people who use the marks to commit fraud. This means that, while you have considerable freedom to redistribute and modify our software, there are tight restrictions on your ability to use our names and logos in ways which fall in the domain of trademark law, even when built into binaries that we provide.
This file is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License. Parts of it were taken from https://www.mozilla.org/en-US/foundation/licensing/.

View File

@@ -0,0 +1,65 @@
package account
import (
"errors"
"net/http"
)
// Authentication Headers.
const (
AuthHeaderDevice = "Device-17"
AuthHeaderToken = "Token-17"
AuthHeaderNextToken = "Next-Token-17"
AuthHeaderNextTokenDeprecated = "Next_token_17"
)
// Errors.
var (
ErrMissingDeviceID = errors.New("missing device ID")
ErrMissingToken = errors.New("missing token")
)
// AuthToken holds an authentication token.
type AuthToken struct {
Device string
Token string
}
// GetAuthTokenFromRequest extracts an authentication token from a request.
func GetAuthTokenFromRequest(request *http.Request) (*AuthToken, error) {
device := request.Header.Get(AuthHeaderDevice)
if device == "" {
return nil, ErrMissingDeviceID
}
token := request.Header.Get(AuthHeaderToken)
if token == "" {
return nil, ErrMissingToken
}
return &AuthToken{
Device: device,
Token: token,
}, nil
}
// ApplyTo applies the authentication token to a request.
func (at *AuthToken) ApplyTo(request *http.Request) {
request.Header.Set(AuthHeaderDevice, at.Device)
request.Header.Set(AuthHeaderToken, at.Token)
}
// GetNextTokenFromResponse extracts an authentication token from a response.
func GetNextTokenFromResponse(resp *http.Response) (token string, ok bool) {
token = resp.Header.Get(AuthHeaderNextToken)
if token == "" {
// TODO: Remove when fixed on server.
token = resp.Header.Get(AuthHeaderNextTokenDeprecated)
}
return token, token != ""
}
// ApplyNextTokenToResponse applies the next authentication token to a response.
func ApplyNextTokenToResponse(w http.ResponseWriter, token string) {
w.Header().Set(AuthHeaderNextToken, token)
}

View File

@@ -0,0 +1,14 @@
package account
// Customer Agent URLs.
const (
CAAuthenticateURL = "/authenticate"
CAProfileURL = "/user/profile"
CAGetTokensURL = "/tokens"
)
// Customer Hub URLs.
const (
CHAuthenticateURL = "/v1/authenticate"
CHUserProfileURL = "/v1/user_profile"
)

137
spn/access/account/types.go Normal file
View File

@@ -0,0 +1,137 @@
package account
import (
"time"
"golang.org/x/exp/slices"
)
// User, Subscription and Charge states.
const (
// UserStateNone is only used within Portmaster for saving information for
// logging into the same device.
UserStateNone = ""
UserStateFresh = "fresh"
UserStateQueued = "queued"
UserStateApproved = "approved"
UserStateSuspended = "suspended"
UserStateLoggedOut = "loggedout" // Portmaster only.
SubscriptionStateManual = "manual" // Manual renewal.
SubscriptionStateActive = "active" // Automatic renewal.
SubscriptionStateCancelled = "cancelled" // Automatic, but canceled.
ChargeStatePending = "pending"
ChargeStateCompleted = "completed"
ChargeStateDead = "dead"
)
// Agent and Hub return statuses.
const (
// StatusInvalidAuth [401 Unauthorized] is returned when the credentials are
// invalid or the user was logged out.
StatusInvalidAuth = 401
// StatusNoAccess [403 Forbidden] is returned when the user does not have
// an active subscription or the subscription does not include the required
// feature for the request.
StatusNoAccess = 403
// StatusInvalidDevice [410 Gone] is returned when the device trying to
// log into does not exist.
StatusInvalidDevice = 410
// StatusReachedDeviceLimit [409 Conflict] is returned when the device limit is reached.
StatusReachedDeviceLimit = 409
// StatusDeviceInactive [423 Locked] is returned when the device is locked.
StatusDeviceInactive = 423
// StatusNotLoggedIn [412 Precondition] is returned by the Portmaster, if an action required to be logged in, but the user is not logged in.
StatusNotLoggedIn = 412
// StatusUnknownError is a special status code that signifies an unknown or
// unexpected error by the API.
StatusUnknownError = -1
// StatusConnectionError is a special status code that signifies a
// connection error.
StatusConnectionError = -2
)
// User describes an SPN user account.
type User struct {
Username string `json:"username"`
State string `json:"state"`
Balance int `json:"balance"`
Device *Device `json:"device"`
Subscription *Subscription `json:"subscription"`
CurrentPlan *Plan `json:"current_plan"`
NextPlan *Plan `json:"next_plan"`
View *View `json:"view"`
}
// MayUseSPN returns whether the user may currently use the SPN.
func (u *User) MayUseSPN() bool {
return u.MayUse(FeatureSPN)
}
// MayUsePrioritySupport returns whether the user may currently use the priority support.
func (u *User) MayUsePrioritySupport() bool {
return u.MayUse(FeatureSafingSupport)
}
// MayUse returns whether the user may currently use the feature identified by
// the given feature ID.
// Leave feature ID empty to check without feature.
func (u *User) MayUse(featureID FeatureID) bool {
switch {
case u == nil:
// We need a user, obviously.
case u.State != UserStateApproved:
// Only approved users may use the SPN.
case u.Subscription == nil:
// Need a subscription.
case u.Subscription.EndsAt == nil:
case time.Now().After(*u.Subscription.EndsAt):
// Subscription needs to be active.
case u.CurrentPlan == nil:
// Need a plan / package.
case featureID != "" &&
!slices.Contains(u.CurrentPlan.FeatureIDs, featureID):
// Required feature ID must be in plan / package feature IDs.
default:
// All checks passed!
return true
}
return false
}
// Device describes a device of an SPN user.
type Device struct {
Name string `json:"name"`
ID string `json:"id"`
}
// Subscription describes an SPN subscription.
type Subscription struct {
EndsAt *time.Time `json:"ends_at"`
State string `json:"state"`
NextBillingDate *time.Time `json:"next_billing_date"`
PaymentProvider string `json:"payment_provider"`
}
// FeatureID defines a feature that requires a plan/subscription.
type FeatureID string
// A list of all supported features.
const (
FeatureSPN = FeatureID("spn")
FeatureSafingSupport = FeatureID("support")
FeatureHistory = FeatureID("history")
FeatureBWVis = FeatureID("bw-vis")
FeatureVPNCompat = FeatureID("vpn-compat")
)
// Plan describes an SPN subscription plan.
type Plan struct {
Name string `json:"name"`
Amount int `json:"amount"`
Months int `json:"months"`
Renewable bool `json:"renewable"`
FeatureIDs []FeatureID `json:"feature_ids"`
}

123
spn/access/account/view.go Normal file
View File

@@ -0,0 +1,123 @@
package account
import (
"fmt"
"strings"
"time"
)
// View holds metadata that assists in displaying account information.
type View struct {
Message string
ShowAccountData bool
ShowAccountButton bool
ShowLoginButton bool
ShowRefreshButton bool
ShowLogoutButton bool
}
// UpdateView updates the view and handles plan/package fallbacks.
func (u *User) UpdateView(requestStatusCode int) {
v := &View{}
// Clean up naming and fallbacks when finished.
defer func() {
// Display "Free" package if no plan is set or if it expired.
switch {
case u.CurrentPlan == nil,
u.Subscription == nil,
u.Subscription.EndsAt == nil:
// Reset to free plan.
u.CurrentPlan = &Plan{
Name: "Free",
}
u.Subscription = nil
case u.Subscription.NextBillingDate != nil:
// Subscription is on auto-renew.
// Wait for update from server.
case time.Since(*u.Subscription.EndsAt) > 0:
// Reset to free plan.
u.CurrentPlan = &Plan{
Name: "Free",
}
u.Subscription = nil
}
// Prepend "Portmaster " to plan name.
// TODO: Remove when Plan/Package naming has been updated.
if u.CurrentPlan != nil && !strings.HasPrefix(u.CurrentPlan.Name, "Portmaster ") {
u.CurrentPlan.Name = "Portmaster " + u.CurrentPlan.Name
}
// Apply new view to user.
u.View = v
}()
// Set view data based on return code.
switch requestStatusCode {
case StatusInvalidAuth, StatusInvalidDevice, StatusDeviceInactive:
// Account deleted or Device inactive or deleted.
// When using token based auth, there is no difference between these cases.
v.Message = "This device may have been deactivated or removed from your account. Please log in again."
v.ShowAccountData = true
v.ShowAccountButton = true
v.ShowLoginButton = true
v.ShowLogoutButton = true
return
case StatusUnknownError:
v.Message = "There is an unknown error in the communication with the account server. The shown information may not be accurate. "
case StatusConnectionError:
v.Message = "Portmaster could not connect to the account server. The shown information may not be accurate. "
}
// Set view data based on profile data.
switch {
case u.State == UserStateLoggedOut:
// User logged out.
v.ShowAccountButton = true
v.ShowLoginButton = true
return
case u.State == UserStateSuspended:
// Account is suspended.
v.Message += fmt.Sprintf("Your account (%s) was suspended. Please contact support for details.", u.Username)
v.ShowAccountButton = true
v.ShowRefreshButton = true
v.ShowLogoutButton = true
return
case u.Subscription == nil || u.Subscription.EndsAt == nil:
// Account has never had a subscription.
v.Message += "Get more features. Upgrade today."
case u.Subscription.NextBillingDate != nil:
switch {
case time.Since(*u.Subscription.NextBillingDate) > 0:
v.Message += "Your auto-renewal seems to be delayed. Please refresh and check the status of your payment. Payment information may be delayed."
case time.Until(*u.Subscription.NextBillingDate) < 24*time.Hour:
v.Message += "Your subscription will auto-renew soon. Please note that payment information may be delayed."
}
case time.Since(*u.Subscription.EndsAt) > 0:
// Subscription expired.
if u.CurrentPlan != nil {
v.Message += fmt.Sprintf("Your package %s has ended. Extend it on the Account Page.", u.CurrentPlan.Name)
} else {
v.Message += "Your package has ended. Extend it on the Account Page."
}
case time.Until(*u.Subscription.EndsAt) < 7*24*time.Hour:
// Add generic ending soon message if the package ends in less than 7 days.
v.Message += "Your package ends soon. Extend it on the Account Page."
}
// Defaults for generally good accounts.
v.ShowAccountData = true
v.ShowAccountButton = true
v.ShowRefreshButton = true
v.ShowLogoutButton = true
}

168
spn/access/api.go Normal file
View File

@@ -0,0 +1,168 @@
package access
import (
"fmt"
"net/http"
"github.com/safing/portbase/api"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/access/account"
)
func registerAPIEndpoints() error {
if err := api.RegisterEndpoint(api.Endpoint{
Path: `spn/account/login`,
Write: api.PermitAdmin,
WriteMethod: http.MethodPost,
HandlerFunc: handleLogin,
Name: "SPN Login",
Description: "Log into your SPN account.",
}); err != nil {
return err
}
if err := api.RegisterEndpoint(api.Endpoint{
Path: `spn/account/logout`,
Write: api.PermitAdmin,
WriteMethod: http.MethodDelete,
ActionFunc: handleLogout,
Name: "SPN Logout",
Description: "Logout from your SPN account.",
Parameters: []api.Parameter{
{
Method: http.MethodDelete,
Field: "purge",
Value: "",
Description: "If set, account data is purged. Otherwise, the username and device ID are kept in order to log into the same device when logging in with the same user again.",
},
},
}); err != nil {
return err
}
if err := api.RegisterEndpoint(api.Endpoint{
Path: `spn/account/user/profile`,
Read: api.PermitUser,
ReadMethod: http.MethodGet,
RecordFunc: handleGetUserProfile,
Name: "SPN User Profile",
Description: "Get the user profile of the logged in SPN account.",
Parameters: []api.Parameter{
{
Method: http.MethodGet,
Field: "refresh",
Value: "",
Description: "If set, the user profile is freshly fetched from the account server.",
},
},
}); err != nil {
return err
}
if err := api.RegisterEndpoint(api.Endpoint{
Path: `account/features`,
Read: api.PermitUser,
ReadMethod: http.MethodGet,
StructFunc: func(_ *api.Request) (i interface{}, err error) {
return struct {
Features []Feature
}{
Features: features,
}, nil
},
Name: "Get Account Features",
Description: "Returns all account features.",
}); err != nil {
return err
}
if err := api.RegisterEndpoint(api.Endpoint{
Path: `account/features/{id:[A-Za-z0-9_-]+}/icon`,
Read: api.PermitUser,
ReadMethod: http.MethodGet,
Name: "Returns the image of the featuare",
MimeType: "image/svg+xml",
DataFunc: func(ar *api.Request) (data []byte, err error) {
featureID, ok := ar.URLVars["id"]
if !ok {
return nil, fmt.Errorf("invalid feature id")
}
for _, feature := range features {
if feature.ID == featureID {
return []byte(feature.icon), nil
}
}
return nil, fmt.Errorf("feature id not found")
},
}); err != nil {
return err
}
return nil
}
func handleLogin(w http.ResponseWriter, r *http.Request) {
// Get username and password.
username, password, ok := r.BasicAuth()
// Request, if omitted.
if !ok || username == "" || password == "" {
w.Header().Set("WWW-Authenticate", "Basic realm=SPN Login")
http.Error(w, "Login with your SPN account.", http.StatusUnauthorized)
return
}
// Process login.
user, code, err := Login(username, password)
if err != nil {
log.Warningf("spn/access: failed to login: %s", err)
if code == 0 {
http.Error(w, "Internal error: "+err.Error(), http.StatusInternalServerError)
} else {
http.Error(w, err.Error(), code)
}
return
}
// Return success.
_, _ = w.Write([]byte(
fmt.Sprintf("Now logged in as %s as device %s", user.Username, user.Device.Name),
))
}
func handleLogout(ar *api.Request) (msg string, err error) {
purge := ar.URL.Query().Get("purge") != ""
err = Logout(false, purge)
switch {
case err != nil:
log.Warningf("spn/access: failed to logout: %s", err)
return "", err
case purge:
return "Logged out and user data purged.", nil
default:
return "Logged out.", nil
}
}
func handleGetUserProfile(ar *api.Request) (r record.Record, err error) {
// Check if we are already authenticated.
user, err := GetUser()
if err != nil || user.State == account.UserStateNone {
return nil, api.ErrorWithStatus(
ErrNotLoggedIn,
account.StatusInvalidAuth,
)
}
// Should we refresh the user profile?
if ar.URL.Query().Get("refresh") != "" {
user, _, err = UpdateUser()
if err != nil {
return nil, err
}
}
return user, nil
}

550
spn/access/client.go Normal file
View File

@@ -0,0 +1,550 @@
package access
import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"time"
"github.com/safing/portbase/database"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/access/account"
"github.com/safing/portmaster/spn/access/token"
)
// Client URLs.
const (
AccountServer = "https://api.account.safing.io"
LoginPath = "/api/v1/authenticate"
UserProfilePath = "/api/v1/user/profile"
TokenRequestSetupPath = "/api/v1/token/request/setup" //nolint:gosec
TokenRequestIssuePath = "/api/v1/token/request/issue" //nolint:gosec
HealthCheckPath = "/api/v1/health"
defaultDataFormat = dsd.CBOR
defaultRequestTimeout = 30 * time.Second
)
var (
accountClient = &http.Client{}
clientRequestLock sync.Mutex
// EnableAfterLogin automatically enables the SPN subsystem/module after login.
EnableAfterLogin = true
)
type clientRequestOptions struct {
method string
url string
send interface{}
recv interface{}
requestTimeout time.Duration
dataFormat uint8
setAuthToken bool
requireNextAuthToken bool
logoutOnAuthError bool
requestSetupFunc func(*http.Request) error
}
func makeClientRequest(opts *clientRequestOptions) (resp *http.Response, err error) {
// Get request timeout.
if opts.requestTimeout == 0 {
opts.requestTimeout = defaultRequestTimeout
}
// Get context for request.
var ctx context.Context
var cancel context.CancelFunc
if module.Online() {
// Only use module context if online.
ctx, cancel = context.WithTimeout(module.Ctx, opts.requestTimeout)
defer cancel()
} else {
// Otherwise, use the background context.
ctx, cancel = context.WithTimeout(context.Background(), opts.requestTimeout)
defer cancel()
}
// Create new request.
request, err := http.NewRequestWithContext(ctx, opts.method, opts.url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request structure: %w", err)
}
// Prepare body and content type.
if opts.dataFormat == dsd.AUTO {
opts.dataFormat = defaultDataFormat
}
if opts.send != nil {
// Add data to body.
err = dsd.DumpToHTTPRequest(request, opts.send, opts.dataFormat)
if err != nil {
return nil, fmt.Errorf("failed to add request body: %w", err)
}
} else {
// Set requested HTTP response format.
_, err = dsd.RequestHTTPResponseFormat(request, opts.dataFormat)
if err != nil {
return nil, fmt.Errorf("failed to set requested response format: %w", err)
}
}
// Get auth token to apply to request.
var authToken *AuthTokenRecord
if opts.setAuthToken {
authToken, err = GetAuthToken()
if err != nil {
return nil, ErrNotLoggedIn
}
authToken.Token.ApplyTo(request)
}
// Do any additional custom request setup.
if opts.requestSetupFunc != nil {
err = opts.requestSetupFunc(request)
if err != nil {
return nil, err
}
}
// Make request.
resp, err = accountClient.Do(request)
if err != nil {
updateUserWithFailedRequest(account.StatusConnectionError, false)
tokenIssuerFailed()
return nil, fmt.Errorf("http request failed: %w", err)
}
log.Debugf("spn/access: request to %s returned %s", request.URL, resp.Status)
defer func() {
_ = resp.Body.Close()
}()
// Handle request error.
switch resp.StatusCode {
case http.StatusOK, http.StatusCreated:
// All good!
case account.StatusInvalidAuth, account.StatusInvalidDevice:
// Wrong username / password.
updateUserWithFailedRequest(resp.StatusCode, true)
return resp, ErrInvalidCredentials
case account.StatusReachedDeviceLimit:
// Device limit is reached.
updateUserWithFailedRequest(resp.StatusCode, true)
return resp, ErrDeviceLimitReached
case account.StatusDeviceInactive:
// Device is locked.
updateUserWithFailedRequest(resp.StatusCode, true)
return resp, ErrDeviceIsLocked
default:
updateUserWithFailedRequest(account.StatusUnknownError, false)
tokenIssuerFailed()
return resp, fmt.Errorf("unexpected reply: [%d] %s", resp.StatusCode, resp.Status)
}
// Save next auth token.
if authToken != nil {
err = authToken.Update(resp)
if err != nil {
if errors.Is(err, account.ErrMissingToken) {
if opts.requireNextAuthToken {
return resp, fmt.Errorf("failed to save next auth token: %w", err)
}
} else {
return resp, fmt.Errorf("failed to save next auth token: %w", err)
}
}
} else if opts.requireNextAuthToken {
return resp, fmt.Errorf("failed to save next auth token: %w", account.ErrMissingToken)
}
// Load response data.
if opts.recv != nil {
_, err = dsd.LoadFromHTTPResponse(resp, opts.recv)
if err != nil {
return resp, fmt.Errorf("failed to parse response: %w", err)
}
}
tokenIssuerIsFailing.UnSet()
return resp, nil
}
func updateUserWithFailedRequest(statusCode int, disableSubscription bool) {
// Get user from database.
user, err := GetUser()
if err != nil {
if !errors.Is(err, ErrNotLoggedIn) {
log.Warningf("spn/access: failed to get user to update with failed request: %s", err)
}
return
}
func() {
user.Lock()
defer user.Unlock()
// Ignore update if user state is undefined or logged out.
if user.State == "" || user.State == account.UserStateLoggedOut {
return
}
// Disable the subscription if desired.
if disableSubscription && user.Subscription != nil {
user.Subscription.EndsAt = nil
}
// Update view with the status code and save user.
user.UpdateView(statusCode)
}()
err = user.Save()
if err != nil {
log.Warningf("spn/access: failed to save user after update with failed request: %s", err)
}
}
// Login logs the user into the SPN account with the given username and password.
func Login(username, password string) (user *UserRecord, code int, err error) {
clientRequestLock.Lock()
defer clientRequestLock.Unlock()
// Trigger account update when done.
defer module.TriggerEvent(AccountUpdateEvent, nil)
// Get previous user.
previousUser, err := GetUser()
if err != nil {
if !errors.Is(err, ErrNotLoggedIn) {
log.Warningf("spn/access: failed to get previous for re-login: %s", err)
}
previousUser = nil
}
// Create request options.
userAccount := &account.User{}
requestOptions := &clientRequestOptions{
method: http.MethodPost,
url: AccountServer + LoginPath,
recv: userAccount,
dataFormat: dsd.JSON,
requestSetupFunc: func(request *http.Request) error {
// Add username and password.
request.SetBasicAuth(username, password)
// Try to reuse the device ID, if the username matches the previous user.
if previousUser != nil && username == previousUser.Username {
request.Header.Set(account.AuthHeaderDevice, previousUser.Device.ID)
}
return nil
},
}
// Make request.
resp, err := makeClientRequest(requestOptions) //nolint:bodyclose // Body is closed in function.
if err != nil {
if resp != nil && resp.StatusCode == account.StatusInvalidDevice {
// Try again without the previous device ID.
previousUser = nil
log.Info("spn/access: retrying log in without re-using previous device ID")
resp, err = makeClientRequest(requestOptions) //nolint:bodyclose // Body is closed in function.
}
if err != nil {
if resp != nil {
return nil, resp.StatusCode, err
}
return nil, 0, err
}
}
// Save new user.
now := time.Now()
user = &UserRecord{
User: userAccount,
LoggedInAt: &now,
}
user.UpdateView(0)
err = user.Save()
if err != nil {
return user, resp.StatusCode, fmt.Errorf("failed to save new user profile: %w", err)
}
// Save initial auth token.
err = SaveNewAuthToken(user.Device.ID, resp)
if err != nil {
return user, resp.StatusCode, fmt.Errorf("failed to save initial auth token: %w", err)
}
// Enable the SPN right after login.
if user.MayUseSPN() && EnableAfterLogin {
enableSPN()
}
log.Infof("spn/access: logged in as %q on device %q", user.Username, user.Device.Name)
return user, resp.StatusCode, nil
}
// Logout logs the user out of the SPN account.
// Specify "shallow" to keep user data in order to display data in the
// UI - preferably when logged out be the server.
// Specify "purge" in order to fully delete all user account data, even
// the device ID so that logging in again will create a new device.
func Logout(shallow, purge bool) error {
clientRequestLock.Lock()
defer clientRequestLock.Unlock()
// Trigger account update when done.
defer module.TriggerEvent(AccountUpdateEvent, nil)
// Clear caches.
clearUserCaches()
// Clear tokens.
clearTokens()
// Delete auth token.
err := db.Delete(authTokenRecordKey)
if err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to delete auth token: %w", err)
}
// Delete all user data if purging.
if purge {
err := db.Delete(userRecordKey)
if err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to delete user: %w", err)
}
// Disable SPN when the user logs out directly.
disableSPN()
log.Info("spn/access: logged out and purged data")
return nil
}
// Else, just update the user.
user, err := GetUser()
if err != nil {
if errors.Is(err, ErrNotLoggedIn) {
return nil
}
return fmt.Errorf("failed to load user for logout: %w", err)
}
func() {
user.Lock()
defer user.Unlock()
if shallow {
// Shallow logout: User stays logged in the UI to display status when
// logged out from the Portmaster or Customer Hub.
user.User.State = account.UserStateLoggedOut
} else {
// Proper logout: User is logged out from UI.
// Reset all user data, except for username and device ID in order to log
// into the same device again.
user.User = &account.User{
Username: user.Username,
Device: &account.Device{
ID: user.Device.ID,
},
}
user.LoggedInAt = &time.Time{}
}
user.UpdateView(0)
}()
err = user.Save()
if err != nil {
return fmt.Errorf("failed to save user for logout: %w", err)
}
if shallow {
log.Info("spn/access: logged out shallow")
} else {
log.Info("spn/access: logged out")
// Disable SPN when the user logs out directly.
disableSPN()
}
return nil
}
// UpdateUser fetches the current user information from the server.
func UpdateUser() (user *UserRecord, statusCode int, err error) {
clientRequestLock.Lock()
defer clientRequestLock.Unlock()
// Trigger account update when done.
defer module.TriggerEvent(AccountUpdateEvent, nil)
// Create request options.
userData := &account.User{}
requestOptions := &clientRequestOptions{
method: http.MethodGet,
url: AccountServer + UserProfilePath,
recv: userData,
dataFormat: dsd.JSON,
setAuthToken: true,
requireNextAuthToken: true,
logoutOnAuthError: true,
}
// Make request.
resp, err := makeClientRequest(requestOptions) //nolint:bodyclose // Body is closed in function.
if err != nil {
if resp != nil {
return nil, resp.StatusCode, err
}
return nil, 0, err
}
// Save to previous user, if exists.
previousUser, err := GetUser()
if err == nil {
func() {
previousUser.Lock()
defer previousUser.Unlock()
previousUser.User = userData
previousUser.UpdateView(resp.StatusCode)
}()
err := previousUser.Save()
if err != nil {
log.Warningf("spn/access: failed to save updated user profile: %s", err)
}
// Notify user of nearing end of package.
notifyOfPackageEnd(previousUser)
log.Infof("spn/access: got user profile, updated existing")
return previousUser, resp.StatusCode, nil
}
// Else, save as new user.
now := time.Now()
newUser := &UserRecord{
User: userData,
LoggedInAt: &now,
}
newUser.UpdateView(resp.StatusCode)
err = newUser.Save()
if err != nil {
log.Warningf("spn/access: failed to save new user profile: %s", err)
}
// Notify user of nearing end of package.
notifyOfPackageEnd(newUser)
log.Infof("spn/access: got user profile, saved as new")
return newUser, resp.StatusCode, nil
}
// UpdateTokens fetches more tokens for handlers that need it.
func UpdateTokens() error {
clientRequestLock.Lock()
defer clientRequestLock.Unlock()
// Check if the user may request tokens.
user, err := GetUser()
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if !user.MayUseTheSPN() {
return ErrMayNotUseSPN
}
// Create setup request, return if not required.
setupRequest, setupRequired := token.CreateSetupRequest()
var setupResponse *token.SetupResponse
if setupRequired {
// Request setup data.
setupResponse = &token.SetupResponse{}
_, err := makeClientRequest(&clientRequestOptions{ //nolint:bodyclose // Body is closed in function.
method: http.MethodPost,
url: AccountServer + TokenRequestSetupPath,
send: setupRequest,
recv: setupResponse,
dataFormat: dsd.MsgPack,
setAuthToken: true,
logoutOnAuthError: true,
})
if err != nil {
return fmt.Errorf("failed to request setup data: %w", err)
}
}
// Create request for issuing new tokens.
tokenRequest, requestRequired, err := token.CreateTokenRequest(setupResponse)
if err != nil {
return fmt.Errorf("failed to create token request: %w", err)
}
if !requestRequired {
return nil
}
// Request issuing new tokens.
issuedTokens := &token.IssuedTokens{}
_, err = makeClientRequest(&clientRequestOptions{ //nolint:bodyclose // Body is closed in function.
method: http.MethodPost,
url: AccountServer + TokenRequestIssuePath,
send: tokenRequest,
recv: issuedTokens,
dataFormat: dsd.MsgPack,
setAuthToken: true,
logoutOnAuthError: true,
})
if err != nil {
return fmt.Errorf("failed to request tokens: %w", err)
}
// Save tokens to handlers.
err = token.ProcessIssuedTokens(issuedTokens)
if err != nil {
return fmt.Errorf("failed to process issued tokens: %w", err)
}
// Log new status.
regular, fallback := GetTokenAmount(ExpandAndConnectZones)
log.Infof(
"spn/access: got new tokens, now at %d regular and %d fallback tokens for expand and connect",
regular,
fallback,
)
return nil
}
var (
lastHealthCheckExpires time.Time
lastHealthCheckLock sync.Mutex
lastHealthCheckValidityDuration = 30 * time.Second
)
func healthCheck() (ok bool) {
lastHealthCheckLock.Lock()
defer lastHealthCheckLock.Unlock()
// Return current value if recently checked.
if time.Now().Before(lastHealthCheckExpires) {
return tokenIssuerIsFailing.IsNotSet()
}
// Check health.
_, err := makeClientRequest(&clientRequestOptions{ //nolint:bodyclose // Body is closed in function.
method: http.MethodGet,
url: AccountServer + HealthCheckPath,
})
if err != nil {
log.Warningf("spn/access: token issuer health check failed: %s", err)
}
// Update health check expiry.
lastHealthCheckExpires = time.Now().Add(lastHealthCheckValidityDuration)
return tokenIssuerIsFailing.IsNotSet()
}

79
spn/access/client_test.go Normal file
View File

@@ -0,0 +1,79 @@
package access
import (
"os"
"testing"
)
var (
testUsername = os.Getenv("SPN_TEST_USERNAME")
testPassword = os.Getenv("SPN_TEST_PASSWORD")
)
func TestClient(t *testing.T) {
// Skip test in CI.
if testing.Short() {
t.Skip()
}
t.Parallel()
if testUsername == "" || testPassword == "" {
t.Fatal("test username or password not configured")
}
loginAndRefresh(t, true, 5)
clearUserCaches()
loginAndRefresh(t, false, 1)
err := Logout(false, false)
if err != nil {
t.Fatalf("failed to log out: %s", err)
}
t.Logf("logged out")
loginAndRefresh(t, true, 1)
err = Logout(false, true)
if err != nil {
t.Fatalf("failed to log out: %s", err)
}
t.Logf("logged out with purge")
loginAndRefresh(t, true, 1)
}
func loginAndRefresh(t *testing.T, doLogin bool, refreshTimes int) {
t.Helper()
if doLogin {
_, _, err := Login(testUsername, testPassword)
if err != nil {
t.Fatalf("login failed: %s", err)
}
user, err := GetUser()
if err != nil {
t.Fatalf("failed to get user: %s", err)
}
t.Logf("user (from login): %+v", user.User)
t.Logf("device (from login): %+v", user.User.Device)
authToken, err := GetAuthToken()
if err != nil {
t.Fatalf("failed to get auth token: %s", err)
}
t.Logf("auth token: %+v", authToken.Token)
}
for i := 0; i < refreshTimes; i++ {
user, _, err := UpdateUser()
if err != nil {
t.Fatalf("getting profile failed: %s", err)
}
t.Logf("user (from refresh): %+v", user.User)
authToken, err := GetAuthToken()
if err != nil {
t.Fatalf("failed to get auth token: %s", err)
}
t.Logf("auth token: %+v", authToken.Token)
}
}

258
spn/access/database.go Normal file
View File

@@ -0,0 +1,258 @@
package access
import (
"errors"
"fmt"
"net/http"
"sync"
"time"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/record"
"github.com/safing/portmaster/spn/access/account"
)
const (
userRecordKey = "core:spn/account/user"
authTokenRecordKey = "core:spn/account/authtoken" //nolint:gosec // Not a credential.
tokenStorageKeyTemplate = "core:spn/account/tokens/%s" //nolint:gosec // Not a credential.
)
var db = database.NewInterface(&database.Options{
Local: true,
Internal: true,
})
// UserRecord holds a SPN user account.
type UserRecord struct {
record.Base
sync.Mutex
*account.User
LastNotifiedOfEnd *time.Time
LoggedInAt *time.Time
}
// MayUseSPN returns whether the user may currently use the SPN.
func (user *UserRecord) MayUseSPN() bool {
// Shadow this function in order to allow calls on a nil user.
if user == nil || user.User == nil {
return false
}
return user.User.MayUseSPN()
}
// MayUsePrioritySupport returns whether the user may currently use the priority support.
func (user *UserRecord) MayUsePrioritySupport() bool {
// Shadow this function in order to allow calls on a nil user.
if user == nil || user.User == nil {
return false
}
return user.User.MayUsePrioritySupport()
}
// MayUse returns whether the user may currently use the feature identified by
// the given feature ID.
// Leave feature ID empty to check without feature.
func (user *UserRecord) MayUse(featureID account.FeatureID) bool {
// Shadow this function in order to allow calls on a nil user.
if user == nil || user.User == nil {
return false
}
return user.User.MayUse(featureID)
}
// AuthTokenRecord holds an authentication token.
type AuthTokenRecord struct {
record.Base
sync.Mutex
Token *account.AuthToken
}
// GetToken returns the token from the record.
func (authToken *AuthTokenRecord) GetToken() *account.AuthToken {
authToken.Lock()
defer authToken.Unlock()
return authToken.Token
}
// SaveNewAuthToken saves a new auth token to the database.
func SaveNewAuthToken(deviceID string, resp *http.Response) error {
token, ok := account.GetNextTokenFromResponse(resp)
if !ok {
return account.ErrMissingToken
}
newAuthToken := &AuthTokenRecord{
Token: &account.AuthToken{
Device: deviceID,
Token: token,
},
}
return newAuthToken.Save()
}
// Update updates an existing auth token with the next token from a response.
func (authToken *AuthTokenRecord) Update(resp *http.Response) error {
token, ok := account.GetNextTokenFromResponse(resp)
if !ok {
return account.ErrMissingToken
}
// Update token with new account.AuthToken.
func() {
authToken.Lock()
defer authToken.Unlock()
authToken.Token = &account.AuthToken{
Device: authToken.Token.Device,
Token: token,
}
}()
return authToken.Save()
}
var (
accountCacheLock sync.Mutex
cachedUser *UserRecord
cachedUserSet bool
cachedAuthToken *AuthTokenRecord
)
func clearUserCaches() {
accountCacheLock.Lock()
defer accountCacheLock.Unlock()
cachedUser = nil
cachedUserSet = false
cachedAuthToken = nil
}
// GetUser returns the current user account.
// Returns nil when no user is logged in.
func GetUser() (*UserRecord, error) {
// Check cache.
accountCacheLock.Lock()
defer accountCacheLock.Unlock()
if cachedUserSet {
if cachedUser == nil {
return nil, ErrNotLoggedIn
}
return cachedUser, nil
}
// Load from disk.
r, err := db.Get(userRecordKey)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
cachedUser = nil
cachedUserSet = true
return nil, ErrNotLoggedIn
}
return nil, err
}
// Unwrap record.
if r.IsWrapped() {
// only allocate a new struct, if we need it
newUser := &UserRecord{}
err = record.Unwrap(r, newUser)
if err != nil {
return nil, err
}
cachedUser = newUser
cachedUserSet = true
return cachedUser, nil
}
// Or adjust type.
newUser, ok := r.(*UserRecord)
if !ok {
return nil, fmt.Errorf("record not of type *UserRecord, but %T", r)
}
cachedUser = newUser
cachedUserSet = true
return cachedUser, nil
}
// Save saves the User.
func (user *UserRecord) Save() error {
// Update cache.
accountCacheLock.Lock()
defer accountCacheLock.Unlock()
cachedUser = user
cachedUserSet = true
// Update view if unset.
if user.View == nil {
user.UpdateView(0)
}
// Set, check and update metadata.
if !user.KeyIsSet() {
user.SetKey(userRecordKey)
}
user.UpdateMeta()
return db.Put(user)
}
// GetAuthToken returns the current auth token.
func GetAuthToken() (*AuthTokenRecord, error) {
// Check cache.
accountCacheLock.Lock()
defer accountCacheLock.Unlock()
if cachedAuthToken != nil {
return cachedAuthToken, nil
}
// Load from disk.
r, err := db.Get(authTokenRecordKey)
if err != nil {
return nil, err
}
// Unwrap record.
if r.IsWrapped() {
// only allocate a new struct, if we need it
newAuthRecord := &AuthTokenRecord{}
err = record.Unwrap(r, newAuthRecord)
if err != nil {
return nil, err
}
cachedAuthToken = newAuthRecord
return newAuthRecord, nil
}
// Or adjust type.
newAuthRecord, ok := r.(*AuthTokenRecord)
if !ok {
return nil, fmt.Errorf("record not of type *AuthTokenRecord, but %T", r)
}
cachedAuthToken = newAuthRecord
return newAuthRecord, nil
}
// Save saves the auth token to the database.
func (authToken *AuthTokenRecord) Save() error {
// Update cache.
accountCacheLock.Lock()
defer accountCacheLock.Unlock()
cachedAuthToken = authToken
// Set, check and update metadata.
if !authToken.KeyIsSet() {
authToken.SetKey(authTokenRecordKey)
}
authToken.UpdateMeta()
authToken.Meta().MakeSecret()
authToken.Meta().MakeCrownJewel()
return db.Put(authToken)
}

127
spn/access/features.go Normal file
View File

@@ -0,0 +1,127 @@
package access
import "github.com/safing/portmaster/spn/access/account"
// Feature describes a notable part of the program.
type Feature struct {
Name string
ID string
RequiredFeatureID account.FeatureID
ConfigKey string
ConfigScope string
InPackage *Package
Comment string
Beta bool
ComingSoon bool
icon string
}
// Package combines a set of features.
type Package struct {
Name string
HexColor string
InfoURL string
}
var (
infoURL = "https://safing.io/pricing/"
packageFree = &Package{
Name: "Free",
HexColor: "#ffffff",
InfoURL: infoURL,
}
packagePlus = &Package{
Name: "Plus",
HexColor: "#2fcfae",
InfoURL: infoURL,
}
packagePro = &Package{
Name: "Pro",
HexColor: "#029ad0",
InfoURL: infoURL,
}
features = []Feature{
{
Name: "Secure DNS",
ID: "dns",
ConfigScope: "dns/",
InPackage: packageFree,
icon: `
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round"
d="M12 21a9.004 9.004 0 008.716-6.747M12 21a9.004 9.004 0 01-8.716-6.747M12 21c2.485 0 4.5-4.03 4.5-9S14.485 3 12 3m0 18c-2.485 0-4.5-4.03-4.5-9S9.515 3 12 3m0 0a8.997 8.997 0 017.843 4.582M12 3a8.997 8.997 0 00-7.843 4.582m15.686 0A11.953 11.953 0 0112 10.5c-2.998 0-5.74-1.1-7.843-2.918m15.686 0A8.959 8.959 0 0121 12c0 .778-.099 1.533-.284 2.253m0 0A17.919 17.919 0 0112 16.5c-3.162 0-6.133-.815-8.716-2.247m0 0A9.015 9.015 0 013 12c0-1.605.42-3.113 1.157-4.418" />
</svg>
`,
},
{
Name: "Privacy Filter",
ID: "filter",
ConfigScope: "filter/",
InPackage: packageFree,
icon: `
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" d="M3.98 8.223A10.477 10.477 0 001.934 12C3.226 16.338 7.244 19.5 12 19.5c.993 0 1.953-.138 2.863-.395M6.228 6.228A10.45 10.45 0 0112 4.5c4.756 0 8.773 3.162 10.065 7.498a10.523 10.523 0 01-4.293 5.774M6.228 6.228L3 3m3.228 3.228l3.65 3.65m7.894 7.894L21 21m-3.228-3.228l-3.65-3.65m0 0a3 3 0 10-4.243-4.243m4.242 4.242L9.88 9.88" />
</svg>
`,
},
{
Name: "Network History",
ID: string(account.FeatureHistory),
RequiredFeatureID: account.FeatureHistory,
ConfigKey: "history/enable",
ConfigScope: "history/",
InPackage: packagePlus,
icon: `
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round"
d="M12 6.042A8.967 8.967 0 006 3.75c-1.052 0-2.062.18-3 .512v14.25A8.987 8.987 0 016 18c2.305 0 4.408.867 6 2.292m0-14.25a8.966 8.966 0 016-2.292c1.052 0 2.062.18 3 .512v14.25A8.987 8.987 0 0018 18a8.967 8.967 0 00-6 2.292m0-14.25v14.25" />
</svg>
`,
},
{
Name: "Bandwidth Visibility",
ID: string(account.FeatureBWVis),
RequiredFeatureID: account.FeatureBWVis,
InPackage: packagePlus,
Beta: true,
icon: `
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round"
d="M3 13.125C3 12.504 3.504 12 4.125 12h2.25c.621 0 1.125.504 1.125 1.125v6.75C7.5 20.496 6.996 21 6.375 21h-2.25A1.125 1.125 0 013 19.875v-6.75zM9.75 8.625c0-.621.504-1.125 1.125-1.125h2.25c.621 0 1.125.504 1.125 1.125v11.25c0 .621-.504 1.125-1.125 1.125h-2.25a1.125 1.125 0 01-1.125-1.125V8.625zM16.5 4.125c0-.621.504-1.125 1.125-1.125h2.25C20.496 3 21 3.504 21 4.125v15.75c0 .621-.504 1.125-1.125 1.125h-2.25a1.125 1.125 0 01-1.125-1.125V4.125z" />
</svg>
`,
},
{
Name: "Safing Support",
ID: string(account.FeatureSafingSupport),
RequiredFeatureID: account.FeatureSafingSupport,
InPackage: packagePlus,
icon: `
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round"
d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" />
</svg>
`,
},
{
Name: "Safing Privacy Network",
ID: string(account.FeatureSPN),
RequiredFeatureID: account.FeatureSPN,
ConfigKey: "spn/enable",
ConfigScope: "spn/",
InPackage: packagePro,
icon: `
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" stroke="currentColor" class="text-green-300">
<g fill="none" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5">
<path
d="M6.488 15.581c.782.781.782 2.048 0 2.829-.782.781-2.049.781-2.83 0-.782-.781-.782-2.048 0-2.829.781-.781 2.048-.781 2.83 0M13.415 3.586c.782.781.782 2.048 0 2.829-.782.781-2.049.781-2.83 0-.782-.781-.782-2.048 0-2.829.781-.781 2.049-.781 2.83 0M20.343 15.58c.782.781.782 2.048 0 2.829-.782.781-2.049.781-2.83 0-.782-.781-.782-2.048 0-2.829.781-.781 2.048-.781 2.83 0">
</path>
<path
d="M17.721 18.581C16.269 20.071 14.246 21 12 21c-1.146 0-2.231-.246-3.215-.68M4.293 15.152c-.56-1.999-.352-4.21.769-6.151.574-.995 1.334-1.814 2.205-2.449M13.975 5.254c2.017.512 3.834 1.799 4.957 3.743.569.985.899 2.041 1.018 3.103">
</path>
</g>
</svg>
`,
},
}
)

194
spn/access/module.go Normal file
View File

@@ -0,0 +1,194 @@
package access
import (
"context"
"errors"
"fmt"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/config"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/spn/access/account"
"github.com/safing/portmaster/spn/access/token"
"github.com/safing/portmaster/spn/conf"
)
var (
module *modules.Module
accountUpdateTask *modules.Task
tokenIssuerIsFailing = abool.New()
tokenIssuerRetryDuration = 10 * time.Minute
// AccountUpdateEvent is fired when the account has changed in any way.
AccountUpdateEvent = "account update"
)
// Errors.
var (
ErrDeviceIsLocked = errors.New("device is locked")
ErrDeviceLimitReached = errors.New("device limit reached")
ErrFallbackNotAvailable = errors.New("fallback tokens not available, token issuer is online")
ErrInvalidCredentials = errors.New("invalid credentials")
ErrMayNotUseSPN = errors.New("may not use SPN")
ErrNotLoggedIn = errors.New("not logged in")
)
func init() {
module = modules.Register("access", prep, start, stop, "terminal")
}
func prep() error {
module.RegisterEvent(AccountUpdateEvent, true)
// Register API handlers.
if conf.Client() {
err := registerAPIEndpoints()
if err != nil {
return err
}
}
return nil
}
func start() error {
// Initialize zones.
if err := InitializeZones(); err != nil {
return err
}
if conf.Client() {
// Load tokens from database.
loadTokens()
// Register new task.
accountUpdateTask = module.NewTask(
"update account",
UpdateAccount,
).Repeat(24 * time.Hour).Schedule(time.Now().Add(1 * time.Minute))
}
return nil
}
func stop() error {
if conf.Client() {
// Stop account update task.
accountUpdateTask.Cancel()
accountUpdateTask = nil
// Store tokens to database.
storeTokens()
}
// Reset zones.
token.ResetRegistry()
return nil
}
// UpdateAccount updates the user account and fetches new tokens, if needed.
func UpdateAccount(_ context.Context, task *modules.Task) error {
// Retry sooner if the token issuer is failing.
defer func() {
if tokenIssuerIsFailing.IsSet() && task != nil {
task.Schedule(time.Now().Add(tokenIssuerRetryDuration))
}
}()
// Get current user.
u, err := GetUser()
if err == nil {
// Do not update if we just updated.
if time.Since(time.Unix(u.Meta().Modified, 0)) < 2*time.Minute {
return nil
}
}
u, _, err = UpdateUser()
if err != nil {
return fmt.Errorf("failed to update user profile: %w", err)
}
err = UpdateTokens()
if err != nil {
return fmt.Errorf("failed to get tokens: %w", err)
}
// Schedule next check.
switch {
case u == nil: // No user.
case u.Subscription == nil: // No subscription.
case u.Subscription.EndsAt == nil: // Subscription not active
case time.Until(*u.Subscription.EndsAt) < 24*time.Hour &&
time.Since(*u.Subscription.EndsAt) < 24*time.Hour:
// Update account every hour 24h hours before and after the subscription ends.
task.Schedule(time.Now().Add(time.Hour))
case u.Subscription.NextBillingDate == nil: // No auto-subscription.
case time.Until(*u.Subscription.NextBillingDate) < 24*time.Hour &&
time.Since(*u.Subscription.NextBillingDate) < 24*time.Hour:
// Update account every hour 24h hours before and after the next billing date.
task.Schedule(time.Now().Add(time.Hour))
}
return nil
}
func enableSPN() {
err := config.SetConfigOption("spn/enable", true)
if err != nil {
log.Warningf("spn/access: failed to enable the SPN during login: %s", err)
}
}
func disableSPN() {
err := config.SetConfigOption("spn/enable", false)
if err != nil {
log.Warningf("spn/access: failed to disable the SPN during logout: %s", err)
}
}
// TokenIssuerIsFailing returns whether token issuing is currently failing.
func TokenIssuerIsFailing() bool {
return tokenIssuerIsFailing.IsSet()
}
func tokenIssuerFailed() {
if !tokenIssuerIsFailing.SetToIf(false, true) {
return
}
if !module.Online() {
return
}
accountUpdateTask.Schedule(time.Now().Add(tokenIssuerRetryDuration))
}
// IsLoggedIn returns whether a User is currently logged in.
func (user *UserRecord) IsLoggedIn() bool {
user.Lock()
defer user.Unlock()
switch user.State {
case account.UserStateNone, account.UserStateLoggedOut:
return false
default:
return true
}
}
// MayUseTheSPN returns whether the currently logged in User may use the SPN.
func (user *UserRecord) MayUseTheSPN() bool {
user.Lock()
defer user.Unlock()
return user.User.MayUseSPN()
}

13
spn/access/module_test.go Normal file
View File

@@ -0,0 +1,13 @@
package access
import (
"testing"
"github.com/safing/portmaster/service/core/pmtesting"
"github.com/safing/portmaster/spn/conf"
)
func TestMain(m *testing.M) {
conf.EnableClient(true)
pmtesting.TestMain(m, module)
}

105
spn/access/notify.go Normal file
View File

@@ -0,0 +1,105 @@
package access
import (
"fmt"
"strings"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portbase/notifications"
)
const (
day = 24 * time.Hour
week = 7 * day
endOfPackageNearNotifID = "access:end-of-package-near"
)
func notifyOfPackageEnd(u *UserRecord) {
// TODO: Check if subscription auto-renews.
// Skip if there is not active subscription or if it has ended already.
switch {
case u.Subscription == nil, // No subscription.
u.Subscription.EndsAt == nil, // Subscription not active.
u.Subscription.NextBillingDate != nil, // Subscription is auto-renewing.
time.Now().After(*u.Subscription.EndsAt): // Subscription has ended.
return
}
// Calculate durations.
sinceLastNotified := 52 * week // Never.
if u.LastNotifiedOfEnd != nil {
sinceLastNotified = time.Since(*u.LastNotifiedOfEnd)
}
untilEnd := time.Until(*u.Subscription.EndsAt)
// Notify every two days in the week before end.
notifType := notifications.Info
switch {
case untilEnd < week && sinceLastNotified > 2*day:
// Notify 7, 5, 3 and 1 days before end.
if untilEnd < 4*day {
notifType = notifications.Warning
}
fallthrough
case u.CurrentPlan != nil && u.CurrentPlan.Months >= 6 &&
untilEnd < 4*week && sinceLastNotified > week:
// Notify 4, 3 and 2 weeks before end - on long running packages.
// Get names and messages.
packageNameTitle := "Portmaster Package"
if u.CurrentPlan != nil {
packageNameTitle = u.CurrentPlan.Name
}
packageNameBody := packageNameTitle
if !strings.HasSuffix(packageNameBody, " Package") {
packageNameBody += " Package"
}
var endsText string
daysUntilEnd := untilEnd / day
switch daysUntilEnd { //nolint:exhaustive
case 0:
endsText = "today"
case 1:
endsText = "tomorrow"
default:
endsText = fmt.Sprintf("in %d days", daysUntilEnd)
}
// Send notification.
notifications.Notify(&notifications.Notification{
EventID: endOfPackageNearNotifID,
Type: notifType,
Title: fmt.Sprintf("%s About to Expire", packageNameTitle),
Message: fmt.Sprintf(
"Your current %s ends %s. Extend it to keep your full privacy protections.",
packageNameBody,
endsText,
),
ShowOnSystem: notifType == notifications.Warning,
AvailableActions: []*notifications.Action{
{
Text: "Open Account Page",
Type: notifications.ActionTypeOpenURL,
Payload: "https://account.safing.io",
},
{
ID: "ack",
Text: "Got it!",
},
},
})
// Save that we sent a notification.
now := time.Now()
u.LastNotifiedOfEnd = &now
err := u.Save()
if err != nil {
log.Warningf("spn/access: failed to save user after sending subscription ending soon notification: %s", err)
}
}
}

75
spn/access/op_auth.go Normal file
View File

@@ -0,0 +1,75 @@
package access
import (
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/access/token"
"github.com/safing/portmaster/spn/terminal"
)
// OpTypeAccessCodeAuth is the type ID of the auth operation.
const OpTypeAccessCodeAuth = "auth"
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: OpTypeAccessCodeAuth,
Start: checkAccessCode,
})
}
// AuthorizeOp is used to authorize a session.
type AuthorizeOp struct {
terminal.OneOffOperationBase
}
// Type returns the type ID.
func (op *AuthorizeOp) Type() string {
return OpTypeAccessCodeAuth
}
// AuthorizeToTerminal starts an authorization operation.
func AuthorizeToTerminal(t terminal.Terminal) (*AuthorizeOp, *terminal.Error) {
op := &AuthorizeOp{}
op.Init()
newToken, err := GetToken(ExpandAndConnectZones)
if err != nil {
return nil, terminal.ErrInternalError.With("failed to get access token: %w", err)
}
tErr := t.StartOperation(op, container.New(newToken.Raw()), 10*time.Second)
if tErr != nil {
return nil, terminal.ErrInternalError.With("failed to init auth op: %w", tErr)
}
return op, nil
}
func checkAccessCode(t terminal.Terminal, opID uint32, initData *container.Container) (terminal.Operation, *terminal.Error) {
// Parse provided access token.
receivedToken, err := token.ParseRawToken(initData.CompileData())
if err != nil {
return nil, terminal.ErrMalformedData.With("failed to parse access token: %w", err)
}
// Check if token is valid.
granted, err := VerifyToken(receivedToken)
if err != nil {
return nil, terminal.ErrPermissionDenied.With("invalid access token: %w", err)
}
// Get the authorizing terminal for applying the granted permission.
authTerm, ok := t.(terminal.AuthorizingTerminal)
if !ok {
return nil, terminal.ErrIncorrectUsage.With("terminal does not handle authorization")
}
// Grant permissions.
authTerm.GrantPermission(granted)
log.Debugf("spn/access: granted %s permissions via %s zone", t.FmtID(), receivedToken.Zone)
// End successfully.
return nil, terminal.ErrExplicitAck
}

131
spn/access/storage.go Normal file
View File

@@ -0,0 +1,131 @@
package access
import (
"context"
"errors"
"fmt"
"time"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/access/token"
)
func loadTokens() {
for _, zone := range persistentZones {
// Get handler of zone.
handler, ok := token.GetHandler(zone)
if !ok {
log.Warningf("spn/access: could not find zone %s for loading tokens", zone)
continue
}
// Get data from database.
r, err := db.Get(fmt.Sprintf(tokenStorageKeyTemplate, zone))
if err != nil {
if errors.Is(err, database.ErrNotFound) {
log.Debugf("spn/access: no %s tokens to load", zone)
} else {
log.Warningf("spn/access: failed to load %s tokens: %s", zone, err)
}
continue
}
// Get wrapper.
wrapper, ok := r.(*record.Wrapper)
if !ok {
log.Warningf("spn/access: failed to parse %s tokens: expected wrapper, got %T", zone, r)
continue
}
// Load into handler.
err = handler.Load(wrapper.Data)
if err != nil {
log.Warningf("spn/access: failed to load %s tokens: %s", zone, err)
}
log.Infof("spn/access: loaded %d %s tokens", handler.Amount(), zone)
}
}
func storeTokens() {
for _, zone := range persistentZones {
// Get handler of zone.
handler, ok := token.GetHandler(zone)
if !ok {
log.Warningf("spn/access: could not find zone %s for storing tokens", zone)
continue
}
// Generate storage key.
storageKey := fmt.Sprintf(tokenStorageKeyTemplate, zone)
// Check if there is data to save.
amount := handler.Amount()
if amount == 0 {
// Remove possible old entry from database.
err := db.Delete(storageKey)
if err != nil {
log.Warningf("spn/access: failed to delete possible old %s tokens from storage: %s", zone, err)
}
log.Debugf("spn/access: no %s tokens to store", zone)
continue
}
// Export data.
data, err := handler.Save()
if err != nil {
log.Warningf("spn/access: failed to export %s tokens for storing: %s", zone, err)
continue
}
// Wrap data into raw record.
r, err := record.NewWrapper(storageKey, nil, dsd.RAW, data)
if err != nil {
log.Warningf("spn/access: failed to prepare %s token export for storing: %s", zone, err)
continue
}
// Let tokens expire after one month.
// This will regularly happen when we switch zones.
r.UpdateMeta()
r.Meta().MakeSecret()
r.Meta().MakeCrownJewel()
r.Meta().SetRelativateExpiry(30 * 86400)
// Save to database.
err = db.Put(r)
if err != nil {
log.Warningf("spn/access: failed to store %s tokens: %s", zone, err)
continue
}
log.Infof("spn/access: stored %d %s tokens", amount, zone)
}
}
func clearTokens() {
for _, zone := range persistentZones {
// Get handler of zone.
handler, ok := token.GetHandler(zone)
if !ok {
log.Warningf("spn/access: could not find zone %s for clearing tokens", zone)
continue
}
// Clear tokens.
handler.Clear()
}
// Purge database storage prefix.
ctx, cancel := context.WithTimeout(module.Ctx, 10*time.Second)
defer cancel()
n, err := db.Purge(ctx, query.New(fmt.Sprintf(tokenStorageKeyTemplate, "")))
if err != nil {
log.Warningf("spn/access: failed to clear token storages: %s", err)
return
}
log.Infof("spn/access: cleared %d token storages", n)
}

View File

@@ -0,0 +1,15 @@
package token
import "errors"
// Errors.
var (
ErrEmpty = errors.New("token storage is empty")
ErrNoZone = errors.New("no zone specified")
ErrTokenInvalid = errors.New("token is invalid")
ErrTokenMalformed = errors.New("token malformed")
ErrTokenUsed = errors.New("token already used")
ErrZoneMismatch = errors.New("zone mismatch")
ErrZoneTaken = errors.New("zone taken")
ErrZoneUnknown = errors.New("zone unknown")
)

View File

@@ -0,0 +1,13 @@
package token
import (
"testing"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/service/core/pmtesting"
)
func TestMain(m *testing.M) {
module := modules.Register("token", nil, nil, nil, "rng")
pmtesting.TestMain(m, module)
}

552
spn/access/token/pblind.go Normal file
View File

@@ -0,0 +1,552 @@
package token
import (
"crypto/elliptic"
"crypto/rand"
"errors"
"fmt"
"math"
"math/big"
mrand "math/rand"
"sync"
"github.com/mr-tron/base58"
"github.com/rot256/pblind"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
)
const pblindSecretSize = 32
// PBlindToken is token based on the pblind library.
type PBlindToken struct {
Serial int `json:"N,omitempty"`
Token []byte `json:"T,omitempty"`
Signature *pblind.Signature `json:"S,omitempty"`
}
// Pack packs the token.
func (pbt *PBlindToken) Pack() ([]byte, error) {
return dsd.Dump(pbt, dsd.CBOR)
}
// UnpackPBlindToken unpacks the token.
func UnpackPBlindToken(token []byte) (*PBlindToken, error) {
t := &PBlindToken{}
_, err := dsd.Load(token, t)
if err != nil {
return nil, err
}
return t, nil
}
// PBlindHandler is a handler for the pblind tokens.
type PBlindHandler struct {
sync.Mutex
opts *PBlindOptions
publicKey *pblind.PublicKey
privateKey *pblind.SecretKey
storageLock sync.Mutex
Storage []*PBlindToken
// Client request state.
requestStateLock sync.Mutex
requestState []RequestState
}
// PBlindOptions are options for the PBlindHandler.
type PBlindOptions struct {
Zone string
CurveName string
Curve elliptic.Curve
PublicKey string
PrivateKey string
BatchSize int
UseSerials bool
RandomizeOrder bool
Fallback bool
SignalShouldRequest func(Handler)
DoubleSpendProtection func([]byte) error
}
// PBlindSignerState is a signer state.
type PBlindSignerState struct {
signers []*pblind.StateSigner
}
// PBlindSetupResponse is a setup response.
type PBlindSetupResponse struct {
Msgs []*pblind.Message1
}
// PBlindTokenRequest is a token request.
type PBlindTokenRequest struct {
Msgs []*pblind.Message2
}
// IssuedPBlindTokens are issued pblind tokens.
type IssuedPBlindTokens struct {
Msgs []*pblind.Message3
}
// RequestState is a request state.
type RequestState struct {
Token []byte
State *pblind.StateRequester
}
// NewPBlindHandler creates a new pblind handler.
func NewPBlindHandler(opts PBlindOptions) (*PBlindHandler, error) {
pbh := &PBlindHandler{
opts: &opts,
}
// Check curve, get from name.
if opts.Curve == nil {
switch opts.CurveName {
case "P-256":
opts.Curve = elliptic.P256()
case "P-384":
opts.Curve = elliptic.P384()
case "P-521":
opts.Curve = elliptic.P521()
default:
return nil, errors.New("no curve supplied")
}
} else if opts.CurveName != "" {
return nil, errors.New("both curve and curve name supplied")
}
// Load keys.
switch {
case pbh.opts.PrivateKey != "":
keyData, err := base58.Decode(pbh.opts.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to decode private key: %w", err)
}
pivateKey := pblind.SecretKeyFromBytes(pbh.opts.Curve, keyData)
pbh.privateKey = &pivateKey
publicKey := pbh.privateKey.GetPublicKey()
pbh.publicKey = &publicKey
// Check public key if also provided.
if pbh.opts.PublicKey != "" {
if pbh.opts.PublicKey != base58.Encode(pbh.publicKey.Bytes()) {
return nil, errors.New("private and public mismatch")
}
}
case pbh.opts.PublicKey != "":
keyData, err := base58.Decode(pbh.opts.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to decode public key: %w", err)
}
publicKey, err := pblind.PublicKeyFromBytes(pbh.opts.Curve, keyData)
if err != nil {
return nil, fmt.Errorf("failed to decode public key: %w", err)
}
pbh.publicKey = &publicKey
default:
return nil, errors.New("no key supplied")
}
return pbh, nil
}
func (pbh *PBlindHandler) makeInfo(serial int) (*pblind.Info, error) {
// Gather data for info.
infoData := container.New()
infoData.AppendAsBlock([]byte(pbh.opts.Zone))
if pbh.opts.UseSerials {
infoData.AppendInt(serial)
}
// Compress to point.
info, err := pblind.CompressInfo(pbh.opts.Curve, infoData.CompileData())
if err != nil {
return nil, fmt.Errorf("failed to compress info: %w", err)
}
return &info, nil
}
// Zone returns the zone name.
func (pbh *PBlindHandler) Zone() string {
return pbh.opts.Zone
}
// ShouldRequest returns whether the new tokens should be requested.
func (pbh *PBlindHandler) ShouldRequest() bool {
pbh.storageLock.Lock()
defer pbh.storageLock.Unlock()
return pbh.shouldRequest()
}
func (pbh *PBlindHandler) shouldRequest() bool {
// Return true if storage is at or below 10%.
return len(pbh.Storage) == 0 || pbh.opts.BatchSize/len(pbh.Storage) > 10
}
// Amount returns the current amount of tokens in this handler.
func (pbh *PBlindHandler) Amount() int {
pbh.storageLock.Lock()
defer pbh.storageLock.Unlock()
return len(pbh.Storage)
}
// IsFallback returns whether this handler should only be used as a fallback.
func (pbh *PBlindHandler) IsFallback() bool {
return pbh.opts.Fallback
}
// CreateSetup sets up signers for a request.
func (pbh *PBlindHandler) CreateSetup() (state *PBlindSignerState, setupResponse *PBlindSetupResponse, err error) {
state = &PBlindSignerState{
signers: make([]*pblind.StateSigner, pbh.opts.BatchSize),
}
setupResponse = &PBlindSetupResponse{
Msgs: make([]*pblind.Message1, pbh.opts.BatchSize),
}
// Go through the batch.
for i := 0; i < pbh.opts.BatchSize; i++ {
info, err := pbh.makeInfo(i + 1)
if err != nil {
return nil, nil, fmt.Errorf("failed to create info #%d: %w", i, err)
}
// Create signer.
signer, err := pblind.CreateSigner(*pbh.privateKey, *info)
if err != nil {
return nil, nil, fmt.Errorf("failed to create signer #%d: %w", i, err)
}
state.signers[i] = signer
// Create request setup.
setupMsg, err := signer.CreateMessage1()
if err != nil {
return nil, nil, fmt.Errorf("failed to create setup msg #%d: %w", i, err)
}
setupResponse.Msgs[i] = &setupMsg
}
return state, setupResponse, nil
}
// CreateTokenRequest creates a token request to be sent to the token server.
func (pbh *PBlindHandler) CreateTokenRequest(requestSetup *PBlindSetupResponse) (request *PBlindTokenRequest, err error) {
// Check request setup data.
if len(requestSetup.Msgs) != pbh.opts.BatchSize {
return nil, fmt.Errorf("invalid request setup msg count of %d", len(requestSetup.Msgs))
}
// Lock and reset the request state.
pbh.requestStateLock.Lock()
defer pbh.requestStateLock.Unlock()
pbh.requestState = make([]RequestState, pbh.opts.BatchSize)
request = &PBlindTokenRequest{
Msgs: make([]*pblind.Message2, pbh.opts.BatchSize),
}
// Go through the batch.
for i := 0; i < pbh.opts.BatchSize; i++ {
// Check if we have setup data.
if requestSetup.Msgs[i] == nil {
return nil, fmt.Errorf("missing setup data #%d", i)
}
// Generate secret token.
token := make([]byte, pblindSecretSize)
n, err := rand.Read(token) //nolint:gosec // False positive - check the imports.
if err != nil {
return nil, fmt.Errorf("failed to get random token #%d: %w", i, err)
}
if n != pblindSecretSize {
return nil, fmt.Errorf("failed to get full random token #%d: only got %d bytes", i, n)
}
pbh.requestState[i].Token = token
// Create public metadata.
info, err := pbh.makeInfo(i + 1)
if err != nil {
return nil, fmt.Errorf("failed to make token info #%d: %w", i, err)
}
// Create request and request state.
requester, err := pblind.CreateRequester(*pbh.publicKey, *info, token)
if err != nil {
return nil, fmt.Errorf("failed to create request state #%d: %w", i, err)
}
pbh.requestState[i].State = requester
err = requester.ProcessMessage1(*requestSetup.Msgs[i])
if err != nil {
return nil, fmt.Errorf("failed to process setup message #%d: %w", i, err)
}
// Create request message.
requestMsg, err := requester.CreateMessage2()
if err != nil {
return nil, fmt.Errorf("failed to create request message #%d: %w", i, err)
}
request.Msgs[i] = &requestMsg
}
return request, nil
}
// IssueTokens sign the requested tokens.
func (pbh *PBlindHandler) IssueTokens(state *PBlindSignerState, request *PBlindTokenRequest) (response *IssuedPBlindTokens, err error) {
// Check request data.
if len(request.Msgs) != pbh.opts.BatchSize {
return nil, fmt.Errorf("invalid request msg count of %d", len(request.Msgs))
}
if len(state.signers) != pbh.opts.BatchSize {
return nil, fmt.Errorf("invalid request state count of %d", len(request.Msgs))
}
// Create response.
response = &IssuedPBlindTokens{
Msgs: make([]*pblind.Message3, pbh.opts.BatchSize),
}
// Go through the batch.
for i := 0; i < pbh.opts.BatchSize; i++ {
// Check if we have request data.
if request.Msgs[i] == nil {
return nil, fmt.Errorf("missing request data #%d", i)
}
// Process request msg.
err = state.signers[i].ProcessMessage2(*request.Msgs[i])
if err != nil {
return nil, fmt.Errorf("failed to process request msg #%d: %w", i, err)
}
// Issue token.
responseMsg, err := state.signers[i].CreateMessage3()
if err != nil {
return nil, fmt.Errorf("failed to issue token #%d: %w", i, err)
}
response.Msgs[i] = &responseMsg
}
return response, nil
}
// ProcessIssuedTokens processes the issued token from the server.
func (pbh *PBlindHandler) ProcessIssuedTokens(issuedTokens *IssuedPBlindTokens) error {
// Check data.
if len(issuedTokens.Msgs) != pbh.opts.BatchSize {
return fmt.Errorf("invalid issued token count of %d", len(issuedTokens.Msgs))
}
// Step 1: Process issued tokens.
// Lock and reset the request state.
pbh.requestStateLock.Lock()
defer pbh.requestStateLock.Unlock()
defer func() {
pbh.requestState = make([]RequestState, pbh.opts.BatchSize)
}()
finalizedTokens := make([]*PBlindToken, pbh.opts.BatchSize)
// Go through the batch.
for i := 0; i < pbh.opts.BatchSize; i++ {
// Finalize token.
err := pbh.requestState[i].State.ProcessMessage3(*issuedTokens.Msgs[i])
if err != nil {
return fmt.Errorf("failed to create final signature #%d: %w", i, err)
}
// Get and check final signature.
signature, err := pbh.requestState[i].State.Signature()
if err != nil {
return fmt.Errorf("failed to create final signature #%d: %w", i, err)
}
info, err := pbh.makeInfo(i + 1)
if err != nil {
return fmt.Errorf("failed to make token info #%d: %w", i, err)
}
if !pbh.publicKey.Check(signature, *info, pbh.requestState[i].Token) {
return fmt.Errorf("invalid signature on #%d", i)
}
// Save to temporary slice.
newToken := &PBlindToken{
Token: pbh.requestState[i].Token,
Signature: &signature,
}
if pbh.opts.UseSerials {
newToken.Serial = i + 1
}
finalizedTokens[i] = newToken
}
// Step 2: Randomize received tokens
if pbh.opts.RandomizeOrder {
rInt, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
if err != nil {
return fmt.Errorf("failed to get seed for shuffle: %w", err)
}
mr := mrand.New(mrand.NewSource(rInt.Int64())) //nolint:gosec
mr.Shuffle(len(finalizedTokens), func(i, j int) {
finalizedTokens[i], finalizedTokens[j] = finalizedTokens[j], finalizedTokens[i]
})
}
// Step 3: Add tokens to storage.
// Wait for all processing to be complete, as using tokens from a faulty
// batch can be dangerous, as the server could be doing this purposely to
// create conditions that may benefit an attacker.
pbh.storageLock.Lock()
defer pbh.storageLock.Unlock()
// Add finalized tokens to storage.
pbh.Storage = append(pbh.Storage, finalizedTokens...)
return nil
}
// GetToken returns a token.
func (pbh *PBlindHandler) GetToken() (token *Token, err error) {
pbh.storageLock.Lock()
defer pbh.storageLock.Unlock()
// Check if we have supply.
if len(pbh.Storage) == 0 {
return nil, ErrEmpty
}
// Pack token.
data, err := pbh.Storage[0].Pack()
if err != nil {
return nil, fmt.Errorf("failed to pack token: %w", err)
}
// Shift to next token.
pbh.Storage = pbh.Storage[1:]
// Check if we should signal that we should request tokens.
if pbh.opts.SignalShouldRequest != nil && pbh.shouldRequest() {
pbh.opts.SignalShouldRequest(pbh)
}
return &Token{
Zone: pbh.opts.Zone,
Data: data,
}, nil
}
// Verify verifies the given token.
func (pbh *PBlindHandler) Verify(token *Token) error {
// Check if zone matches.
if token.Zone != pbh.opts.Zone {
return ErrZoneMismatch
}
// Unpack token.
t, err := UnpackPBlindToken(token.Data)
if err != nil {
return fmt.Errorf("%w: %w", ErrTokenMalformed, err)
}
// Check if serial is valid.
switch {
case pbh.opts.UseSerials && t.Serial > 0 && t.Serial <= pbh.opts.BatchSize:
// Using serials in accepted range.
case !pbh.opts.UseSerials && t.Serial == 0:
// Not using serials and serial is zero.
default:
return fmt.Errorf("%w: invalid serial", ErrTokenMalformed)
}
// Build info for checking signature.
info, err := pbh.makeInfo(t.Serial)
if err != nil {
return fmt.Errorf("%w: %w", ErrTokenMalformed, err)
}
// Check signature.
if !pbh.publicKey.Check(*t.Signature, *info, t.Token) {
return ErrTokenInvalid
}
// Check for double spending.
if pbh.opts.DoubleSpendProtection != nil {
if err := pbh.opts.DoubleSpendProtection(t.Token); err != nil {
return fmt.Errorf("%w: %w", ErrTokenUsed, err)
}
}
return nil
}
// PBlindStorage is a storage for pblind tokens.
type PBlindStorage struct {
Storage []*PBlindToken
}
// Save serializes and returns the current tokens.
func (pbh *PBlindHandler) Save() ([]byte, error) {
pbh.storageLock.Lock()
defer pbh.storageLock.Unlock()
if len(pbh.Storage) == 0 {
return nil, ErrEmpty
}
s := &PBlindStorage{
Storage: pbh.Storage,
}
return dsd.Dump(s, dsd.CBOR)
}
// Load loads the given tokens into the handler.
func (pbh *PBlindHandler) Load(data []byte) error {
pbh.storageLock.Lock()
defer pbh.storageLock.Unlock()
s := &PBlindStorage{}
_, err := dsd.Load(data, s)
if err != nil {
return err
}
// Check signatures on load.
for _, t := range s.Storage {
// Build info for checking signature.
info, err := pbh.makeInfo(t.Serial)
if err != nil {
return err
}
// Check signature.
if !pbh.publicKey.Check(*t.Signature, *info, t.Token) {
return ErrTokenInvalid
}
}
pbh.Storage = s.Storage
return nil
}
// Clear clears all the tokens in the handler.
func (pbh *PBlindHandler) Clear() {
pbh.storageLock.Lock()
defer pbh.storageLock.Unlock()
pbh.Storage = nil
}

View File

@@ -0,0 +1,39 @@
package token
import (
"crypto/elliptic"
"fmt"
"testing"
"github.com/mr-tron/base58"
"github.com/rot256/pblind"
)
func TestGeneratePBlindKeys(t *testing.T) {
t.Parallel()
for _, curve := range []elliptic.Curve{
elliptic.P256(),
elliptic.P384(),
elliptic.P521(),
} {
privateKey, err := pblind.NewSecretKey(curve)
if err != nil {
t.Fatal(err)
}
publicKey := privateKey.GetPublicKey()
fmt.Printf(
"%s (%dbit) private key: %s\n",
curve.Params().Name,
curve.Params().BitSize,
base58.Encode(privateKey.Bytes()),
)
fmt.Printf(
"%s (%dbit) public key: %s\n",
curve.Params().Name,
curve.Params().BitSize,
base58.Encode(publicKey.Bytes()),
)
}
}

View File

@@ -0,0 +1,260 @@
package token
import (
"crypto/elliptic"
"encoding/asn1"
"testing"
"time"
"github.com/rot256/pblind"
)
const PBlindTestZone = "test-pblind"
func init() {
// Combined testing config.
h, err := NewPBlindHandler(PBlindOptions{
Zone: PBlindTestZone,
Curve: elliptic.P256(),
PrivateKey: "HbwGtLsqek1Fdwuz1MhNQfiY7tj9EpWHeMWHPZ9c6KYY",
UseSerials: true,
BatchSize: 1000,
RandomizeOrder: true,
})
if err != nil {
panic(err)
}
err = RegisterPBlindHandler(h)
if err != nil {
panic(err)
}
}
func TestPBlind(t *testing.T) {
t.Parallel()
opts := &PBlindOptions{
Zone: PBlindTestZone,
Curve: elliptic.P256(),
UseSerials: true,
BatchSize: 1000,
RandomizeOrder: true,
}
// Issuer
opts.PrivateKey = "HbwGtLsqek1Fdwuz1MhNQfiY7tj9EpWHeMWHPZ9c6KYY"
issuer, err := NewPBlindHandler(*opts)
if err != nil {
t.Fatal(err)
}
// Client
opts.PrivateKey = ""
opts.PublicKey = "285oMDh3w5mxyFgpmmURifKfhkcqwwsdnePpPZ6Nqm8cc"
client, err := NewPBlindHandler(*opts)
if err != nil {
t.Fatal(err)
}
// Verifier
verifier, err := NewPBlindHandler(*opts)
if err != nil {
t.Fatal(err)
}
// Play through the whole use case.
signerState, setupResponse, err := issuer.CreateSetup()
if err != nil {
t.Fatal(err)
}
request, err := client.CreateTokenRequest(setupResponse)
if err != nil {
t.Fatal(err)
}
issuedTokens, err := issuer.IssueTokens(signerState, request)
if err != nil {
t.Fatal(err)
}
err = client.ProcessIssuedTokens(issuedTokens)
if err != nil {
t.Fatal(err)
}
token, err := client.GetToken()
if err != nil {
t.Fatal(err)
}
err = verifier.Verify(token)
if err != nil {
t.Fatal(err)
}
}
func TestPBlindLibrary(t *testing.T) {
t.Parallel()
// generate a key-pair
curve := elliptic.P256()
sk, _ := pblind.NewSecretKey(curve)
pk := sk.GetPublicKey()
msgStr := []byte("128b_accesstoken")
infoStr := []byte("v=1 serial=12345")
info, err := pblind.CompressInfo(curve, infoStr)
if err != nil {
t.Fatal(err)
}
totalStart := time.Now()
batchSize := 1000
signers := make([]*pblind.StateSigner, batchSize)
requesters := make([]*pblind.StateRequester, batchSize)
toServer := make([][]byte, batchSize)
toClient := make([][]byte, batchSize)
// Create signers and prep requests.
start := time.Now()
for i := 0; i < batchSize; i++ {
signer, err := pblind.CreateSigner(sk, info)
if err != nil {
t.Fatal(err)
}
signers[i] = signer
msg1S, err := signer.CreateMessage1()
if err != nil {
t.Fatal(err)
}
ser1S, err := asn1.Marshal(msg1S)
if err != nil {
t.Fatal(err)
}
toClient[i] = ser1S
}
t.Logf("created %d signers and request preps in %s", batchSize, time.Since(start))
t.Logf("sending %d bytes to client", lenOfByteSlices(toClient))
// Create requesters and create requests.
start = time.Now()
for i := 0; i < batchSize; i++ {
requester, err := pblind.CreateRequester(pk, info, msgStr)
if err != nil {
t.Fatal(err)
}
requesters[i] = requester
var msg1R pblind.Message1
_, err = asn1.Unmarshal(toClient[i], &msg1R)
if err != nil {
t.Fatal(err)
}
err = requester.ProcessMessage1(msg1R)
if err != nil {
t.Fatal(err)
}
msg2R, err := requester.CreateMessage2()
if err != nil {
t.Fatal(err)
}
ser2R, err := asn1.Marshal(msg2R)
if err != nil {
t.Fatal(err)
}
toServer[i] = ser2R
}
t.Logf("created %d requesters and requests in %s", batchSize, time.Since(start))
t.Logf("sending %d bytes to server", lenOfByteSlices(toServer))
// Sign requests
start = time.Now()
for i := 0; i < batchSize; i++ {
var msg2S pblind.Message2
_, err = asn1.Unmarshal(toServer[i], &msg2S)
if err != nil {
t.Fatal(err)
}
err = signers[i].ProcessMessage2(msg2S)
if err != nil {
t.Fatal(err)
}
msg3S, err := signers[i].CreateMessage3()
if err != nil {
t.Fatal(err)
}
ser3S, err := asn1.Marshal(msg3S)
if err != nil {
t.Fatal(err)
}
toClient[i] = ser3S
}
t.Logf("signed %d requests in %s", batchSize, time.Since(start))
t.Logf("sending %d bytes to client", lenOfByteSlices(toClient))
// Verify signed requests
start = time.Now()
for i := 0; i < batchSize; i++ {
var msg3R pblind.Message3
_, err := asn1.Unmarshal(toClient[i], &msg3R)
if err != nil {
t.Fatal(err)
}
err = requesters[i].ProcessMessage3(msg3R)
if err != nil {
t.Fatal(err)
}
signature, err := requesters[i].Signature()
if err != nil {
t.Fatal(err)
}
sig, err := asn1.Marshal(signature)
if err != nil {
t.Fatal(err)
}
toServer[i] = sig
// check signature
if !pk.Check(signature, info, msgStr) {
t.Fatal("signature invalid")
}
}
t.Logf("finalized and verified %d signed tokens in %s", batchSize, time.Since(start))
t.Logf("stored %d signed tokens in %d bytes", batchSize, lenOfByteSlices(toServer))
// Verify on server
start = time.Now()
for i := 0; i < batchSize; i++ {
var sig pblind.Signature
_, err := asn1.Unmarshal(toServer[i], &sig)
if err != nil {
t.Fatal(err)
}
// check signature
if !pk.Check(sig, info, msgStr) {
t.Fatal("signature invalid")
}
}
t.Logf("verified %d signed tokens in %s", batchSize, time.Since(start))
t.Logf("process complete")
t.Logf("simulated the whole process for %d tokens in %s", batchSize, time.Since(totalStart))
}
func lenOfByteSlices(v [][]byte) (length int) {
for _, s := range v {
length += len(s)
}
return
}

View File

@@ -0,0 +1,116 @@
package token
import "sync"
// Handler represents a token handling system.
type Handler interface {
// Zone returns the zone name.
Zone() string
// ShouldRequest returns whether the new tokens should be requested.
ShouldRequest() bool
// Amount returns the current amount of tokens in this handler.
Amount() int
// IsFallback returns whether this handler should only be used as a fallback.
IsFallback() bool
// GetToken returns a token.
GetToken() (token *Token, err error)
// Verify verifies the given token.
Verify(token *Token) error
// Save serializes and returns the current tokens.
Save() ([]byte, error)
// Load loads the given tokens into the handler.
Load(data []byte) error
// Clear clears all the tokens in the handler.
Clear()
}
var (
registry map[string]Handler
pblindRegistry []*PBlindHandler
scrambleRegistry []*ScrambleHandler
registryLock sync.RWMutex
)
func init() {
initRegistry()
}
func initRegistry() {
registry = make(map[string]Handler)
pblindRegistry = make([]*PBlindHandler, 0, 1)
scrambleRegistry = make([]*ScrambleHandler, 0, 1)
}
// RegisterPBlindHandler registers a pblind handler with the registry.
func RegisterPBlindHandler(h *PBlindHandler) error {
registryLock.Lock()
defer registryLock.Unlock()
if err := registerHandler(h, h.opts.Zone); err != nil {
return err
}
pblindRegistry = append(pblindRegistry, h)
return nil
}
// RegisterScrambleHandler registers a scramble handler with the registry.
func RegisterScrambleHandler(h *ScrambleHandler) error {
registryLock.Lock()
defer registryLock.Unlock()
if err := registerHandler(h, h.opts.Zone); err != nil {
return err
}
scrambleRegistry = append(scrambleRegistry, h)
return nil
}
func registerHandler(h Handler, zone string) error {
if zone == "" {
return ErrNoZone
}
_, ok := registry[zone]
if ok {
return ErrZoneTaken
}
registry[zone] = h
return nil
}
// GetHandler returns the handler of the given zone.
func GetHandler(zone string) (handler Handler, ok bool) {
registryLock.RLock()
defer registryLock.RUnlock()
handler, ok = registry[zone]
return
}
// ResetRegistry resets the token handler registry.
func ResetRegistry() {
registryLock.Lock()
defer registryLock.Unlock()
initRegistry()
}
// RegistrySize returns the amount of handler registered.
func RegistrySize() int {
registryLock.Lock()
defer registryLock.Unlock()
return len(registry)
}

244
spn/access/token/request.go Normal file
View File

@@ -0,0 +1,244 @@
package token
import (
"crypto/rand"
"errors"
"fmt"
"github.com/mr-tron/base58"
)
const sessionIDSize = 32
// RequestHandlingState is a request handling state.
type RequestHandlingState struct {
SessionID string
PBlind map[string]*PBlindSignerState
}
// SetupRequest is a setup request.
type SetupRequest struct {
PBlind map[string]struct{} `json:"PB,omitempty"`
}
// SetupResponse is a setup response.
type SetupResponse struct {
SessionID string `json:"ID,omitempty"`
PBlind map[string]*PBlindSetupResponse `json:"PB,omitempty"`
}
// TokenRequest is a token request.
type TokenRequest struct { //nolint:golint // Be explicit.
SessionID string `json:"ID,omitempty"`
PBlind map[string]*PBlindTokenRequest `json:"PB,omitempty"`
Scramble map[string]*ScrambleTokenRequest `json:"S,omitempty"`
}
// IssuedTokens are issued tokens.
type IssuedTokens struct {
PBlind map[string]*IssuedPBlindTokens `json:"PB,omitempty"`
Scramble map[string]*IssuedScrambleTokens `json:"SC,omitempty"`
}
// CreateSetupRequest creates a combined setup request for all registered tokens, if needed.
func CreateSetupRequest() (request *SetupRequest, setupRequired bool) {
registryLock.RLock()
defer registryLock.RUnlock()
request = &SetupRequest{
PBlind: make(map[string]struct{}, len(pblindRegistry)),
}
// Go through handlers and create request setups.
for _, pblindHandler := range pblindRegistry {
// Check if we need to request with this handler.
if pblindHandler.ShouldRequest() {
request.PBlind[pblindHandler.Zone()] = struct{}{}
setupRequired = true
}
}
return
}
// HandleSetupRequest handles a setup request for all registered tokens.
func HandleSetupRequest(request *SetupRequest) (*RequestHandlingState, *SetupResponse, error) {
registryLock.RLock()
defer registryLock.RUnlock()
// Generate session token.
randomID := make([]byte, sessionIDSize)
n, err := rand.Read(randomID)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate session ID: %w", err)
}
if n != sessionIDSize {
return nil, nil, fmt.Errorf("failed to get full session ID: only got %d bytes", n)
}
sessionID := base58.Encode(randomID)
// Create state and response.
state := &RequestHandlingState{
SessionID: sessionID,
PBlind: make(map[string]*PBlindSignerState, len(pblindRegistry)),
}
setup := &SetupResponse{
SessionID: sessionID,
PBlind: make(map[string]*PBlindSetupResponse, len(pblindRegistry)),
}
// Go through handlers and create setups.
for _, pblindHandler := range pblindRegistry {
// Check if we have a request for this handler.
_, ok := request.PBlind[pblindHandler.Zone()]
if !ok {
continue
}
plindState, pblindSetup, err := pblindHandler.CreateSetup()
if err != nil {
return nil, nil, fmt.Errorf("failed to create setup for %s: %w", pblindHandler.Zone(), err)
}
state.PBlind[pblindHandler.Zone()] = plindState
setup.PBlind[pblindHandler.Zone()] = pblindSetup
}
return state, setup, nil
}
// CreateTokenRequest creates a token request for all registered tokens.
func CreateTokenRequest(setup *SetupResponse) (request *TokenRequest, requestRequired bool, err error) {
registryLock.RLock()
defer registryLock.RUnlock()
// Check setup data.
if setup != nil && setup.SessionID == "" {
return nil, false, errors.New("setup data is missing a session ID")
}
// Create token request.
request = &TokenRequest{
PBlind: make(map[string]*PBlindTokenRequest, len(pblindRegistry)),
Scramble: make(map[string]*ScrambleTokenRequest, len(scrambleRegistry)),
}
if setup != nil {
request.SessionID = setup.SessionID
}
// Go through handlers and create requests.
if setup != nil {
for _, pblindHandler := range pblindRegistry {
// Check if we have setup data for this handler.
pblindSetup, ok := setup.PBlind[pblindHandler.Zone()]
if !ok {
// TODO: Abort if we should have received request data.
continue
}
// Create request.
pblindRequest, err := pblindHandler.CreateTokenRequest(pblindSetup)
if err != nil {
return nil, false, fmt.Errorf("failed to create token request for %s: %w", pblindHandler.Zone(), err)
}
requestRequired = true
request.PBlind[pblindHandler.Zone()] = pblindRequest
}
}
for _, scrambleHandler := range scrambleRegistry {
// Check if we need to request with this handler.
if scrambleHandler.ShouldRequest() {
requestRequired = true
request.Scramble[scrambleHandler.Zone()] = scrambleHandler.CreateTokenRequest()
}
}
return request, requestRequired, nil
}
// IssueTokens issues tokens for all registered tokens.
func IssueTokens(state *RequestHandlingState, request *TokenRequest) (response *IssuedTokens, err error) {
registryLock.RLock()
defer registryLock.RUnlock()
// Create token response.
response = &IssuedTokens{
PBlind: make(map[string]*IssuedPBlindTokens, len(pblindRegistry)),
Scramble: make(map[string]*IssuedScrambleTokens, len(scrambleRegistry)),
}
// Go through handlers and create requests.
for _, pblindHandler := range pblindRegistry {
// Check if we have all the data for issuing.
pblindState, ok := state.PBlind[pblindHandler.Zone()]
if !ok {
continue
}
pblindRequest, ok := request.PBlind[pblindHandler.Zone()]
if !ok {
continue
}
// Issue tokens.
pblindTokens, err := pblindHandler.IssueTokens(pblindState, pblindRequest)
if err != nil {
return nil, fmt.Errorf("failed to issue tokens for %s: %w", pblindHandler.Zone(), err)
}
response.PBlind[pblindHandler.Zone()] = pblindTokens
}
for _, scrambleHandler := range scrambleRegistry {
// Check if we have all the data for issuing.
scrambleRequest, ok := request.Scramble[scrambleHandler.Zone()]
if !ok {
continue
}
// Issue tokens.
scrambleTokens, err := scrambleHandler.IssueTokens(scrambleRequest)
if err != nil {
return nil, fmt.Errorf("failed to issue tokens for %s: %w", scrambleHandler.Zone(), err)
}
response.Scramble[scrambleHandler.Zone()] = scrambleTokens
}
return response, nil
}
// ProcessIssuedTokens processes issued tokens for all registered tokens.
func ProcessIssuedTokens(response *IssuedTokens) error {
registryLock.RLock()
defer registryLock.RUnlock()
// Go through handlers and create requests.
for _, pblindHandler := range pblindRegistry {
// Check if we received tokens.
pblindResponse, ok := response.PBlind[pblindHandler.Zone()]
if !ok {
continue
}
// Process issued tokens.
err := pblindHandler.ProcessIssuedTokens(pblindResponse)
if err != nil {
return fmt.Errorf("failed to process issued tokens for %s: %w", pblindHandler.Zone(), err)
}
}
for _, scrambleHandler := range scrambleRegistry {
// Check if we received tokens.
scrambleResponse, ok := response.Scramble[scrambleHandler.Zone()]
if !ok {
continue
}
// Process issued tokens.
err := scrambleHandler.ProcessIssuedTokens(scrambleResponse)
if err != nil {
return fmt.Errorf("failed to process issued tokens for %s: %w", scrambleHandler.Zone(), err)
}
}
return nil
}

View File

@@ -0,0 +1,125 @@
package token
import (
"testing"
"time"
"github.com/safing/portbase/formats/dsd"
)
func TestFull(t *testing.T) {
t.Parallel()
testStart := time.Now()
// Roundtrip 1
start := time.Now()
setupRequest, setupRequired := CreateSetupRequest()
if !setupRequired {
t.Fatal("setup should be required")
}
setupRequestData, err := dsd.Dump(setupRequest, dsd.CBOR)
if err != nil {
t.Fatal(err)
}
setupRequest = nil // nolint:ineffassign,wastedassign // Just to be sure.
t.Logf("setupRequest: %s, %d bytes", time.Since(start), len(setupRequestData))
start = time.Now()
loadedSetupRequest := &SetupRequest{}
_, err = dsd.Load(setupRequestData, loadedSetupRequest)
if err != nil {
t.Fatal(err)
}
serverState, setupResponse, err := HandleSetupRequest(loadedSetupRequest)
if err != nil {
t.Fatal(err)
}
setupResponseData, err := dsd.Dump(setupResponse, dsd.CBOR)
if err != nil {
t.Fatal(err)
}
setupResponse = nil // nolint:ineffassign,wastedassign // Just to be sure.
t.Logf("setupResponse: %s, %d bytes", time.Since(start), len(setupResponseData))
// Roundtrip 2
start = time.Now()
loadedSetupResponse := &SetupResponse{}
_, err = dsd.Load(setupResponseData, loadedSetupResponse)
if err != nil {
t.Fatal(err)
}
request, requestRequired, err := CreateTokenRequest(loadedSetupResponse)
if err != nil {
t.Fatal(err)
}
if !requestRequired {
t.Fatal("request should be required")
}
requestData, err := dsd.Dump(request, dsd.CBOR)
if err != nil {
t.Fatal(err)
}
request = nil // nolint:ineffassign,wastedassign // Just to be sure.
t.Logf("request: %s, %d bytes", time.Since(start), len(requestData))
start = time.Now()
loadedRequest := &TokenRequest{}
_, err = dsd.Load(requestData, loadedRequest)
if err != nil {
t.Fatal(err)
}
response, err := IssueTokens(serverState, loadedRequest)
if err != nil {
t.Fatal(err)
}
responseData, err := dsd.Dump(response, dsd.CBOR)
if err != nil {
t.Fatal(err)
}
response = nil // nolint:ineffassign,wastedassign // Just to be sure.
t.Logf("response: %s, %d bytes", time.Since(start), len(responseData))
start = time.Now()
loadedResponse := &IssuedTokens{}
_, err = dsd.Load(responseData, loadedResponse)
if err != nil {
t.Fatal(err)
}
err = ProcessIssuedTokens(loadedResponse)
if err != nil {
t.Fatal(err)
}
t.Logf("processing: %s", time.Since(start))
// Token Usage
for _, testZone := range []string{
PBlindTestZone,
ScrambleTestZone,
} {
start = time.Now()
token, err := GetToken(testZone)
if err != nil {
t.Fatal(err)
}
tokenData := token.Raw()
token = nil // nolint:wastedassign // Just to be sure.
loadedToken, err := ParseRawToken(tokenData)
if err != nil {
t.Fatal(err)
}
err = VerifyToken(loadedToken)
if err != nil {
t.Fatal(err)
}
t.Logf("using %s token: %s", testZone, time.Since(start))
}
t.Logf("full simulation took %s", time.Since(testStart))
}

View File

@@ -0,0 +1,240 @@
package token
import (
"fmt"
"sync"
"github.com/mr-tron/base58"
"github.com/safing/jess/lhash"
"github.com/safing/portbase/formats/dsd"
)
const (
scrambleSecretSize = 32
)
// ScrambleToken is token based on hashing.
type ScrambleToken struct {
Token []byte
}
// Pack packs the token.
func (pbt *ScrambleToken) Pack() ([]byte, error) {
return pbt.Token, nil
}
// UnpackScrambleToken unpacks the token.
func UnpackScrambleToken(token []byte) (*ScrambleToken, error) {
return &ScrambleToken{Token: token}, nil
}
// ScrambleHandler is a handler for the scramble tokens.
type ScrambleHandler struct {
sync.Mutex
opts *ScrambleOptions
storageLock sync.Mutex
Storage []*ScrambleToken
verifiersLock sync.RWMutex
verifiers map[string]*ScrambleToken
}
// ScrambleOptions are options for the ScrambleHandler.
type ScrambleOptions struct {
Zone string
Algorithm lhash.Algorithm
InitialTokens []string
InitialVerifiers []string
Fallback bool
}
// ScrambleTokenRequest is a token request.
type ScrambleTokenRequest struct{}
// IssuedScrambleTokens are issued scrambled tokens.
type IssuedScrambleTokens struct {
Tokens []*ScrambleToken
}
// NewScrambleHandler creates a new scramble handler.
func NewScrambleHandler(opts ScrambleOptions) (*ScrambleHandler, error) {
sh := &ScrambleHandler{
opts: &opts,
verifiers: make(map[string]*ScrambleToken, len(opts.InitialTokens)+len(opts.InitialVerifiers)),
}
// Add initial tokens.
sh.Storage = make([]*ScrambleToken, len(opts.InitialTokens))
for i, token := range opts.InitialTokens {
// Add to storage.
tokenData, err := base58.Decode(token)
if err != nil {
return nil, fmt.Errorf("failed to decode initial token %q: %w", token, err)
}
sh.Storage[i] = &ScrambleToken{
Token: tokenData,
}
// Add to verifiers.
scrambledToken := lhash.Digest(sh.opts.Algorithm, tokenData).Bytes()
sh.verifiers[string(scrambledToken)] = sh.Storage[i]
}
// Add initial verifiers.
for _, verifier := range opts.InitialVerifiers {
verifierData, err := base58.Decode(verifier)
if err != nil {
return nil, fmt.Errorf("failed to decode verifier %q: %w", verifier, err)
}
sh.verifiers[string(verifierData)] = &ScrambleToken{}
}
return sh, nil
}
// Zone returns the zone name.
func (sh *ScrambleHandler) Zone() string {
return sh.opts.Zone
}
// ShouldRequest returns whether the new tokens should be requested.
func (sh *ScrambleHandler) ShouldRequest() bool {
sh.storageLock.Lock()
defer sh.storageLock.Unlock()
return len(sh.Storage) == 0
}
// Amount returns the current amount of tokens in this handler.
func (sh *ScrambleHandler) Amount() int {
sh.storageLock.Lock()
defer sh.storageLock.Unlock()
return len(sh.Storage)
}
// IsFallback returns whether this handler should only be used as a fallback.
func (sh *ScrambleHandler) IsFallback() bool {
return sh.opts.Fallback
}
// CreateTokenRequest creates a token request to be sent to the token server.
func (sh *ScrambleHandler) CreateTokenRequest() (request *ScrambleTokenRequest) {
return &ScrambleTokenRequest{}
}
// IssueTokens sign the requested tokens.
func (sh *ScrambleHandler) IssueTokens(request *ScrambleTokenRequest) (response *IssuedScrambleTokens, err error) {
// Copy the storage.
tokens := make([]*ScrambleToken, len(sh.Storage))
copy(tokens, sh.Storage)
return &IssuedScrambleTokens{
Tokens: tokens,
}, nil
}
// ProcessIssuedTokens processes the issued token from the server.
func (sh *ScrambleHandler) ProcessIssuedTokens(issuedTokens *IssuedScrambleTokens) error {
sh.verifiersLock.RLock()
defer sh.verifiersLock.RUnlock()
// Validate tokens.
for i, newToken := range issuedTokens.Tokens {
// Scramle token.
scrambledToken := lhash.Digest(sh.opts.Algorithm, newToken.Token).Bytes()
// Check if token is valid.
_, ok := sh.verifiers[string(scrambledToken)]
if !ok {
return fmt.Errorf("invalid token on #%d", i)
}
}
// Copy to storage.
sh.Storage = issuedTokens.Tokens
return nil
}
// Verify verifies the given token.
func (sh *ScrambleHandler) Verify(token *Token) error {
if token.Zone != sh.opts.Zone {
return ErrZoneMismatch
}
// Hash the data.
scrambledToken := lhash.Digest(sh.opts.Algorithm, token.Data).Bytes()
sh.verifiersLock.RLock()
defer sh.verifiersLock.RUnlock()
// Check if token is valid.
_, ok := sh.verifiers[string(scrambledToken)]
if !ok {
return ErrTokenInvalid
}
return nil
}
// GetToken returns a token.
func (sh *ScrambleHandler) GetToken() (*Token, error) {
sh.storageLock.Lock()
defer sh.storageLock.Unlock()
if len(sh.Storage) == 0 {
return nil, ErrEmpty
}
return &Token{
Zone: sh.opts.Zone,
Data: sh.Storage[0].Token,
}, nil
}
// ScrambleStorage is a storage for scramble tokens.
type ScrambleStorage struct {
Storage []*ScrambleToken
}
// Save serializes and returns the current tokens.
func (sh *ScrambleHandler) Save() ([]byte, error) {
sh.storageLock.Lock()
defer sh.storageLock.Unlock()
if len(sh.Storage) == 0 {
return nil, ErrEmpty
}
s := &ScrambleStorage{
Storage: sh.Storage,
}
return dsd.Dump(s, dsd.CBOR)
}
// Load loads the given tokens into the handler.
func (sh *ScrambleHandler) Load(data []byte) error {
sh.storageLock.Lock()
defer sh.storageLock.Unlock()
s := &ScrambleStorage{}
_, err := dsd.Load(data, s)
if err != nil {
return err
}
sh.Storage = s.Storage
return nil
}
// Clear clears all the tokens in the handler.
func (sh *ScrambleHandler) Clear() {
sh.storageLock.Lock()
defer sh.storageLock.Unlock()
sh.Storage = nil
}

View File

@@ -0,0 +1,48 @@
package token
import (
"crypto/rand"
"fmt"
"testing"
"github.com/mr-tron/base58"
"github.com/safing/jess/lhash"
)
type genAlgs struct {
alg lhash.Algorithm
name string
}
func TestGenerateScrambleKeys(t *testing.T) {
t.Parallel()
for _, alg := range []genAlgs{
{alg: lhash.SHA2_256, name: "SHA2_256"},
{alg: lhash.SHA3_256, name: "SHA3_256"},
{alg: lhash.SHA3_512, name: "SHA3_512"},
{alg: lhash.BLAKE2b_256, name: "BLAKE2b_256"},
} {
token := make([]byte, scrambleSecretSize)
n, err := rand.Read(token)
if err != nil {
t.Fatal(err)
}
if n != scrambleSecretSize {
t.Fatalf("only got %d bytes", n)
}
scrambledToken := lhash.Digest(alg.alg, token).Bytes()
fmt.Printf(
"%s secret token: %s\n",
alg.name,
base58.Encode(token),
)
fmt.Printf(
"%s scrambled (public) token: %s\n",
alg.name,
base58.Encode(scrambledToken),
)
}
}

View File

@@ -0,0 +1,84 @@
package token
import (
"testing"
"github.com/safing/jess/lhash"
)
const ScrambleTestZone = "test-scramble"
func init() {
// Combined testing config.
h, err := NewScrambleHandler(ScrambleOptions{
Zone: ScrambleTestZone,
Algorithm: lhash.SHA2_256,
InitialTokens: []string{"2VqJ8BvDew1tUpytZhR7tuvq7ToPpW3tQtHvu3veE3iW"},
})
if err != nil {
panic(err)
}
err = RegisterScrambleHandler(h)
if err != nil {
panic(err)
}
}
func TestScramble(t *testing.T) {
t.Parallel()
opts := &ScrambleOptions{
Zone: ScrambleTestZone,
Algorithm: lhash.SHA2_256,
}
// Issuer
opts.InitialTokens = []string{"2VqJ8BvDew1tUpytZhR7tuvq7ToPpW3tQtHvu3veE3iW"}
issuer, err := NewScrambleHandler(*opts)
if err != nil {
t.Fatal(err)
}
// Client
opts.InitialTokens = nil
opts.InitialVerifiers = []string{"Cy9tz37Xq9NiXGDRU9yicjGU62GjXskE9KqUmuoddSxaE3"}
client, err := NewScrambleHandler(*opts)
if err != nil {
t.Fatal(err)
}
// Verifier
verifier, err := NewScrambleHandler(*opts)
if err != nil {
t.Fatal(err)
}
// Play through the whole use case.
request := client.CreateTokenRequest()
if err != nil {
t.Fatal(err)
}
issuedTokens, err := issuer.IssueTokens(request)
if err != nil {
t.Fatal(err)
}
err = client.ProcessIssuedTokens(issuedTokens)
if err != nil {
t.Fatal(err)
}
token, err := client.GetToken()
if err != nil {
t.Fatal(err)
}
err = verifier.Verify(token)
if err != nil {
t.Fatal(err)
}
}

83
spn/access/token/token.go Normal file
View File

@@ -0,0 +1,83 @@
package token
import (
"bytes"
"errors"
"fmt"
"strings"
"github.com/mr-tron/base58"
"github.com/safing/portbase/container"
)
// Token represents a token, consisting of a zone (name) and some data.
type Token struct {
Zone string
Data []byte
}
// GetToken returns a token of the given zone.
func GetToken(zone string) (*Token, error) {
handler, ok := GetHandler(zone)
if !ok {
return nil, ErrZoneUnknown
}
return handler.GetToken()
}
// VerifyToken verifies the given token.
func VerifyToken(token *Token) error {
handler, ok := GetHandler(token.Zone)
if !ok {
return ErrZoneUnknown
}
return handler.Verify(token)
}
// Raw returns the raw format of the token.
func (c *Token) Raw() []byte {
cont := container.New()
cont.Append([]byte(c.Zone))
cont.Append([]byte(":"))
cont.Append(c.Data)
return cont.CompileData()
}
// String returns the stringified format of the token.
func (c *Token) String() string {
return c.Zone + ":" + base58.Encode(c.Data)
}
// ParseRawToken parses a raw token.
func ParseRawToken(code []byte) (*Token, error) {
splitted := bytes.SplitN(code, []byte(":"), 2)
if len(splitted) < 2 {
return nil, errors.New("invalid code format: zone/data separator missing")
}
return &Token{
Zone: string(splitted[0]),
Data: splitted[1],
}, nil
}
// ParseToken parses a stringified token.
func ParseToken(code string) (*Token, error) {
splitted := strings.SplitN(code, ":", 2)
if len(splitted) < 2 {
return nil, errors.New("invalid code format: zone/data separator missing")
}
data, err := base58.Decode(splitted[1])
if err != nil {
return nil, fmt.Errorf("invalid code format: %w", err)
}
return &Token{
Zone: splitted[0],
Data: data,
}, nil
}

View File

@@ -0,0 +1,33 @@
package token
import (
"testing"
"github.com/safing/portbase/rng"
)
func TestToken(t *testing.T) {
t.Parallel()
randomData, err := rng.Bytes(32)
if err != nil {
t.Fatal(err)
}
c := &Token{
Zone: "test",
Data: randomData,
}
s := c.String()
_, err = ParseToken(s)
if err != nil {
t.Fatal(err)
}
r := c.Raw()
_, err = ParseRawToken(r)
if err != nil {
t.Fatal(err)
}
}

257
spn/access/zones.go Normal file
View File

@@ -0,0 +1,257 @@
package access
import (
"errors"
"fmt"
"os"
"strings"
"github.com/tevino/abool"
"github.com/safing/jess/lhash"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/access/token"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/terminal"
)
var (
// ExpandAndConnectZones are the zones that grant access to the expand and
// connect operations.
ExpandAndConnectZones = []string{"pblind1", "alpha2", "fallback1"}
zonePermissions = map[string]terminal.Permission{
"pblind1": terminal.AddPermissions(terminal.MayExpand, terminal.MayConnect),
"alpha2": terminal.AddPermissions(terminal.MayExpand, terminal.MayConnect),
"fallback1": terminal.AddPermissions(terminal.MayExpand, terminal.MayConnect),
}
persistentZones = ExpandAndConnectZones
enableTestMode = abool.New()
)
// EnableTestMode enables the test mode, leading the access module to only
// register a test zone.
// This should not be used to test the access module itself.
func EnableTestMode() {
enableTestMode.Set()
}
// InitializeZones initialized the permission zones.
// It initializes the test zones, if EnableTestMode was called before.
// Must only be called once.
func InitializeZones() error {
// Check if we are testing.
if enableTestMode.IsSet() {
return initializeTestZone()
}
// Special client zone config.
var requestSignalHandler func(token.Handler)
if conf.Client() {
requestSignalHandler = shouldRequestTokensHandler
}
// Register pblind1 as the first primary zone.
ph, err := token.NewPBlindHandler(token.PBlindOptions{
Zone: "pblind1",
CurveName: "P-256",
PublicKey: "eXoJXzXbM66UEsM2eVi9HwyBPLMfVnNrC7gNrsfMUJDs",
UseSerials: true,
BatchSize: 1000,
RandomizeOrder: true,
SignalShouldRequest: requestSignalHandler,
})
if err != nil {
return fmt.Errorf("failed to create pblind1 token handler: %w", err)
}
err = token.RegisterPBlindHandler(ph)
if err != nil {
return fmt.Errorf("failed to register pblind1 token handler: %w", err)
}
// Register fallback1 zone as fallback when the issuer is not available.
sh, err := token.NewScrambleHandler(token.ScrambleOptions{
Zone: "fallback1",
Algorithm: lhash.BLAKE2b_256,
InitialVerifiers: []string{"ZwkQoaAttVBMURzeLzNXokFBMAMUUwECfM1iHojcVKBmjk"},
Fallback: true,
})
if err != nil {
return fmt.Errorf("failed to create fallback1 token handler: %w", err)
}
err = token.RegisterScrambleHandler(sh)
if err != nil {
return fmt.Errorf("failed to register fallback1 token handler: %w", err)
}
// Register alpha2 zone for transition phase.
sh, err = token.NewScrambleHandler(token.ScrambleOptions{
Zone: "alpha2",
Algorithm: lhash.BLAKE2b_256,
InitialVerifiers: []string{"ZwojEvXZmAv7SZdNe7m94Xzu7F9J8vULqKf7QYtoTpN2tH"},
})
if err != nil {
return fmt.Errorf("failed to create alpha2 token handler: %w", err)
}
err = token.RegisterScrambleHandler(sh)
if err != nil {
return fmt.Errorf("failed to register alpha2 token handler: %w", err)
}
return nil
}
func initializeTestZone() error {
// Safeguard checks if we should really enable the test zone.
if !strings.HasSuffix(os.Args[0], ".test") {
return errors.New("tried to enable test mode, but no test binary was detected")
}
if token.RegistrySize() > 0 {
return fmt.Errorf("tried to enable test zone, but %d handlers are already registered", token.RegistrySize())
}
// Reset zones.
token.ResetRegistry()
// Set eligible zones.
ExpandAndConnectZones = []string{"unittest"}
zonePermissions = map[string]terminal.Permission{
"unittest": terminal.AddPermissions(terminal.MayExpand, terminal.MayConnect),
}
// Register unittest zone as for testing.
sh, err := token.NewScrambleHandler(token.ScrambleOptions{
Zone: "unittest",
Algorithm: lhash.BLAKE2b_256,
InitialTokens: []string{"6jFqLA93uSLL52utGKrvctG3ZfopSQ8WFqjsRK1c2Svt"},
InitialVerifiers: []string{"ZwoEoL59sr81s7WnF2vydGzjeejE3u8CqVafig1NTQzUr7"},
})
if err != nil {
return fmt.Errorf("failed to create unittest token handler: %w", err)
}
err = token.RegisterScrambleHandler(sh)
if err != nil {
return fmt.Errorf("failed to register unittest token handler: %w", err)
}
return nil
}
func shouldRequestTokensHandler(_ token.Handler) {
// accountUpdateTask is always set in client mode and when the module is online.
// Check if it's set in case this gets executed in other circumstances.
if accountUpdateTask == nil {
log.Warningf("spn/access: trying to trigger account update, but the task is not available")
return
}
accountUpdateTask.StartASAP()
}
// GetTokenAmount returns the amount of tokens for the given zones.
func GetTokenAmount(zones []string) (regular, fallback int) {
handlerLoop:
for _, zone := range zones {
// Get handler and check if it should be used.
handler, ok := token.GetHandler(zone)
if !ok {
log.Warningf("spn/access: use of non-registered zone %q", zone)
continue handlerLoop
}
if handler.IsFallback() {
fallback += handler.Amount()
} else {
regular += handler.Amount()
}
}
return
}
// ShouldRequest returns whether tokens should be requested for the given zones.
func ShouldRequest(zones []string) (shouldRequest bool) {
handlerLoop:
for _, zone := range zones {
// Get handler and check if it should be used.
handler, ok := token.GetHandler(zone)
if !ok {
log.Warningf("spn/access: use of non-registered zone %q", zone)
continue handlerLoop
}
// Go through all handlers every time as this will be the case anyway most
// of the time and will help us better catch zone misconfiguration.
if handler.ShouldRequest() {
shouldRequest = true
}
}
return shouldRequest
}
// GetToken returns a token of one of the given zones.
func GetToken(zones []string) (t *token.Token, err error) {
handlerSelection:
for _, zone := range zones {
// Get handler and check if it should be used.
handler, ok := token.GetHandler(zone)
switch {
case !ok:
log.Warningf("spn/access: use of non-registered zone %q", zone)
continue handlerSelection
case handler.IsFallback() && !TokenIssuerIsFailing():
// Skip fallback zone if everything works.
continue handlerSelection
}
// Get token from handler.
t, err = token.GetToken(zone)
if err == nil {
return t, nil
}
}
// Return existing error, if exists.
if err != nil {
return nil, err
}
return nil, token.ErrEmpty
}
// VerifyRawToken verifies a raw token.
func VerifyRawToken(data []byte) (granted terminal.Permission, err error) {
t, err := token.ParseRawToken(data)
if err != nil {
return 0, fmt.Errorf("failed to parse token: %w", err)
}
return VerifyToken(t)
}
// VerifyToken verifies a token.
func VerifyToken(t *token.Token) (granted terminal.Permission, err error) {
handler, ok := token.GetHandler(t.Zone)
if !ok {
return terminal.NoPermission, token.ErrZoneUnknown
}
// Check if the token is a fallback token.
if handler.IsFallback() && !healthCheck() {
return terminal.NoPermission, ErrFallbackNotAvailable
}
// Verify token.
err = handler.Verify(t)
if err != nil {
return 0, fmt.Errorf("failed to verify token: %w", err)
}
// Return permission of zone.
granted, ok = zonePermissions[t.Zone]
if !ok {
return terminal.NoPermission, nil
}
return granted, nil
}

392
spn/cabin/config-public.go Normal file
View File

@@ -0,0 +1,392 @@
package cabin
import (
"fmt"
"net"
"os"
"github.com/safing/portbase/config"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/profile/endpoints"
"github.com/safing/portmaster/spn/hub"
)
// Configuration Keys.
var (
// Name of the node.
publicCfgOptionNameKey = "spn/publicHub/name"
publicCfgOptionName config.StringOption
publicCfgOptionNameDefault = ""
publicCfgOptionNameOrder = 512
// Person or organisation, who is in control of the node (should be same for all nodes of this person or organisation).
publicCfgOptionGroupKey = "spn/publicHub/group"
publicCfgOptionGroup config.StringOption
publicCfgOptionGroupDefault = ""
publicCfgOptionGroupOrder = 513
// Contact possibility (recommended, but optional).
publicCfgOptionContactAddressKey = "spn/publicHub/contactAddress"
publicCfgOptionContactAddress config.StringOption
publicCfgOptionContactAddressDefault = ""
publicCfgOptionContactAddressOrder = 514
// Type of service of the contact address, if not email.
publicCfgOptionContactServiceKey = "spn/publicHub/contactService"
publicCfgOptionContactService config.StringOption
publicCfgOptionContactServiceDefault = ""
publicCfgOptionContactServiceOrder = 515
// Hosters - supply chain (reseller, hosting provider, datacenter operator, ...).
publicCfgOptionHostersKey = "spn/publicHub/hosters"
publicCfgOptionHosters config.StringArrayOption
publicCfgOptionHostersDefault = []string{}
publicCfgOptionHostersOrder = 516
// Datacenter
// Format: CC-COMPANY-INTERNALCODE
// Eg: DE-Hetzner-FSN1-DC5
//.
publicCfgOptionDatacenterKey = "spn/publicHub/datacenter"
publicCfgOptionDatacenter config.StringOption
publicCfgOptionDatacenterDefault = ""
publicCfgOptionDatacenterOrder = 517
// Network Location and Access.
// IPv4 must be global and accessible.
publicCfgOptionIPv4Key = "spn/publicHub/ip4"
publicCfgOptionIPv4 config.StringOption
publicCfgOptionIPv4Default = ""
publicCfgOptionIPv4Order = 518
// IPv6 must be global and accessible.
publicCfgOptionIPv6Key = "spn/publicHub/ip6"
publicCfgOptionIPv6 config.StringOption
publicCfgOptionIPv6Default = ""
publicCfgOptionIPv6Order = 519
// Transports.
publicCfgOptionTransportsKey = "spn/publicHub/transports"
publicCfgOptionTransports config.StringArrayOption
publicCfgOptionTransportsDefault = []string{
"tcp:17",
}
publicCfgOptionTransportsOrder = 520
// Entry Policy.
publicCfgOptionEntryKey = "spn/publicHub/entry"
publicCfgOptionEntry config.StringArrayOption
publicCfgOptionEntryDefault = []string{}
publicCfgOptionEntryOrder = 521
// Exit Policy.
publicCfgOptionExitKey = "spn/publicHub/exit"
publicCfgOptionExit config.StringArrayOption
publicCfgOptionExitDefault = []string{"- * TCP/25"}
publicCfgOptionExitOrder = 522
// Allow Unencrypted.
publicCfgOptionAllowUnencryptedKey = "spn/publicHub/allowUnencrypted"
publicCfgOptionAllowUnencrypted config.BoolOption
publicCfgOptionAllowUnencryptedDefault = false
publicCfgOptionAllowUnencryptedOrder = 523
)
func prepPublicHubConfig() error {
err := config.Register(&config.Option{
Name: "Name",
Key: publicCfgOptionNameKey,
Description: "Human readable name of the Hub.",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionNameDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionNameOrder,
},
})
if err != nil {
return err
}
publicCfgOptionName = config.GetAsString(publicCfgOptionNameKey, publicCfgOptionNameDefault)
err = config.Register(&config.Option{
Name: "Group",
Key: publicCfgOptionGroupKey,
Description: "Name of the hub group this Hub belongs to.",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionGroupDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionGroupOrder,
},
})
if err != nil {
return err
}
publicCfgOptionGroup = config.GetAsString(publicCfgOptionGroupKey, publicCfgOptionGroupDefault)
err = config.Register(&config.Option{
Name: "Contact Address",
Key: publicCfgOptionContactAddressKey,
Description: "Contact address where the Hub operator can be reached.",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionContactAddressDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionContactAddressOrder,
},
})
if err != nil {
return err
}
publicCfgOptionContactAddress = config.GetAsString(publicCfgOptionContactAddressKey, publicCfgOptionContactAddressDefault)
err = config.Register(&config.Option{
Name: "Contact Service",
Key: publicCfgOptionContactServiceKey,
Description: "Name of the service the contact address corresponds to, if not email.",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionContactServiceDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionContactServiceOrder,
},
})
if err != nil {
return err
}
publicCfgOptionContactService = config.GetAsString(publicCfgOptionContactServiceKey, publicCfgOptionContactServiceDefault)
err = config.Register(&config.Option{
Name: "Hosters",
Key: publicCfgOptionHostersKey,
Description: "List of all involved entities and organisations that are involved in hosting this Hub.",
OptType: config.OptTypeStringArray,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionHostersDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionHostersOrder,
},
})
if err != nil {
return err
}
publicCfgOptionHosters = config.GetAsStringArray(publicCfgOptionHostersKey, publicCfgOptionHostersDefault)
err = config.Register(&config.Option{
Name: "Datacenter",
Key: publicCfgOptionDatacenterKey,
Description: "Identifier of the datacenter this Hub is hosted in.",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionDatacenterDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionDatacenterOrder,
},
})
if err != nil {
return err
}
publicCfgOptionDatacenter = config.GetAsString(publicCfgOptionDatacenterKey, publicCfgOptionDatacenterDefault)
err = config.Register(&config.Option{
Name: "IPv4",
Key: publicCfgOptionIPv4Key,
Description: "IPv4 address of this Hub. Must be globally reachable.",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionIPv4Default,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionIPv4Order,
},
})
if err != nil {
return err
}
publicCfgOptionIPv4 = config.GetAsString(publicCfgOptionIPv4Key, publicCfgOptionIPv4Default)
err = config.Register(&config.Option{
Name: "IPv6",
Key: publicCfgOptionIPv6Key,
Description: "IPv6 address of this Hub. Must be globally reachable.",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionIPv6Default,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionIPv6Order,
},
})
if err != nil {
return err
}
publicCfgOptionIPv6 = config.GetAsString(publicCfgOptionIPv6Key, publicCfgOptionIPv6Default)
err = config.Register(&config.Option{
Name: "Transports",
Key: publicCfgOptionTransportsKey,
Description: "List of transports this Hub supports.",
OptType: config.OptTypeStringArray,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionTransportsDefault,
ValidationFunc: func(value any) error {
if transports, ok := value.([]string); ok {
for i, transport := range transports {
if _, err := hub.ParseTransport(transport); err != nil {
return fmt.Errorf("failed to parse transport #%d: %w", i, err)
}
}
} else {
return fmt.Errorf("not a []string, but %T", value)
}
return nil
},
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionTransportsOrder,
},
})
if err != nil {
return err
}
publicCfgOptionTransports = config.GetAsStringArray(publicCfgOptionTransportsKey, publicCfgOptionTransportsDefault)
err = config.Register(&config.Option{
Name: "Entry",
Key: publicCfgOptionEntryKey,
Description: "Define an entry policy. The format is the same for the endpoint lists. Default is permit.",
OptType: config.OptTypeStringArray,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionEntryDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionEntryOrder,
config.DisplayHintAnnotation: endpoints.DisplayHintEndpointList,
},
})
if err != nil {
return err
}
publicCfgOptionEntry = config.GetAsStringArray(publicCfgOptionEntryKey, publicCfgOptionEntryDefault)
err = config.Register(&config.Option{
Name: "Exit",
Key: publicCfgOptionExitKey,
Description: "Define an exit policy. The format is the same for the endpoint lists. Default is permit.",
OptType: config.OptTypeStringArray,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionExitDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionExitOrder,
config.DisplayHintAnnotation: endpoints.DisplayHintEndpointList,
},
})
if err != nil {
return err
}
publicCfgOptionExit = config.GetAsStringArray(publicCfgOptionExitKey, publicCfgOptionExitDefault)
err = config.Register(&config.Option{
Name: "Allow Unencrypted Connections",
Key: publicCfgOptionAllowUnencryptedKey,
Description: "Advertise that this Hub is available for handling unencrypted connections, as detected by clients.",
OptType: config.OptTypeBool,
ExpertiseLevel: config.ExpertiseLevelExpert,
RequiresRestart: true,
DefaultValue: publicCfgOptionAllowUnencryptedDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: publicCfgOptionAllowUnencryptedOrder,
},
})
if err != nil {
return err
}
publicCfgOptionAllowUnencrypted = config.GetAsBool(publicCfgOptionAllowUnencryptedKey, publicCfgOptionAllowUnencryptedDefault)
// update defaults from system
setDynamicPublicDefaults()
return nil
}
func getPublicHubInfo() *hub.Announcement {
// get configuration
info := &hub.Announcement{
Name: publicCfgOptionName(),
Group: publicCfgOptionGroup(),
ContactAddress: publicCfgOptionContactAddress(),
ContactService: publicCfgOptionContactService(),
Hosters: publicCfgOptionHosters(),
Datacenter: publicCfgOptionDatacenter(),
Transports: publicCfgOptionTransports(),
Entry: publicCfgOptionEntry(),
Exit: publicCfgOptionExit(),
Flags: []string{},
}
if publicCfgOptionAllowUnencrypted() {
info.Flags = append(info.Flags, hub.FlagAllowUnencrypted)
}
ip4 := publicCfgOptionIPv4()
if ip4 != "" {
ip := net.ParseIP(ip4)
if ip == nil {
log.Warningf("spn/cabin: invalid %s config: %s", publicCfgOptionIPv4Key, ip4)
} else {
info.IPv4 = ip
}
}
ip6 := publicCfgOptionIPv6()
if ip6 != "" {
ip := net.ParseIP(ip6)
if ip == nil {
log.Warningf("spn/cabin: invalid %s config: %s", publicCfgOptionIPv6Key, ip6)
} else {
info.IPv6 = ip
}
}
return info
}
func setDynamicPublicDefaults() {
// name
hostname, err := os.Hostname()
if err == nil {
err := config.SetDefaultConfigOption(publicCfgOptionNameKey, hostname)
if err != nil {
log.Warningf("spn/cabin: failed to set %s default to %s", publicCfgOptionNameKey, hostname)
}
}
// IPs
v4IPs, v6IPs, err := netenv.GetAssignedGlobalAddresses()
if err != nil {
log.Warningf("spn/cabin: failed to get assigned addresses: %s", err)
return
}
if len(v4IPs) == 1 {
err = config.SetDefaultConfigOption(publicCfgOptionIPv4Key, v4IPs[0].String())
if err != nil {
log.Warningf("spn/cabin: failed to set %s default to %s", publicCfgOptionIPv4Key, v4IPs[0].String())
}
}
if len(v6IPs) == 1 {
err = config.SetDefaultConfigOption(publicCfgOptionIPv6Key, v6IPs[0].String())
if err != nil {
log.Warningf("spn/cabin: failed to set %s default to %s", publicCfgOptionIPv6Key, v6IPs[0].String())
}
}
}

98
spn/cabin/database.go Normal file
View File

@@ -0,0 +1,98 @@
package cabin
import (
"errors"
"fmt"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/record"
"github.com/safing/portmaster/spn/hub"
)
var db = database.NewInterface(nil)
// LoadIdentity loads an identify with the given key.
func LoadIdentity(key string) (id *Identity, changed bool, err error) {
r, err := db.Get(key)
if err != nil {
return nil, false, err
}
id, err = EnsureIdentity(r)
if err != nil {
return nil, false, fmt.Errorf("failed to parse identity: %w", err)
}
// Check if required fields are present.
switch {
case id.Hub == nil:
return nil, false, errors.New("missing id.Hub")
case id.Signet == nil:
return nil, false, errors.New("missing id.Signet")
case id.Hub.Info == nil:
return nil, false, errors.New("missing hub.Info")
case id.Hub.Status == nil:
return nil, false, errors.New("missing hub.Status")
case id.ID != id.Hub.ID:
return nil, false, errors.New("hub.ID mismatch")
case id.ID != id.Hub.Info.ID:
return nil, false, errors.New("hub.Info.ID mismatch")
case id.Map == "":
return nil, false, errors.New("invalid id.Map")
case id.Hub.Map == "":
return nil, false, errors.New("invalid hub.Map")
case id.Hub.FirstSeen.IsZero():
return nil, false, errors.New("missing hub.FirstSeen")
case id.Hub.Info.Timestamp == 0:
return nil, false, errors.New("missing hub.Info.Timestamp")
case id.Hub.Status.Timestamp == 0:
return nil, false, errors.New("missing hub.Status.Timestamp")
}
// Run a initial maintenance routine.
infoChanged, err := id.MaintainAnnouncement(nil, true)
if err != nil {
return nil, false, fmt.Errorf("failed to initialize announcement: %w", err)
}
statusChanged, err := id.MaintainStatus(nil, nil, nil, true)
if err != nil {
return nil, false, fmt.Errorf("failed to initialize status: %w", err)
}
// Ensure the Measurements reset the values.
measurements := id.Hub.GetMeasurements()
measurements.SetLatency(0)
measurements.SetCapacity(0)
measurements.SetCalculatedCost(hub.MaxCalculatedCost)
return id, infoChanged || statusChanged, nil
}
// EnsureIdentity makes sure a database record is an Identity.
func EnsureIdentity(r record.Record) (*Identity, error) {
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
id := &Identity{}
err := record.Unwrap(r, id)
if err != nil {
return nil, err
}
return id, nil
}
// or adjust type
id, ok := r.(*Identity)
if !ok {
return nil, fmt.Errorf("record not of type *Identity, but %T", r)
}
return id, nil
}
// Save saves the Identity to the database.
func (id *Identity) Save() error {
if !id.KeyIsSet() {
return errors.New("no key set")
}
return db.Put(id)
}

311
spn/cabin/identity.go Normal file
View File

@@ -0,0 +1,311 @@
package cabin
import (
"context"
"errors"
"fmt"
"time"
"github.com/safing/jess"
"github.com/safing/jess/tools"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/info"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/hub"
)
const (
// DefaultIDKeyScheme is the default jess tool for creating ID keys.
DefaultIDKeyScheme = "Ed25519"
// DefaultIDKeySecurityLevel is the default security level for creating ID keys.
DefaultIDKeySecurityLevel = 256 // Ed25519 security level is fixed, setting is ignored.
)
// Identity holds the identity of a Hub.
type Identity struct {
record.Base
ID string
Map string
Hub *hub.Hub
Signet *jess.Signet
ExchKeys map[string]*ExchKey
infoExportCache []byte
statusExportCache []byte
}
// Lock locks the Identity through the Hub lock.
func (id *Identity) Lock() {
id.Hub.Lock()
}
// Unlock unlocks the Identity through the Hub lock.
func (id *Identity) Unlock() {
id.Hub.Unlock()
}
// ExchKey holds the private information of a HubKey.
type ExchKey struct {
Created time.Time
Expires time.Time
key *jess.Signet
tool *tools.Tool
}
// CreateIdentity creates a new identity.
func CreateIdentity(ctx context.Context, mapName string) (*Identity, error) {
id := &Identity{
Map: mapName,
ExchKeys: make(map[string]*ExchKey),
}
// create signet
signet, recipient, err := hub.CreateHubSignet(DefaultIDKeyScheme, DefaultIDKeySecurityLevel)
if err != nil {
return nil, err
}
id.Signet = signet
id.ID = signet.ID
id.Hub = &hub.Hub{
ID: id.ID,
Map: mapName,
PublicKey: recipient,
}
// initial maintenance routine
_, err = id.MaintainAnnouncement(nil, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize announcement: %w", err)
}
_, err = id.MaintainStatus([]*hub.Lane{}, new(int), nil, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize status: %w", err)
}
return id, nil
}
// MaintainAnnouncement maintains the Hub's Announcenemt and returns whether
// there was a change that should be communicated to other Hubs.
// If newInfo is nil, it will be derived from configuration.
func (id *Identity) MaintainAnnouncement(newInfo *hub.Announcement, selfcheck bool) (changed bool, err error) {
id.Lock()
defer id.Unlock()
// Populate new info with data.
if newInfo == nil {
newInfo = getPublicHubInfo()
}
newInfo.ID = id.Hub.ID
if id.Hub.Info != nil {
newInfo.Timestamp = id.Hub.Info.Timestamp
}
if !newInfo.Equal(id.Hub.Info) {
changed = true
}
if changed {
// Update timestamp.
newInfo.Timestamp = time.Now().Unix()
}
if changed || selfcheck {
// Export new data.
newInfoData, err := newInfo.Export(id.signingEnvelope())
if err != nil {
return false, fmt.Errorf("failed to export: %w", err)
}
// Apply the status as all other Hubs would in order to check if it's valid.
_, _, _, err = hub.ApplyAnnouncement(id.Hub, newInfoData, conf.MainMapName, conf.MainMapScope, true)
if err != nil {
return false, fmt.Errorf("failed to apply new announcement: %w", err)
}
id.infoExportCache = newInfoData
// Save message to hub message storage.
err = hub.SaveHubMsg(id.ID, conf.MainMapName, hub.MsgTypeAnnouncement, newInfoData)
if err != nil {
log.Warningf("spn/cabin: failed to save own new/updated announcement of %s: %s", id.ID, err)
}
}
return changed, nil
}
// MaintainStatus maintains the Hub's Status and returns whether there was a change that should be communicated to other Hubs.
func (id *Identity) MaintainStatus(lanes []*hub.Lane, load *int, flags []string, selfcheck bool) (changed bool, err error) {
id.Lock()
defer id.Unlock()
// Create a new status or make a copy of the status for editing.
var newStatus *hub.Status
if id.Hub.Status != nil {
newStatus = id.Hub.Status.Copy()
} else {
newStatus = &hub.Status{}
}
// Update software version.
if newStatus.Version != info.Version() {
newStatus.Version = info.Version()
changed = true
}
// Update keys.
keysChanged, err := id.MaintainExchKeys(newStatus, time.Now())
if err != nil {
return false, fmt.Errorf("failed to maintain keys: %w", err)
}
if keysChanged {
changed = true
}
// Update lanes.
if lanes != nil && !hub.LanesEqual(newStatus.Lanes, lanes) {
newStatus.Lanes = lanes
changed = true
}
// Update load.
if load != nil && newStatus.Load != *load {
newStatus.Load = *load
changed = true
}
// Update flags.
if !hub.FlagsEqual(newStatus.Flags, flags) {
newStatus.Flags = flags
changed = true
}
// Update timestamp if something changed.
if changed {
newStatus.Timestamp = time.Now().Unix()
}
if changed || selfcheck {
// Export new data.
newStatusData, err := newStatus.Export(id.signingEnvelope())
if err != nil {
return false, fmt.Errorf("failed to export: %w", err)
}
// Apply the status as all other Hubs would in order to check if it's valid.
_, _, _, err = hub.ApplyStatus(id.Hub, newStatusData, conf.MainMapName, conf.MainMapScope, true)
if err != nil {
return false, fmt.Errorf("failed to apply new status: %w", err)
}
id.statusExportCache = newStatusData
// Save message to hub message storage.
err = hub.SaveHubMsg(id.ID, conf.MainMapName, hub.MsgTypeStatus, newStatusData)
if err != nil {
log.Warningf("spn/cabin: failed to save own new/updated status: %s", err)
}
}
return changed, nil
}
// MakeOfflineStatus creates and signs an offline status message.
func (id *Identity) MakeOfflineStatus() (offlineStatusExport []byte, err error) {
// Make offline status.
newStatus := &hub.Status{
Timestamp: time.Now().Unix(),
Version: info.Version(),
Flags: []string{hub.FlagOffline},
}
// Export new data.
newStatusData, err := newStatus.Export(id.signingEnvelope())
if err != nil {
return nil, fmt.Errorf("failed to export: %w", err)
}
return newStatusData, nil
}
func (id *Identity) signingEnvelope() *jess.Envelope {
env := jess.NewUnconfiguredEnvelope()
env.SuiteID = jess.SuiteSignV1
env.Senders = []*jess.Signet{id.Signet}
return env
}
// ExportAnnouncement serializes and signs the Announcement.
func (id *Identity) ExportAnnouncement() ([]byte, error) {
id.Lock()
defer id.Unlock()
if id.infoExportCache == nil {
return nil, errors.New("announcement not exported")
}
return id.infoExportCache, nil
}
// ExportStatus serializes and signs the Status.
func (id *Identity) ExportStatus() ([]byte, error) {
id.Lock()
defer id.Unlock()
if id.statusExportCache == nil {
return nil, errors.New("status not exported")
}
return id.statusExportCache, nil
}
// SignHubMsg signs a data blob with the identity's private key.
func (id *Identity) SignHubMsg(data []byte) ([]byte, error) {
return hub.SignHubMsg(data, id.signingEnvelope(), false)
}
// GetSignet returns the private exchange key with the given ID.
func (id *Identity) GetSignet(keyID string, recipient bool) (*jess.Signet, error) {
if recipient {
return nil, errors.New("cabin.Identity only serves private keys")
}
id.Lock()
defer id.Unlock()
key, ok := id.ExchKeys[keyID]
if !ok {
return nil, errors.New("the requested key does not exist")
}
if time.Now().After(key.Expires) || key.key == nil {
return nil, errors.New("the requested key has expired")
}
return key.key, nil
}
func (ek *ExchKey) toHubKey() (*hub.Key, error) {
if ek.key == nil {
return nil, errors.New("no key")
}
// export public key
rcpt, err := ek.key.AsRecipient()
if err != nil {
return nil, err
}
err = rcpt.StoreKey()
if err != nil {
return nil, err
}
// repackage
return &hub.Key{
Scheme: rcpt.Scheme,
Key: rcpt.Key,
Expires: ek.Expires.Unix(),
}, nil
}

129
spn/cabin/identity_test.go Normal file
View File

@@ -0,0 +1,129 @@
package cabin
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/hub"
)
func TestIdentity(t *testing.T) {
t.Parallel()
// Register config options for public hub.
if err := prepPublicHubConfig(); err != nil {
t.Fatal(err)
}
// Create new identity.
identityTestKey := "core:spn/public/identity"
id, err := CreateIdentity(module.Ctx, conf.MainMapName)
if err != nil {
t.Fatal(err)
}
id.SetKey(identityTestKey)
// Check values
// Identity
assert.NotEmpty(t, id.ID, "id.ID must be set")
assert.NotEmpty(t, id.Map, "id.Map must be set")
assert.NotNil(t, id.Signet, "id.Signet must be set")
assert.NotNil(t, id.infoExportCache, "id.infoExportCache must be set")
assert.NotNil(t, id.statusExportCache, "id.statusExportCache must be set")
// Hub
assert.NotEmpty(t, id.Hub.ID, "hub.ID must be set")
assert.NotEmpty(t, id.Hub.Map, "hub.Map must be set")
assert.NotZero(t, id.Hub.FirstSeen, "hub.FirstSeen must be set")
// Info
assert.NotEmpty(t, id.Hub.Info.ID, "info.ID must be set")
assert.NotEqual(t, 0, id.Hub.Info.Timestamp, "info.Timestamp must be set")
assert.NotEqual(t, "", id.Hub.Info.Name, "info.Name must be set (to hostname)")
// Status
assert.NotEqual(t, 0, id.Hub.Status.Timestamp, "status.Timestamp must be set")
assert.NotEmpty(t, id.Hub.Status.Keys, "status.Keys must be set")
fmt.Printf("id: %+v\n", id)
fmt.Printf("id.hub: %+v\n", id.Hub)
fmt.Printf("id.Hub.Info: %+v\n", id.Hub.Info)
fmt.Printf("id.Hub.Status: %+v\n", id.Hub.Status)
// Maintenance is run in creation, so nothing should change now.
changed, err := id.MaintainAnnouncement(nil, false)
if err != nil {
t.Fatal(err)
}
if changed {
t.Error("unexpected change of announcement")
}
changed, err = id.MaintainStatus(nil, nil, nil, false)
if err != nil {
t.Fatal(err)
}
if changed {
t.Error("unexpected change of status")
}
// Change lanes.
lanes := []*hub.Lane{
{
ID: "A",
Capacity: 1,
Latency: 2,
},
{
ID: "B",
Capacity: 3,
Latency: 4,
},
{
ID: "C",
Capacity: 5,
Latency: 6,
},
}
changed, err = id.MaintainStatus(lanes, new(int), nil, false)
if err != nil {
t.Fatal(err)
}
if !changed {
t.Error("status should have changed")
}
// Change nothing.
changed, err = id.MaintainStatus(lanes, new(int), nil, false)
if err != nil {
t.Fatal(err)
}
if changed {
t.Error("unexpected change of status")
}
// Exporting
_, err = id.ExportAnnouncement()
if err != nil {
t.Fatal(err)
}
_, err = id.ExportStatus()
if err != nil {
t.Fatal(err)
}
// Save to and load from database.
err = id.Save()
if err != nil {
t.Fatal(err)
}
id2, changed, err := LoadIdentity(identityTestKey)
if err != nil {
t.Fatal(err)
}
if changed {
t.Error("unexpected change")
}
// Check if they match
assert.Equal(t, id, id2, "identities should be equal")
}

179
spn/cabin/keys.go Normal file
View File

@@ -0,0 +1,179 @@
package cabin
import (
"encoding/base64"
"errors"
"fmt"
"time"
"github.com/safing/jess"
"github.com/safing/jess/tools"
"github.com/safing/portbase/log"
"github.com/safing/portbase/rng"
"github.com/safing/portmaster/spn/hub"
)
type providedExchKeyScheme struct {
id string
securityLevel int //nolint:structcheck // TODO
tool *tools.Tool
}
var (
// validFor defines how long keys are valid for use by clients.
validFor = 48 * time.Hour // 2 days
// renewBeforeExpiry defines the duration how long before expiry keys should be renewed.
renewBeforeExpiry = 24 * time.Hour // 1 day
// burnAfter defines how long after expiry keys are burnt/deleted.
burnAfter = 12 * time.Hour // 1/2 day
// reuseAfter defines how long IDs should be blocked after expiry (and not be reused for new keys).
reuseAfter = 2 * 7 * 24 * time.Hour // 2 weeks
// provideExchKeySchemes defines the jess tools for creating exchange keys.
provideExchKeySchemes = []*providedExchKeyScheme{
{
id: "ECDH-X25519",
securityLevel: 128, // informative only, security level of ECDH-X25519 is fixed
},
// TODO: test with rsa keys
}
)
func initProvidedExchKeySchemes() error {
for _, eks := range provideExchKeySchemes {
tool, err := tools.Get(eks.id)
if err != nil {
return err
}
eks.tool = tool
}
return nil
}
// MaintainExchKeys maintains the exchange keys, creating new ones and
// deprecating and deleting old ones.
func (id *Identity) MaintainExchKeys(newStatus *hub.Status, now time.Time) (changed bool, err error) {
// create Keys map
if id.ExchKeys == nil {
id.ExchKeys = make(map[string]*ExchKey)
}
// lifecycle management
for keyID, exchKey := range id.ExchKeys {
if exchKey.key != nil && now.After(exchKey.Expires.Add(burnAfter)) {
// delete key
err := exchKey.tool.StaticLogic.BurnKey(exchKey.key)
if err != nil {
log.Warningf(
"spn/cabin: failed to burn key %s (%s) of %s: %s",
keyID,
exchKey.tool.Info.Name,
id.Hub.ID,
err,
)
}
// remove reference
exchKey.key = nil
}
if now.After(exchKey.Expires.Add(reuseAfter)) {
// remove key
delete(id.ExchKeys, keyID)
}
}
// find or create current keys
for _, eks := range provideExchKeySchemes {
found := false
for _, exchKey := range id.ExchKeys {
if exchKey.key != nil &&
exchKey.key.Scheme == eks.id &&
now.Before(exchKey.Expires.Add(-renewBeforeExpiry)) {
found = true
break
}
}
if !found {
err := id.createExchKey(eks, now)
if err != nil {
return false, fmt.Errorf("failed to create %s exchange key: %w", eks.tool.Info.Name, err)
}
changed = true
}
}
// export most recent keys to HubStatus
if changed || len(newStatus.Keys) == 0 {
// reset
newStatus.Keys = make(map[string]*hub.Key)
// find longest valid key for every provided scheme
for _, eks := range provideExchKeySchemes {
// find key of scheme that is valid the longest
longestValid := &ExchKey{
Expires: now,
}
for _, exchKey := range id.ExchKeys {
if exchKey.key != nil &&
exchKey.key.Scheme == eks.id &&
exchKey.Expires.After(longestValid.Expires) {
longestValid = exchKey
}
}
// check result
if longestValid.key == nil {
log.Warningf("spn/cabin: could not find export candidate for exchange key scheme %s", eks.id)
continue
}
// export
hubKey, err := longestValid.toHubKey()
if err != nil {
return false, fmt.Errorf("failed to export %s exchange key: %w", longestValid.tool.Info.Name, err)
}
// add
newStatus.Keys[longestValid.key.ID] = hubKey
}
}
return changed, nil
}
func (id *Identity) createExchKey(eks *providedExchKeyScheme, now time.Time) error {
// get ID
var keyID string
for i := 0; i < 1000000; i++ { // not forever
// generate new ID
b, err := rng.Bytes(3)
if err != nil {
return fmt.Errorf("failed to get random data for key ID: %w", err)
}
keyID = base64.RawURLEncoding.EncodeToString(b)
_, exists := id.ExchKeys[keyID]
if !exists {
break
}
}
if keyID == "" {
return errors.New("unable to find available exchange key ID")
}
// generate key
signet := jess.NewSignetBase(eks.tool)
signet.ID = keyID
// TODO: use security level for key generation
if err := signet.GenerateKey(); err != nil {
return fmt.Errorf("failed to get new exchange key: %w", err)
}
// add to key map
id.ExchKeys[keyID] = &ExchKey{
Created: now,
Expires: now.Add(validFor),
key: signet,
tool: eks.tool,
}
return nil
}

43
spn/cabin/keys_test.go Normal file
View File

@@ -0,0 +1,43 @@
package cabin
import (
"testing"
"time"
"github.com/safing/portmaster/spn/conf"
)
func TestKeyMaintenance(t *testing.T) {
t.Parallel()
id, err := CreateIdentity(module.Ctx, conf.MainMapName)
if err != nil {
t.Fatal(err)
}
iterations := 1000
changeCnt := 0
now := time.Now()
for i := 0; i < iterations; i++ {
changed, err := id.MaintainExchKeys(id.Hub.Status, now)
if err != nil {
t.Fatal(err)
}
if changed {
changeCnt++
t.Logf("===== exchange keys updated at %s:\n", now)
for keyID, exchKey := range id.ExchKeys {
t.Logf("[%s] %s %v\n", exchKey.Created, keyID, exchKey.key)
}
}
now = now.Add(1 * time.Hour)
}
if iterations/changeCnt > 25 { // one new key every 24 hours/ticks
t.Fatal("more changes than expected")
}
if len(id.ExchKeys) > 17 { // one new key every day for two weeks + 3 in use
t.Fatal("more keys than expected")
}
}

26
spn/cabin/module.go Normal file
View File

@@ -0,0 +1,26 @@
package cabin
import (
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/spn/conf"
)
var module *modules.Module
func init() {
module = modules.Register("cabin", prep, nil, nil, "base", "rng")
}
func prep() error {
if err := initProvidedExchKeySchemes(); err != nil {
return err
}
if conf.PublicHub() {
if err := prepPublicHubConfig(); err != nil {
return err
}
}
return nil
}

13
spn/cabin/module_test.go Normal file
View File

@@ -0,0 +1,13 @@
package cabin
import (
"testing"
"github.com/safing/portmaster/service/core/pmtesting"
"github.com/safing/portmaster/spn/conf"
)
func TestMain(m *testing.M) {
conf.EnablePublicHub(true)
pmtesting.TestMain(m, module)
}

157
spn/cabin/verification.go Normal file
View File

@@ -0,0 +1,157 @@
package cabin
import (
"crypto/subtle"
"errors"
"fmt"
"github.com/safing/jess"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/rng"
"github.com/safing/portmaster/spn/hub"
)
var (
verificationChallengeSize = 32
verificationChallengeMinSize = 16
verificationSigningSuite = jess.SuiteSignV1
verificationRequirements = jess.NewRequirements().
Remove(jess.Confidentiality).
Remove(jess.Integrity).
Remove(jess.RecipientAuthentication)
)
// Verification is used to verify certain aspects of another Hub.
type Verification struct {
// Challenge is a random value chosen by the client.
Challenge []byte `json:"c"`
// Purpose defines the purpose of the verification. Protects against using verification for other purposes.
Purpose string `json:"p"`
// ClientReference is an optional field for exchanging metadata about the client. Protects against forwarding/relay attacks.
ClientReference string `json:"cr"`
// ServerReference is an optional field for exchanging metadata about the server. Protects against forwarding/relay attacks.
ServerReference string `json:"sr"`
}
// CreateVerificationRequest creates a new verification request with the given
// purpose and references.
func CreateVerificationRequest(purpose, clientReference, serverReference string) (v *Verification, request []byte, err error) {
// Generate random challenge.
challenge, err := rng.Bytes(verificationChallengeSize)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate challenge: %w", err)
}
// Create verification object.
v = &Verification{
Purpose: purpose,
ClientReference: clientReference,
Challenge: challenge,
}
// Serialize verification.
request, err = dsd.Dump(v, dsd.JSON)
if err != nil {
return nil, nil, fmt.Errorf("failed to serialize verification request: %w", err)
}
// The server reference is not sent to the server, but needs to be supplied
// by the server.
v.ServerReference = serverReference
return v, request, nil
}
// SignVerificationRequest sign a verification request.
// The purpose and references must match the request, else the verification
// will fail.
func (id *Identity) SignVerificationRequest(request []byte, purpose, clientReference, serverReference string) (response []byte, err error) {
// Parse request.
v := new(Verification)
_, err = dsd.Load(request, v)
if err != nil {
return nil, fmt.Errorf("failed to parse request: %w", err)
}
// Validate request.
if len(v.Challenge) < verificationChallengeMinSize {
return nil, errors.New("challenge too small")
}
if v.Purpose != purpose {
return nil, errors.New("purpose mismatch")
}
if v.ClientReference != clientReference {
return nil, errors.New("client reference mismatch")
}
// Assign server reference and serialize.
v.ServerReference = serverReference
dataToSign, err := dsd.Dump(v, dsd.JSON)
if err != nil {
return nil, fmt.Errorf("failed to serialize verification response: %w", err)
}
// Sign response.
e := jess.NewUnconfiguredEnvelope()
e.SuiteID = verificationSigningSuite
e.Senders = []*jess.Signet{id.Signet}
jession, err := e.Correspondence(nil)
if err != nil {
return nil, fmt.Errorf("failed to setup signer: %w", err)
}
letter, err := jession.Close(dataToSign)
if err != nil {
return nil, fmt.Errorf("failed to sign: %w", err)
}
// Serialize and return.
signedResponse, err := letter.ToDSD(dsd.JSON)
if err != nil {
return nil, fmt.Errorf("failed to serialize letter: %w", err)
}
return signedResponse, nil
}
// Verify verifies the verification response and checks if everything is valid.
func (v *Verification) Verify(response []byte, h *hub.Hub) error {
// Parse response.
letter, err := jess.LetterFromDSD(response)
if err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
// Verify response.
responseData, err := letter.Open(
verificationRequirements,
&hub.SingleTrustStore{
Signet: h.PublicKey,
},
)
if err != nil {
return fmt.Errorf("failed to verify response: %w", err)
}
// Parse verified response.
responseV := new(Verification)
_, err = dsd.Load(responseData, responseV)
if err != nil {
return fmt.Errorf("failed to parse verified response: %w", err)
}
// Validate request.
if subtle.ConstantTimeCompare(v.Challenge, responseV.Challenge) != 1 {
return errors.New("challenge mismatch")
}
if subtle.ConstantTimeCompare([]byte(v.Purpose), []byte(responseV.Purpose)) != 1 {
return errors.New("purpose mismatch")
}
if subtle.ConstantTimeCompare([]byte(v.ClientReference), []byte(responseV.ClientReference)) != 1 {
return errors.New("client reference mismatch")
}
if subtle.ConstantTimeCompare([]byte(v.ServerReference), []byte(responseV.ServerReference)) != 1 {
return errors.New("server reference mismatch")
}
return nil
}

View File

@@ -0,0 +1,127 @@
package cabin
import (
"fmt"
"testing"
)
func TestVerification(t *testing.T) {
t.Parallel()
id, err := CreateIdentity(module.Ctx, "test")
if err != nil {
t.Fatal(err)
}
if err := testVerificationWith(
t, id,
"a", "b", "c",
"a", "b", "c",
"", "", "", nil,
); err != nil {
t.Fatal(err)
}
if err := testVerificationWith(
t, id,
"a", "b", "c",
"x", "b", "c",
"", "", "", nil,
); err == nil {
t.Fatal("should fail on purpose mismatch")
}
if err := testVerificationWith(
t, id,
"a", "b", "c",
"a", "x", "c",
"", "", "", nil,
); err == nil {
t.Fatal("should fail on client ref mismatch")
}
if err := testVerificationWith(
t, id,
"a", "b", "c",
"a", "b", "x",
"", "", "", nil,
); err == nil {
t.Fatal("should fail on server ref mismatch")
}
if err := testVerificationWith(
t, id,
"a", "b", "c",
"a", "b", "c",
"x", "", "", nil,
); err == nil {
t.Fatal("should fail on purpose mismatch")
}
if err := testVerificationWith(
t, id,
"a", "b", "c",
"a", "b", "c",
"", "x", "", nil,
); err == nil {
t.Fatal("should fail on client ref mismatch")
}
if err := testVerificationWith(
t, id,
"a", "b", "c",
"a", "b", "c",
"", "", "x", nil,
); err == nil {
t.Fatal("should fail on server ref mismatch")
}
if err := testVerificationWith(
t, id,
"a", "b", "c",
"a", "b", "c",
"", "", "", []byte{1, 2, 3, 4},
); err == nil {
t.Fatal("should fail on challenge mismatch")
}
}
func testVerificationWith(
t *testing.T, id *Identity,
purpose1, clientRef1, serverRef1 string, //nolint:unparam
purpose2, clientRef2, serverRef2 string,
mitmPurpose, mitmClientRef, mitmServerRef string,
mitmChallenge []byte,
) error {
t.Helper()
v, request, err := CreateVerificationRequest(purpose1, clientRef1, serverRef1)
if err != nil {
return fmt.Errorf("failed to create verification request: %w", err)
}
response, err := id.SignVerificationRequest(request, purpose2, clientRef2, serverRef2)
if err != nil {
return fmt.Errorf("failed to sign verification response: %w", err)
}
if mitmPurpose != "" {
v.Purpose = mitmPurpose
}
if mitmClientRef != "" {
v.ClientReference = mitmClientRef
}
if mitmServerRef != "" {
v.ServerReference = mitmServerRef
}
if mitmChallenge != nil {
v.Challenge = mitmChallenge
}
err = v.Verify(response, id.Hub)
if err != nil {
return fmt.Errorf("failed to verify: %w", err)
}
return nil
}

68
spn/captain/api.go Normal file
View File

@@ -0,0 +1,68 @@
package captain
import (
"errors"
"fmt"
"github.com/safing/portbase/api"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/modules"
)
const (
apiPathForSPNReInit = "spn/reinit"
)
func registerAPIEndpoints() error {
if err := api.RegisterEndpoint(api.Endpoint{
Path: apiPathForSPNReInit,
Write: api.PermitAdmin,
// BelongsTo: module, // Do not attach to module, as this must run outside of the module.
ActionFunc: handleReInit,
Name: "Re-initialize SPN",
Description: "Stops the SPN, resets all caches and starts it again. The SPN account and settings are not changed.",
}); err != nil {
return err
}
return nil
}
func handleReInit(ar *api.Request) (msg string, err error) {
// Disable module and check
changed := module.Disable()
if !changed {
return "", errors.New("can only re-initialize when the SPN is enabled")
}
// Run module manager.
err = modules.ManageModules()
if err != nil {
return "", fmt.Errorf("failed to stop SPN: %w", err)
}
// Delete SPN cache.
db := database.NewInterface(&database.Options{
Local: true,
Internal: true,
})
deletedRecords, err := db.Purge(ar.Context(), query.New("cache:spn/"))
if err != nil {
return "", fmt.Errorf("failed to delete SPN cache: %w", err)
}
// Enable module.
module.Enable()
// Run module manager.
err = modules.ManageModules()
if err != nil {
return "", fmt.Errorf("failed to start SPN after cache reset: %w", err)
}
return fmt.Sprintf(
"Completed SPN re-initialization and deleted %d cache records in the process.",
deletedRecords,
), nil
}

152
spn/captain/bootstrap.go Normal file
View File

@@ -0,0 +1,152 @@
package captain
import (
"errors"
"flag"
"fmt"
"io/fs"
"os"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/navigator"
)
// BootstrapFile is used for sideloading bootstrap data.
type BootstrapFile struct {
Main BootstrapFileEntry
}
// BootstrapFileEntry is the bootstrap data structure for one map.
type BootstrapFileEntry struct {
Hubs []string
}
var (
bootstrapHubFlag string
bootstrapFileFlag string
)
func init() {
flag.StringVar(&bootstrapHubFlag, "bootstrap-hub", "", "transport address of hub for bootstrapping with the hub ID in the fragment")
flag.StringVar(&bootstrapFileFlag, "bootstrap-file", "", "bootstrap file containing bootstrap hubs - will be initialized if running a public hub and it doesn't exist")
}
// prepBootstrapHubFlag checks the bootstrap-hub argument if it is valid.
func prepBootstrapHubFlag() error {
if bootstrapHubFlag != "" {
_, _, _, err := hub.ParseBootstrapHub(bootstrapHubFlag)
return err
}
return nil
}
// processBootstrapHubFlag processes the bootstrap-hub argument.
func processBootstrapHubFlag() error {
if bootstrapHubFlag != "" {
return navigator.Main.AddBootstrapHubs([]string{bootstrapHubFlag})
}
return nil
}
// processBootstrapFileFlag processes the bootstrap-file argument.
func processBootstrapFileFlag() error {
if bootstrapFileFlag == "" {
return nil
}
_, err := os.Stat(bootstrapFileFlag)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return createBootstrapFile(bootstrapFileFlag)
}
return fmt.Errorf("failed to access bootstrap hub file: %w", err)
}
return loadBootstrapFile(bootstrapFileFlag)
}
// bootstrapWithUpdates loads bootstrap hubs from the updates server and imports them.
func bootstrapWithUpdates() error {
if bootstrapFileFlag != "" {
return errors.New("using the bootstrap-file argument disables bootstrapping via the update system")
}
return updateSPNIntel(module.Ctx, nil)
}
// loadBootstrapFile loads a file with bootstrap hub entries and imports them.
func loadBootstrapFile(filename string) (err error) {
// Load bootstrap file from disk and parse it.
data, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("failed to load bootstrap file: %w", err)
}
bootstrapFile := &BootstrapFile{}
_, err = dsd.Load(data, bootstrapFile)
if err != nil {
return fmt.Errorf("failed to parse bootstrap file: %w", err)
}
if len(bootstrapFile.Main.Hubs) == 0 {
return errors.New("bootstrap holds no hubs for main map")
}
// Add Hubs to map.
err = navigator.Main.AddBootstrapHubs(bootstrapFile.Main.Hubs)
if err == nil {
log.Infof("spn/captain: loaded bootstrap file %s", filename)
}
return err
}
// createBootstrapFile save a bootstrap hub file with an entry of the public identity.
func createBootstrapFile(filename string) error {
if !conf.PublicHub() {
log.Infof("spn/captain: skipped writing a bootstrap hub file, as this is not a public hub")
return nil
}
// create bootstrap hub
if len(publicIdentity.Hub.Info.Transports) == 0 {
return errors.New("public identity has no transports available")
}
// parse first transport
t, err := hub.ParseTransport(publicIdentity.Hub.Info.Transports[0])
if err != nil {
return fmt.Errorf("failed to parse transport of public identity: %w", err)
}
// add IP address
switch {
case publicIdentity.Hub.Info.IPv4 != nil:
t.Domain = publicIdentity.Hub.Info.IPv4.String()
case publicIdentity.Hub.Info.IPv6 != nil:
t.Domain = "[" + publicIdentity.Hub.Info.IPv6.String() + "]"
default:
return errors.New("public identity has no IP address available")
}
// add Hub ID
t.Option = publicIdentity.Hub.ID
// put together
bs := &BootstrapFile{
Main: BootstrapFileEntry{
Hubs: []string{t.String()},
},
}
// serialize
fileData, err := dsd.Dump(bs, dsd.JSON)
if err != nil {
return err
}
// save to disk
err = os.WriteFile(filename, fileData, 0o0664) //nolint:gosec // Should be able to be read by others.
if err != nil {
return err
}
log.Infof("spn/captain: created bootstrap file %s", filename)
return nil
}

506
spn/captain/client.go Normal file
View File

@@ -0,0 +1,506 @@
package captain
import (
"context"
"errors"
"fmt"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
"github.com/safing/portbase/notifications"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/spn/access"
"github.com/safing/portmaster/spn/crew"
"github.com/safing/portmaster/spn/docks"
"github.com/safing/portmaster/spn/navigator"
"github.com/safing/portmaster/spn/terminal"
)
var (
ready = abool.New()
spnLoginButton = notifications.Action{
Text: "Login",
Type: notifications.ActionTypeOpenPage,
Payload: "spn",
}
spnOpenAccountPage = notifications.Action{
Text: "Open Account Page",
Type: notifications.ActionTypeOpenURL,
Payload: "https://account.safing.io",
}
)
// ClientReady signifies if the SPN client is fully ready to handle connections.
func ClientReady() bool {
return ready.IsSet()
}
type (
clientComponentFunc func(ctx context.Context) clientComponentResult
clientComponentResult uint8
)
const (
clientResultOk clientComponentResult = iota // Continue and clean module status.
clientResultRetry // Go back to start of current step, don't clear module status.
clientResultReconnect // Stop current connection and start from zero.
clientResultShutdown // SPN Module is shutting down.
)
var (
clientNetworkChangedFlag = netenv.GetNetworkChangedFlag()
clientIneligibleAccountUpdateDelay = 1 * time.Minute
clientRetryConnectBackoffDuration = 5 * time.Second
clientInitialHealthCheckDelay = 10 * time.Second
clientHealthCheckTickDuration = 1 * time.Minute
clientHealthCheckTickDurationSleepMode = 5 * time.Minute
clientHealthCheckTimeout = 15 * time.Second
clientHealthCheckTrigger = make(chan struct{}, 1)
lastHealthCheck time.Time
)
func triggerClientHealthCheck() {
select {
case clientHealthCheckTrigger <- struct{}{}:
default:
}
}
func clientManager(ctx context.Context) error {
defer func() {
ready.UnSet()
netenv.ConnectedToSPN.UnSet()
resetSPNStatus(StatusDisabled, true)
module.Resolve("")
clientStopHomeHub(ctx)
}()
module.Hint(
"spn:establishing-home-hub",
"Connecting to SPN...",
"Connecting to the SPN network is in progress.",
)
// TODO: When we are starting and the SPN module is faster online than the
// nameserver, then updating the account will fail as the DNS query is
// redirected to a closed port.
// We also can't add the nameserver as a module dependency, as the nameserver
// is not part of the server.
select {
case <-time.After(1 * time.Second):
case <-ctx.Done():
return nil
}
healthCheckTicker := module.NewSleepyTicker(clientHealthCheckTickDuration, clientHealthCheckTickDurationSleepMode)
reconnect:
for {
// Check if we are shutting down.
select {
case <-ctx.Done():
return nil
default:
}
// Reset SPN status.
if ready.SetToIf(true, false) {
netenv.ConnectedToSPN.UnSet()
log.Info("spn/captain: client not ready")
}
resetSPNStatus(StatusConnecting, true)
// Check everything and connect to the SPN.
for _, clientFunc := range []clientComponentFunc{
clientStopHomeHub,
clientCheckNetworkReady,
clientCheckAccountAndTokens,
clientConnectToHomeHub,
clientSetActiveConnectionStatus,
} {
switch clientFunc(ctx) {
case clientResultOk:
// Continue
case clientResultRetry, clientResultReconnect:
// Wait for a short time to not loop too quickly.
select {
case <-time.After(clientRetryConnectBackoffDuration):
continue reconnect
case <-ctx.Done():
return nil
}
case clientResultShutdown:
return nil
}
}
log.Info("spn/captain: client is ready")
ready.Set()
netenv.ConnectedToSPN.Set()
module.TriggerEvent(SPNConnectedEvent, nil)
module.StartWorker("update quick setting countries", navigator.Main.UpdateConfigQuickSettings)
// Reset last health check value, as we have just connected.
lastHealthCheck = time.Now()
// Back off before starting initial health checks.
select {
case <-time.After(clientInitialHealthCheckDelay):
case <-ctx.Done():
return nil
}
for {
// Check health of the current SPN connection and monitor the user status.
maintainers:
for _, clientFunc := range []clientComponentFunc{
clientCheckHomeHubConnection,
clientCheckAccountAndTokens,
clientSetActiveConnectionStatus,
} {
switch clientFunc(ctx) {
case clientResultOk:
// Continue
case clientResultRetry:
// Abort and wait for the next run.
break maintainers
case clientResultReconnect:
continue reconnect
case clientResultShutdown:
return nil
}
}
// Wait for signal to run maintenance again.
select {
case <-healthCheckTicker.Wait():
case <-clientHealthCheckTrigger:
case <-crew.ConnectErrors():
case <-clientNetworkChangedFlag.Signal():
clientNetworkChangedFlag.Refresh()
case <-ctx.Done():
return nil
}
}
}
}
func clientCheckNetworkReady(ctx context.Context) clientComponentResult {
// Check if we are online enough for connecting.
switch netenv.GetOnlineStatus() { //nolint:exhaustive
case netenv.StatusOffline,
netenv.StatusLimited:
select {
case <-ctx.Done():
return clientResultShutdown
case <-time.After(1 * time.Second):
return clientResultRetry
}
}
return clientResultOk
}
// DisableAccount disables using any account related SPN functionality.
// Attempts to use the same will result in errors.
var DisableAccount bool
func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult {
if DisableAccount {
return clientResultOk
}
// Get SPN user.
user, err := access.GetUser()
if err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
notifications.NotifyError(
"spn:failed-to-get-user",
"SPN Internal Error",
`Please restart Portmaster.`,
// TODO: Add restart button.
// TODO: Use special UI restart action in order to reload UI on restart.
).AttachToModule(module)
resetSPNStatus(StatusFailed, true)
log.Errorf("spn/captain: client internal error: %s", err)
return clientResultReconnect
}
// Check if user is logged in.
if user == nil || !user.IsLoggedIn() {
notifications.NotifyWarn(
"spn:not-logged-in",
"SPN Login Required",
`Please log in to access the SPN.`,
spnLoginButton,
).AttachToModule(module)
resetSPNStatus(StatusFailed, true)
log.Warningf("spn/captain: enabled but not logged in")
return clientResultReconnect
}
// Check if user is eligible.
if !user.MayUseTheSPN() {
// Update user in case there was a change.
// Only update here if we need to - there is an update task in the access
// module for periodic updates.
if time.Now().Add(-clientIneligibleAccountUpdateDelay).After(time.Unix(user.Meta().Modified, 0)) {
_, _, err := access.UpdateUser()
if err != nil {
notifications.NotifyError(
"spn:failed-to-update-user",
"SPN Account Server Error",
fmt.Sprintf(`The status of your SPN account could not be updated: %s`, err),
).AttachToModule(module)
resetSPNStatus(StatusFailed, true)
log.Errorf("spn/captain: failed to update ineligible account: %s", err)
return clientResultReconnect
}
}
// Check if user is eligible after a possible update.
if !user.MayUseTheSPN() {
// If package is generally valid, then the current package does not have access to the SPN.
if user.MayUse("") {
notifications.NotifyError(
"spn:package-not-eligible",
"SPN Not Included In Package",
"Your current Portmaster Package does not include access to the SPN. Please upgrade your package on the Account Page.",
spnOpenAccountPage,
).AttachToModule(module)
resetSPNStatus(StatusFailed, true)
return clientResultReconnect
}
// Otherwise, include the message from the user view.
message := "There is an issue with your Portmaster Package. Please check the Account Page."
if user.View != nil && user.View.Message != "" {
message = user.View.Message
}
notifications.NotifyError(
"spn:subscription-inactive",
"Portmaster Package Issue",
"Cannot enable SPN: "+message,
spnOpenAccountPage,
).AttachToModule(module)
resetSPNStatus(StatusFailed, true)
return clientResultReconnect
}
}
// Check if we have enough tokens.
if access.ShouldRequest(access.ExpandAndConnectZones) {
err := access.UpdateTokens()
if err != nil {
log.Errorf("spn/captain: failed to get tokens: %s", err)
// There was an error updating the account.
// Check if we have enough tokens to continue anyway.
regular, _ := access.GetTokenAmount(access.ExpandAndConnectZones)
if regular == 0 /* && fallback == 0 */ { // TODO: Add fallback token check when fallback was tested on servers.
notifications.NotifyError(
"spn:tokens-exhausted",
"SPN Access Tokens Exhausted",
`The Portmaster failed to get new access tokens to access the SPN. The Portmaster will automatically retry to get new access tokens.`,
).AttachToModule(module)
resetSPNStatus(StatusFailed, false)
}
return clientResultRetry
}
}
return clientResultOk
}
func clientStopHomeHub(ctx context.Context) clientComponentResult {
// Don't use the context in this function, as it will likely be canceled
// already and would disrupt any context usage in here.
// Get crane connecting to home.
home, _ := navigator.Main.GetHome()
if home == nil {
return clientResultOk
}
crane := docks.GetAssignedCrane(home.Hub.ID)
if crane == nil {
return clientResultOk
}
// Stop crane and all connected terminals.
crane.Stop(nil)
return clientResultOk
}
func clientConnectToHomeHub(ctx context.Context) clientComponentResult {
err := establishHomeHub(ctx)
if err != nil {
log.Errorf("spn/captain: failed to establish connection to home hub: %s", err)
resetSPNStatus(StatusFailed, true)
switch {
case errors.Is(err, ErrAllHomeHubsExcluded):
notifications.NotifyError(
"spn:all-home-hubs-excluded",
"All Home Nodes Excluded",
"Your current Home Node Rules exclude all available and eligible SPN Nodes. Please change your rules to allow for at least one available and eligible Home Node.",
notifications.Action{
Text: "Configure",
Type: notifications.ActionTypeOpenSetting,
Payload: &notifications.ActionTypeOpenSettingPayload{
Key: CfgOptionHomeHubPolicyKey,
},
},
).AttachToModule(module)
case errors.Is(err, ErrReInitSPNSuggested):
notifications.NotifyError(
"spn:cannot-bootstrap",
"SPN Cannot Bootstrap",
"The local state of the SPN network is likely outdated. Portmaster was not able to identify a server to connect to. Please re-initialize the SPN using the tools menu or the button on the notification.",
notifications.Action{
ID: "re-init",
Text: "Re-Init SPN",
Type: notifications.ActionTypeWebhook,
Payload: &notifications.ActionTypeWebhookPayload{
URL: apiPathForSPNReInit,
ResultAction: "display",
},
},
).AttachToModule(module)
default:
notifications.NotifyWarn(
"spn:home-hub-failure",
"SPN Failed to Connect",
fmt.Sprintf("Failed to connect to a home hub: %s. The Portmaster will retry to connect automatically.", err),
).AttachToModule(module)
}
return clientResultReconnect
}
// Log new connection.
home, _ := navigator.Main.GetHome()
if home != nil {
log.Infof("spn/captain: established new home %s", home.Hub)
}
return clientResultOk
}
func clientSetActiveConnectionStatus(ctx context.Context) clientComponentResult {
// Get current home.
home, homeTerminal := navigator.Main.GetHome()
if home == nil || homeTerminal == nil {
return clientResultReconnect
}
// Resolve any connection error.
module.Resolve("")
// Update SPN Status with connection information, if not already correctly set.
spnStatus.Lock()
defer spnStatus.Unlock()
if spnStatus.Status != StatusConnected || spnStatus.HomeHubID != home.Hub.ID {
// Fill connection status data.
spnStatus.Status = StatusConnected
spnStatus.HomeHubID = home.Hub.ID
spnStatus.HomeHubName = home.Hub.Info.Name
connectedIP, _, err := netutils.IPPortFromAddr(homeTerminal.RemoteAddr())
if err != nil {
spnStatus.ConnectedIP = homeTerminal.RemoteAddr().String()
} else {
spnStatus.ConnectedIP = connectedIP.String()
}
spnStatus.ConnectedTransport = homeTerminal.Transport().String()
geoLoc := home.GetLocation(connectedIP)
if geoLoc != nil {
spnStatus.ConnectedCountry = &geoLoc.Country
}
now := time.Now()
spnStatus.ConnectedSince = &now
// Push new status.
pushSPNStatusUpdate()
}
return clientResultOk
}
func clientCheckHomeHubConnection(ctx context.Context) clientComponentResult {
// Check the status of the Home Hub.
home, homeTerminal := navigator.Main.GetHome()
if home == nil || homeTerminal == nil || homeTerminal.IsBeingAbandoned() {
return clientResultReconnect
}
// Get crane controller for health check.
crane := docks.GetAssignedCrane(home.Hub.ID)
if crane == nil {
log.Errorf("spn/captain: could not find home hub crane for health check")
return clientResultOk
}
// Ping home hub.
latency, tErr := pingHome(ctx, crane.Controller, clientHealthCheckTimeout)
if tErr != nil {
log.Warningf("spn/captain: failed to ping home hub: %s", tErr)
// Prepare to reconnect to the network.
// Reset all failing states, as these might have been caused by the failing home hub.
navigator.Main.ResetFailingStates(ctx)
// If the last health check is clearly too long ago, assume that the device was sleeping and do not set the home node to failing yet.
if time.Since(lastHealthCheck) > clientHealthCheckTickDuration+
clientHealthCheckTickDurationSleepMode+
(clientHealthCheckTimeout*2) {
return clientResultReconnect
}
// Mark the home hub itself as failing, as we want to try to connect to somewhere else.
home.MarkAsFailingFor(5 * time.Minute)
return clientResultReconnect
}
lastHealthCheck = time.Now()
log.Debugf("spn/captain: pinged home hub in %s", latency)
return clientResultOk
}
func pingHome(ctx context.Context, t terminal.Terminal, timeout time.Duration) (latency time.Duration, err *terminal.Error) {
started := time.Now()
// Start ping operation.
pingOp, tErr := crew.NewPingOp(t)
if tErr != nil {
return 0, tErr
}
// Wait for response.
select {
case <-ctx.Done():
return 0, terminal.ErrCanceled
case <-time.After(timeout):
return 0, terminal.ErrTimeout
case result := <-pingOp.Result:
if result.Is(terminal.ErrExplicitAck) {
return time.Since(started), nil
}
if result.IsOK() {
return 0, result.Wrap("unexpected response")
}
return 0, result
}
}

253
spn/captain/config.go Normal file
View File

@@ -0,0 +1,253 @@
package captain
import (
"sync"
"github.com/safing/portbase/config"
"github.com/safing/portmaster/service/profile"
"github.com/safing/portmaster/service/profile/endpoints"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/navigator"
)
var (
// CfgOptionEnableSPNKey is the configuration key for the SPN module.
CfgOptionEnableSPNKey = "spn/enable"
cfgOptionEnableSPNOrder = 128
// CfgOptionHomeHubPolicyKey is the configuration key for the SPN home policy.
CfgOptionHomeHubPolicyKey = "spn/homePolicy"
cfgOptionHomeHubPolicy config.StringArrayOption
cfgOptionHomeHubPolicyOrder = 145
// CfgOptionDNSExitHubPolicyKey is the configuration key for the SPN DNS exit policy.
CfgOptionDNSExitHubPolicyKey = "spn/dnsExitPolicy"
cfgOptionDNSExitHubPolicy config.StringArrayOption
cfgOptionDNSExitHubPolicyOrder = 148
// CfgOptionUseCommunityNodesKey is the configuration key for whether to use community nodes.
CfgOptionUseCommunityNodesKey = "spn/useCommunityNodes"
cfgOptionUseCommunityNodes config.BoolOption
cfgOptionUseCommunityNodesOrder = 149
// NonCommunityVerifiedOwners holds a list of verified owners that are not
// considered "community".
NonCommunityVerifiedOwners = []string{"Safing"}
// CfgOptionTrustNodeNodesKey is the configuration key for whether additional trusted nodes.
CfgOptionTrustNodeNodesKey = "spn/trustNodes"
cfgOptionTrustNodeNodes config.StringArrayOption
cfgOptionTrustNodeNodesOrder = 150
// Special Access Code.
cfgOptionSpecialAccessCodeKey = "spn/specialAccessCode"
cfgOptionSpecialAccessCodeDefault = "none"
cfgOptionSpecialAccessCode config.StringOption //nolint:unused // Linter, you drunk?
cfgOptionSpecialAccessCodeOrder = 160
// IPv6 must be global and accessible.
cfgOptionBindToAdvertisedKey = "spn/publicHub/bindToAdvertised"
cfgOptionBindToAdvertised config.BoolOption
cfgOptionBindToAdvertisedDefault = false
cfgOptionBindToAdvertisedOrder = 161
// Config options for use.
cfgOptionRoutingAlgorithm config.StringOption
)
func prepConfig() error {
// Home Node Rules
err := config.Register(&config.Option{
Name: "Home Node Rules",
Key: CfgOptionHomeHubPolicyKey,
Description: `Customize which countries should or should not be used for your Home Node. The Home Node is your entry into the SPN. You connect directly to it and all your connections are routed through it.
By default, the Portmaster tries to choose the nearest node as your Home Node in order to reduce your exposure to the open Internet.
Reconnect to the SPN in order to apply new rules.`,
Help: profile.SPNRulesHelp,
Sensitive: true,
OptType: config.OptTypeStringArray,
RequiresRestart: true,
ExpertiseLevel: config.ExpertiseLevelExpert,
DefaultValue: []string{},
Annotations: config.Annotations{
config.CategoryAnnotation: "Routing",
config.DisplayOrderAnnotation: cfgOptionHomeHubPolicyOrder,
config.DisplayHintAnnotation: endpoints.DisplayHintEndpointList,
config.QuickSettingsAnnotation: profile.SPNRulesQuickSettings,
endpoints.EndpointListVerdictNamesAnnotation: profile.SPNRulesVerdictNames,
},
ValidationRegex: endpoints.ListEntryValidationRegex,
ValidationFunc: endpoints.ValidateEndpointListConfigOption,
})
if err != nil {
return err
}
cfgOptionHomeHubPolicy = config.Concurrent.GetAsStringArray(CfgOptionHomeHubPolicyKey, []string{})
// DNS Exit Node Rules
err = config.Register(&config.Option{
Name: "DNS Exit Node Rules",
Key: CfgOptionDNSExitHubPolicyKey,
Description: `Customize which countries should or should not be used as DNS Exit Nodes.
By default, the Portmaster will exit DNS requests directly at your Home Node in order to keep them fast and close to your location. This is important, as DNS resolution often takes your approximate location into account when deciding which optimized DNS records are returned to you. As the Portmaster encrypts your DNS requests by default, you effectively gain a two-hop security level for your DNS requests in order to protect your privacy.
This setting mainly exists for when you need to simulate your presence in another location on a lower level too. This might be necessary to defeat more intelligent geo-blocking systems.`,
Help: profile.SPNRulesHelp,
Sensitive: true,
OptType: config.OptTypeStringArray,
RequiresRestart: true,
ExpertiseLevel: config.ExpertiseLevelExpert,
DefaultValue: []string{},
Annotations: config.Annotations{
config.CategoryAnnotation: "Routing",
config.DisplayOrderAnnotation: cfgOptionDNSExitHubPolicyOrder,
config.DisplayHintAnnotation: endpoints.DisplayHintEndpointList,
config.QuickSettingsAnnotation: profile.SPNRulesQuickSettings,
endpoints.EndpointListVerdictNamesAnnotation: profile.SPNRulesVerdictNames,
},
ValidationRegex: endpoints.ListEntryValidationRegex,
ValidationFunc: endpoints.ValidateEndpointListConfigOption,
})
if err != nil {
return err
}
cfgOptionDNSExitHubPolicy = config.Concurrent.GetAsStringArray(CfgOptionDNSExitHubPolicyKey, []string{})
err = config.Register(&config.Option{
Name: "Use Community Nodes",
Key: CfgOptionUseCommunityNodesKey,
Description: "Use nodes (servers) not operated by Safing themselves. The use of community nodes is recommended as it diversifies the ownership of the nodes you use for your connections and further strengthens your privacy. Plain connections (eg. http, smtp, ...) will never exit via community nodes, making this setting safe to use.",
Sensitive: true,
OptType: config.OptTypeBool,
RequiresRestart: true,
DefaultValue: true,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionUseCommunityNodesOrder,
config.CategoryAnnotation: "Routing",
},
})
if err != nil {
return err
}
cfgOptionUseCommunityNodes = config.Concurrent.GetAsBool(CfgOptionUseCommunityNodesKey, true)
err = config.Register(&config.Option{
Name: "Trust Nodes",
Key: CfgOptionTrustNodeNodesKey,
Description: "Specify which community nodes to additionally trust. These nodes may then also be used as a Home Node, as well as an Exit Node for unencrypted connections.",
Help: "You can specify nodes by their ID or their verified operator.",
Sensitive: true,
OptType: config.OptTypeStringArray,
ExpertiseLevel: config.ExpertiseLevelExpert,
DefaultValue: []string{},
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionTrustNodeNodesOrder,
config.CategoryAnnotation: "Routing",
},
})
if err != nil {
return err
}
cfgOptionTrustNodeNodes = config.Concurrent.GetAsStringArray(CfgOptionTrustNodeNodesKey, []string{})
err = config.Register(&config.Option{
Name: "Special Access Code",
Key: cfgOptionSpecialAccessCodeKey,
Description: "Special Access Codes grant access to the SPN for testing or evaluation purposes.",
Sensitive: true,
OptType: config.OptTypeString,
DefaultValue: cfgOptionSpecialAccessCodeDefault,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionSpecialAccessCodeOrder,
config.CategoryAnnotation: "Advanced",
},
})
if err != nil {
return err
}
cfgOptionSpecialAccessCode = config.Concurrent.GetAsString(cfgOptionSpecialAccessCodeKey, "")
if conf.PublicHub() {
err = config.Register(&config.Option{
Name: "Connect From Advertised IPs Only",
Key: cfgOptionBindToAdvertisedKey,
Description: "Only connect from (bind to) the advertised IP addresses.",
OptType: config.OptTypeBool,
ExpertiseLevel: config.ExpertiseLevelExpert,
DefaultValue: cfgOptionBindToAdvertisedDefault,
RequiresRestart: true,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionBindToAdvertisedOrder,
},
})
if err != nil {
return err
}
cfgOptionBindToAdvertised = config.GetAsBool(cfgOptionBindToAdvertisedKey, cfgOptionBindToAdvertisedDefault)
}
// Config options for use.
cfgOptionRoutingAlgorithm = config.Concurrent.GetAsString(profile.CfgOptionRoutingAlgorithmKey, navigator.DefaultRoutingProfileID)
return nil
}
var (
homeHubPolicy endpoints.Endpoints
homeHubPolicyLock sync.Mutex
homeHubPolicyConfigFlag = config.NewValidityFlag()
)
func getHomeHubPolicy() (endpoints.Endpoints, error) {
homeHubPolicyLock.Lock()
defer homeHubPolicyLock.Unlock()
// Return cached value if config is still valid.
if homeHubPolicyConfigFlag.IsValid() {
return homeHubPolicy, nil
}
homeHubPolicyConfigFlag.Refresh()
// Parse new policy.
policy, err := endpoints.ParseEndpoints(cfgOptionHomeHubPolicy())
if err != nil {
homeHubPolicy = nil
return nil, err
}
// Save and return the new policy.
homeHubPolicy = policy
return homeHubPolicy, nil
}
var (
dnsExitHubPolicy endpoints.Endpoints
dnsExitHubPolicyLock sync.Mutex
dnsExitHubPolicyConfigFlag = config.NewValidityFlag()
)
// GetDNSExitHubPolicy return the current DNS exit policy.
func GetDNSExitHubPolicy() (endpoints.Endpoints, error) {
dnsExitHubPolicyLock.Lock()
defer dnsExitHubPolicyLock.Unlock()
// Return cached value if config is still valid.
if dnsExitHubPolicyConfigFlag.IsValid() {
return dnsExitHubPolicy, nil
}
dnsExitHubPolicyConfigFlag.Refresh()
// Parse new policy.
policy, err := endpoints.ParseEndpoints(cfgOptionDNSExitHubPolicy())
if err != nil {
dnsExitHubPolicy = nil
return nil, err
}
// Save and return the new policy.
dnsExitHubPolicy = policy
return dnsExitHubPolicy, nil
}

105
spn/captain/establish.go Normal file
View File

@@ -0,0 +1,105 @@
package captain
import (
"context"
"errors"
"fmt"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/docks"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/ships"
"github.com/safing/portmaster/spn/terminal"
)
// EstablishCrane establishes a crane to another Hub.
func EstablishCrane(callerCtx context.Context, dst *hub.Hub) (*docks.Crane, error) {
if conf.PublicHub() && dst.ID == publicIdentity.ID {
return nil, errors.New("connecting to self")
}
if docks.GetAssignedCrane(dst.ID) != nil {
return nil, fmt.Errorf("route to %s already exists", dst.ID)
}
ship, err := ships.Launch(callerCtx, dst, nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to launch ship: %w", err)
}
// On pure clients, mark all ships as public in order to show unmasked data in logs.
if conf.Client() && !conf.PublicHub() {
ship.MarkPublic()
}
crane, err := docks.NewCrane(ship, dst, publicIdentity)
if err != nil {
return nil, fmt.Errorf("failed to create crane: %w", err)
}
err = crane.Start(callerCtx)
if err != nil {
return nil, fmt.Errorf("failed to start crane: %w", err)
}
// Start gossip op for live map updates.
_, tErr := NewGossipOp(crane.Controller)
if tErr != nil {
crane.Stop(tErr)
return nil, fmt.Errorf("failed to start gossip op: %w", tErr)
}
return crane, nil
}
// EstablishPublicLane establishes a crane to another Hub and publishes it.
func EstablishPublicLane(ctx context.Context, dst *hub.Hub) (*docks.Crane, *terminal.Error) {
// Create new context with timeout.
// The maximum timeout is a worst case safeguard.
// Keep in mind that multiple IPs and protocols may be tried in all configurations.
// Some servers will be (possibly on purpose) hard to reach.
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
// Connect to destination and establish communication.
crane, err := EstablishCrane(ctx, dst)
if err != nil {
return nil, terminal.ErrInternalError.With("failed to establish crane: %w", err)
}
// Publish as Lane.
publishOp, tErr := NewPublishOp(crane.Controller, publicIdentity)
if tErr != nil {
return nil, terminal.ErrInternalError.With("failed to publish: %w", err)
}
// Wait for publishing to complete.
select {
case tErr := <-publishOp.Result():
if !tErr.Is(terminal.ErrExplicitAck) {
// Stop crane again, because we failed to publish it.
defer crane.Stop(nil)
return nil, terminal.ErrInternalError.With("failed to publish lane: %w", tErr)
}
case <-crane.Controller.Ctx().Done():
defer crane.Stop(nil)
return nil, terminal.ErrStopping
case <-ctx.Done():
defer crane.Stop(nil)
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return nil, terminal.ErrTimeout
}
return nil, terminal.ErrCanceled
}
// Query all gossip msgs.
_, tErr = NewGossipQueryOp(crane.Controller)
if tErr != nil {
log.Warningf("spn/captain: failed to start initial gossip query: %s", tErr)
}
return crane, nil
}

28
spn/captain/exceptions.go Normal file
View File

@@ -0,0 +1,28 @@
package captain
import (
"net"
"sync"
)
var (
exceptionLock sync.Mutex
exceptIPv4 net.IP
exceptIPv6 net.IP
)
func setExceptions(ipv4, ipv6 net.IP) {
exceptionLock.Lock()
defer exceptionLock.Unlock()
exceptIPv4 = ipv4
exceptIPv6 = ipv6
}
// IsExcepted checks if the given IP is currently excepted from the SPN.
func IsExcepted(ip net.IP) bool {
exceptionLock.Lock()
defer exceptionLock.Unlock()
return ip.Equal(exceptIPv4) || ip.Equal(exceptIPv6)
}

38
spn/captain/gossip.go Normal file
View File

@@ -0,0 +1,38 @@
package captain
import (
"sync"
)
var (
gossipOps = make(map[string]*GossipOp)
gossipOpsLock sync.RWMutex
)
func registerGossipOp(craneID string, op *GossipOp) {
gossipOpsLock.Lock()
defer gossipOpsLock.Unlock()
gossipOps[craneID] = op
}
func deleteGossipOp(craneID string) {
gossipOpsLock.Lock()
defer gossipOpsLock.Unlock()
delete(gossipOps, craneID)
}
func gossipRelayMsg(receivedFrom string, msgType GossipMsgType, data []byte) {
gossipOpsLock.RLock()
defer gossipOpsLock.RUnlock()
for craneID, gossipOp := range gossipOps {
// Don't return same msg back to sender.
if craneID == receivedFrom {
continue
}
gossipOp.sendMsg(msgType, data)
}
}

47
spn/captain/hooks.go Normal file
View File

@@ -0,0 +1,47 @@
package captain
import (
"time"
"github.com/safing/portmaster/service/updates"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/docks"
)
func startDockHooks() {
docks.RegisterCraneUpdateHook(handleCraneUpdate)
}
func stopDockHooks() {
docks.ResetCraneUpdateHook()
}
func handleCraneUpdate(crane *docks.Crane) {
if crane == nil {
return
}
if conf.Client() && crane.Controller != nil && crane.Controller.Abandoning.IsSet() {
// Check connection to home hub.
triggerClientHealthCheck()
}
if conf.PublicHub() && crane.Public() {
// Update Hub status.
updateConnectionStatus()
}
}
func updateConnectionStatus() {
// Delay updating status for a better chance to combine multiple changes.
statusUpdateTask.Schedule(time.Now().Add(maintainStatusUpdateDelay))
// Check if we lost all connections and trigger a pending restart if we did.
for _, crane := range docks.GetAllAssignedCranes() {
if crane.Public() && !crane.Stopped() {
// There is at least one public and active crane, so don't restart now.
return
}
}
updates.TriggerRestartIfPending()
}

108
spn/captain/intel.go Normal file
View File

@@ -0,0 +1,108 @@
package captain
import (
"context"
"fmt"
"os"
"sync"
"github.com/safing/portbase/config"
"github.com/safing/portbase/updater"
"github.com/safing/portmaster/service/updates"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/navigator"
"github.com/safing/portmaster/spn/ships"
)
var (
intelResource *updater.File
intelResourcePath = "intel/spn/main-intel.yaml"
intelResourceMapName = "main"
intelResourceUpdateLock sync.Mutex
)
func registerIntelUpdateHook() error {
if err := module.RegisterEventHook(
updates.ModuleName,
updates.ResourceUpdateEvent,
"update SPN intel",
updateSPNIntel,
); err != nil {
return err
}
if err := module.RegisterEventHook(
"config",
config.ChangeEvent,
"update SPN intel",
updateSPNIntel,
); err != nil {
return err
}
return nil
}
func updateSPNIntel(ctx context.Context, _ interface{}) (err error) {
intelResourceUpdateLock.Lock()
defer intelResourceUpdateLock.Unlock()
// Only update SPN intel when using the matching map.
if conf.MainMapName != intelResourceMapName {
return fmt.Errorf("intel resource not for map %q", conf.MainMapName)
}
// Check if there is something to do.
if intelResource != nil && !intelResource.UpgradeAvailable() {
return nil
}
// Get intel file and load it from disk.
intelResource, err = updates.GetFile(intelResourcePath)
if err != nil {
return fmt.Errorf("failed to get SPN intel update: %w", err)
}
intelData, err := os.ReadFile(intelResource.Path())
if err != nil {
return fmt.Errorf("failed to load SPN intel update: %w", err)
}
// Parse and apply intel data.
intel, err := hub.ParseIntel(intelData)
if err != nil {
return fmt.Errorf("failed to parse SPN intel update: %w", err)
}
setVirtualNetworkConfig(intel.VirtualNetworks)
return navigator.Main.UpdateIntel(intel, cfgOptionTrustNodeNodes())
}
func resetSPNIntel() {
intelResourceUpdateLock.Lock()
defer intelResourceUpdateLock.Unlock()
intelResource = nil
}
func setVirtualNetworkConfig(configs []*hub.VirtualNetworkConfig) {
// Do nothing if not public Hub.
if !conf.PublicHub() {
return
}
// Reset if there are no virtual networks configured.
if len(configs) == 0 {
ships.SetVirtualNetworkConfig(nil)
}
// Check if we are in a virtual network.
for _, config := range configs {
if _, ok := config.Mapping[publicIdentity.Hub.ID]; ok {
ships.SetVirtualNetworkConfig(config)
return
}
}
// If not, reset - we might have been in one before.
ships.SetVirtualNetworkConfig(nil)
}

219
spn/captain/module.go Normal file
View File

@@ -0,0 +1,219 @@
package captain
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"time"
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portbase/rng"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/crew"
"github.com/safing/portmaster/spn/navigator"
"github.com/safing/portmaster/spn/patrol"
"github.com/safing/portmaster/spn/ships"
_ "github.com/safing/portmaster/spn/sluice"
)
const controlledFailureExitCode = 24
var module *modules.Module
// SPNConnectedEvent is the name of the event that is fired when the SPN has connected and is ready.
const SPNConnectedEvent = "spn connect"
func init() {
module = modules.Register("captain", prep, start, stop, "base", "terminal", "cabin", "ships", "docks", "crew", "navigator", "sluice", "patrol", "netenv")
module.RegisterEvent(SPNConnectedEvent, false)
subsystems.Register(
"spn",
"SPN",
"Safing Privacy Network",
module,
"config:spn/",
&config.Option{
Name: "SPN Module",
Key: CfgOptionEnableSPNKey,
Description: "Start the Safing Privacy Network module. If turned off, the SPN is fully disabled on this device.",
OptType: config.OptTypeBool,
DefaultValue: false,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionEnableSPNOrder,
config.CategoryAnnotation: "General",
},
},
)
}
func prep() error {
// Check if we can parse the bootstrap hub flag.
if err := prepBootstrapHubFlag(); err != nil {
return err
}
// Register SPN status provider.
if err := registerSPNStatusProvider(); err != nil {
return err
}
// Register API endpoints.
if err := registerAPIEndpoints(); err != nil {
return err
}
if conf.PublicHub() {
// Register API authenticator.
if err := api.SetAuthenticator(apiAuthenticator); err != nil {
return err
}
if err := module.RegisterEventHook(
"patrol",
patrol.ChangeSignalEventName,
"trigger hub status maintenance",
func(_ context.Context, _ any) error {
TriggerHubStatusMaintenance()
return nil
},
); err != nil {
return err
}
}
return prepConfig()
}
func start() error {
maskingBytes, err := rng.Bytes(16)
if err != nil {
return fmt.Errorf("failed to get random bytes for masking: %w", err)
}
ships.EnableMasking(maskingBytes)
// Initialize intel.
if err := registerIntelUpdateHook(); err != nil {
return err
}
if err := updateSPNIntel(module.Ctx, nil); err != nil {
log.Errorf("spn/captain: failed to update SPN intel: %s", err)
}
// Initialize identity and piers.
if conf.PublicHub() {
// Load identity.
if err := loadPublicIdentity(); err != nil {
// We cannot recover from this, set controlled failure (do not retry).
modules.SetExitStatusCode(controlledFailureExitCode)
return err
}
// Check if any networks are configured.
if !conf.HubHasIPv4() && !conf.HubHasIPv6() {
// We cannot recover from this, set controlled failure (do not retry).
modules.SetExitStatusCode(controlledFailureExitCode)
return errors.New("no IP addresses for Hub configured (or detected)")
}
// Start management of identity and piers.
if err := prepPublicIdentityMgmt(); err != nil {
return err
}
// Set ID to display on http info page.
ships.DisplayHubID = publicIdentity.ID
// Start listeners.
if err := startPiers(); err != nil {
return err
}
// Enable connect operation.
crew.EnableConnecting(publicIdentity.Hub)
}
// Subscribe to updates of cranes.
startDockHooks()
// bootstrapping
if err := processBootstrapHubFlag(); err != nil {
return err
}
if err := processBootstrapFileFlag(); err != nil {
return err
}
// network optimizer
if conf.PublicHub() {
module.NewTask("optimize network", optimizeNetwork).
Repeat(1 * time.Minute).
Schedule(time.Now().Add(15 * time.Second))
}
// client + home hub manager
if conf.Client() {
module.StartServiceWorker("client manager", 0, clientManager)
// Reset failing hubs when the network changes while not connected.
if err := module.RegisterEventHook(
"netenv",
"network changed",
"reset failing hubs",
func(_ context.Context, _ interface{}) error {
if ready.IsNotSet() {
navigator.Main.ResetFailingStates(module.Ctx)
}
return nil
},
); err != nil {
return err
}
}
return nil
}
func stop() error {
// Reset intel resource so that it is loaded again when starting.
resetSPNIntel()
// Unregister crane update hook.
stopDockHooks()
// Send shutdown status message.
if conf.PublicHub() {
publishShutdownStatus()
stopPiers()
}
return nil
}
// apiAuthenticator grants User permissions for local API requests.
func apiAuthenticator(r *http.Request, s *http.Server) (*api.AuthToken, error) {
// Get remote IP.
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return nil, fmt.Errorf("failed to split host/port: %w", err)
}
remoteIP := net.ParseIP(host)
if remoteIP == nil {
return nil, fmt.Errorf("failed to parse remote address %s", host)
}
if !netutils.GetIPScope(remoteIP).IsLocalhost() {
return nil, api.ErrAPIAccessDeniedMessage
}
return &api.AuthToken{
Read: api.PermitUser,
Write: api.PermitUser,
}, nil
}

306
spn/captain/navigation.go Normal file
View File

@@ -0,0 +1,306 @@
package captain
import (
"context"
"errors"
"fmt"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/service/intel"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/profile/endpoints"
"github.com/safing/portmaster/spn/access"
"github.com/safing/portmaster/spn/docks"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/navigator"
"github.com/safing/portmaster/spn/terminal"
)
const stopCraneAfterBeingUnsuggestedFor = 6 * time.Hour
var (
// ErrAllHomeHubsExcluded is returned when all available home hubs were excluded.
ErrAllHomeHubsExcluded = errors.New("all home hubs are excluded")
// ErrReInitSPNSuggested is returned when no home hub can be found, even without rules.
ErrReInitSPNSuggested = errors.New("SPN re-init suggested")
)
func establishHomeHub(ctx context.Context) error {
// Get own IP.
locations, ok := netenv.GetInternetLocation()
if !ok || len(locations.All) == 0 {
return errors.New("failed to locate own device")
}
log.Debugf(
"spn/captain: looking for new home hub near %s and %s",
locations.BestV4(),
locations.BestV6(),
)
// Get own entity.
// Checking the entity against the entry policies is somewhat hit and miss
// anyway, as the device location is an approximation.
var myEntity *intel.Entity
if dl := locations.BestV4(); dl != nil && dl.IP != nil {
myEntity = (&intel.Entity{IP: dl.IP}).Init(0)
myEntity.FetchData(ctx)
} else if dl := locations.BestV6(); dl != nil && dl.IP != nil {
myEntity = (&intel.Entity{IP: dl.IP}).Init(0)
myEntity.FetchData(ctx)
}
// Get home hub policy for selecting the home hub.
homePolicy, err := getHomeHubPolicy()
if err != nil {
return err
}
// Build navigation options for searching for a home hub.
opts := &navigator.Options{
Home: &navigator.HomeHubOptions{
HubPolicies: []endpoints.Endpoints{homePolicy},
CheckHubPolicyWith: myEntity,
},
}
// Add requirement to only use Safing nodes when not using community nodes.
if !cfgOptionUseCommunityNodes() {
opts.Home.RequireVerifiedOwners = NonCommunityVerifiedOwners
}
// Require a trusted home node when the routing profile requires less than two hops.
routingProfile := navigator.GetRoutingProfile(cfgOptionRoutingAlgorithm())
if routingProfile.MinHops < 2 {
opts.Home.Regard = opts.Home.Regard.Add(navigator.StateTrusted)
}
// Find nearby hubs.
findCandidates:
candidates, err := navigator.Main.FindNearestHubs(
locations.BestV4().LocationOrNil(),
locations.BestV6().LocationOrNil(),
opts, navigator.HomeHub,
)
if err != nil {
switch {
case errors.Is(err, navigator.ErrEmptyMap):
// bootstrap to the network!
err := bootstrapWithUpdates()
if err != nil {
return err
}
goto findCandidates
case errors.Is(err, navigator.ErrAllPinsDisregarded):
if len(homePolicy) > 0 {
return ErrAllHomeHubsExcluded
}
return ErrReInitSPNSuggested
default:
return fmt.Errorf("failed to find nearby hubs: %w", err)
}
}
// Try connecting to a hub.
var tries int
var candidate *hub.Hub
for tries, candidate = range candidates {
err = connectToHomeHub(ctx, candidate)
if err != nil {
// Check if context is canceled.
if ctx.Err() != nil {
return ctx.Err()
}
// Check if the SPN protocol is stopping again.
if errors.Is(err, terminal.ErrStopping) {
return err
}
log.Warningf("spn/captain: failed to connect to %s as new home: %s", candidate, err)
} else {
log.Infof("spn/captain: established connection to %s as new home with %d failed tries", candidate, tries)
return nil
}
}
if err != nil {
return fmt.Errorf("failed to connect to a new home hub - tried %d hubs: %w", tries+1, err)
}
return fmt.Errorf("no home hub candidates available")
}
func connectToHomeHub(ctx context.Context, dst *hub.Hub) error {
// Create new context with timeout.
// The maximum timeout is a worst case safeguard.
// Keep in mind that multiple IPs and protocols may be tried in all configurations.
// Some servers will be (possibly on purpose) hard to reach.
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
// Set and clean up exceptions.
setExceptions(dst.Info.IPv4, dst.Info.IPv6)
defer setExceptions(nil, nil)
// Connect to hub.
crane, err := EstablishCrane(ctx, dst)
if err != nil {
return err
}
// Cleanup connection in case of failure.
var success bool
defer func() {
if !success {
crane.Stop(nil)
}
}()
// Query all gossip msgs on first connection.
gossipQuery, tErr := NewGossipQueryOp(crane.Controller)
if tErr != nil {
log.Warningf("spn/captain: failed to start initial gossip query: %s", tErr)
}
// Wait for gossip query to complete.
select {
case <-gossipQuery.ctx.Done():
case <-ctx.Done():
return context.Canceled
}
// Create communication terminal.
homeTerminal, initData, tErr := docks.NewLocalCraneTerminal(crane, nil, terminal.DefaultHomeHubTerminalOpts())
if tErr != nil {
return tErr.Wrap("failed to create home terminal")
}
tErr = crane.EstablishNewTerminal(homeTerminal, initData)
if tErr != nil {
return tErr.Wrap("failed to connect home terminal")
}
if !DisableAccount {
// Authenticate to home hub.
authOp, tErr := access.AuthorizeToTerminal(homeTerminal)
if tErr != nil {
return tErr.Wrap("failed to authorize")
}
select {
case tErr := <-authOp.Result:
if !tErr.Is(terminal.ErrExplicitAck) {
return tErr.Wrap("failed to authenticate to")
}
case <-time.After(3 * time.Second):
return terminal.ErrTimeout.With("waiting for auth to complete")
case <-ctx.Done():
return terminal.ErrStopping
}
}
// Set new home on map.
ok := navigator.Main.SetHome(dst.ID, homeTerminal)
if !ok {
return fmt.Errorf("failed to set home hub on map")
}
// Assign crane to home hub in order to query it later.
docks.AssignCrane(crane.ConnectedHub.ID, crane)
success = true
return nil
}
func optimizeNetwork(ctx context.Context, task *modules.Task) error {
if publicIdentity == nil {
return nil
}
optimize:
result, err := navigator.Main.Optimize(nil)
if err != nil {
if errors.Is(err, navigator.ErrEmptyMap) {
// bootstrap to the network!
err := bootstrapWithUpdates()
if err != nil {
return err
}
goto optimize
}
return err
}
// Create any new connections.
var createdConnections int
var attemptedConnections int
for _, connectTo := range result.SuggestedConnections {
// Skip duplicates.
if connectTo.Duplicate {
continue
}
// Check if connection already exists.
crane := docks.GetAssignedCrane(connectTo.Hub.ID)
if crane != nil {
// Update last suggested timestamp.
crane.NetState.UpdateLastSuggestedAt()
// Continue crane if stopping.
if crane.AbortStopping() {
log.Infof("spn/captain: optimization aborted retiring of %s, removed stopping mark", crane)
crane.NotifyUpdate()
}
// Create new connections if we have connects left.
} else if createdConnections < result.MaxConnect {
attemptedConnections++
crane, tErr := EstablishPublicLane(ctx, connectTo.Hub)
if !tErr.IsOK() {
log.Warningf("spn/captain: failed to establish lane to %s: %s", connectTo.Hub, tErr)
} else {
createdConnections++
crane.NetState.UpdateLastSuggestedAt()
log.Infof("spn/captain: established lane to %s", connectTo.Hub)
}
}
}
// Log optimization result.
if attemptedConnections > 0 {
log.Infof(
"spn/captain: created %d/%d new connections for %s optimization",
createdConnections,
attemptedConnections,
result.Purpose)
} else {
log.Infof(
"spn/captain: checked %d connections for %s optimization",
len(result.SuggestedConnections),
result.Purpose,
)
}
// Retire cranes if unsuggested for a while.
if result.StopOthers {
for _, crane := range docks.GetAllAssignedCranes() {
switch {
case crane.Stopped():
// Crane already stopped.
case crane.IsStopping():
// Crane is stopping, forcibly stop if mine and suggested.
if crane.IsMine() && crane.NetState.StopSuggested() {
crane.Stop(nil)
}
case crane.IsMine() && crane.NetState.StoppingSuggested():
// Mark as stopping if mine and suggested.
crane.MarkStopping()
case crane.NetState.RequestStoppingSuggested(stopCraneAfterBeingUnsuggestedFor):
// Mark as stopping requested.
crane.MarkStoppingRequested()
}
}
}
return nil
}

156
spn/captain/op_gossip.go Normal file
View File

@@ -0,0 +1,156 @@
package captain
import (
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/docks"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/terminal"
)
// GossipOpType is the type ID of the gossip operation.
const GossipOpType string = "gossip"
// GossipMsgType is the gossip message type.
type GossipMsgType uint8
// Gossip Message Types.
const (
GossipHubAnnouncementMsg GossipMsgType = 1
GossipHubStatusMsg GossipMsgType = 2
)
func (msgType GossipMsgType) String() string {
switch msgType {
case GossipHubAnnouncementMsg:
return "hub announcement"
case GossipHubStatusMsg:
return "hub status"
default:
return "unknown gossip msg"
}
}
// GossipOp is used to gossip Hub messages.
type GossipOp struct {
terminal.OperationBase
craneID string
}
// Type returns the type ID.
func (op *GossipOp) Type() string {
return GossipOpType
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: GossipOpType,
Requires: terminal.IsCraneController,
Start: runGossipOp,
})
}
// NewGossipOp start a new gossip operation.
func NewGossipOp(controller *docks.CraneControllerTerminal) (*GossipOp, *terminal.Error) {
// Create and init.
op := &GossipOp{
craneID: controller.Crane.ID,
}
err := controller.StartOperation(op, nil, 1*time.Minute)
if err != nil {
return nil, err
}
op.InitOperationBase(controller, op.ID())
// Register and return.
registerGossipOp(controller.Crane.ID, op)
return op, nil
}
func runGossipOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Check if we are run by a controller.
controller, ok := t.(*docks.CraneControllerTerminal)
if !ok {
return nil, terminal.ErrIncorrectUsage.With("gossip op may only be started by a crane controller terminal, but was started by %T", t)
}
// Create, init, register and return.
op := &GossipOp{
craneID: controller.Crane.ID,
}
op.InitOperationBase(t, opID)
registerGossipOp(controller.Crane.ID, op)
return op, nil
}
func (op *GossipOp) sendMsg(msgType GossipMsgType, data []byte) {
// Create message.
msg := op.NewEmptyMsg()
msg.Data = container.New(
varint.Pack8(uint8(msgType)),
data,
)
msg.Unit.MakeHighPriority()
// Send.
err := op.Send(msg, 1*time.Second)
if err != nil {
log.Debugf("spn/captain: failed to forward %s via %s: %s", msgType, op.craneID, err)
}
}
// Deliver delivers a message to the operation.
func (op *GossipOp) Deliver(msg *terminal.Msg) *terminal.Error {
defer msg.Finish()
gossipMsgTypeN, err := msg.Data.GetNextN8()
if err != nil {
return terminal.ErrMalformedData.With("failed to parse gossip message type")
}
gossipMsgType := GossipMsgType(gossipMsgTypeN)
// Prepare data.
data := msg.Data.CompileData()
var announcementData, statusData []byte
switch gossipMsgType {
case GossipHubAnnouncementMsg:
announcementData = data
case GossipHubStatusMsg:
statusData = data
default:
log.Warningf("spn/captain: received unknown gossip message type from %s: %d", op.craneID, gossipMsgType)
return nil
}
// Import and verify.
h, forward, tErr := docks.ImportAndVerifyHubInfo(module.Ctx, "", announcementData, statusData, conf.MainMapName, conf.MainMapScope)
if tErr != nil {
if tErr.Is(hub.ErrOldData) {
log.Debugf("spn/captain: ignoring old %s from %s", gossipMsgType, op.craneID)
} else {
log.Warningf("spn/captain: failed to import %s from %s: %s", gossipMsgType, op.craneID, tErr)
}
} else if forward {
// Only log if we received something to save/forward.
log.Infof("spn/captain: received %s for %s", gossipMsgType, h)
}
// Relay data.
if forward {
gossipRelayMsg(op.craneID, gossipMsgType, data)
}
return nil
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *GossipOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
deleteGossipOp(op.craneID)
return err
}

View File

@@ -0,0 +1,195 @@
package captain
import (
"context"
"strings"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/docks"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/terminal"
)
// GossipQueryOpType is the type ID of the gossip query operation.
const GossipQueryOpType string = "gossip/query"
// GossipQueryOp is used to query gossip messages.
type GossipQueryOp struct {
terminal.OperationBase
t terminal.Terminal
client bool
importCnt int
ctx context.Context
cancelCtx context.CancelFunc
}
// Type returns the type ID.
func (op *GossipQueryOp) Type() string {
return GossipQueryOpType
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: GossipQueryOpType,
Requires: terminal.IsCraneController,
Start: runGossipQueryOp,
})
}
// NewGossipQueryOp starts a new gossip query operation.
func NewGossipQueryOp(t terminal.Terminal) (*GossipQueryOp, *terminal.Error) {
// Create and init.
op := &GossipQueryOp{
t: t,
client: true,
}
op.ctx, op.cancelCtx = context.WithCancel(t.Ctx())
err := t.StartOperation(op, nil, 1*time.Minute)
if err != nil {
return nil, err
}
return op, nil
}
func runGossipQueryOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Create, init, register and return.
op := &GossipQueryOp{t: t}
op.ctx, op.cancelCtx = context.WithCancel(t.Ctx())
op.InitOperationBase(t, opID)
module.StartWorker("gossip query handler", op.handler)
return op, nil
}
func (op *GossipQueryOp) handler(_ context.Context) error {
tErr := op.sendMsgs(hub.MsgTypeAnnouncement)
if tErr != nil {
op.Stop(op, tErr)
return nil // Clean worker exit.
}
tErr = op.sendMsgs(hub.MsgTypeStatus)
if tErr != nil {
op.Stop(op, tErr)
return nil // Clean worker exit.
}
op.Stop(op, nil)
return nil // Clean worker exit.
}
func (op *GossipQueryOp) sendMsgs(msgType hub.MsgType) *terminal.Error {
it, err := hub.QueryRawGossipMsgs(conf.MainMapName, msgType)
if err != nil {
return terminal.ErrInternalError.With("failed to query: %w", err)
}
defer it.Cancel()
iterating:
for {
select {
case r := <-it.Next:
// Check if we are done.
if r == nil {
return nil
}
// Ensure we're handling a hub msg.
hubMsg, err := hub.EnsureHubMsg(r)
if err != nil {
log.Warningf("spn/captain: failed to load hub msg: %s", err)
continue iterating
}
// Create gossip msg.
var c *container.Container
switch hubMsg.Type {
case hub.MsgTypeAnnouncement:
c = container.New(
varint.Pack8(uint8(GossipHubAnnouncementMsg)),
hubMsg.Data,
)
case hub.MsgTypeStatus:
c = container.New(
varint.Pack8(uint8(GossipHubStatusMsg)),
hubMsg.Data,
)
default:
log.Warningf("spn/captain: unknown hub msg for gossip query at %q: %s", hubMsg.Key(), hubMsg.Type)
}
// Send msg.
if c != nil {
msg := op.NewEmptyMsg()
msg.Unit.MakeHighPriority()
msg.Data = c
tErr := op.Send(msg, 1*time.Second)
if tErr != nil {
return tErr.Wrap("failed to send msg")
}
}
case <-op.ctx.Done():
return terminal.ErrStopping
}
}
}
// Deliver delivers the message to the operation.
func (op *GossipQueryOp) Deliver(msg *terminal.Msg) *terminal.Error {
defer msg.Finish()
gossipMsgTypeN, err := msg.Data.GetNextN8()
if err != nil {
return terminal.ErrMalformedData.With("failed to parse gossip message type")
}
gossipMsgType := GossipMsgType(gossipMsgTypeN)
// Prepare data.
data := msg.Data.CompileData()
var announcementData, statusData []byte
switch gossipMsgType {
case GossipHubAnnouncementMsg:
announcementData = data
case GossipHubStatusMsg:
statusData = data
default:
log.Warningf("spn/captain: received unknown gossip message type from gossip query: %d", gossipMsgType)
return nil
}
// Import and verify.
h, forward, tErr := docks.ImportAndVerifyHubInfo(module.Ctx, "", announcementData, statusData, conf.MainMapName, conf.MainMapScope)
if tErr != nil {
log.Warningf("spn/captain: failed to import %s from gossip query: %s", gossipMsgType, tErr)
} else {
log.Infof("spn/captain: received %s for %s from gossip query", gossipMsgType, h)
op.importCnt++
}
// Relay data.
if forward {
// TODO: Find better way to get craneID.
craneID := strings.SplitN(op.t.FmtID(), "#", 2)[0]
gossipRelayMsg(craneID, gossipMsgType, data)
}
return nil
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *GossipQueryOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
if op.client {
log.Infof("spn/captain: gossip query imported %d entries", op.importCnt)
}
op.cancelCtx()
return err
}

183
spn/captain/op_publish.go Normal file
View File

@@ -0,0 +1,183 @@
package captain
import (
"time"
"github.com/safing/portbase/container"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/docks"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/terminal"
)
// PublishOpType is the type ID of the publish operation.
const PublishOpType string = "publish"
// PublishOp is used to publish a connection.
type PublishOp struct {
terminal.OperationBase
controller *docks.CraneControllerTerminal
identity *cabin.Identity
requestingHub *hub.Hub
verification *cabin.Verification
result chan *terminal.Error
}
// Type returns the type ID.
func (op *PublishOp) Type() string {
return PublishOpType
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: PublishOpType,
Requires: terminal.IsCraneController,
Start: runPublishOp,
})
}
// NewPublishOp start a new publish operation.
func NewPublishOp(controller *docks.CraneControllerTerminal, identity *cabin.Identity) (*PublishOp, *terminal.Error) {
// Create and init.
op := &PublishOp{
controller: controller,
identity: identity,
result: make(chan *terminal.Error, 1),
}
msg := container.New()
// Add Hub Announcement.
announcementData, err := identity.ExportAnnouncement()
if err != nil {
return nil, terminal.ErrInternalError.With("failed to export announcement: %w", err)
}
msg.AppendAsBlock(announcementData)
// Add Hub Status.
statusData, err := identity.ExportStatus()
if err != nil {
return nil, terminal.ErrInternalError.With("failed to export status: %w", err)
}
msg.AppendAsBlock(statusData)
tErr := controller.StartOperation(op, msg, 10*time.Second)
if tErr != nil {
return nil, tErr
}
return op, nil
}
func runPublishOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Check if we are run by a controller.
controller, ok := t.(*docks.CraneControllerTerminal)
if !ok {
return nil, terminal.ErrIncorrectUsage.With("publish op may only be started by a crane controller terminal, but was started by %T", t)
}
// Parse and import Announcement and Status.
announcementData, err := data.GetNextBlock()
if err != nil {
return nil, terminal.ErrMalformedData.With("failed to get announcement: %w", err)
}
statusData, err := data.GetNextBlock()
if err != nil {
return nil, terminal.ErrMalformedData.With("failed to get status: %w", err)
}
h, forward, tErr := docks.ImportAndVerifyHubInfo(module.Ctx, "", announcementData, statusData, conf.MainMapName, conf.MainMapScope)
if tErr != nil {
return nil, tErr.Wrap("failed to import and verify hub")
}
// Update reference in case it was changed by the import.
controller.Crane.ConnectedHub = h
// Relay data.
if forward {
gossipRelayMsg(controller.Crane.ID, GossipHubAnnouncementMsg, announcementData)
gossipRelayMsg(controller.Crane.ID, GossipHubStatusMsg, statusData)
}
// Create verification request.
v, request, err := cabin.CreateVerificationRequest(PublishOpType, "", "")
if err != nil {
return nil, terminal.ErrInternalError.With("failed to create verification request: %w", err)
}
// Create operation.
op := &PublishOp{
controller: controller,
requestingHub: h,
verification: v,
result: make(chan *terminal.Error, 1),
}
op.InitOperationBase(controller, opID)
// Reply with verification request.
tErr = op.Send(op.NewMsg(request), 10*time.Second)
if tErr != nil {
return nil, tErr.Wrap("failed to send verification request")
}
return op, nil
}
// Deliver delivers a message to the operation.
func (op *PublishOp) Deliver(msg *terminal.Msg) *terminal.Error {
defer msg.Finish()
if op.identity != nil {
// Client
// Sign the received verification request.
response, err := op.identity.SignVerificationRequest(msg.Data.CompileData(), PublishOpType, "", "")
if err != nil {
return terminal.ErrPermissionDenied.With("signing verification request failed: %w", err)
}
return op.Send(op.NewMsg(response), 10*time.Second)
} else if op.requestingHub != nil {
// Server
// Verify the signed request.
err := op.verification.Verify(msg.Data.CompileData(), op.requestingHub)
if err != nil {
return terminal.ErrPermissionDenied.With("checking verification request failed: %w", err)
}
return terminal.ErrExplicitAck
}
return terminal.ErrInternalError.With("invalid operation state")
}
// Result returns the result (end error) of the operation.
func (op *PublishOp) Result() <-chan *terminal.Error {
return op.result
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *PublishOp) HandleStop(tErr *terminal.Error) (errorToSend *terminal.Error) {
if tErr.Is(terminal.ErrExplicitAck) {
// TODO: Check for concurrenct access.
if op.controller.Crane.ConnectedHub == nil {
op.controller.Crane.ConnectedHub = op.requestingHub
}
// Publish crane, abort if it fails.
err := op.controller.Crane.Publish()
if err != nil {
tErr = terminal.ErrInternalError.With("failed to publish crane: %w", err)
op.controller.Crane.Stop(tErr)
} else {
op.controller.Crane.NotifyUpdate()
}
}
select {
case op.result <- tErr:
default:
}
return tErr
}

131
spn/captain/piers.go Normal file
View File

@@ -0,0 +1,131 @@
package captain
import (
"context"
"errors"
"fmt"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/intel"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/profile/endpoints"
"github.com/safing/portmaster/spn/docks"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/ships"
)
var (
dockingRequests = make(chan ships.Ship, 100)
piers []ships.Pier
)
func startPiers() error {
// Get and check transports.
transports := publicIdentity.Hub.Info.Transports
if len(transports) == 0 {
return errors.New("no transports defined")
}
piers = make([]ships.Pier, 0, len(transports))
for _, t := range transports {
// Parse transport.
transport, err := hub.ParseTransport(t)
if err != nil {
return fmt.Errorf("cannot build pier for invalid transport %q: %w", t, err)
}
// Establish pier / listener.
pier, err := ships.EstablishPier(transport, dockingRequests)
if err != nil {
return fmt.Errorf("failed to establish pier for transport %q: %w", t, err)
}
piers = append(piers, pier)
log.Infof("spn/captain: pier for transport %q built", t)
}
// Start worker to handle docking requests.
module.StartServiceWorker("docking request handler", 0, dockingRequestHandler)
return nil
}
func stopPiers() {
for _, pier := range piers {
pier.Abolish()
}
}
func dockingRequestHandler(ctx context.Context) error {
// Sink all waiting ships when this worker ends.
// But don't be destructive so the service worker could recover.
defer func() {
for {
select {
case ship := <-dockingRequests:
if ship != nil {
ship.Sink()
}
default:
return
}
}
}()
for {
select {
case <-ctx.Done():
return nil
case ship := <-dockingRequests:
// Ignore nil ships.
if ship == nil {
continue
}
if err := checkDockingPermission(ctx, ship); err != nil {
log.Warningf("spn/captain: denied ship from %s to dock at pier %s: %s", ship.RemoteAddr(), ship.Transport().String(), err)
} else {
handleDockingRequest(ship)
}
}
}
}
func checkDockingPermission(ctx context.Context, ship ships.Ship) error {
remoteIP, remotePort, err := netutils.IPPortFromAddr(ship.RemoteAddr())
if err != nil {
return fmt.Errorf("failed to parse remote IP: %w", err)
}
// Create entity.
entity := (&intel.Entity{
IP: remoteIP,
Protocol: uint8(netutils.ProtocolFromNetwork(ship.RemoteAddr().Network())),
Port: remotePort,
}).Init(ship.Transport().Port)
entity.FetchData(ctx)
// Check against policy.
result, reason := publicIdentity.Hub.GetInfo().EntryPolicy().Match(ctx, entity)
if result == endpoints.Denied {
return fmt.Errorf("entry policy violated: %s", reason)
}
return nil
}
func handleDockingRequest(ship ships.Ship) {
log.Infof("spn/captain: pemitting %s to dock", ship)
crane, err := docks.NewCrane(ship, nil, publicIdentity)
if err != nil {
log.Warningf("spn/captain: failed to commission crane for %s: %s", ship, err)
return
}
module.StartWorker("start crane", func(ctx context.Context) error {
_ = crane.Start(ctx)
// Crane handles errors internally.
return nil
})
}

247
spn/captain/public.go Normal file
View File

@@ -0,0 +1,247 @@
package captain
import (
"context"
"errors"
"fmt"
"sort"
"time"
"github.com/safing/portbase/database"
"github.com/safing/portbase/log"
"github.com/safing/portbase/metrics"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/docks"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/navigator"
"github.com/safing/portmaster/spn/patrol"
)
const (
maintainStatusInterval = 15 * time.Minute
maintainStatusUpdateDelay = 5 * time.Second
)
var (
publicIdentity *cabin.Identity
publicIdentityKey = "core:spn/public/identity"
publicIdentityUpdateTask *modules.Task
statusUpdateTask *modules.Task
)
func loadPublicIdentity() (err error) {
var changed bool
publicIdentity, changed, err = cabin.LoadIdentity(publicIdentityKey)
switch {
case err == nil:
// load was successful
log.Infof("spn/captain: loaded public hub identity %s", publicIdentity.Hub.ID)
case errors.Is(err, database.ErrNotFound):
// does not exist, create new
publicIdentity, err = cabin.CreateIdentity(module.Ctx, conf.MainMapName)
if err != nil {
return fmt.Errorf("failed to create new identity: %w", err)
}
publicIdentity.SetKey(publicIdentityKey)
changed = true
log.Infof("spn/captain: created new public hub identity %s", publicIdentity.ID)
default:
// loading error, abort
return fmt.Errorf("failed to load public identity: %w", err)
}
// Save to database if the identity changed.
if changed {
err = publicIdentity.Save()
if err != nil {
return fmt.Errorf("failed to save new/updated identity to database: %w", err)
}
}
// Set available networks.
conf.SetHubNetworks(
publicIdentity.Hub.Info.IPv4 != nil,
publicIdentity.Hub.Info.IPv6 != nil,
)
if cfgOptionBindToAdvertised() {
conf.SetBindAddr(publicIdentity.Hub.Info.IPv4, publicIdentity.Hub.Info.IPv6)
}
// Set Home Hub before updating the hub on the map, as this would trigger a
// recalculation without a Home Hub.
ok := navigator.Main.SetHome(publicIdentity.ID, nil)
// Always update the navigator in any case in order to sync the reference to
// the active struct of the identity.
navigator.Main.UpdateHub(publicIdentity.Hub)
// Setting the Home Hub will have failed if the identidy was only just
// created - try again if it failed.
if !ok {
ok = navigator.Main.SetHome(publicIdentity.ID, nil)
if !ok {
return errors.New("failed to set self as home hub")
}
}
return nil
}
func prepPublicIdentityMgmt() error {
publicIdentityUpdateTask = module.NewTask(
"maintain public identity",
maintainPublicIdentity,
)
statusUpdateTask = module.NewTask(
"maintain public status",
maintainPublicStatus,
).Repeat(maintainStatusInterval)
return module.RegisterEventHook(
"config",
"config change",
"update public identity from config",
func(_ context.Context, _ interface{}) error {
// trigger update in 5 minutes
publicIdentityUpdateTask.Schedule(time.Now().Add(5 * time.Minute))
return nil
},
)
}
// TriggerHubStatusMaintenance queues the Hub status update task to be executed.
func TriggerHubStatusMaintenance() {
if statusUpdateTask != nil {
statusUpdateTask.Queue()
}
}
func maintainPublicIdentity(ctx context.Context, task *modules.Task) error {
changed, err := publicIdentity.MaintainAnnouncement(nil, false)
if err != nil {
return fmt.Errorf("failed to maintain announcement: %w", err)
}
if !changed {
return nil
}
// Update on map.
navigator.Main.UpdateHub(publicIdentity.Hub)
log.Debug("spn/captain: updated own hub on map after announcement change")
// export announcement
announcementData, err := publicIdentity.ExportAnnouncement()
if err != nil {
return fmt.Errorf("failed to export announcement: %w", err)
}
// forward to other connected Hubs
gossipRelayMsg("", GossipHubAnnouncementMsg, announcementData)
return nil
}
func maintainPublicStatus(ctx context.Context, task *modules.Task) error {
// Get current lanes.
cranes := docks.GetAllAssignedCranes()
lanes := make([]*hub.Lane, 0, len(cranes))
for _, crane := range cranes {
// Ignore private, stopped or stopping cranes.
if !crane.Public() || crane.Stopped() || crane.IsStopping() {
continue
}
// Get measurements.
measurements := crane.ConnectedHub.GetMeasurements()
latency, _ := measurements.GetLatency()
capacity, _ := measurements.GetCapacity()
// Add crane lane.
lanes = append(lanes, &hub.Lane{
ID: crane.ConnectedHub.ID,
Latency: latency,
Capacity: capacity,
})
}
// Sort Lanes for comparing.
hub.SortLanes(lanes)
// Get system load and convert to fixed steps.
var load int
loadAvg, ok := metrics.LoadAvg15()
switch {
case !ok:
load = -1
case loadAvg >= 1:
load = 100
case loadAvg >= 0.95:
load = 95
case loadAvg >= 0.8:
load = 80
default:
load = 0
}
if loadAvg >= 0.8 {
log.Warningf("spn/captain: publishing 15m system load average of %.2f as %d", loadAvg, load)
}
// Set flags.
var flags []string
if !patrol.HTTPSConnectivityConfirmed() {
flags = append(flags, hub.FlagNetError)
}
// Sort Lanes for comparing.
sort.Strings(flags)
// Run maintenance with the new data.
changed, err := publicIdentity.MaintainStatus(lanes, &load, flags, false)
if err != nil {
return fmt.Errorf("failed to maintain status: %w", err)
}
if !changed {
return nil
}
// Update on map.
navigator.Main.UpdateHub(publicIdentity.Hub)
log.Debug("spn/captain: updated own hub on map after status change")
// export status
statusData, err := publicIdentity.ExportStatus()
if err != nil {
return fmt.Errorf("failed to export status: %w", err)
}
// forward to other connected Hubs
gossipRelayMsg("", GossipHubStatusMsg, statusData)
log.Infof(
"spn/captain: updated status with load %d and current lanes: %v",
publicIdentity.Hub.Status.Load,
publicIdentity.Hub.Status.Lanes,
)
return nil
}
func publishShutdownStatus() {
// Create offline status.
offlineStatusData, err := publicIdentity.MakeOfflineStatus()
if err != nil {
log.Errorf("spn/captain: failed to create offline status: %s", err)
return
}
// Forward to other connected Hubs.
gossipRelayMsg("", GossipHubStatusMsg, offlineStatusData)
// Leave some time for the message to broadcast.
time.Sleep(2 * time.Second)
log.Infof("spn/captain: broadcasted offline status")
}

154
spn/captain/status.go Normal file
View File

@@ -0,0 +1,154 @@
package captain
import (
"fmt"
"sort"
"sync"
"time"
"github.com/safing/portbase/config"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/runtime"
"github.com/safing/portbase/utils/debug"
"github.com/safing/portmaster/service/intel/geoip"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/navigator"
)
// SPNStatus holds SPN status information.
type SPNStatus struct {
record.Base
sync.Mutex
Status SPNStatusName
HomeHubID string
HomeHubName string
ConnectedIP string
ConnectedTransport string
ConnectedCountry *geoip.CountryInfo
ConnectedSince *time.Time
}
// SPNStatusName is a SPN status.
type SPNStatusName string
// SPN Stati.
const (
StatusFailed SPNStatusName = "failed"
StatusDisabled SPNStatusName = "disabled"
StatusConnecting SPNStatusName = "connecting"
StatusConnected SPNStatusName = "connected"
)
var (
spnStatus = &SPNStatus{
Status: StatusDisabled,
}
spnStatusPushFunc runtime.PushFunc
)
func registerSPNStatusProvider() (err error) {
spnStatus.SetKey("runtime:spn/status")
spnStatus.UpdateMeta()
spnStatusPushFunc, err = runtime.Register("spn/status", runtime.ProvideRecord(spnStatus))
return
}
func resetSPNStatus(statusName SPNStatusName, overrideEvenIfConnected bool) {
// Lock for updating values.
spnStatus.Lock()
defer spnStatus.Unlock()
// Ignore when connected and not overriding
if !overrideEvenIfConnected && spnStatus.Status == StatusConnected {
return
}
// Reset status.
spnStatus.Status = statusName
spnStatus.HomeHubID = ""
spnStatus.HomeHubName = ""
spnStatus.ConnectedIP = ""
spnStatus.ConnectedTransport = ""
spnStatus.ConnectedCountry = nil
spnStatus.ConnectedSince = nil
// Push new status.
pushSPNStatusUpdate()
}
// pushSPNStatusUpdate pushes an update of spnStatus, which must be locked.
func pushSPNStatusUpdate() {
spnStatus.UpdateMeta()
spnStatusPushFunc(spnStatus)
}
// GetSPNStatus returns the current SPN status.
func GetSPNStatus() *SPNStatus {
spnStatus.Lock()
defer spnStatus.Unlock()
return &SPNStatus{
Status: spnStatus.Status,
HomeHubID: spnStatus.HomeHubID,
HomeHubName: spnStatus.HomeHubName,
ConnectedIP: spnStatus.ConnectedIP,
ConnectedTransport: spnStatus.ConnectedTransport,
ConnectedCountry: spnStatus.ConnectedCountry,
ConnectedSince: spnStatus.ConnectedSince,
}
}
// AddToDebugInfo adds the SPN status to the given debug.Info.
func AddToDebugInfo(di *debug.Info) {
spnStatus.Lock()
defer spnStatus.Unlock()
// Check if SPN module is enabled.
var moduleStatus string
spnEnabled := config.GetAsBool(CfgOptionEnableSPNKey, false)
if spnEnabled() {
moduleStatus = "enabled"
} else {
moduleStatus = "disabled"
}
// Collect status data.
lines := make([]string, 0, 20)
lines = append(lines, fmt.Sprintf("HomeHubID: %v", spnStatus.HomeHubID))
lines = append(lines, fmt.Sprintf("HomeHubName: %v", spnStatus.HomeHubName))
lines = append(lines, fmt.Sprintf("HomeHubIP: %v", spnStatus.ConnectedIP))
lines = append(lines, fmt.Sprintf("Transport: %v", spnStatus.ConnectedTransport))
if spnStatus.ConnectedSince != nil {
lines = append(lines, fmt.Sprintf("Connected: %v ago", time.Since(*spnStatus.ConnectedSince).Round(time.Minute)))
}
lines = append(lines, "---")
lines = append(lines, fmt.Sprintf("Client: %v", conf.Client()))
lines = append(lines, fmt.Sprintf("PublicHub: %v", conf.PublicHub()))
lines = append(lines, fmt.Sprintf("HubHasIPv4: %v", conf.HubHasIPv4()))
lines = append(lines, fmt.Sprintf("HubHasIPv6: %v", conf.HubHasIPv6()))
// Collect status data of map.
if navigator.Main != nil {
lines = append(lines, "---")
mainMapStats := navigator.Main.Stats()
lines = append(lines, fmt.Sprintf("Map %s:", navigator.Main.Name))
lines = append(lines, fmt.Sprintf("Active Terminals: %d Hubs", mainMapStats.ActiveTerminals))
// Collect hub states.
mapStateSummary := make([]string, 0, len(mainMapStats.States))
for state, cnt := range mainMapStats.States {
if cnt > 0 {
mapStateSummary = append(mapStateSummary, fmt.Sprintf("State %s: %d Hubs", state, cnt))
}
}
sort.Strings(mapStateSummary)
lines = append(lines, mapStateSummary...)
}
// Add all data as section.
di.AddSection(
fmt.Sprintf("SPN: %s (module %s)", spnStatus.Status, moduleStatus),
debug.UseCodeSection|debug.AddContentLineBreaks,
lines...,
)
}

17
spn/conf/map.go Normal file
View File

@@ -0,0 +1,17 @@
package conf
import (
"flag"
"github.com/safing/portmaster/spn/hub"
)
// Primary Map Configuration.
var (
MainMapName = "main"
MainMapScope = hub.ScopePublic
)
func init() {
flag.StringVar(&MainMapName, "spn-map", "main", "set main SPN map - use only for testing")
}

30
spn/conf/mode.go Normal file
View File

@@ -0,0 +1,30 @@
package conf
import (
"github.com/tevino/abool"
)
var (
publicHub = abool.New()
client = abool.New()
)
// PublicHub returns whether this is a public Hub.
func PublicHub() bool {
return publicHub.IsSet()
}
// EnablePublicHub enables the public hub mode.
func EnablePublicHub(enable bool) {
publicHub.SetTo(enable)
}
// Client returns whether this is a client.
func Client() bool {
return client.IsSet()
}
// EnableClient enables the client mode.
func EnableClient(enable bool) {
client.SetTo(enable)
}

110
spn/conf/networks.go Normal file
View File

@@ -0,0 +1,110 @@
package conf
import (
"net"
"sync"
"github.com/tevino/abool"
)
var (
hubHasV4 = abool.New()
hubHasV6 = abool.New()
)
// SetHubNetworks sets the available IP networks on the Hub.
func SetHubNetworks(v4, v6 bool) {
hubHasV4.SetTo(v4)
hubHasV6.SetTo(v6)
}
// HubHasIPv4 returns whether the Hub has IPv4 support.
func HubHasIPv4() bool {
return hubHasV4.IsSet()
}
// HubHasIPv6 returns whether the Hub has IPv6 support.
func HubHasIPv6() bool {
return hubHasV6.IsSet()
}
var (
bindIPv4 net.IP
bindIPv6 net.IP
bindIPLock sync.Mutex
)
// SetBindAddr sets the preferred connect (bind) addresses.
func SetBindAddr(ip4, ip6 net.IP) {
bindIPLock.Lock()
defer bindIPLock.Unlock()
bindIPv4 = ip4
bindIPv6 = ip6
}
// BindAddrIsSet returns whether any bind address is set.
func BindAddrIsSet() bool {
bindIPLock.Lock()
defer bindIPLock.Unlock()
return bindIPv4 != nil || bindIPv6 != nil
}
// GetBindAddr returns an address with the preferred binding address for the
// given dial network.
// The dial network must have a suffix specifying the IP version.
func GetBindAddr(dialNetwork string) net.Addr {
bindIPLock.Lock()
defer bindIPLock.Unlock()
switch dialNetwork {
case "ip4":
if bindIPv4 != nil {
return &net.IPAddr{IP: bindIPv4}
}
case "ip6":
if bindIPv6 != nil {
return &net.IPAddr{IP: bindIPv6}
}
case "tcp4":
if bindIPv4 != nil {
return &net.TCPAddr{IP: bindIPv4}
}
case "tcp6":
if bindIPv6 != nil {
return &net.TCPAddr{IP: bindIPv6}
}
case "udp4":
if bindIPv4 != nil {
return &net.UDPAddr{IP: bindIPv4}
}
case "udp6":
if bindIPv6 != nil {
return &net.UDPAddr{IP: bindIPv6}
}
}
return nil
}
// GetBindIPs returns the preferred binding IPs.
// Returns a slice with a single nil IP if no preferred binding IPs are set.
func GetBindIPs() []net.IP {
bindIPLock.Lock()
defer bindIPLock.Unlock()
switch {
case bindIPv4 == nil && bindIPv6 == nil:
// Match most common case first.
return []net.IP{nil}
case bindIPv4 != nil && bindIPv6 != nil:
return []net.IP{bindIPv4, bindIPv6}
case bindIPv4 != nil:
return []net.IP{bindIPv4}
case bindIPv6 != nil:
return []net.IP{bindIPv6}
}
return []net.IP{nil}
}

9
spn/conf/version.go Normal file
View File

@@ -0,0 +1,9 @@
package conf
const (
// VersionOne is the first protocol version.
VersionOne = 1
// CurrentVersion always holds the newest version in production.
CurrentVersion = 1
)

482
spn/crew/connect.go Normal file
View File

@@ -0,0 +1,482 @@
package crew
import (
"context"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/profile/endpoints"
"github.com/safing/portmaster/spn/access"
"github.com/safing/portmaster/spn/docks"
"github.com/safing/portmaster/spn/navigator"
"github.com/safing/portmaster/spn/terminal"
)
// connectLock locks all routing operations to mitigate racy stuff for now.
// TODO: Find a nice way to parallelize route creation.
var connectLock sync.Mutex
// HandleSluiceRequest handles a sluice request to build a tunnel.
func HandleSluiceRequest(connInfo *network.Connection, conn net.Conn) {
if conn == nil {
log.Debugf("spn/crew: closing tunnel for %s before starting because of shutdown", connInfo)
// This is called within the connInfo lock.
connInfo.Failed("tunnel entry closed", "")
connInfo.SaveWhenFinished()
return
}
t := &Tunnel{
connInfo: connInfo,
conn: conn,
}
module.StartWorker("tunnel handler", t.connectWorker)
}
// Tunnel represents the local information and endpoint of a data tunnel.
type Tunnel struct {
connInfo *network.Connection
conn net.Conn
dstPin *navigator.Pin
dstTerminal terminal.Terminal
route *navigator.Route
failedTries int
stickied bool
}
func (t *Tunnel) connectWorker(ctx context.Context) (err error) {
// Get tracing logger.
ctx, tracer := log.AddTracer(ctx)
defer tracer.Submit()
// Save start time.
started := time.Now()
// Check the status of the Home Hub.
home, homeTerminal := navigator.Main.GetHome()
if home == nil || homeTerminal == nil || homeTerminal.IsBeingAbandoned() {
reportConnectError(terminal.ErrUnknownError.With("home terminal is abandoned"))
t.connInfo.Lock()
defer t.connInfo.Unlock()
t.connInfo.Failed("SPN not ready for tunneling", "")
t.connInfo.Save()
tracer.Infof("spn/crew: not tunneling %s, as the SPN is not ready", t.connInfo)
return nil
}
// Create path through the SPN.
err = t.establish(ctx)
if err != nil {
log.Warningf("spn/crew: failed to establish route for %s: %s", t.connInfo, err)
// TODO: Clean this up.
t.connInfo.Lock()
defer t.connInfo.Unlock()
t.connInfo.Failed(fmt.Sprintf("SPN failed to establish route: %s", err), "")
t.connInfo.Save()
tracer.Warningf("spn/crew: failed to establish route for %s: %s", t.connInfo, err)
return nil
}
// Connect via established tunnel.
_, tErr := NewConnectOp(t)
if tErr != nil {
tErr = tErr.Wrap("failed to initialize tunnel")
reportConnectError(tErr)
t.connInfo.Lock()
defer t.connInfo.Unlock()
t.connInfo.Failed(fmt.Sprintf("SPN failed to initialize data tunnel (connect op): %s", tErr.Error()), "")
t.connInfo.Save()
// TODO: try with another route?
tracer.Warningf("spn/crew: failed to initialize data tunnel (connect op) for %s: %s", t.connInfo, err)
return tErr
}
// Report time taken to find, build and check route and send connect request.
connectOpTTCRDurationHistogram.UpdateDuration(started)
t.connInfo.Lock()
defer t.connInfo.Unlock()
addTunnelContextToConnection(t)
t.connInfo.Save()
tracer.Infof("spn/crew: connected %s via %s", t.connInfo, t.dstPin.Hub)
return nil
}
func (t *Tunnel) establish(ctx context.Context) (err error) {
var routes *navigator.Routes
// Check if the destination sticks to a Hub.
sticksTo := getStickiedHub(t.connInfo)
switch {
case sticksTo == nil:
// Continue.
case sticksTo.Avoid:
log.Tracer(ctx).Tracef("spn/crew: avoiding %s", sticksTo.Pin.Hub)
// Avoid this Hub.
// TODO: Remember more than one hub to avoid.
avoidPolicy := []endpoints.Endpoint{
&endpoints.EndpointDomain{
OriginalValue: sticksTo.Pin.Hub.ID,
Domain: strings.ToLower(sticksTo.Pin.Hub.ID) + ".",
},
}
// Append to policies.
t.connInfo.TunnelOpts.Destination.HubPolicies = append(t.connInfo.TunnelOpts.Destination.HubPolicies, avoidPolicy)
default:
log.Tracer(ctx).Tracef("spn/crew: using stickied %s", sticksTo.Pin.Hub)
// Check if the stickied Hub has an active terminal.
dstTerminal := sticksTo.Pin.GetActiveTerminal()
if dstTerminal != nil {
t.dstPin = sticksTo.Pin
t.dstTerminal = dstTerminal
t.route = sticksTo.Route
t.stickied = true
return nil
}
// If not, attempt to find a route to the stickied hub.
routes, err = navigator.Main.FindRouteToHub(
sticksTo.Pin.Hub.ID,
t.connInfo.TunnelOpts,
)
if err != nil {
log.Tracer(ctx).Tracef("spn/crew: failed to find route to stickied %s: %s", sticksTo.Pin.Hub, err)
routes = nil
} else {
t.stickied = true
}
}
// Find possible routes to destination.
if routes == nil {
log.Tracer(ctx).Trace("spn/crew: finding routes...")
routes, err = navigator.Main.FindRoutes(
t.connInfo.Entity.IP,
t.connInfo.TunnelOpts,
)
if err != nil {
return fmt.Errorf("failed to find routes to %s: %w", t.connInfo.Entity.IP, err)
}
}
// Check if routes are okay (again).
if len(routes.All) == 0 {
return fmt.Errorf("no routes to %s", t.connInfo.Entity.IP)
}
// Try routes until one succeeds.
log.Tracer(ctx).Trace("spn/crew: establishing route...")
var dstPin *navigator.Pin
var dstTerminal terminal.Terminal
for tries, route := range routes.All {
dstPin, dstTerminal, err = establishRoute(route)
if err != nil {
continue
}
// Assign route data to tunnel.
t.dstPin = dstPin
t.dstTerminal = dstTerminal
t.route = route
t.failedTries = tries
// Push changes to Pins and return.
navigator.Main.PushPinChanges()
return nil
}
return fmt.Errorf("failed to establish a route to %s: %w", t.connInfo.Entity.IP, err)
}
type hopCheck struct {
pin *navigator.Pin
route *navigator.Route
expansion *docks.ExpansionTerminal
authOp *access.AuthorizeOp
pingOp *PingOp
}
func establishRoute(route *navigator.Route) (dstPin *navigator.Pin, dstTerminal terminal.Terminal, err error) {
connectLock.Lock()
defer connectLock.Unlock()
// Check for path length.
if len(route.Path) < 1 {
return nil, nil, errors.New("path too short")
}
// Check for failing hubs in path.
for _, hop := range route.Path[1:] {
if hop.Pin().GetState().Has(navigator.StateFailing) {
return nil, nil, fmt.Errorf("failing hub in path: %s", hop.Pin().Hub.Name())
}
}
// Get home hub.
previousHop, homeTerminal := navigator.Main.GetHome()
if previousHop == nil || homeTerminal == nil {
return nil, nil, navigator.ErrHomeHubUnset
}
// Convert to interface for later use.
var previousTerminal terminal.Terminal = homeTerminal
// Check if first hub in path is the home hub.
if route.Path[0].HubID != previousHop.Hub.ID {
return nil, nil, errors.New("path start does not match home hub")
}
// Check if path only exists of home hub.
if len(route.Path) == 1 {
return previousHop, previousTerminal, nil
}
// TODO: Check what needs locking.
// Build path and save created paths.
hopChecks := make([]*hopCheck, 0, len(route.Path)-1)
for i, hop := range route.Path[1:] {
// Check if we already have a connection to the Hub.
activeTerminal := hop.Pin().GetActiveTerminal()
if activeTerminal != nil {
// Ping terminal if not recently checked.
if activeTerminal.NeedsReachableCheck(1 * time.Minute) {
pingOp, tErr := NewPingOp(activeTerminal)
if tErr.IsError() {
return nil, nil, tErr.Wrap("failed start ping to %s", hop.Pin())
}
// Add for checking results later.
hopChecks = append(hopChecks, &hopCheck{
pin: hop.Pin(),
route: route.CopyUpTo(i + 2),
expansion: activeTerminal,
pingOp: pingOp,
})
}
previousHop = hop.Pin()
previousTerminal = activeTerminal
continue
}
// Expand to next Hub.
expansion, authOp, tErr := expand(previousTerminal, previousHop, hop.Pin())
if tErr != nil {
return nil, nil, tErr.Wrap("failed to expand to %s", hop.Pin())
}
// Add for checking results later.
hopChecks = append(hopChecks, &hopCheck{
pin: hop.Pin(),
route: route.CopyUpTo(i + 2),
expansion: expansion,
authOp: authOp,
})
// Save previous pin for next loop or end.
previousHop = hop.Pin()
previousTerminal = expansion
}
// Check results.
for _, check := range hopChecks {
switch {
case check.authOp != nil:
// Wait for authOp result.
select {
case tErr := <-check.authOp.Result:
switch {
case tErr.IsError():
// There was a network or authentication error.
check.pin.MarkAsFailingFor(3 * time.Minute)
log.Warningf("spn/crew: failed to auth to %s: %s", check.pin.Hub, tErr)
return nil, nil, tErr.Wrap("failed to authenticate to %s: %w", check.pin.Hub, tErr)
case tErr.Is(terminal.ErrExplicitAck):
// Authentication was successful.
default:
// Authentication was aborted.
if tErr != nil {
tErr = terminal.ErrUnknownError
}
log.Warningf("spn/crew: auth to %s aborted with %s", check.pin.Hub, tErr)
return nil, nil, tErr.Wrap("authentication to %s aborted: %w", check.pin.Hub, tErr)
}
case <-time.After(5 * time.Second):
// Mark as failing for just a minute, until server load may be less.
check.pin.MarkAsFailingFor(1 * time.Minute)
log.Warningf("spn/crew: auth to %s timed out", check.pin.Hub)
return nil, nil, terminal.ErrTimeout.With("waiting for auth to %s", check.pin.Hub)
}
// Add terminal extension to the map.
check.pin.SetActiveTerminal(&navigator.PinConnection{
Terminal: check.expansion,
Route: check.route,
})
check.expansion.MarkReachable()
log.Infof("spn/crew: added conn to %s via %s", check.pin, check.route)
case check.pingOp != nil:
// Wait for ping result.
select {
case tErr := <-check.pingOp.Result:
if !tErr.Is(terminal.ErrExplicitAck) {
// Mark as failing long enough to expire connections and session and shutdown connections.
// TODO: Should we forcibly disconnect instead?
// TODO: This might also be triggered if a relay fails and ends the operation.
check.pin.MarkAsFailingFor(7 * time.Minute)
// Forget about existing active terminal, re-create if needed.
check.pin.SetActiveTerminal(nil)
log.Warningf("spn/crew: failed to check reachability of %s: %s", check.pin.Hub, tErr)
return nil, nil, tErr.Wrap("failed to check reachability of %s: %w", check.pin.Hub, tErr)
}
case <-time.After(5 * time.Second):
// Mark as failing for just a minute, until server load may be less.
check.pin.MarkAsFailingFor(1 * time.Minute)
// Forget about existing active terminal, re-create if needed.
check.pin.SetActiveTerminal(nil)
log.Warningf("spn/crew: reachability check to %s timed out", check.pin.Hub)
return nil, nil, terminal.ErrTimeout.With("waiting for ping to %s", check.pin.Hub)
}
check.expansion.MarkReachable()
log.Debugf("spn/crew: checked conn to %s via %s", check.pin.Hub, check.route)
default:
log.Errorf("spn/crew: invalid hop check for %s", check.pin.Hub)
return nil, nil, terminal.ErrInternalError.With("invalid hop check")
}
}
// Return last hop.
return previousHop, previousTerminal, nil
}
func expand(fromTerminal terminal.Terminal, from, to *navigator.Pin) (expansion *docks.ExpansionTerminal, authOp *access.AuthorizeOp, tErr *terminal.Error) {
expansion, tErr = docks.ExpandTo(fromTerminal, to.Hub.ID, to.Hub)
if tErr != nil {
return nil, nil, tErr.Wrap("failed to expand to %s", to.Hub)
}
authOp, tErr = access.AuthorizeToTerminal(expansion)
if tErr != nil {
expansion.Abandon(nil)
return nil, nil, tErr.Wrap("failed to authorize")
}
log.Infof("spn/crew: expanded to %s (from %s)", to.Hub, from.Hub)
return expansion, authOp, nil
}
// TunnelContext holds additional information about the tunnel to be added to a
// connection.
type TunnelContext struct {
Path []*TunnelContextHop
PathCost float32
RoutingAlg string
tunnel *Tunnel
}
// GetExitNodeID returns the ID of the exit node.
// It returns an empty string in case no path exists.
func (tc *TunnelContext) GetExitNodeID() string {
if len(tc.Path) == 0 {
return ""
}
return tc.Path[len(tc.Path)-1].ID
}
// StopTunnel stops the tunnel.
func (tc *TunnelContext) StopTunnel() error {
if tc.tunnel != nil && tc.tunnel.conn != nil {
return tc.tunnel.conn.Close()
}
return nil
}
// TunnelContextHop holds hop data for TunnelContext.
type TunnelContextHop struct {
ID string
Name string
IPv4 *TunnelContextHopIPInfo `json:",omitempty"`
IPv6 *TunnelContextHopIPInfo `json:",omitempty"`
}
// TunnelContextHopIPInfo holds hop IP data for TunnelContextHop.
type TunnelContextHopIPInfo struct {
IP net.IP
Country string
ASN uint
ASOwner string
}
func addTunnelContextToConnection(t *Tunnel) {
// Create and add basic info.
tunnelCtx := &TunnelContext{
Path: make([]*TunnelContextHop, len(t.route.Path)),
PathCost: t.route.TotalCost,
RoutingAlg: t.route.Algorithm,
tunnel: t,
}
t.connInfo.TunnelContext = tunnelCtx
// Add path info.
for i, hop := range t.route.Path {
// Add hub info.
hopCtx := &TunnelContextHop{
ID: hop.HubID,
Name: hop.Pin().Hub.Info.Name,
}
tunnelCtx.Path[i] = hopCtx
// Add hub IPv4 info.
if hop.Pin().Hub.Info.IPv4 != nil {
hopCtx.IPv4 = &TunnelContextHopIPInfo{
IP: hop.Pin().Hub.Info.IPv4,
}
if hop.Pin().LocationV4 != nil {
hopCtx.IPv4.Country = hop.Pin().LocationV4.Country.Code
hopCtx.IPv4.ASN = hop.Pin().LocationV4.AutonomousSystemNumber
hopCtx.IPv4.ASOwner = hop.Pin().LocationV4.AutonomousSystemOrganization
}
}
// Add hub IPv6 info.
if hop.Pin().Hub.Info.IPv6 != nil {
hopCtx.IPv6 = &TunnelContextHopIPInfo{
IP: hop.Pin().Hub.Info.IPv6,
}
if hop.Pin().LocationV6 != nil {
hopCtx.IPv6.Country = hop.Pin().LocationV6.Country.Code
hopCtx.IPv6.ASN = hop.Pin().LocationV6.AutonomousSystemNumber
hopCtx.IPv6.ASOwner = hop.Pin().LocationV6.AutonomousSystemOrganization
}
}
}
}

223
spn/crew/metrics.go Normal file
View File

@@ -0,0 +1,223 @@
package crew
import (
"sync/atomic"
"github.com/tevino/abool"
"github.com/safing/portbase/api"
"github.com/safing/portbase/metrics"
)
var (
connectOpCnt *metrics.Counter
connectOpCntError *metrics.Counter
connectOpCntBadRequest *metrics.Counter
connectOpCntCanceled *metrics.Counter
connectOpCntFailed *metrics.Counter
connectOpCntConnected *metrics.Counter
connectOpCntRateLimited *metrics.Counter
connectOpIncomingBytes *metrics.Counter
connectOpOutgoingBytes *metrics.Counter
connectOpTTCRDurationHistogram *metrics.Histogram
connectOpTTFBDurationHistogram *metrics.Histogram
connectOpDurationHistogram *metrics.Histogram
connectOpIncomingDataHistogram *metrics.Histogram
connectOpOutgoingDataHistogram *metrics.Histogram
metricsRegistered = abool.New()
)
func registerMetrics() (err error) {
// Only register metrics once.
if !metricsRegistered.SetToIf(false, true) {
return nil
}
// Connect Op Stats on client.
connectOpCnt, err = metrics.NewCounter(
"spn/op/connect/total",
nil,
&metrics.Options{
Name: "SPN Total Connect Operations",
InternalID: "spn_connect_count",
Permission: api.PermitUser,
Persist: true,
},
)
if err != nil {
return err
}
// Connect Op Stats on server.
connectOpCntOptions := &metrics.Options{
Name: "SPN Total Connect Operations",
Permission: api.PermitUser,
Persist: true,
}
connectOpCntError, err = metrics.NewCounter(
"spn/op/connect/total",
map[string]string{"result": "error"},
connectOpCntOptions,
)
if err != nil {
return err
}
connectOpCntBadRequest, err = metrics.NewCounter(
"spn/op/connect/total",
map[string]string{"result": "bad_request"},
connectOpCntOptions,
)
if err != nil {
return err
}
connectOpCntCanceled, err = metrics.NewCounter(
"spn/op/connect/total",
map[string]string{"result": "canceled"},
connectOpCntOptions,
)
if err != nil {
return err
}
connectOpCntFailed, err = metrics.NewCounter(
"spn/op/connect/total",
map[string]string{"result": "failed"},
connectOpCntOptions,
)
if err != nil {
return err
}
connectOpCntConnected, err = metrics.NewCounter(
"spn/op/connect/total",
map[string]string{"result": "connected"},
connectOpCntOptions,
)
if err != nil {
return err
}
connectOpCntRateLimited, err = metrics.NewCounter(
"spn/op/connect/total",
map[string]string{"result": "rate_limited"},
connectOpCntOptions,
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/op/connect/active",
nil,
getActiveConnectOpsStat,
&metrics.Options{
Name: "SPN Active Connect Operations",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
connectOpIncomingBytes, err = metrics.NewCounter(
"spn/op/connect/incoming/bytes",
nil,
&metrics.Options{
Name: "SPN Connect Operation Incoming Bytes",
InternalID: "spn_connect_in_bytes",
Permission: api.PermitUser,
Persist: true,
},
)
if err != nil {
return err
}
connectOpOutgoingBytes, err = metrics.NewCounter(
"spn/op/connect/outgoing/bytes",
nil,
&metrics.Options{
Name: "SPN Connect Operation Outgoing Bytes",
InternalID: "spn_connect_out_bytes",
Permission: api.PermitUser,
Persist: true,
},
)
if err != nil {
return err
}
connectOpTTCRDurationHistogram, err = metrics.NewHistogram(
"spn/op/connect/histogram/ttcr/seconds",
nil,
&metrics.Options{
Name: "SPN Connect Operation time-to-connect-request Histogram",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
connectOpTTFBDurationHistogram, err = metrics.NewHistogram(
"spn/op/connect/histogram/ttfb/seconds",
nil,
&metrics.Options{
Name: "SPN Connect Operation time-to-first-byte (from TTCR) Histogram",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
connectOpDurationHistogram, err = metrics.NewHistogram(
"spn/op/connect/histogram/duration/seconds",
nil,
&metrics.Options{
Name: "SPN Connect Operation Duration Histogram",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
connectOpIncomingDataHistogram, err = metrics.NewHistogram(
"spn/op/connect/histogram/incoming/bytes",
nil,
&metrics.Options{
Name: "SPN Connect Operation Downloaded Data Histogram",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
connectOpOutgoingDataHistogram, err = metrics.NewHistogram(
"spn/op/connect/histogram/outgoing/bytes",
nil,
&metrics.Options{
Name: "SPN Connect Operation Outgoing Data Histogram",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
return nil
}
func getActiveConnectOpsStat() float64 {
return float64(atomic.LoadInt64(activeConnectOps))
}

44
spn/crew/module.go Normal file
View File

@@ -0,0 +1,44 @@
package crew
import (
"time"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/spn/terminal"
)
var module *modules.Module
func init() {
module = modules.Register("crew", nil, start, stop, "terminal", "docks", "navigator", "intel", "cabin")
}
func start() error {
module.NewTask("sticky cleaner", cleanStickyHubs).
Repeat(10 * time.Minute)
return registerMetrics()
}
func stop() error {
clearStickyHubs()
terminal.StopScheduler()
return nil
}
var connectErrors = make(chan *terminal.Error, 10)
func reportConnectError(tErr *terminal.Error) {
select {
case connectErrors <- tErr:
default:
}
}
// ConnectErrors returns errors of connect operations.
// It only has a small and shared buffer and may only be used for indications,
// not for full monitoring.
func ConnectErrors() <-chan *terminal.Error {
return connectErrors
}

13
spn/crew/module_test.go Normal file
View File

@@ -0,0 +1,13 @@
package crew
import (
"testing"
"github.com/safing/portmaster/service/core/pmtesting"
"github.com/safing/portmaster/spn/conf"
)
func TestMain(m *testing.M) {
conf.EnablePublicHub(true)
pmtesting.TestMain(m, module)
}

585
spn/crew/op_connect.go Normal file
View File

@@ -0,0 +1,585 @@
package crew
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync/atomic"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/terminal"
)
// ConnectOpType is the type ID for the connection operation.
const ConnectOpType string = "connect"
var activeConnectOps = new(int64)
// ConnectOp is used to connect data tunnels to servers on the Internet.
type ConnectOp struct {
terminal.OperationBase
// Flow Control
dfq *terminal.DuplexFlowQueue
// Context and shutdown handling
// ctx is the context of the Terminal.
ctx context.Context
// cancelCtx cancels ctx.
cancelCtx context.CancelFunc
// doneWriting signals that the writer has finished writing.
doneWriting chan struct{}
// Metrics
incomingTraffic atomic.Uint64
outgoingTraffic atomic.Uint64
started time.Time
// Connection
t terminal.Terminal
conn net.Conn
request *ConnectRequest
entry bool
tunnel *Tunnel
}
// Type returns the type ID.
func (op *ConnectOp) Type() string {
return ConnectOpType
}
// Ctx returns the operation context.
func (op *ConnectOp) Ctx() context.Context {
return op.ctx
}
// ConnectRequest holds all the information necessary for a connect operation.
type ConnectRequest struct {
Domain string `json:"d,omitempty"`
IP net.IP `json:"ip,omitempty"`
UsePriorityDataMsgs bool `json:"pr,omitempty"`
Protocol packet.IPProtocol `json:"p,omitempty"`
Port uint16 `json:"po,omitempty"`
QueueSize uint32 `json:"qs,omitempty"`
}
// DialNetwork returns the address of the connect request.
func (r *ConnectRequest) DialNetwork() string {
if ip4 := r.IP.To4(); ip4 != nil {
switch r.Protocol { //nolint:exhaustive // Only looking for supported protocols.
case packet.TCP:
return "tcp4"
case packet.UDP:
return "udp4"
}
} else {
switch r.Protocol { //nolint:exhaustive // Only looking for supported protocols.
case packet.TCP:
return "tcp6"
case packet.UDP:
return "udp6"
}
}
return ""
}
// Address returns the address of the connext request.
func (r *ConnectRequest) Address() string {
return net.JoinHostPort(r.IP.String(), strconv.Itoa(int(r.Port)))
}
func (r *ConnectRequest) String() string {
if r.Domain != "" {
return fmt.Sprintf("%s (%s %s)", r.Domain, r.Protocol, r.Address())
}
return fmt.Sprintf("%s %s", r.Protocol, r.Address())
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: ConnectOpType,
Requires: terminal.MayConnect,
Start: startConnectOp,
})
}
// NewConnectOp starts a new connect operation.
func NewConnectOp(tunnel *Tunnel) (*ConnectOp, *terminal.Error) {
// Submit metrics.
connectOpCnt.Inc()
// Create request.
request := &ConnectRequest{
Domain: tunnel.connInfo.Entity.Domain,
IP: tunnel.connInfo.Entity.IP,
Protocol: packet.IPProtocol(tunnel.connInfo.Entity.Protocol),
Port: tunnel.connInfo.Entity.Port,
UsePriorityDataMsgs: terminal.UsePriorityDataMsgs,
}
// Set defaults.
if request.QueueSize == 0 {
request.QueueSize = terminal.DefaultQueueSize
}
// Create new op.
op := &ConnectOp{
doneWriting: make(chan struct{}),
t: tunnel.dstTerminal,
conn: tunnel.conn,
request: request,
entry: true,
tunnel: tunnel,
}
op.ctx, op.cancelCtx = context.WithCancel(module.Ctx)
op.dfq = terminal.NewDuplexFlowQueue(op.Ctx(), request.QueueSize, op.submitUpstream)
// Prepare init msg.
data, err := dsd.Dump(request, dsd.CBOR)
if err != nil {
return nil, terminal.ErrInternalError.With("failed to pack connect request: %w", err)
}
// Initialize.
tErr := op.t.StartOperation(op, container.New(data), 5*time.Second)
if err != nil {
return nil, tErr
}
// Setup metrics.
op.started = time.Now()
module.StartWorker("connect op conn reader", op.connReader)
module.StartWorker("connect op conn writer", op.connWriter)
module.StartWorker("connect op flow handler", op.dfq.FlowHandler)
log.Infof("spn/crew: connected to %s via %s", request, tunnel.dstPin.Hub)
return op, nil
}
func startConnectOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Check if we are running a public hub.
if !conf.PublicHub() {
return nil, terminal.ErrPermissionDenied.With("connecting is only allowed on public hubs")
}
// Parse connect request.
request := &ConnectRequest{}
_, err := dsd.Load(data.CompileData(), request)
if err != nil {
connectOpCntError.Inc() // More like a protocol/system error than a bad request.
return nil, terminal.ErrMalformedData.With("failed to parse connect request: %w", err)
}
if request.QueueSize == 0 || request.QueueSize > terminal.MaxQueueSize {
connectOpCntError.Inc() // More like a protocol/system error than a bad request.
return nil, terminal.ErrInvalidOptions.With("invalid queue size of %d", request.QueueSize)
}
// Check if IP seems valid.
if len(request.IP) != net.IPv4len && len(request.IP) != net.IPv6len {
connectOpCntError.Inc() // More like a protocol/system error than a bad request.
return nil, terminal.ErrInvalidOptions.With("ip address is not valid")
}
// Create and initialize operation.
op := &ConnectOp{
doneWriting: make(chan struct{}),
t: t,
request: request,
}
op.InitOperationBase(t, opID)
op.ctx, op.cancelCtx = context.WithCancel(t.Ctx())
op.dfq = terminal.NewDuplexFlowQueue(op.Ctx(), request.QueueSize, op.submitUpstream)
// Start worker to complete setting up the connection.
module.StartWorker("connect op setup", op.handleSetup)
return op, nil
}
func (op *ConnectOp) handleSetup(_ context.Context) error {
// Get terminal session for rate limiting.
var session *terminal.Session
if sessionTerm, ok := op.t.(terminal.SessionTerminal); ok {
session = sessionTerm.GetSession()
} else {
connectOpCntError.Inc()
log.Errorf("spn/crew: %T is not a session terminal, aborting op %s#%d", op.t, op.t.FmtID(), op.ID())
op.Stop(op, terminal.ErrInternalError.With("no session available"))
return nil
}
// Limit concurrency of connecting.
cancelErr := session.LimitConcurrency(op.Ctx(), func() {
op.setup(session)
})
// If context was canceled, stop operation.
if cancelErr != nil {
connectOpCntCanceled.Inc()
op.Stop(op, terminal.ErrCanceled.With(cancelErr.Error()))
}
// Do not return a worker error.
return nil
}
func (op *ConnectOp) setup(session *terminal.Session) {
// Rate limit before connecting.
if tErr := session.RateLimit(); tErr != nil {
// Add rate limit info to error.
if tErr.Is(terminal.ErrRateLimited) {
connectOpCntRateLimited.Inc()
op.Stop(op, tErr.With(session.RateLimitInfo()))
return
}
connectOpCntError.Inc()
op.Stop(op, tErr)
return
}
// Check if connection target is in global scope.
ipScope := netutils.GetIPScope(op.request.IP)
if ipScope != netutils.Global {
session.ReportSuspiciousActivity(terminal.SusFactorQuiteUnusual)
connectOpCntBadRequest.Inc()
op.Stop(op, terminal.ErrPermissionDenied.With("denied request to connect to non-global IP %s", op.request.IP))
return
}
// Check exit policy.
if tErr := checkExitPolicy(op.request); tErr != nil {
session.ReportSuspiciousActivity(terminal.SusFactorQuiteUnusual)
connectOpCntBadRequest.Inc()
op.Stop(op, tErr)
return
}
// Check one last time before connecting if operation was not canceled.
if op.Ctx().Err() != nil {
op.Stop(op, terminal.ErrCanceled.With(op.Ctx().Err().Error()))
connectOpCntCanceled.Inc()
return
}
// Connect to destination.
dialNet := op.request.DialNetwork()
if dialNet == "" {
session.ReportSuspiciousActivity(terminal.SusFactorCommon)
connectOpCntBadRequest.Inc()
op.Stop(op, terminal.ErrIncorrectUsage.With("protocol %s is not supported", op.request.Protocol))
return
}
dialer := &net.Dialer{
Timeout: 10 * time.Second,
LocalAddr: conf.GetBindAddr(dialNet),
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
KeepAlive: -1, // Disable keep-alive.
}
conn, err := dialer.DialContext(op.Ctx(), dialNet, op.request.Address())
if err != nil {
// Connection errors are common, but still a bit suspicious.
var netError net.Error
switch {
case errors.As(err, &netError) && netError.Timeout():
session.ReportSuspiciousActivity(terminal.SusFactorCommon)
connectOpCntFailed.Inc()
case errors.Is(err, context.Canceled):
session.ReportSuspiciousActivity(terminal.SusFactorCommon)
connectOpCntCanceled.Inc()
default:
session.ReportSuspiciousActivity(terminal.SusFactorWeirdButOK)
connectOpCntFailed.Inc()
}
op.Stop(op, terminal.ErrConnectionError.With("failed to connect to %s: %w", op.request, err))
return
}
op.conn = conn
// Start worker.
module.StartWorker("connect op conn reader", op.connReader)
module.StartWorker("connect op conn writer", op.connWriter)
module.StartWorker("connect op flow handler", op.dfq.FlowHandler)
connectOpCntConnected.Inc()
log.Infof("spn/crew: connected op %s#%d to %s", op.t.FmtID(), op.ID(), op.request)
}
func (op *ConnectOp) submitUpstream(msg *terminal.Msg, timeout time.Duration) {
err := op.Send(msg, timeout)
if err != nil {
msg.Finish()
op.Stop(op, err.Wrap("failed to send data (op) read from %s", op.connectedType()))
}
}
const (
readBufSize = 1500
// High priority up to first 10MB.
highPrioThreshold = 10_000_000
// Rate limit to 128 Mbit/s after 1GB traffic.
// Do NOT use time.Sleep per packet, as it is very inaccurate and will sleep a lot longer than desired.
rateLimitThreshold = 1_000_000_000
rateLimitMaxMbit = 128
)
func (op *ConnectOp) connReader(_ context.Context) error {
// Metrics setup and submitting.
atomic.AddInt64(activeConnectOps, 1)
defer func() {
atomic.AddInt64(activeConnectOps, -1)
connectOpDurationHistogram.UpdateDuration(op.started)
connectOpIncomingDataHistogram.Update(float64(op.incomingTraffic.Load()))
}()
rateLimiter := terminal.NewRateLimiter(rateLimitMaxMbit)
for {
// Read from connection.
buf := make([]byte, readBufSize)
n, err := op.conn.Read(buf)
if err != nil {
if errors.Is(err, io.EOF) {
op.Stop(op, terminal.ErrStopping.With("connection to %s was closed on read", op.connectedType()))
} else {
op.Stop(op, terminal.ErrConnectionError.With("failed to read from %s: %w", op.connectedType(), err))
}
return nil
}
if n == 0 {
log.Tracef("spn/crew: connect op %s>%d read 0 bytes from %s", op.t.FmtID(), op.ID(), op.connectedType())
continue
}
// Submit metrics.
connectOpIncomingBytes.Add(n)
inBytes := op.incomingTraffic.Add(uint64(n))
// Rate limit if over threshold.
if inBytes > rateLimitThreshold {
rateLimiter.Limit(uint64(n))
}
// Create message from data.
msg := op.NewMsg(buf[:n])
// Define priority and possibly wait for slot.
switch {
case inBytes > highPrioThreshold:
msg.Unit.WaitForSlot()
case op.request.UsePriorityDataMsgs:
msg.Unit.MakeHighPriority()
}
// Send packet.
tErr := op.dfq.Send(
msg,
30*time.Second,
)
if tErr != nil {
msg.Finish()
op.Stop(op, tErr.Wrap("failed to send data (dfq) from %s", op.connectedType()))
return nil
}
}
}
// Deliver delivers a messages to the operation.
func (op *ConnectOp) Deliver(msg *terminal.Msg) *terminal.Error {
return op.dfq.Deliver(msg)
}
func (op *ConnectOp) connWriter(_ context.Context) error {
// Metrics submitting.
defer func() {
connectOpOutgoingDataHistogram.Update(float64(op.outgoingTraffic.Load()))
}()
defer func() {
// Signal that we are done with writing.
close(op.doneWriting)
// Close connection.
_ = op.conn.Close()
}()
var msg *terminal.Msg
defer msg.Finish()
rateLimiter := terminal.NewRateLimiter(rateLimitMaxMbit)
writing:
for {
msg.Finish()
select {
case msg = <-op.dfq.Receive():
case <-op.ctx.Done():
op.Stop(op, terminal.ErrCanceled)
return nil
default:
// Handle all data before also listening for the context cancel.
// This ensures all data is written properly before stopping.
select {
case msg = <-op.dfq.Receive():
case op.doneWriting <- struct{}{}:
op.Stop(op, terminal.ErrStopping)
return nil
case <-op.ctx.Done():
op.Stop(op, terminal.ErrCanceled)
return nil
}
}
// TODO: Instead of compiling data here again, can we send it as in the container?
data := msg.Data.CompileData()
if len(data) == 0 {
continue writing
}
// Submit metrics.
connectOpOutgoingBytes.Add(len(data))
out := op.outgoingTraffic.Add(uint64(len(data)))
// Rate limit if over threshold.
if out > rateLimitThreshold {
rateLimiter.Limit(uint64(len(data)))
}
// Special handling after first data was received on client.
if op.entry &&
out == uint64(len(data)) {
// Report time taken to receive first byte.
connectOpTTFBDurationHistogram.UpdateDuration(op.started)
// If not stickied yet, stick destination to Hub.
if !op.tunnel.stickied {
op.tunnel.stickDestinationToHub()
}
}
// Send all given data.
for {
n, err := op.conn.Write(data)
switch {
case err != nil:
if errors.Is(err, io.EOF) {
op.Stop(op, terminal.ErrStopping.With("connection to %s was closed on write", op.connectedType()))
} else {
op.Stop(op, terminal.ErrConnectionError.With("failed to send to %s: %w", op.connectedType(), err))
}
return nil
case n == 0:
op.Stop(op, terminal.ErrConnectionError.With("sent 0 bytes to %s", op.connectedType()))
return nil
case n < len(data):
// If not all data was sent, try again.
log.Debugf("spn/crew: %s#%d only sent %d/%d bytes to %s", op.t.FmtID(), op.ID(), n, len(data), op.connectedType())
data = data[n:]
default:
continue writing
}
}
}
}
func (op *ConnectOp) connectedType() string {
if op.entry {
return "origin"
}
return "destination"
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *ConnectOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
if err.IsError() {
reportConnectError(err)
}
// If the connection has sent or received any data so far, finish the data
// flows as it makes sense.
if op.incomingTraffic.Load() > 0 || op.outgoingTraffic.Load() > 0 {
// If the op was ended locally, send all data before closing.
// If the op was ended remotely, don't bother sending remaining data.
if !err.IsExternal() {
// Flushing could mean sending a full buffer of 50000 packets.
op.dfq.Flush(5 * time.Minute)
}
// If the op was ended remotely, write all remaining received data.
// If the op was ended locally, don't bother writing remaining data.
if err.IsExternal() {
select {
case <-op.doneWriting:
default:
select {
case <-op.doneWriting:
case <-time.After(5 * time.Second):
}
}
}
}
// Cancel workers.
op.cancelCtx()
// Special client-side handling.
if op.entry {
// Mark the connection as failed if there was an error and no data was sent to the app yet.
if err.IsError() && op.outgoingTraffic.Load() == 0 {
// Set connection to failed and save it to propagate the update.
c := op.tunnel.connInfo
func() {
c.Lock()
defer c.Unlock()
if err.IsExternal() {
c.Failed(fmt.Sprintf(
"the exit node reported an error: %s", err,
), "")
} else {
c.Failed(fmt.Sprintf(
"connection failed locally: %s", err,
), "")
}
c.Save()
}()
}
// Avoid connecting to the destination via this Hub if:
// - The error is external - ie. from the server.
// - The error is a connection error.
// - No data was received.
// This indicates that there is some network level issue that we can
// possibly work around by using another exit node.
if err.IsError() && err.IsExternal() &&
err.Is(terminal.ErrConnectionError) &&
op.outgoingTraffic.Load() == 0 {
op.tunnel.avoidDestinationHub()
}
// Don't leak local errors to the server.
if !err.IsExternal() {
// Change error that is reported.
return terminal.ErrStopping
}
}
return err
}

115
spn/crew/op_connect_test.go Normal file
View File

@@ -0,0 +1,115 @@
package crew
import (
"fmt"
"net"
"net/http"
"net/url"
"testing"
"time"
"github.com/safing/portmaster/service/intel"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/navigator"
"github.com/safing/portmaster/spn/terminal"
)
const (
testPadding = 8
testQueueSize = 10
)
func TestConnectOp(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skipping test in short mode, as it interacts with the network")
}
// Create test terminal pair.
a, b, err := terminal.NewSimpleTestTerminalPair(0, 0,
&terminal.TerminalOpts{
FlowControl: terminal.FlowControlDFQ,
FlowControlSize: testQueueSize,
Padding: testPadding,
},
)
if err != nil {
t.Fatalf("failed to create test terminal pair: %s", err)
}
// Set up connect op.
b.GrantPermission(terminal.MayConnect)
conf.EnablePublicHub(true)
identity, err := cabin.CreateIdentity(module.Ctx, "test")
if err != nil {
t.Fatalf("failed to create identity: %s", err)
}
_, err = identity.MaintainAnnouncement(&hub.Announcement{
Transports: []string{
"tcp:17",
},
Exit: []string{
"+ * */80",
"- *",
},
}, true)
if err != nil {
t.Fatalf("failed to update identity: %s", err)
}
EnableConnecting(identity.Hub)
for i := 0; i < 1; i++ {
appConn, sluiceConn := net.Pipe()
_, tErr := NewConnectOp(&Tunnel{
connInfo: &network.Connection{
Entity: (&intel.Entity{
Protocol: 6,
Port: 80,
Domain: "orf.at.",
IP: net.IPv4(194, 232, 104, 142),
}).Init(0),
},
conn: sluiceConn,
dstTerminal: a,
dstPin: &navigator.Pin{
Hub: identity.Hub,
},
})
if tErr != nil {
t.Fatalf("failed to start connect op: %s", tErr)
}
// Send request.
requestURL, err := url.Parse("http://orf.at/")
if err != nil {
t.Fatalf("failed to parse request url: %s", err)
}
r := http.Request{
Method: http.MethodHead,
URL: requestURL,
}
err = r.Write(appConn)
if err != nil {
t.Fatalf("failed to write request: %s", err)
}
// Recv response.
data := make([]byte, 1500)
n, err := appConn.Read(data)
if err != nil {
t.Fatalf("failed to read request: %s", err)
}
if n == 0 {
t.Fatal("received empty reply")
}
t.Log("received data:")
fmt.Println(string(data[:n]))
time.Sleep(500 * time.Millisecond)
}
}

149
spn/crew/op_ping.go Normal file
View File

@@ -0,0 +1,149 @@
package crew
import (
"crypto/subtle"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/rng"
"github.com/safing/portmaster/spn/terminal"
)
const (
// PingOpType is the type ID of the latency test operation.
PingOpType = "ping"
pingOpNonceSize = 16
pingOpTimeout = 3 * time.Second
)
// PingOp is used to measure latency.
type PingOp struct {
terminal.OneOffOperationBase
started time.Time
nonce []byte
}
// PingOpRequest is a ping request.
type PingOpRequest struct {
Nonce []byte `json:"n,omitempty"`
}
// PingOpResponse is a ping response.
type PingOpResponse struct {
Nonce []byte `json:"n,omitempty"`
Time time.Time `json:"t,omitempty"`
}
// Type returns the type ID.
func (op *PingOp) Type() string {
return PingOpType
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: PingOpType,
Start: startPingOp,
})
}
// NewPingOp runs a latency test.
func NewPingOp(t terminal.Terminal) (*PingOp, *terminal.Error) {
// Generate nonce.
nonce, err := rng.Bytes(pingOpNonceSize)
if err != nil {
return nil, terminal.ErrInternalError.With("failed to generate ping nonce: %w", err)
}
// Create operation and init.
op := &PingOp{
started: time.Now().UTC(),
nonce: nonce,
}
op.OneOffOperationBase.Init()
// Create request.
pingRequest, err := dsd.Dump(&PingOpRequest{
Nonce: op.nonce,
}, dsd.CBOR)
if err != nil {
return nil, terminal.ErrInternalError.With("failed to create ping request: %w", err)
}
// Send ping.
tErr := t.StartOperation(op, container.New(pingRequest), pingOpTimeout)
if tErr != nil {
return nil, tErr
}
return op, nil
}
// Deliver delivers a message to the operation.
func (op *PingOp) Deliver(msg *terminal.Msg) *terminal.Error {
defer msg.Finish()
// Parse response.
response := &PingOpResponse{}
_, err := dsd.Load(msg.Data.CompileData(), response)
if err != nil {
return terminal.ErrMalformedData.With("failed to parse ping response: %w", err)
}
// Check if the nonce matches.
if subtle.ConstantTimeCompare(op.nonce, response.Nonce) != 1 {
return terminal.ErrIntegrity.With("ping nonce mismatched")
}
return terminal.ErrExplicitAck
}
func startPingOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Parse request.
request := &PingOpRequest{}
_, err := dsd.Load(data.CompileData(), request)
if err != nil {
return nil, terminal.ErrMalformedData.With("failed to parse ping request: %w", err)
}
// Create response.
response, err := dsd.Dump(&PingOpResponse{
Nonce: request.Nonce,
Time: time.Now().UTC(),
}, dsd.CBOR)
if err != nil {
return nil, terminal.ErrInternalError.With("failed to create ping response: %w", err)
}
// Send response.
msg := terminal.NewMsg(response)
msg.FlowID = opID
msg.Unit.MakeHighPriority()
if terminal.UsePriorityDataMsgs {
msg.Type = terminal.MsgTypePriorityData
}
tErr := t.Send(msg, pingOpTimeout)
if tErr != nil {
// Finish message unit on failure.
msg.Finish()
return nil, tErr.With("failed to send ping response")
}
// Operation is just one response and finished successfully.
return nil, nil
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *PingOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
// Prevent remote from sending explicit ack, as we use it as a success signal internally.
if err.Is(terminal.ErrExplicitAck) && err.IsExternal() {
err = terminal.ErrStopping.AsExternal()
}
// Continue with usual handling of inherited base.
return op.OneOffOperationBase.HandleStop(err)
}

32
spn/crew/op_ping_test.go Normal file
View File

@@ -0,0 +1,32 @@
package crew
import (
"testing"
"time"
"github.com/safing/portmaster/spn/terminal"
)
func TestPingOp(t *testing.T) {
t.Parallel()
// Create test terminal pair.
a, _, err := terminal.NewSimpleTestTerminalPair(0, 0, nil)
if err != nil {
t.Fatalf("failed to create test terminal pair: %s", err)
}
// Create ping op.
op, tErr := NewPingOp(a)
if tErr.IsError() {
t.Fatal(tErr)
}
// Wait for result.
select {
case result := <-op.Result:
t.Logf("ping result: %s", result.Error())
case <-time.After(pingOpTimeout):
t.Fatal("timed out")
}
}

51
spn/crew/policy.go Normal file
View File

@@ -0,0 +1,51 @@
package crew
import (
"context"
"sync"
"github.com/safing/portmaster/service/intel"
"github.com/safing/portmaster/service/profile/endpoints"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/terminal"
)
var (
connectingHubLock sync.Mutex
connectingHub *hub.Hub
)
// EnableConnecting enables connecting from this Hub.
func EnableConnecting(my *hub.Hub) {
connectingHubLock.Lock()
defer connectingHubLock.Unlock()
connectingHub = my
}
func checkExitPolicy(request *ConnectRequest) *terminal.Error {
connectingHubLock.Lock()
defer connectingHubLock.Unlock()
// Check if connect requests are allowed.
if connectingHub == nil {
return terminal.ErrPermissionDenied.With("connect requests disabled")
}
// Create entity.
entity := (&intel.Entity{
IP: request.IP,
Protocol: uint8(request.Protocol),
Port: request.Port,
Domain: request.Domain,
}).Init(0)
entity.FetchData(context.TODO())
// Check against policy.
result, reason := connectingHub.GetInfo().ExitPolicy().Match(context.TODO(), entity)
if result == endpoints.Denied {
return terminal.ErrPermissionDenied.With("connect request for %s violates the exit policy: %s", request, reason)
}
return nil
}

176
spn/crew/sticky.go Normal file
View File

@@ -0,0 +1,176 @@
package crew
import (
"context"
"fmt"
"sync"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/spn/navigator"
)
const (
stickyTTL = 1 * time.Hour
)
var (
stickyIPs = make(map[string]*stickyHub)
stickyDomains = make(map[string]*stickyHub)
stickyLock sync.Mutex
)
type stickyHub struct {
Pin *navigator.Pin
Route *navigator.Route
LastSeen time.Time
Avoid bool
}
func (sh *stickyHub) isExpired() bool {
return time.Now().Add(-stickyTTL).After(sh.LastSeen)
}
func makeStickyIPKey(conn *network.Connection) string {
if p := conn.Process().Profile(); p != nil {
return fmt.Sprintf(
"%s/%s>%s",
p.LocalProfile().Source,
p.LocalProfile().ID,
conn.Entity.IP,
)
}
return "?>" + string(conn.Entity.IP)
}
func makeStickyDomainKey(conn *network.Connection) string {
if p := conn.Process().Profile(); p != nil {
return fmt.Sprintf(
"%s/%s>%s",
p.LocalProfile().Source,
p.LocalProfile().ID,
conn.Entity.Domain,
)
}
return "?>" + conn.Entity.Domain
}
func getStickiedHub(conn *network.Connection) (sticksTo *stickyHub) {
stickyLock.Lock()
defer stickyLock.Unlock()
// Check if IP is sticky.
sticksTo = stickyIPs[makeStickyIPKey(conn)] // byte comparison
if sticksTo != nil && !sticksTo.isExpired() {
sticksTo.LastSeen = time.Now()
}
// If the IP did not stick and we have a domain, check if that sticks.
if sticksTo == nil && conn.Entity.Domain != "" {
sticksTo, ok := stickyDomains[makeStickyDomainKey(conn)]
if ok && !sticksTo.isExpired() {
sticksTo.LastSeen = time.Now()
}
}
// If nothing sticked, return now.
if sticksTo == nil {
return nil
}
// Get intel from map before locking pin to avoid simultaneous locking.
mapIntel := navigator.Main.GetIntel()
// Lock Pin for checking.
sticksTo.Pin.Lock()
defer sticksTo.Pin.Unlock()
// Check if the stickied Hub supports the needed IP version.
switch {
case conn.IPVersion == packet.IPv4 && sticksTo.Pin.EntityV4 == nil:
// Connection is IPv4, but stickied Hub has no IPv4.
return nil
case conn.IPVersion == packet.IPv6 && sticksTo.Pin.EntityV6 == nil:
// Connection is IPv4, but stickied Hub has no IPv4.
return nil
}
// Disregard stickied Hub if it is disregard with the current options.
matcher := conn.TunnelOpts.Destination.Matcher(mapIntel)
if !matcher(sticksTo.Pin) {
return nil
}
// Return fully checked stickied Hub.
return sticksTo
}
func (t *Tunnel) stickDestinationToHub() {
stickyLock.Lock()
defer stickyLock.Unlock()
// Stick to IP.
ipKey := makeStickyIPKey(t.connInfo)
stickyIPs[ipKey] = &stickyHub{
Pin: t.dstPin,
Route: t.route,
LastSeen: time.Now(),
}
log.Infof("spn/crew: sticking %s to %s", ipKey, t.dstPin.Hub)
// Stick to Domain, if present.
if t.connInfo.Entity.Domain != "" {
domainKey := makeStickyDomainKey(t.connInfo)
stickyDomains[domainKey] = &stickyHub{
Pin: t.dstPin,
Route: t.route,
LastSeen: time.Now(),
}
log.Infof("spn/crew: sticking %s to %s", domainKey, t.dstPin.Hub)
}
}
func (t *Tunnel) avoidDestinationHub() {
stickyLock.Lock()
defer stickyLock.Unlock()
// Stick to Hub/IP Pair.
ipKey := makeStickyIPKey(t.connInfo)
stickyIPs[ipKey] = &stickyHub{
Pin: t.dstPin,
LastSeen: time.Now(),
Avoid: true,
}
log.Warningf("spn/crew: avoiding %s for %s", t.dstPin.Hub, ipKey)
}
func cleanStickyHubs(ctx context.Context, task *modules.Task) error {
stickyLock.Lock()
defer stickyLock.Unlock()
for _, stickyRegistry := range []map[string]*stickyHub{stickyIPs, stickyDomains} {
for key, stickedEntry := range stickyRegistry {
if stickedEntry.isExpired() {
delete(stickyRegistry, key)
}
}
}
return nil
}
func clearStickyHubs() {
stickyLock.Lock()
defer stickyLock.Unlock()
for _, stickyRegistry := range []map[string]*stickyHub{stickyIPs, stickyDomains} {
for key := range stickyRegistry {
delete(stickyRegistry, key)
}
}
}

View File

@@ -0,0 +1,90 @@
package docks
import (
"testing"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portmaster/spn/terminal"
)
func TestEffectiveBandwidth(t *testing.T) { //nolint:paralleltest // Run alone.
// Skip in CI.
if testing.Short() {
t.Skip()
}
var (
bwTestDelay = 50 * time.Millisecond
bwTestQueueSize uint32 = 1000
bwTestVolume = 10000000 // 10MB
bwTestTime = 10 * time.Second
)
// Create test terminal pair.
a, b, err := terminal.NewSimpleTestTerminalPair(
bwTestDelay,
int(bwTestQueueSize),
&terminal.TerminalOpts{
FlowControl: terminal.FlowControlDFQ,
FlowControlSize: bwTestQueueSize,
},
)
if err != nil {
t.Fatalf("failed to create test terminal pair: %s", err)
}
// Grant permission for op on remote terminal and start op.
b.GrantPermission(terminal.IsCraneController)
// Re-use the capacity test for the bandwidth test.
op := &CapacityTestOp{
opts: &CapacityTestOptions{
TestVolume: bwTestVolume,
MaxTime: bwTestTime,
testing: true,
},
recvQueue: make(chan *terminal.Msg),
dataSent: new(int64),
dataSentWasAckd: abool.New(),
result: make(chan *terminal.Error, 1),
}
// Disable sender again.
op.senderStarted = true
op.dataSentWasAckd.Set()
// Make capacity test request.
request, err := dsd.Dump(op.opts, dsd.CBOR)
if err != nil {
t.Fatal(terminal.ErrInternalError.With("failed to serialize capactity test options: %w", err))
}
// Send test request.
tErr := a.StartOperation(op, container.New(request), 1*time.Second)
if tErr != nil {
t.Fatal(tErr)
}
// Start handler.
module.StartWorker("op capacity handler", op.handler)
// Wait for result and check error.
tErr = <-op.Result()
if !tErr.IsOK() {
t.Fatalf("op failed: %s", tErr)
}
t.Logf("measured capacity: %d bit/s", op.testResult)
// Calculate expected bandwidth.
expectedBitsPerSecond := (float64(capacityTestMsgSize*8*int64(bwTestQueueSize)) / float64(bwTestDelay)) * float64(time.Second)
t.Logf("expected capacity: %f bit/s", expectedBitsPerSecond)
// Check if measured bandwidth is within parameters.
if float64(op.testResult) > expectedBitsPerSecond*1.6 {
t.Fatal("measured capacity too high")
}
// TODO: Check if we can raise this to at least 90%.
if float64(op.testResult) < expectedBitsPerSecond*0.2 {
t.Fatal("measured capacity too low")
}
}

100
spn/docks/controller.go Normal file
View File

@@ -0,0 +1,100 @@
package docks
import (
"github.com/safing/portbase/container"
"github.com/safing/portmaster/spn/terminal"
)
// CraneControllerTerminal is a terminal for the crane itself.
type CraneControllerTerminal struct {
*terminal.TerminalBase
Crane *Crane
}
// NewLocalCraneControllerTerminal returns a new local crane controller.
func NewLocalCraneControllerTerminal(
crane *Crane,
initMsg *terminal.TerminalOpts,
) (*CraneControllerTerminal, *container.Container, *terminal.Error) {
// Remove unnecessary options from the crane controller.
initMsg.Padding = 0
// Create Terminal Base.
t, initData, err := terminal.NewLocalBaseTerminal(
crane.ctx,
0,
crane.ID,
nil,
initMsg,
terminal.UpstreamSendFunc(crane.sendImportantTerminalMsg),
)
if err != nil {
return nil, nil, err
}
return initCraneController(crane, t, initMsg), initData, nil
}
// NewRemoteCraneControllerTerminal returns a new remote crane controller.
func NewRemoteCraneControllerTerminal(
crane *Crane,
initData *container.Container,
) (*CraneControllerTerminal, *terminal.TerminalOpts, *terminal.Error) {
// Create Terminal Base.
t, initMsg, err := terminal.NewRemoteBaseTerminal(
crane.ctx,
0,
crane.ID,
nil,
initData,
terminal.UpstreamSendFunc(crane.sendImportantTerminalMsg),
)
if err != nil {
return nil, nil, err
}
return initCraneController(crane, t, initMsg), initMsg, nil
}
func initCraneController(
crane *Crane,
t *terminal.TerminalBase,
initMsg *terminal.TerminalOpts,
) *CraneControllerTerminal {
// Create Crane Terminal and assign it as the extended Terminal.
cct := &CraneControllerTerminal{
TerminalBase: t,
Crane: crane,
}
t.SetTerminalExtension(cct)
// Assign controller to crane.
crane.Controller = cct
crane.terminals[cct.ID()] = cct
// Copy the options to the crane itself.
crane.opts = *initMsg
// Grant crane controller permission.
t.GrantPermission(terminal.IsCraneController)
// Start workers.
t.StartWorkers(module, "crane controller terminal")
return cct
}
// HandleAbandon gives the terminal the ability to cleanly shut down.
func (controller *CraneControllerTerminal) HandleAbandon(err *terminal.Error) (errorToSend *terminal.Error) {
// Abandon terminal.
controller.Crane.AbandonTerminal(0, err)
return err
}
// HandleDestruction gives the terminal the ability to clean up.
func (controller *CraneControllerTerminal) HandleDestruction(err *terminal.Error) {
// Stop controlled crane.
controller.Crane.Stop(nil)
}

913
spn/docks/crane.go Normal file
View File

@@ -0,0 +1,913 @@
package docks
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/tevino/abool"
"github.com/safing/jess"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portbase/log"
"github.com/safing/portbase/rng"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/ships"
"github.com/safing/portmaster/spn/terminal"
)
const (
// QOTD holds the quote of the day to return on idling unused connections.
QOTD = "Privacy is not an option, and it shouldn't be the price we accept for just getting on the Internet.\nGary Kovacs\n"
// maxUnloadSize defines the maximum size of a message to unload.
maxUnloadSize = 16384
maxSegmentLength = 16384
maxCraneStoppingDuration = 6 * time.Hour
maxCraneStopDuration = 10 * time.Second
)
var (
// optimalMinLoadSize defines minimum for Crane.targetLoadSize.
optimalMinLoadSize = 3072 // Targeting around 4096.
// loadingMaxWaitDuration is the maximum time a crane will wait for
// additional data to send.
loadingMaxWaitDuration = 5 * time.Millisecond
)
// Errors.
var (
ErrDone = errors.New("crane is done")
)
// Crane is the primary duplexer and connection manager.
type Crane struct {
// ID is the ID of the Crane.
ID string
// opts holds options.
opts terminal.TerminalOpts
// ctx is the context of the Terminal.
ctx context.Context
// cancelCtx cancels ctx.
cancelCtx context.CancelFunc
// stopping indicates if the Crane will be stopped soon. The Crane may still
// be used until stopped, but must not be advertised anymore.
stopping *abool.AtomicBool
// stopped indicates if the Crane has been stopped. Whoever stopped the Crane
// already took care of notifying everyone, so a silent fail is normally the
// best response.
stopped *abool.AtomicBool
// authenticated indicates if there is has been any successful authentication.
authenticated *abool.AtomicBool
// ConnectedHub is the identity of the remote Hub.
ConnectedHub *hub.Hub
// NetState holds the network optimization state.
// It must always be set and the reference must not be changed.
// Access to fields within are coordinated by itself.
NetState *NetworkOptimizationState
// identity is identity of this instance and is usually only populated on a server.
identity *cabin.Identity
// jession is the jess session used for encryption.
jession *jess.Session
// jessionLock locks jession.
jessionLock sync.Mutex
// Controller is the Crane's Controller Terminal.
Controller *CraneControllerTerminal
// ship represents the underlying physical connection.
ship ships.Ship
// unloading moves containers from the ship to the crane.
unloading chan *container.Container
// loading moves containers from the crane to the ship.
loading chan *container.Container
// terminalMsgs holds containers from terminals waiting to be laoded.
terminalMsgs chan *terminal.Msg
// controllerMsgs holds important containers from terminals waiting to be laoded.
controllerMsgs chan *terminal.Msg
// terminals holds all the connected terminals.
terminals map[uint32]terminal.Terminal
// terminalsLock locks terminals.
terminalsLock sync.Mutex
// nextTerminalID holds the next terminal ID.
nextTerminalID uint32
// targetLoadSize defines the optimal loading size.
targetLoadSize int
}
// NewCrane returns a new crane.
func NewCrane(ship ships.Ship, connectedHub *hub.Hub, id *cabin.Identity) (*Crane, error) {
// Cranes always run in module context.
ctx, cancelCtx := context.WithCancel(module.Ctx)
newCrane := &Crane{
ctx: ctx,
cancelCtx: cancelCtx,
stopping: abool.NewBool(false),
stopped: abool.NewBool(false),
authenticated: abool.NewBool(false),
ConnectedHub: connectedHub,
NetState: newNetworkOptimizationState(),
identity: id,
ship: ship,
unloading: make(chan *container.Container),
loading: make(chan *container.Container, 100),
terminalMsgs: make(chan *terminal.Msg, 100),
controllerMsgs: make(chan *terminal.Msg, 100),
terminals: make(map[uint32]terminal.Terminal),
}
err := registerCrane(newCrane)
if err != nil {
return nil, fmt.Errorf("failed to register crane: %w", err)
}
// Shift next terminal IDs on the server.
if !ship.IsMine() {
newCrane.nextTerminalID += 4
}
// Calculate target load size.
loadSize := ship.LoadSize()
if loadSize <= 0 {
loadSize = ships.BaseMTU
}
newCrane.targetLoadSize = loadSize
for newCrane.targetLoadSize < optimalMinLoadSize {
newCrane.targetLoadSize += loadSize
}
// Subtract overhead needed for encryption.
newCrane.targetLoadSize -= 25 // Manually tested for jess.SuiteWireV1
// Subtract space needed for length encoding the final chunk.
newCrane.targetLoadSize -= varint.EncodedSize(uint64(newCrane.targetLoadSize))
return newCrane, nil
}
// IsMine returns whether the crane was started on this side.
func (crane *Crane) IsMine() bool {
return crane.ship.IsMine()
}
// Public returns whether the crane has been published.
func (crane *Crane) Public() bool {
return crane.ship.Public()
}
// IsStopping returns whether the crane is stopping.
func (crane *Crane) IsStopping() bool {
return crane.stopping.IsSet()
}
// MarkStoppingRequested marks the crane as stopping requested.
func (crane *Crane) MarkStoppingRequested() {
crane.NetState.lock.Lock()
defer crane.NetState.lock.Unlock()
if !crane.NetState.stoppingRequested {
crane.NetState.stoppingRequested = true
crane.startSyncStateOp()
}
}
// MarkStopping marks the crane as stopping.
func (crane *Crane) MarkStopping() (stopping bool) {
// Can only stop owned cranes.
if !crane.IsMine() {
return false
}
if !crane.stopping.SetToIf(false, true) {
return false
}
crane.NetState.lock.Lock()
defer crane.NetState.lock.Unlock()
crane.NetState.markedStoppingAt = time.Now()
crane.startSyncStateOp()
return true
}
// AbortStopping aborts the stopping.
func (crane *Crane) AbortStopping() (aborted bool) {
aborted = crane.stopping.SetToIf(true, false)
crane.NetState.lock.Lock()
defer crane.NetState.lock.Unlock()
abortedStoppingRequest := crane.NetState.stoppingRequested
crane.NetState.stoppingRequested = false
crane.NetState.markedStoppingAt = time.Time{}
// Sync if any state changed.
if aborted || abortedStoppingRequest {
crane.startSyncStateOp()
}
return aborted
}
// Authenticated returns whether the other side of the crane has authenticated
// itself with an access code.
func (crane *Crane) Authenticated() bool {
return crane.authenticated.IsSet()
}
// Publish publishes the connection as a lane.
func (crane *Crane) Publish() error {
// Check if crane is connected.
if crane.ConnectedHub == nil {
return fmt.Errorf("spn/docks: %s: cannot publish without defined connected hub", crane)
}
// Submit metrics.
if !crane.Public() {
newPublicCranes.Inc()
}
// Mark crane as public.
maskedID := crane.ship.MaskAddress(crane.ship.RemoteAddr())
crane.ship.MarkPublic()
// Assign crane to make it available to others.
AssignCrane(crane.ConnectedHub.ID, crane)
log.Infof("spn/docks: %s (was %s) is now public", crane, maskedID)
return nil
}
// LocalAddr returns ship's local address.
func (crane *Crane) LocalAddr() net.Addr {
return crane.ship.LocalAddr()
}
// RemoteAddr returns ship's local address.
func (crane *Crane) RemoteAddr() net.Addr {
return crane.ship.RemoteAddr()
}
// Transport returns ship's transport.
func (crane *Crane) Transport() *hub.Transport {
return crane.ship.Transport()
}
func (crane *Crane) getNextTerminalID() uint32 {
crane.terminalsLock.Lock()
defer crane.terminalsLock.Unlock()
for {
// Bump to next ID.
crane.nextTerminalID += 8
// Check if it's free.
_, ok := crane.terminals[crane.nextTerminalID]
if !ok {
return crane.nextTerminalID
}
}
}
func (crane *Crane) terminalCount() int {
crane.terminalsLock.Lock()
defer crane.terminalsLock.Unlock()
return len(crane.terminals)
}
func (crane *Crane) getTerminal(id uint32) (t terminal.Terminal, ok bool) {
crane.terminalsLock.Lock()
defer crane.terminalsLock.Unlock()
t, ok = crane.terminals[id]
return
}
func (crane *Crane) setTerminal(t terminal.Terminal) {
crane.terminalsLock.Lock()
defer crane.terminalsLock.Unlock()
crane.terminals[t.ID()] = t
}
func (crane *Crane) deleteTerminal(id uint32) (t terminal.Terminal, ok bool) {
crane.terminalsLock.Lock()
defer crane.terminalsLock.Unlock()
t, ok = crane.terminals[id]
if ok {
delete(crane.terminals, id)
return t, true
}
return nil, false
}
// AbandonTerminal abandons the terminal with the given ID.
func (crane *Crane) AbandonTerminal(id uint32, err *terminal.Error) {
// Get active terminal.
t, ok := crane.deleteTerminal(id)
if ok {
// If the terminal was registered, abandon it.
// Log reason the terminal is ending. Override stopping error with nil.
switch {
case err == nil || err.IsOK():
log.Debugf("spn/docks: %T %s is being abandoned", t, t.FmtID())
case err.Is(terminal.ErrStopping):
err = nil
log.Debugf("spn/docks: %T %s is being abandoned by peer", t, t.FmtID())
case err.Is(terminal.ErrNoActivity):
err = nil
log.Debugf("spn/docks: %T %s is being abandoned due to no activity", t, t.FmtID())
default:
log.Warningf("spn/docks: %T %s: %s", t, t.FmtID(), err)
}
// Call the terminal's abandon function.
t.Abandon(err)
} else { //nolint:gocritic
// When a crane terminal is abandoned, it calls crane.AbandonTerminal when
// finished. This time, the terminal won't be in the registry anymore and
// it finished shutting down, so we can now check if the crane needs to be
// stopped.
// If the crane is stopping, check if we can stop.
// We can stop when all terminals are abandoned or after a timeout.
// FYI: The crane controller will always take up one slot.
if crane.stopping.IsSet() &&
crane.terminalCount() <= 1 {
// Stop the crane in worker, so the caller can do some work.
module.StartWorker("retire crane", func(_ context.Context) error {
// Let enough time for the last errors to be sent, as terminals are abandoned in a goroutine.
time.Sleep(3 * time.Second)
crane.Stop(nil)
return nil
})
}
}
}
func (crane *Crane) sendImportantTerminalMsg(msg *terminal.Msg, timeout time.Duration) *terminal.Error {
select {
case crane.controllerMsgs <- msg:
return nil
case <-crane.ctx.Done():
msg.Finish()
return terminal.ErrCanceled
}
}
// Send is used by others to send a message through the crane.
func (crane *Crane) Send(msg *terminal.Msg, timeout time.Duration) *terminal.Error {
select {
case crane.terminalMsgs <- msg:
return nil
case <-crane.ctx.Done():
msg.Finish()
return terminal.ErrCanceled
}
}
func (crane *Crane) encrypt(shipment *container.Container) (encrypted *container.Container, err error) {
// Skip if encryption is not enabled.
if crane.jession == nil {
return shipment, nil
}
crane.jessionLock.Lock()
defer crane.jessionLock.Unlock()
letter, err := crane.jession.Close(shipment.CompileData())
if err != nil {
return nil, err
}
encrypted, err = letter.ToWire()
if err != nil {
return nil, fmt.Errorf("failed to pack letter: %w", err)
}
return encrypted, nil
}
func (crane *Crane) decrypt(shipment *container.Container) (decrypted *container.Container, err error) {
// Skip if encryption is not enabled.
if crane.jession == nil {
return shipment, nil
}
crane.jessionLock.Lock()
defer crane.jessionLock.Unlock()
letter, err := jess.LetterFromWire(shipment)
if err != nil {
return nil, fmt.Errorf("failed to parse letter: %w", err)
}
decryptedData, err := crane.jession.Open(letter)
if err != nil {
return nil, err
}
return container.New(decryptedData), nil
}
func (crane *Crane) unloader(workerCtx context.Context) error {
// Unclean shutdown safeguard.
defer crane.Stop(terminal.ErrUnknownError.With("unloader died"))
for {
// Get first couple bytes to get the packet length.
// 2 bytes are enough to encode 65535.
// On the other hand, packets can be only 2 bytes small.
lenBuf := make([]byte, 2)
err := crane.unloadUntilFull(lenBuf)
if err != nil {
if errors.Is(err, io.EOF) {
crane.Stop(terminal.ErrStopping.With("connection closed"))
} else {
crane.Stop(terminal.ErrInternalError.With("failed to unload: %w", err))
}
return nil
}
// Unpack length.
containerLen, n, err := varint.Unpack64(lenBuf)
if err != nil {
crane.Stop(terminal.ErrMalformedData.With("failed to get container length: %w", err))
return nil
}
switch {
case containerLen <= 0:
crane.Stop(terminal.ErrMalformedData.With("received empty container with length %d", containerLen))
return nil
case containerLen > maxUnloadSize:
crane.Stop(terminal.ErrMalformedData.With("received oversized container with length %d", containerLen))
return nil
}
// Build shipment.
var shipmentBuf []byte
leftovers := len(lenBuf) - n
if leftovers == int(containerLen) {
// We already have all the shipment data.
shipmentBuf = lenBuf[n:]
} else {
// Create a shipment buffer, copy leftovers and read the rest from the connection.
shipmentBuf = make([]byte, containerLen)
if leftovers > 0 {
copy(shipmentBuf, lenBuf[n:])
}
// Read remaining shipment.
err = crane.unloadUntilFull(shipmentBuf[leftovers:])
if err != nil {
crane.Stop(terminal.ErrInternalError.With("failed to unload: %w", err))
return nil
}
}
// Submit to handler.
select {
case <-crane.ctx.Done():
crane.Stop(nil)
return nil
case crane.unloading <- container.New(shipmentBuf):
}
}
}
func (crane *Crane) unloadUntilFull(buf []byte) error {
var bytesRead int
for {
// Get shipment from ship.
n, err := crane.ship.UnloadTo(buf[bytesRead:])
if err != nil {
return err
}
if n == 0 {
log.Tracef("spn/docks: %s unloaded 0 bytes", crane)
}
bytesRead += n
// Return if buffer has been fully filled.
if bytesRead == len(buf) {
// Submit metrics.
crane.submitCraneTrafficStats(bytesRead)
crane.NetState.ReportTraffic(uint64(bytesRead), true)
return nil
}
}
}
func (crane *Crane) handler(workerCtx context.Context) error {
var partialShipment *container.Container
var segmentLength uint32
// Unclean shutdown safeguard.
defer crane.Stop(terminal.ErrUnknownError.With("handler died"))
handling:
for {
select {
case <-crane.ctx.Done():
crane.Stop(nil)
return nil
case shipment := <-crane.unloading:
// log.Debugf("spn/crane %s: before decrypt: %v ... %v", crane.ID, c.CompileData()[:10], c.CompileData()[c.Length()-10:])
// Decrypt shipment.
shipment, err := crane.decrypt(shipment)
if err != nil {
crane.Stop(terminal.ErrIntegrity.With("failed to decrypt: %w", err))
return nil
}
// Process all segments/containers of the shipment.
for shipment.HoldsData() {
if partialShipment != nil {
// Continue processing partial segment.
// Append new shipment to previous partial segment.
partialShipment.AppendContainer(shipment)
shipment, partialShipment = partialShipment, nil
}
// Get next segment length.
if segmentLength == 0 {
segmentLength, err = shipment.GetNextN32()
if err != nil {
if errors.Is(err, varint.ErrBufTooSmall) {
// Continue handling when there is not yet enough data.
partialShipment = shipment
segmentLength = 0
continue handling
}
crane.Stop(terminal.ErrMalformedData.With("failed to get segment length: %w", err))
return nil
}
if segmentLength == 0 {
// Remainder is padding.
continue handling
}
// Check if the segment is within the boundary.
if segmentLength > maxSegmentLength {
crane.Stop(terminal.ErrMalformedData.With("received oversized segment with length %d", segmentLength))
return nil
}
}
// Check if we have enough data for the segment.
if uint32(shipment.Length()) < segmentLength {
partialShipment = shipment
continue handling
}
// Get segment from shipment.
segment, err := shipment.GetAsContainer(int(segmentLength))
if err != nil {
crane.Stop(terminal.ErrMalformedData.With("failed to get segment: %w", err))
return nil
}
segmentLength = 0
// Get terminal ID and message type of segment.
terminalID, terminalMsgType, err := terminal.ParseIDType(segment)
if err != nil {
crane.Stop(terminal.ErrMalformedData.With("failed to get terminal ID and msg type: %w", err))
return nil
}
switch terminalMsgType {
case terminal.MsgTypeInit:
crane.establishTerminal(terminalID, segment)
case terminal.MsgTypeData, terminal.MsgTypePriorityData:
// Get terminal and let it further handle the message.
t, ok := crane.getTerminal(terminalID)
if ok {
// Create msg and set priority.
msg := terminal.NewEmptyMsg()
msg.FlowID = terminalID
msg.Type = terminalMsgType
msg.Data = segment
if msg.Type == terminal.MsgTypePriorityData {
msg.Unit.MakeHighPriority()
}
// Deliver to terminal.
deliveryErr := t.Deliver(msg)
if deliveryErr != nil {
msg.Finish()
// This is a hot path. Start a worker for abandoning the terminal.
module.StartWorker("end terminal", func(_ context.Context) error {
crane.AbandonTerminal(t.ID(), deliveryErr.Wrap("failed to deliver data"))
return nil
})
}
} else {
log.Tracef("spn/docks: %s received msg for unknown terminal %d", crane, terminalID)
}
case terminal.MsgTypeStop:
// Parse error.
receivedErr, err := terminal.ParseExternalError(segment.CompileData())
if err != nil {
log.Warningf("spn/docks: %s failed to parse abandon error: %s", crane, err)
receivedErr = terminal.ErrUnknownError.AsExternal()
}
// This is a hot path. Start a worker for abandoning the terminal.
module.StartWorker("end terminal", func(_ context.Context) error {
crane.AbandonTerminal(terminalID, receivedErr)
return nil
})
}
}
}
}
}
func (crane *Crane) loader(workerCtx context.Context) (err error) {
shipment := container.New()
var partialShipment *container.Container
var loadingTimer *time.Timer
// Unclean shutdown safeguard.
defer crane.Stop(terminal.ErrUnknownError.With("loader died"))
// Return the loading wait channel if waiting.
loadNow := func() <-chan time.Time {
if loadingTimer != nil {
return loadingTimer.C
}
return nil
}
// Make sure any received message is finished
var msg, firstMsg *terminal.Msg
defer msg.Finish()
defer firstMsg.Finish()
for {
// Reset first message in shipment.
firstMsg.Finish()
firstMsg = nil
fillingShipment:
for shipment.Length() < crane.targetLoadSize {
// Gather segments until shipment is filled.
// Prioritize messages from the controller.
select {
case msg = <-crane.controllerMsgs:
case <-crane.ctx.Done():
crane.Stop(nil)
return nil
default:
// Then listen for all.
select {
case msg = <-crane.controllerMsgs:
case msg = <-crane.terminalMsgs:
case <-loadNow():
break fillingShipment
case <-crane.ctx.Done():
crane.Stop(nil)
return nil
}
}
// Debug unit leaks.
msg.Debug()
// Handle new message.
if msg != nil {
// Pack msg and add to segment.
msg.Pack()
newSegment := msg.Data
// Check if this is the first message.
// This is the only message where we wait for a slot.
if firstMsg == nil {
firstMsg = msg
firstMsg.Unit.WaitForSlot()
} else {
msg.Finish()
}
// Check length.
if newSegment.Length() > maxSegmentLength {
log.Warningf("spn/docks: %s ignored oversized segment with length %d", crane, newSegment.Length())
continue fillingShipment
}
// Append to shipment.
shipment.AppendContainer(newSegment)
// Set loading max wait timer on first segment.
if loadingTimer == nil {
loadingTimer = time.NewTimer(loadingMaxWaitDuration)
}
} else if crane.stopped.IsSet() {
// If there is no new segment, this might have been triggered by a
// closed channel. Check if the crane is still active.
return nil
}
}
sendingShipment:
for {
// Check if we are over the target load size and split the shipment.
if shipment.Length() > crane.targetLoadSize {
partialShipment, err = shipment.GetAsContainer(crane.targetLoadSize)
if err != nil {
crane.Stop(terminal.ErrInternalError.With("failed to split segment: %w", err))
return nil
}
shipment, partialShipment = partialShipment, shipment
}
// Load shipment.
err = crane.load(shipment)
if err != nil {
crane.Stop(terminal.ErrShipSunk.With("failed to load shipment: %w", err))
return nil
}
// Reset loading timer.
loadingTimer = nil
// Continue loading with partial shipment, or a new one.
if partialShipment != nil {
// Continue loading with a partial previous shipment.
shipment, partialShipment = partialShipment, nil
// If shipment is not big enough to send immediately, wait for more data.
if shipment.Length() < crane.targetLoadSize {
loadingTimer = time.NewTimer(loadingMaxWaitDuration)
break sendingShipment
}
} else {
// Continue loading with new shipment.
shipment = container.New()
break sendingShipment
}
}
}
}
func (crane *Crane) load(c *container.Container) error {
// Add Padding if needed.
if crane.opts.Padding > 0 {
paddingNeeded := int(crane.opts.Padding) -
((c.Length() + varint.EncodedSize(uint64(c.Length()))) % int(crane.opts.Padding))
// As the length changes slightly with the padding, we should avoid loading
// lengths around the varint size hops:
// - 128
// - 16384
// - 2097152
// - 268435456
// Pad to target load size at maximum.
maxPadding := crane.targetLoadSize - c.Length()
if paddingNeeded > maxPadding {
paddingNeeded = maxPadding
}
if paddingNeeded > 0 {
// Add padding indicator.
c.Append([]byte{0})
paddingNeeded--
// Add needed padding data.
if paddingNeeded > 0 {
padding, err := rng.Bytes(paddingNeeded)
if err != nil {
log.Debugf("spn/docks: %s failed to get random padding data, using zeros instead", crane)
padding = make([]byte, paddingNeeded)
}
c.Append(padding)
}
}
}
// Encrypt shipment.
c, err := crane.encrypt(c)
if err != nil {
return fmt.Errorf("failed to encrypt: %w", err)
}
// Finalize data.
c.PrependLength()
readyToSend := c.CompileData()
// Submit metrics.
crane.submitCraneTrafficStats(len(readyToSend))
crane.NetState.ReportTraffic(uint64(len(readyToSend)), false)
// Load onto ship.
err = crane.ship.Load(readyToSend)
if err != nil {
return fmt.Errorf("failed to load ship: %w", err)
}
return nil
}
// Stop stops the crane.
func (crane *Crane) Stop(err *terminal.Error) {
if !crane.stopped.SetToIf(false, true) {
return
}
// Log error message.
if err != nil {
if err.IsOK() {
log.Infof("spn/docks: %s is done", crane)
} else {
log.Warningf("spn/docks: %s is stopping: %s", crane, err)
}
}
// Unregister crane.
unregisterCrane(crane)
// Stop all terminals.
for _, t := range crane.allTerms() {
t.Abandon(err) // Async!
}
// Stop controller.
if crane.Controller != nil {
crane.Controller.Abandon(err) // Async!
}
// Wait shortly for all terminals to finish abandoning.
waitStep := 50 * time.Millisecond
for i := time.Duration(0); i < maxCraneStopDuration; i += waitStep {
// Check if all terminals are done.
if crane.terminalCount() == 0 {
break
}
time.Sleep(waitStep)
}
// Close connection.
crane.ship.Sink()
// Cancel crane context.
crane.cancelCtx()
// Notify about change.
crane.NotifyUpdate()
}
func (crane *Crane) allTerms() []terminal.Terminal {
crane.terminalsLock.Lock()
defer crane.terminalsLock.Unlock()
terms := make([]terminal.Terminal, 0, len(crane.terminals))
for _, term := range crane.terminals {
terms = append(terms, term)
}
return terms
}
func (crane *Crane) String() string {
remoteAddr := crane.ship.RemoteAddr()
switch {
case remoteAddr == nil:
return fmt.Sprintf("crane %s", crane.ID)
case crane.ship.IsMine():
return fmt.Sprintf("crane %s to %s", crane.ID, crane.ship.MaskAddress(crane.ship.RemoteAddr()))
default:
return fmt.Sprintf("crane %s from %s", crane.ID, crane.ship.MaskAddress(crane.ship.RemoteAddr()))
}
}
// Stopped returns whether the crane has stopped.
func (crane *Crane) Stopped() bool {
return crane.stopped.IsSet()
}

View File

@@ -0,0 +1,81 @@
package docks
import (
"context"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/terminal"
)
const (
defaultTerminalIdleTimeout = 15 * time.Minute
remoteTerminalIdleTimeout = 30 * time.Minute
)
// EstablishNewTerminal establishes a new terminal with the crane.
func (crane *Crane) EstablishNewTerminal(
localTerm terminal.Terminal,
initData *container.Container,
) *terminal.Error {
// Create message.
msg := terminal.NewEmptyMsg()
msg.FlowID = localTerm.ID()
msg.Type = terminal.MsgTypeInit
msg.Data = initData
// Register terminal with crane.
crane.setTerminal(localTerm)
// Send message.
select {
case crane.controllerMsgs <- msg:
log.Debugf("spn/docks: %s initiated new terminal %d", crane, localTerm.ID())
return nil
case <-crane.ctx.Done():
crane.AbandonTerminal(localTerm.ID(), terminal.ErrStopping.With("initiation aborted"))
return terminal.ErrStopping
}
}
func (crane *Crane) establishTerminal(id uint32, initData *container.Container) {
// Create new remote crane terminal.
newTerminal, _, err := NewRemoteCraneTerminal(
crane,
id,
initData,
)
if err == nil {
// Connections via public cranes have a timeout.
if crane.Public() {
newTerminal.TerminalBase.SetTimeout(remoteTerminalIdleTimeout)
}
// Register terminal with crane.
crane.setTerminal(newTerminal)
log.Debugf("spn/docks: %s established new crane terminal %d", crane, newTerminal.ID())
return
}
// If something goes wrong, send an error back.
log.Warningf("spn/docks: %s failed to establish crane terminal: %s", crane, err)
// Build abandon message.
msg := terminal.NewMsg(err.Pack())
msg.FlowID = id
msg.Type = terminal.MsgTypeStop
// Send message directly, or async.
select {
case crane.terminalMsgs <- msg:
default:
// Send error async.
module.StartWorker("abandon terminal", func(ctx context.Context) error {
select {
case crane.terminalMsgs <- msg:
case <-ctx.Done():
}
return nil
})
}
}

339
spn/docks/crane_init.go Normal file
View File

@@ -0,0 +1,339 @@
package docks
import (
"context"
"time"
"github.com/safing/jess"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portbase/info"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/terminal"
)
/*
Crane Init Message Format:
used by init procedures
- Data [bytes block]
- MsgType [varint]
- Data [bytes; only when MsgType is Verify or Start*]
Crane Init Response Format:
- Data [bytes block]
Crane Operational Message Format:
- Data [bytes block]
- possibly encrypted
*/
// Crane Msg Types.
const (
CraneMsgTypeEnd = 0
CraneMsgTypeInfo = 1
CraneMsgTypeRequestHubInfo = 2
CraneMsgTypeVerify = 3
CraneMsgTypeStartEncrypted = 4
CraneMsgTypeStartUnencrypted = 5
)
// Start starts the crane.
func (crane *Crane) Start(callerCtx context.Context) error {
log.Infof("spn/docks: %s is starting", crane)
// Submit metrics.
newCranes.Inc()
// Start crane depending on situation.
var tErr *terminal.Error
if crane.ship.IsMine() {
tErr = crane.startLocal(callerCtx)
} else {
tErr = crane.startRemote(callerCtx)
}
// Stop crane again if starting failed.
if tErr != nil {
crane.Stop(tErr)
return tErr
}
log.Debugf("spn/docks: %s started", crane)
// Return an explicit nil for working "!= nil" checks.
return nil
}
func (crane *Crane) startLocal(callerCtx context.Context) *terminal.Error {
module.StartWorker("crane unloader", crane.unloader)
if !crane.ship.IsSecure() {
// Start encrypted channel.
// Check if we have all the data we need from the Hub.
if crane.ConnectedHub == nil {
return terminal.ErrIncorrectUsage.With("cannot start encrypted channel without connected hub")
}
// Always request hub info, as we don't know if the hub has restarted in
// the meantime and lost ephemeral keys.
hubInfoRequest := container.New(
varint.Pack8(CraneMsgTypeRequestHubInfo),
)
hubInfoRequest.PrependLength()
err := crane.ship.Load(hubInfoRequest.CompileData())
if err != nil {
return terminal.ErrShipSunk.With("failed to request hub info: %w", err)
}
// Wait for reply.
var reply *container.Container
select {
case reply = <-crane.unloading:
case <-time.After(30 * time.Second):
return terminal.ErrTimeout.With("waiting for hub info")
case <-crane.ctx.Done():
return terminal.ErrShipSunk.With("waiting for hub info")
case <-callerCtx.Done():
return terminal.ErrCanceled.With("waiting for hub info")
}
// Parse and import Announcement and Status.
announcementData, err := reply.GetNextBlock()
if err != nil {
return terminal.ErrMalformedData.With("failed to get announcement: %w", err)
}
statusData, err := reply.GetNextBlock()
if err != nil {
return terminal.ErrMalformedData.With("failed to get status: %w", err)
}
h, _, tErr := ImportAndVerifyHubInfo(
callerCtx,
crane.ConnectedHub.ID,
announcementData, statusData, conf.MainMapName, conf.MainMapScope,
)
if tErr != nil {
return tErr.Wrap("failed to import and verify hub")
}
// Update reference in case it was changed by the import.
crane.ConnectedHub = h
// Now, try to select a public key again.
signet := crane.ConnectedHub.SelectSignet()
if signet == nil {
return terminal.ErrHubNotReady.With("failed to select signet (after updating hub info)")
}
// Configure encryption.
env := jess.NewUnconfiguredEnvelope()
env.SuiteID = jess.SuiteWireV1
env.Recipients = []*jess.Signet{signet}
// Do not encrypt directly, rather get session for future use, then encrypt.
crane.jession, err = env.WireCorrespondence(nil)
if err != nil {
return terminal.ErrInternalError.With("failed to create encryption session: %w", err)
}
}
// Create crane controller.
_, initData, tErr := NewLocalCraneControllerTerminal(crane, terminal.DefaultCraneControllerOpts())
if tErr != nil {
return tErr.Wrap("failed to set up controller")
}
// Prepare init message for sending.
if crane.ship.IsSecure() {
initData.PrependNumber(CraneMsgTypeStartUnencrypted)
} else {
// Encrypt controller initializer.
letter, err := crane.jession.Close(initData.CompileData())
if err != nil {
return terminal.ErrInternalError.With("failed to encrypt initial packet: %w", err)
}
initData, err = letter.ToWire()
if err != nil {
return terminal.ErrInternalError.With("failed to pack initial packet: %w", err)
}
initData.PrependNumber(CraneMsgTypeStartEncrypted)
}
// Send start message.
initData.PrependLength()
err := crane.ship.Load(initData.CompileData())
if err != nil {
return terminal.ErrShipSunk.With("failed to send init msg: %w", err)
}
// Start remaining workers.
module.StartWorker("crane loader", crane.loader)
module.StartWorker("crane handler", crane.handler)
return nil
}
func (crane *Crane) startRemote(callerCtx context.Context) *terminal.Error {
var initMsg *container.Container
module.StartWorker("crane unloader", crane.unloader)
handling:
for {
// Wait for request.
var request *container.Container
select {
case request = <-crane.unloading:
case <-time.After(30 * time.Second):
return terminal.ErrTimeout.With("waiting for crane init msg")
case <-crane.ctx.Done():
return terminal.ErrShipSunk.With("waiting for crane init msg")
case <-callerCtx.Done():
return terminal.ErrCanceled.With("waiting for crane init msg")
}
msgType, err := request.GetNextN8()
if err != nil {
return terminal.ErrMalformedData.With("failed to parse crane msg type: %s", err)
}
switch msgType {
case CraneMsgTypeEnd:
// End connection.
return terminal.ErrStopping
case CraneMsgTypeInfo:
// Info is a terminating request.
err := crane.handleCraneInfo()
if err != nil {
return err
}
log.Debugf("spn/docks: %s sent version info", crane)
case CraneMsgTypeRequestHubInfo:
// Handle Hub info request.
err := crane.handleCraneHubInfo()
if err != nil {
return err
}
log.Debugf("spn/docks: %s sent hub info", crane)
case CraneMsgTypeVerify:
// Verify is a terminating request.
err := crane.handleCraneVerification(request)
if err != nil {
return err
}
log.Infof("spn/docks: %s sent hub verification", crane)
case CraneMsgTypeStartUnencrypted:
initMsg = request
// Start crane with initMsg.
log.Debugf("spn/docks: %s initiated unencrypted channel", crane)
break handling
case CraneMsgTypeStartEncrypted:
if crane.identity == nil {
return terminal.ErrIncorrectUsage.With("cannot start incoming crane without designated identity")
}
// Set up encryption.
letter, err := jess.LetterFromWire(container.New(request.CompileData()))
if err != nil {
return terminal.ErrMalformedData.With("failed to unpack initial packet: %w", err)
}
crane.jession, err = letter.WireCorrespondence(crane.identity)
if err != nil {
return terminal.ErrInternalError.With("failed to create encryption session: %w", err)
}
initMsgData, err := crane.jession.Open(letter)
if err != nil {
return terminal.ErrIntegrity.With("failed to decrypt initial packet: %w", err)
}
initMsg = container.New(initMsgData)
// Start crane with initMsg.
log.Debugf("spn/docks: %s initiated encrypted channel", crane)
break handling
}
}
_, _, err := NewRemoteCraneControllerTerminal(crane, initMsg)
if err != nil {
return err.Wrap("failed to start crane controller")
}
// Start remaining workers.
module.StartWorker("crane loader", crane.loader)
module.StartWorker("crane handler", crane.handler)
return nil
}
func (crane *Crane) endInit() *terminal.Error {
endMsg := container.New(
varint.Pack8(CraneMsgTypeEnd),
)
endMsg.PrependLength()
err := crane.ship.Load(endMsg.CompileData())
if err != nil {
return terminal.ErrShipSunk.With("failed to send end msg: %w", err)
}
return nil
}
func (crane *Crane) handleCraneInfo() *terminal.Error {
// Pack info data.
infoData, err := dsd.Dump(info.GetInfo(), dsd.JSON)
if err != nil {
return terminal.ErrInternalError.With("failed to pack info: %w", err)
}
msg := container.New(infoData)
// Manually send reply.
msg.PrependLength()
err = crane.ship.Load(msg.CompileData())
if err != nil {
return terminal.ErrShipSunk.With("failed to send info reply: %w", err)
}
return nil
}
func (crane *Crane) handleCraneHubInfo() *terminal.Error {
msg := container.New()
// Check if we have an identity.
if crane.identity == nil {
return terminal.ErrIncorrectUsage.With("cannot handle hub info request without designated identity")
}
// Add Hub Announcement.
announcementData, err := crane.identity.ExportAnnouncement()
if err != nil {
return terminal.ErrInternalError.With("failed to export announcement: %w", err)
}
msg.AppendAsBlock(announcementData)
// Add Hub Status.
statusData, err := crane.identity.ExportStatus()
if err != nil {
return terminal.ErrInternalError.With("failed to export status: %w", err)
}
msg.AppendAsBlock(statusData)
// Manually send reply.
msg.PrependLength()
err = crane.ship.Load(msg.CompileData())
if err != nil {
return terminal.ErrShipSunk.With("failed to send hub info reply: %w", err)
}
return nil
}

131
spn/docks/crane_netstate.go Normal file
View File

@@ -0,0 +1,131 @@
package docks
import (
"sync"
"sync/atomic"
"time"
)
// NetStatePeriodInterval defines the interval some of the net state should be reset.
const NetStatePeriodInterval = 15 * time.Minute
// NetworkOptimizationState holds data for optimization purposes.
type NetworkOptimizationState struct {
lock sync.Mutex
// lastSuggestedAt holds the time when the connection to the connected Hub was last suggested by the network optimization.
lastSuggestedAt time.Time
// stoppingRequested signifies whether stopping this lane is requested.
stoppingRequested bool
// stoppingRequestedByPeer signifies whether stopping this lane is requested by the peer.
stoppingRequestedByPeer bool
// markedStoppingAt holds the time when the crane was last marked as stopping.
markedStoppingAt time.Time
lifetimeBytesIn *uint64
lifetimeBytesOut *uint64
lifetimeStarted time.Time
periodBytesIn *uint64
periodBytesOut *uint64
periodStarted time.Time
}
func newNetworkOptimizationState() *NetworkOptimizationState {
return &NetworkOptimizationState{
lifetimeBytesIn: new(uint64),
lifetimeBytesOut: new(uint64),
lifetimeStarted: time.Now(),
periodBytesIn: new(uint64),
periodBytesOut: new(uint64),
periodStarted: time.Now(),
}
}
// UpdateLastSuggestedAt sets when the lane was last suggested to the current time.
func (netState *NetworkOptimizationState) UpdateLastSuggestedAt() {
netState.lock.Lock()
defer netState.lock.Unlock()
netState.lastSuggestedAt = time.Now()
}
// StoppingState returns when the stopping state.
func (netState *NetworkOptimizationState) StoppingState() (requested, requestedByPeer bool, markedAt time.Time) {
netState.lock.Lock()
defer netState.lock.Unlock()
return netState.stoppingRequested, netState.stoppingRequestedByPeer, netState.markedStoppingAt
}
// RequestStoppingSuggested returns whether the crane should request stopping.
func (netState *NetworkOptimizationState) RequestStoppingSuggested(maxNotSuggestedDuration time.Duration) bool {
netState.lock.Lock()
defer netState.lock.Unlock()
return time.Now().Add(-maxNotSuggestedDuration).After(netState.lastSuggestedAt)
}
// StoppingSuggested returns whether the crane should be marked as stopping.
func (netState *NetworkOptimizationState) StoppingSuggested() bool {
netState.lock.Lock()
defer netState.lock.Unlock()
return netState.stoppingRequested &&
netState.stoppingRequestedByPeer
}
// StopSuggested returns whether the crane should be stopped.
func (netState *NetworkOptimizationState) StopSuggested() bool {
netState.lock.Lock()
defer netState.lock.Unlock()
return netState.stoppingRequested &&
netState.stoppingRequestedByPeer &&
!netState.markedStoppingAt.IsZero() &&
time.Now().Add(-maxCraneStoppingDuration).After(netState.markedStoppingAt)
}
// ReportTraffic adds the reported transferred data to the traffic stats.
func (netState *NetworkOptimizationState) ReportTraffic(bytes uint64, in bool) {
if in {
atomic.AddUint64(netState.lifetimeBytesIn, bytes)
atomic.AddUint64(netState.periodBytesIn, bytes)
} else {
atomic.AddUint64(netState.lifetimeBytesOut, bytes)
atomic.AddUint64(netState.periodBytesOut, bytes)
}
}
// LapsePeriod lapses the net state period, if needed.
func (netState *NetworkOptimizationState) LapsePeriod() {
netState.lock.Lock()
defer netState.lock.Unlock()
// Reset period if interval elapsed.
if time.Now().Add(-NetStatePeriodInterval).After(netState.periodStarted) {
atomic.StoreUint64(netState.periodBytesIn, 0)
atomic.StoreUint64(netState.periodBytesOut, 0)
netState.periodStarted = time.Now()
}
}
// GetTrafficStats returns the traffic stats.
func (netState *NetworkOptimizationState) GetTrafficStats() (
lifetimeBytesIn uint64,
lifetimeBytesOut uint64,
lifetimeStarted time.Time,
periodBytesIn uint64,
periodBytesOut uint64,
periodStarted time.Time,
) {
netState.lock.Lock()
defer netState.lock.Unlock()
return atomic.LoadUint64(netState.lifetimeBytesIn),
atomic.LoadUint64(netState.lifetimeBytesOut),
netState.lifetimeStarted,
atomic.LoadUint64(netState.periodBytesIn),
atomic.LoadUint64(netState.periodBytesOut),
netState.periodStarted
}

122
spn/docks/crane_terminal.go Normal file
View File

@@ -0,0 +1,122 @@
package docks
import (
"net"
"github.com/safing/portbase/container"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/terminal"
)
// CraneTerminal is a terminal started by a crane.
type CraneTerminal struct {
*terminal.TerminalBase
// Add-Ons
terminal.SessionAddOn
crane *Crane
}
// NewLocalCraneTerminal returns a new local crane terminal.
func NewLocalCraneTerminal(
crane *Crane,
remoteHub *hub.Hub,
initMsg *terminal.TerminalOpts,
) (*CraneTerminal, *container.Container, *terminal.Error) {
// Create Terminal Base.
t, initData, err := terminal.NewLocalBaseTerminal(
crane.ctx,
crane.getNextTerminalID(),
crane.ID,
remoteHub,
initMsg,
crane,
)
if err != nil {
return nil, nil, err
}
return initCraneTerminal(crane, t), initData, nil
}
// NewRemoteCraneTerminal returns a new remote crane terminal.
func NewRemoteCraneTerminal(
crane *Crane,
id uint32,
initData *container.Container,
) (*CraneTerminal, *terminal.TerminalOpts, *terminal.Error) {
// Create Terminal Base.
t, initMsg, err := terminal.NewRemoteBaseTerminal(
crane.ctx,
id,
crane.ID,
crane.identity,
initData,
crane,
)
if err != nil {
return nil, nil, err
}
return initCraneTerminal(crane, t), initMsg, nil
}
func initCraneTerminal(
crane *Crane,
t *terminal.TerminalBase,
) *CraneTerminal {
// Create Crane Terminal and assign it as the extended Terminal.
ct := &CraneTerminal{
TerminalBase: t,
crane: crane,
}
t.SetTerminalExtension(ct)
// Start workers.
t.StartWorkers(module, "crane terminal")
return ct
}
// GrantPermission grants the given permissions.
// Additionally, it will mark the crane as authenticated, if not public.
func (t *CraneTerminal) GrantPermission(grant terminal.Permission) {
// Forward granted permission to base terminal.
t.TerminalBase.GrantPermission(grant)
// Mark crane as authenticated if not public or already authenticated.
if !t.crane.Public() && !t.crane.Authenticated() {
t.crane.authenticated.Set()
// Submit metrics.
newAuthenticatedCranes.Inc()
}
}
// LocalAddr returns the crane's local address.
func (t *CraneTerminal) LocalAddr() net.Addr {
return t.crane.LocalAddr()
}
// RemoteAddr returns the crane's remote address.
func (t *CraneTerminal) RemoteAddr() net.Addr {
return t.crane.RemoteAddr()
}
// Transport returns the crane's transport.
func (t *CraneTerminal) Transport() *hub.Transport {
return t.crane.Transport()
}
// IsBeingAbandoned returns whether the terminal is being abandoned.
func (t *CraneTerminal) IsBeingAbandoned() bool {
return t.Abandoning.IsSet()
}
// HandleDestruction gives the terminal the ability to clean up.
// The terminal has already fully shut down at this point.
// Should never be called directly. Call Abandon() instead.
func (t *CraneTerminal) HandleDestruction(err *terminal.Error) {
t.crane.AbandonTerminal(t.ID(), err)
}

267
spn/docks/crane_test.go Normal file
View File

@@ -0,0 +1,267 @@
package docks
import (
"context"
"fmt"
"os"
"runtime/pprof"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/ships"
"github.com/safing/portmaster/spn/terminal"
)
func TestCraneCommunication(t *testing.T) {
t.Parallel()
testCraneWithCounter(t, "plain-counter-load-100", false, 100, 1000)
testCraneWithCounter(t, "plain-counter-load-1000", false, 1000, 1000)
testCraneWithCounter(t, "plain-counter-load-10000", false, 10000, 1000)
testCraneWithCounter(t, "encrypted-counter", true, 1000, 1000)
}
func testCraneWithCounter(t *testing.T, testID string, encrypting bool, loadSize int, countTo uint64) { //nolint:unparam,thelper
var identity *cabin.Identity
var connectedHub *hub.Hub
if encrypting {
identity, connectedHub = getTestIdentity(t)
}
// Build ship and cranes.
optimalMinLoadSize = loadSize * 2
ship := ships.NewTestShip(!encrypting, loadSize)
var crane1, crane2 *Crane
var craneWg sync.WaitGroup
craneWg.Add(2)
go func() {
var err error
crane1, err = NewCrane(ship, connectedHub, nil)
if err != nil {
panic(fmt.Sprintf("crane test %s could not create crane1: %s", testID, err))
}
err = crane1.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("crane test %s could not start crane1: %s", testID, err))
}
craneWg.Done()
}()
go func() {
var err error
crane2, err = NewCrane(ship.Reverse(), nil, identity)
if err != nil {
panic(fmt.Sprintf("crane test %s could not create crane2: %s", testID, err))
}
err = crane2.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("crane test %s could not start crane2: %s", testID, err))
}
craneWg.Done()
}()
craneWg.Wait()
t.Logf("crane test %s setup complete", testID)
// Wait async for test to complete, print stack after timeout.
finished := make(chan struct{})
go func() {
select {
case <-finished:
case <-time.After(10 * time.Second):
t.Logf("crane test %s is taking too long, print stack:", testID)
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
os.Exit(1)
}
}()
t.Logf("crane1 controller: %+v", crane1.Controller)
t.Logf("crane2 controller: %+v", crane2.Controller)
// Start counters for testing.
op1, tErr := terminal.NewCounterOp(crane1.Controller, terminal.CounterOpts{
ClientCountTo: countTo,
ServerCountTo: countTo,
})
if tErr != nil {
t.Fatalf("crane test %s failed to run counter op: %s", testID, tErr)
}
// Wait for completion.
op1.Wait()
close(finished)
// Wait a little so that all errors can be propagated, so we can truly see
// if we succeeded.
time.Sleep(1 * time.Second)
// Check errors.
if op1.Error != nil {
t.Fatalf("crane test %s counter op1 failed: %s", testID, op1.Error)
}
}
type StreamingTerminal struct {
terminal.BareTerminal
test *testing.T
id uint32
crane *Crane
recv chan *terminal.Msg
testData []byte
}
func (t *StreamingTerminal) ID() uint32 {
return t.id
}
func (t *StreamingTerminal) Ctx() context.Context {
return module.Ctx
}
func (t *StreamingTerminal) Deliver(msg *terminal.Msg) *terminal.Error {
t.recv <- msg
msg.Finish()
return nil
}
func (t *StreamingTerminal) Abandon(err *terminal.Error) {
t.crane.AbandonTerminal(t.ID(), err)
if err != nil {
t.test.Errorf("streaming terminal %d failed: %s", t.id, err)
}
}
func (t *StreamingTerminal) FmtID() string {
return fmt.Sprintf("test-%d", t.id)
}
func TestCraneLoadingUnloading(t *testing.T) {
t.Parallel()
testCraneWithStreaming(t, "plain-streaming", false, 100)
testCraneWithStreaming(t, "encrypted-streaming", true, 100)
}
func testCraneWithStreaming(t *testing.T, testID string, encrypting bool, loadSize int) { //nolint:thelper
var identity *cabin.Identity
var connectedHub *hub.Hub
if encrypting {
identity, connectedHub = getTestIdentity(t)
}
// Build ship and cranes.
optimalMinLoadSize = loadSize * 2
ship := ships.NewTestShip(!encrypting, loadSize)
var crane1, crane2 *Crane
var craneWg sync.WaitGroup
craneWg.Add(2)
go func() {
var err error
crane1, err = NewCrane(ship, connectedHub, nil)
if err != nil {
panic(fmt.Sprintf("crane test %s could not create crane1: %s", testID, err))
}
err = crane1.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("crane test %s could not start crane1: %s", testID, err))
}
craneWg.Done()
}()
go func() {
var err error
crane2, err = NewCrane(ship.Reverse(), nil, identity)
if err != nil {
panic(fmt.Sprintf("crane test %s could not create crane2: %s", testID, err))
}
err = crane2.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("crane test %s could not start crane2: %s", testID, err))
}
craneWg.Done()
}()
craneWg.Wait()
t.Logf("crane test %s setup complete", testID)
// Wait async for test to complete, print stack after timeout.
finished := make(chan struct{})
go func() {
select {
case <-finished:
case <-time.After(10 * time.Second):
t.Logf("crane test %s is taking too long, print stack:", testID)
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
os.Exit(1)
}
}()
t.Logf("crane1 controller: %+v", crane1.Controller)
t.Logf("crane2 controller: %+v", crane2.Controller)
// Create terminals and run test.
st := &StreamingTerminal{
test: t,
id: 8,
crane: crane2,
recv: make(chan *terminal.Msg),
testData: []byte("The quick brown fox jumps over the lazy dog."),
}
crane2.terminals[st.ID()] = st
// Run streaming test.
var streamingWg sync.WaitGroup
streamingWg.Add(2)
count := 10000
go func() {
for i := 1; i <= count; i++ {
msg := terminal.NewMsg(st.testData)
msg.FlowID = st.id
err := crane1.Send(msg, 1*time.Second)
if err != nil {
msg.Finish()
crane1.Stop(err.Wrap("failed to submit terminal msg"))
}
// log.Tracef("spn/testing: + %d", i)
}
t.Logf("crane test %s done with sending", testID)
streamingWg.Done()
}()
go func() {
for i := 1; i <= count; i++ {
msg := <-st.recv
assert.Equal(t, st.testData, msg.Data.CompileData(), "data mismatched")
// log.Tracef("spn/testing: - %d", i)
}
t.Logf("crane test %s done with receiving", testID)
streamingWg.Done()
}()
// Wait for completion.
streamingWg.Wait()
close(finished)
}
var testIdentity *cabin.Identity
func getTestIdentity(t *testing.T) (*cabin.Identity, *hub.Hub) {
t.Helper()
if testIdentity == nil {
var err error
testIdentity, err = cabin.CreateIdentity(module.Ctx, "test")
if err != nil {
t.Fatalf("failed to create identity: %s", err)
}
}
return testIdentity, testIdentity.Hub
}

85
spn/docks/crane_verify.go Normal file
View File

@@ -0,0 +1,85 @@
package docks
import (
"context"
"errors"
"fmt"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/terminal"
)
const (
hubVerificationPurpose = "hub identify verification"
)
// VerifyConnectedHub verifies the connected Hub.
func (crane *Crane) VerifyConnectedHub(callerCtx context.Context) error {
if !crane.ship.IsMine() || crane.nextTerminalID != 0 || crane.Public() {
return errors.New("hub verification can only be executed in init phase by the client")
}
// Create verification request.
v, request, err := cabin.CreateVerificationRequest(hubVerificationPurpose, "", "")
if err != nil {
return fmt.Errorf("failed to create verification request: %w", err)
}
// Send it.
msg := container.New(
varint.Pack8(CraneMsgTypeVerify),
request,
)
msg.PrependLength()
err = crane.ship.Load(msg.CompileData())
if err != nil {
return terminal.ErrShipSunk.With("failed to send verification request: %w", err)
}
// Wait for reply.
var reply *container.Container
select {
case reply = <-crane.unloading:
case <-time.After(2 * time.Minute):
// Use a big timeout here, as this might keep servers from joining the
// network at all, as every servers needs to verify every server, no
// matter how far away.
return terminal.ErrTimeout.With("waiting for verification reply")
case <-crane.ctx.Done():
return terminal.ErrShipSunk.With("waiting for verification reply")
case <-callerCtx.Done():
return terminal.ErrShipSunk.With("waiting for verification reply")
}
// Verify reply.
return v.Verify(reply.CompileData(), crane.ConnectedHub)
}
func (crane *Crane) handleCraneVerification(request *container.Container) *terminal.Error {
// Check if we have an identity.
if crane.identity == nil {
return terminal.ErrIncorrectUsage.With("cannot handle verification request without designated identity")
}
response, err := crane.identity.SignVerificationRequest(
request.CompileData(),
hubVerificationPurpose,
"", "",
)
if err != nil {
return terminal.ErrInternalError.With("failed to sign verification request: %w", err)
}
msg := container.New(response)
// Manually send reply.
msg.PrependLength()
err = crane.ship.Load(msg.CompileData())
if err != nil {
return terminal.ErrShipSunk.With("failed to send verification reply: %w", err)
}
return nil
}

46
spn/docks/cranehooks.go Normal file
View File

@@ -0,0 +1,46 @@
package docks
import (
"sync"
"github.com/safing/portbase/log"
)
var (
craneUpdateHook func(crane *Crane)
craneUpdateHookLock sync.Mutex
)
// RegisterCraneUpdateHook allows the captain to hook into receiving updates for cranes.
func RegisterCraneUpdateHook(fn func(crane *Crane)) {
craneUpdateHookLock.Lock()
defer craneUpdateHookLock.Unlock()
if craneUpdateHook == nil {
craneUpdateHook = fn
} else {
log.Error("spn/docks: crane update hook already registered")
}
}
// ResetCraneUpdateHook resets the hook for receiving updates for cranes.
func ResetCraneUpdateHook() {
craneUpdateHookLock.Lock()
defer craneUpdateHookLock.Unlock()
craneUpdateHook = nil
}
// NotifyUpdate calls the registers crane update hook function.
func (crane *Crane) NotifyUpdate() {
if crane == nil {
return
}
craneUpdateHookLock.Lock()
defer craneUpdateHookLock.Unlock()
if craneUpdateHook != nil {
craneUpdateHook(crane)
}
}

189
spn/docks/hub_import.go Normal file
View File

@@ -0,0 +1,189 @@
package docks
import (
"context"
"fmt"
"net"
"sync"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/ships"
"github.com/safing/portmaster/spn/terminal"
)
var hubImportLock sync.Mutex
// ImportAndVerifyHubInfo imports the given hub message and verifies them.
func ImportAndVerifyHubInfo(ctx context.Context, hubID string, announcementData, statusData []byte, mapName string, scope hub.Scope) (h *hub.Hub, forward bool, tErr *terminal.Error) {
var firstErr *terminal.Error
// Synchronize import, as we might easily learn of a new hub from different
// gossip channels simultaneously.
hubImportLock.Lock()
defer hubImportLock.Unlock()
// Check arguments.
if announcementData == nil && statusData == nil {
return nil, false, terminal.ErrInternalError.With("no announcement or status supplied")
}
// Import Announcement, if given.
var hubKnown, hubChanged bool
if announcementData != nil {
hubFromMsg, known, changed, err := hub.ApplyAnnouncement(nil, announcementData, mapName, scope, false)
if err != nil && firstErr == nil {
firstErr = terminal.ErrInternalError.With("failed to apply announcement: %w", err)
}
if known {
hubKnown = true
}
if changed {
hubChanged = true
}
if hubFromMsg != nil {
h = hubFromMsg
}
}
// Import Status, if given.
if statusData != nil {
hubFromMsg, known, changed, err := hub.ApplyStatus(h, statusData, mapName, scope, false)
if err != nil && firstErr == nil {
firstErr = terminal.ErrInternalError.With("failed to apply status: %w", err)
}
if known && announcementData == nil {
// If we parsed an announcement before, "known" will always be true here,
// as we supply hub.ApplyStatus with a hub.
hubKnown = true
}
if changed {
hubChanged = true
}
if hubFromMsg != nil {
h = hubFromMsg
}
}
// Only continue if we now have a Hub.
if h == nil {
if firstErr != nil {
return nil, false, firstErr
}
return nil, false, terminal.ErrInternalError.With("got not hub after data import")
}
// Abort if the given hub ID does not match.
// We may have just connected to the wrong IP address.
if hubID != "" && h.ID != hubID {
return nil, false, terminal.ErrInternalError.With("hub mismatch")
}
// Verify hub if:
// - There is no error up until here.
// - There has been any change.
// - The hub is not verified yet.
// - We're a public Hub.
// - We're not testing.
if firstErr == nil && hubChanged && !h.Verified() && conf.PublicHub() && !runningTests {
if !conf.HubHasIPv4() && !conf.HubHasIPv6() {
firstErr = terminal.ErrInternalError.With("no hub networks set")
}
if h.Info.IPv4 != nil && conf.HubHasIPv4() {
err := verifyHubIP(ctx, h, h.Info.IPv4)
if err != nil {
firstErr = terminal.ErrIntegrity.With("failed to verify IPv4 address %s of %s: %w", h.Info.IPv4, h, err)
}
}
if h.Info.IPv6 != nil && conf.HubHasIPv6() {
err := verifyHubIP(ctx, h, h.Info.IPv6)
if err != nil {
firstErr = terminal.ErrIntegrity.With("failed to verify IPv6 address %s of %s: %w", h.Info.IPv6, h, err)
}
}
if firstErr != nil {
func() {
h.Lock()
defer h.Unlock()
h.InvalidInfo = true
}()
log.Warningf("spn/docks: failed to verify IPs of %s: %s", h, firstErr)
} else {
func() {
h.Lock()
defer h.Unlock()
h.VerifiedIPs = true
}()
log.Infof("spn/docks: verified IPs of %s: IPv4=%s IPv6=%s", h, h.Info.IPv4, h.Info.IPv6)
}
}
// Dismiss initial imports with errors.
if !hubKnown && firstErr != nil {
return nil, false, firstErr
}
// Don't do anything if nothing changed.
if !hubChanged {
return h, false, firstErr
}
// We now have one of:
// - A unknown Hub without error.
// - A known Hub without error.
// - A known Hub with error, which we want to save and propagate.
// Save the Hub to the database.
err := h.Save()
if err != nil {
log.Errorf("spn/docks: failed to persist %s: %s", h, err)
}
// Save the raw messages to the database.
if announcementData != nil {
err = hub.SaveHubMsg(h.ID, h.Map, hub.MsgTypeAnnouncement, announcementData)
if err != nil {
log.Errorf("spn/docks: failed to save raw announcement msg of %s: %s", h, err)
}
}
if statusData != nil {
err = hub.SaveHubMsg(h.ID, h.Map, hub.MsgTypeStatus, statusData)
if err != nil {
log.Errorf("spn/docks: failed to save raw status msg of %s: %s", h, err)
}
}
return h, true, firstErr
}
func verifyHubIP(ctx context.Context, h *hub.Hub, ip net.IP) error {
// Create connection.
ship, err := ships.Launch(ctx, h, nil, ip)
if err != nil {
return fmt.Errorf("failed to launch ship to %s: %w", ip, err)
}
// Start crane for receiving reply.
crane, err := NewCrane(ship, h, nil)
if err != nil {
return fmt.Errorf("failed to create crane: %w", err)
}
module.StartWorker("crane unloader", crane.unloader)
defer crane.Stop(nil)
// Verify Hub.
err = crane.VerifyConnectedHub(ctx)
if err != nil {
return err
}
// End connection.
tErr := crane.endInit()
if tErr != nil {
log.Debugf("spn/docks: failed to end verification connection to %s: %s", ip, tErr)
}
return nil
}

108
spn/docks/measurements.go Normal file
View File

@@ -0,0 +1,108 @@
package docks
import (
"context"
"fmt"
"time"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/ships"
"github.com/safing/portmaster/spn/terminal"
)
// Measurement Configuration.
const (
CraneMeasurementTTLDefault = 30 * time.Minute
CraneMeasurementTTLByCostBase = 1 * time.Minute
CraneMeasurementTTLByCostMin = 30 * time.Minute
CraneMeasurementTTLByCostMax = 3 * time.Hour
// With a base TTL of 1m, this leads to:
// 20c -> 20m -> raised to 30m
// 50c -> 50m
// 100c -> 1h40m
// 1000c -> 16h40m -> capped to 3h.
)
// MeasureHub measures the connection to this Hub and saves the results to the
// Hub.
func MeasureHub(ctx context.Context, h *hub.Hub, checkExpiryWith time.Duration) *terminal.Error {
// Check if we are measuring before building a connection.
if capacityTestRunning.IsSet() {
return terminal.ErrTryAgainLater.With("another capacity op is already running")
}
// Check if we have a connection to this Hub.
crane := GetAssignedCrane(h.ID)
if crane == nil {
// Connect to Hub.
var err error
crane, err = establishCraneForMeasuring(ctx, h)
if err != nil {
return terminal.ErrConnectionError.With("failed to connect to %s: %s", h, err)
}
// Stop crane if established just for measuring.
defer crane.Stop(nil)
}
// Run latency test.
_, expires := h.GetMeasurements().GetLatency()
if checkExpiryWith == 0 || time.Now().Add(-checkExpiryWith).After(expires) {
latOp, tErr := NewLatencyTestOp(crane.Controller)
if !tErr.IsOK() {
return tErr
}
select {
case tErr = <-latOp.Result():
if !tErr.IsOK() {
return tErr
}
case <-ctx.Done():
return terminal.ErrCanceled
case <-time.After(1 * time.Minute):
crane.Controller.StopOperation(latOp, terminal.ErrTimeout)
return terminal.ErrTimeout.With("waiting for latency test")
}
}
// Run capacity test.
_, expires = h.GetMeasurements().GetCapacity()
if checkExpiryWith == 0 || time.Now().Add(-checkExpiryWith).After(expires) {
capOp, tErr := NewCapacityTestOp(crane.Controller, nil)
if !tErr.IsOK() {
return tErr
}
select {
case tErr = <-capOp.Result():
if !tErr.IsOK() {
return tErr
}
case <-ctx.Done():
return terminal.ErrCanceled
case <-time.After(1 * time.Minute):
crane.Controller.StopOperation(capOp, terminal.ErrTimeout)
return terminal.ErrTimeout.With("waiting for capacity test")
}
}
return nil
}
func establishCraneForMeasuring(ctx context.Context, dst *hub.Hub) (*Crane, error) {
ship, err := ships.Launch(ctx, dst, nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to launch ship: %w", err)
}
crane, err := NewCrane(ship, dst, nil)
if err != nil {
return nil, fmt.Errorf("failed to create crane: %w", err)
}
err = crane.Start(ctx)
if err != nil {
return nil, fmt.Errorf("failed to start crane: %w", err)
}
return crane, nil
}

404
spn/docks/metrics.go Normal file
View File

@@ -0,0 +1,404 @@
package docks
import (
"sync"
"sync/atomic"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/api"
"github.com/safing/portbase/metrics"
)
var (
newCranes *metrics.Counter
newPublicCranes *metrics.Counter
newAuthenticatedCranes *metrics.Counter
trafficBytesPublicCranes *metrics.Counter
trafficBytesAuthenticatedCranes *metrics.Counter
trafficBytesPrivateCranes *metrics.Counter
newExpandOp *metrics.Counter
expandOpDurationHistogram *metrics.Histogram
expandOpRelayedDataHistogram *metrics.Histogram
metricsRegistered = abool.New()
)
func registerMetrics() (err error) {
// Only register metrics once.
if !metricsRegistered.SetToIf(false, true) {
return nil
}
// Total Crane Stats.
newCranes, err = metrics.NewCounter(
"spn/cranes/total",
nil,
&metrics.Options{
Name: "SPN New Cranes",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
newPublicCranes, err = metrics.NewCounter(
"spn/cranes/public/total",
nil,
&metrics.Options{
Name: "SPN New Public Cranes",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
newAuthenticatedCranes, err = metrics.NewCounter(
"spn/cranes/authenticated/total",
nil,
&metrics.Options{
Name: "SPN New Authenticated Cranes",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
// Active Crane Stats.
_, err = metrics.NewGauge(
"spn/cranes/active",
map[string]string{
"status": "public",
},
getActivePublicCranes,
&metrics.Options{
Name: "SPN Active Public Cranes",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/cranes/active",
map[string]string{
"status": "authenticated",
},
getActiveAuthenticatedCranes,
&metrics.Options{
Name: "SPN Active Authenticated Cranes",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/cranes/active",
map[string]string{
"status": "private",
},
getActivePrivateCranes,
&metrics.Options{
Name: "SPN Active Private Cranes",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/cranes/active",
map[string]string{
"status": "stopping",
},
getActiveStoppingCranes,
&metrics.Options{
Name: "SPN Active Stopping Cranes",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
// Crane Traffic Stats.
trafficBytesPublicCranes, err = metrics.NewCounter(
"spn/cranes/bytes",
map[string]string{
"status": "public",
},
&metrics.Options{
Name: "SPN Public Crane Traffic",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
trafficBytesAuthenticatedCranes, err = metrics.NewCounter(
"spn/cranes/bytes",
map[string]string{
"status": "authenticated",
},
&metrics.Options{
Name: "SPN Authenticated Crane Traffic",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
trafficBytesPrivateCranes, err = metrics.NewCounter(
"spn/cranes/bytes",
map[string]string{
"status": "private",
},
&metrics.Options{
Name: "SPN Private Crane Traffic",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
// Lane Stats.
_, err = metrics.NewGauge(
"spn/lanes/latency/avg/seconds",
nil,
getAvgLaneLatencyStat,
&metrics.Options{
Name: "SPN Avg Lane Latency",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/lanes/latency/min/seconds",
nil,
getMinLaneLatencyStat,
&metrics.Options{
Name: "SPN Min Lane Latency",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/lanes/capacity/avg/bytes",
nil,
getAvgLaneCapacityStat,
&metrics.Options{
Name: "SPN Avg Lane Capacity",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/lanes/capacity/max/bytes",
nil,
getMaxLaneCapacityStat,
&metrics.Options{
Name: "SPN Max Lane Capacity",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
// Expand Op Stats.
newExpandOp, err = metrics.NewCounter(
"spn/op/expand/total",
nil,
&metrics.Options{
Name: "SPN Total Expand Operations",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
_, err = metrics.NewGauge(
"spn/op/expand/active",
nil,
getActiveExpandOpsStat,
&metrics.Options{
Name: "SPN Active Expand Operations",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
expandOpDurationHistogram, err = metrics.NewHistogram(
"spn/op/expand/histogram/duration/seconds",
nil,
&metrics.Options{
Name: "SPN Expand Operation Duration Histogram",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
expandOpRelayedDataHistogram, err = metrics.NewHistogram(
"spn/op/expand/histogram/traffic/bytes",
nil,
&metrics.Options{
Name: "SPN Expand Operation Relayed Data Histogram",
Permission: api.PermitUser,
},
)
if err != nil {
return err
}
return err
}
func getActiveExpandOpsStat() float64 {
return float64(atomic.LoadInt64(activeExpandOps))
}
var (
craneStats *craneGauges
craneStatsExpires time.Time
craneStatsLock sync.Mutex
craneStatsTTL = 55 * time.Second
)
type craneGauges struct {
publicActive float64
authenticatedActive float64
privateActive float64
stoppingActive float64
laneLatencyAvg float64
laneLatencyMin float64
laneCapacityAvg float64
laneCapacityMax float64
}
func getActivePublicCranes() float64 { return getCraneStats().publicActive }
func getActiveAuthenticatedCranes() float64 { return getCraneStats().authenticatedActive }
func getActivePrivateCranes() float64 { return getCraneStats().privateActive }
func getActiveStoppingCranes() float64 { return getCraneStats().stoppingActive }
func getAvgLaneLatencyStat() float64 { return getCraneStats().laneLatencyAvg }
func getMinLaneLatencyStat() float64 { return getCraneStats().laneLatencyMin }
func getAvgLaneCapacityStat() float64 { return getCraneStats().laneCapacityAvg }
func getMaxLaneCapacityStat() float64 { return getCraneStats().laneCapacityMax }
func getCraneStats() *craneGauges {
craneStatsLock.Lock()
defer craneStatsLock.Unlock()
// Return cache if still valid.
if time.Now().Before(craneStatsExpires) {
return craneStats
}
// Refresh.
craneStats = &craneGauges{}
var laneStatCnt float64
for _, crane := range getAllCranes() {
switch {
case crane.Stopped():
continue
case crane.IsStopping():
craneStats.stoppingActive++
continue
case crane.Public():
craneStats.publicActive++
case crane.Authenticated():
craneStats.authenticatedActive++
continue
default:
craneStats.privateActive++
continue
}
// Get lane stats.
if crane.ConnectedHub == nil {
continue
}
measurements := crane.ConnectedHub.GetMeasurements()
laneLatency, _ := measurements.GetLatency()
if laneLatency == 0 {
continue
}
laneCapacity, _ := measurements.GetCapacity()
if laneCapacity == 0 {
continue
}
// Only use data if both latency and capacity is available.
laneStatCnt++
// Convert to base unit: seconds.
latency := laneLatency.Seconds()
// Add to avg and set min if lower.
craneStats.laneLatencyAvg += latency
if craneStats.laneLatencyMin > latency || craneStats.laneLatencyMin == 0 {
craneStats.laneLatencyMin = latency
}
// Convert in base unit: bytes.
capacity := float64(laneCapacity) / 8
// Add to avg and set max if higher.
craneStats.laneCapacityAvg += capacity
if craneStats.laneCapacityMax < capacity {
craneStats.laneCapacityMax = capacity
}
}
// Create averages.
if laneStatCnt > 0 {
craneStats.laneLatencyAvg /= laneStatCnt
craneStats.laneCapacityAvg /= laneStatCnt
}
craneStatsExpires = time.Now().Add(craneStatsTTL)
return craneStats
}
func (crane *Crane) submitCraneTrafficStats(bytes int) {
switch {
case crane.Stopped():
return
case crane.Public():
trafficBytesPublicCranes.Add(bytes)
case crane.Authenticated():
trafficBytesAuthenticatedCranes.Add(bytes)
default:
trafficBytesPrivateCranes.Add(bytes)
}
}

117
spn/docks/module.go Normal file
View File

@@ -0,0 +1,117 @@
package docks
import (
"encoding/hex"
"errors"
"fmt"
"sync"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/rng"
_ "github.com/safing/portmaster/spn/access"
)
var (
module *modules.Module
allCranes = make(map[string]*Crane) // ID = Crane ID
assignedCranes = make(map[string]*Crane) // ID = connected Hub ID
cranesLock sync.RWMutex
runningTests bool
)
func init() {
module = modules.Register("docks", nil, start, stopAllCranes, "terminal", "cabin", "access")
}
func start() error {
return registerMetrics()
}
func registerCrane(crane *Crane) error {
cranesLock.Lock()
defer cranesLock.Unlock()
// Generate new IDs until a unique one is found.
for i := 0; i < 100; i++ {
// Generate random ID.
randomID, err := rng.Bytes(3)
if err != nil {
return fmt.Errorf("failed to generate crane ID: %w", err)
}
newID := hex.EncodeToString(randomID)
// Check if ID already exists.
_, ok := allCranes[newID]
if !ok {
crane.ID = newID
allCranes[crane.ID] = crane
return nil
}
}
return errors.New("failed to find unique crane ID")
}
func unregisterCrane(crane *Crane) {
cranesLock.Lock()
defer cranesLock.Unlock()
delete(allCranes, crane.ID)
if crane.ConnectedHub != nil {
delete(assignedCranes, crane.ConnectedHub.ID)
}
}
func stopAllCranes() error {
for _, crane := range getAllCranes() {
crane.Stop(nil)
}
return nil
}
// AssignCrane assigns a crane to the given Hub ID.
func AssignCrane(hubID string, crane *Crane) {
cranesLock.Lock()
defer cranesLock.Unlock()
assignedCranes[hubID] = crane
}
// GetAssignedCrane returns the assigned crane of the given Hub ID.
func GetAssignedCrane(hubID string) *Crane {
cranesLock.RLock()
defer cranesLock.RUnlock()
crane, ok := assignedCranes[hubID]
if ok {
return crane
}
return nil
}
func getAllCranes() map[string]*Crane {
copiedCranes := make(map[string]*Crane, len(allCranes))
cranesLock.RLock()
defer cranesLock.RUnlock()
for id, crane := range allCranes {
copiedCranes[id] = crane
}
return copiedCranes
}
// GetAllAssignedCranes returns a copy of the map of all assigned cranes.
func GetAllAssignedCranes() map[string]*Crane {
copiedCranes := make(map[string]*Crane, len(assignedCranes))
cranesLock.RLock()
defer cranesLock.RUnlock()
for destination, crane := range assignedCranes {
copiedCranes[destination] = crane
}
return copiedCranes
}

16
spn/docks/module_test.go Normal file
View File

@@ -0,0 +1,16 @@
package docks
import (
"testing"
"github.com/safing/portmaster/service/core/pmtesting"
"github.com/safing/portmaster/spn/access"
"github.com/safing/portmaster/spn/conf"
)
func TestMain(m *testing.M) {
runningTests = true
conf.EnablePublicHub(true) // Make hub config available.
access.EnableTestMode() // Register test zone instead of real ones.
pmtesting.TestMain(m, module)
}

356
spn/docks/op_capacity.go Normal file
View File

@@ -0,0 +1,356 @@
package docks
import (
"bytes"
"context"
"sync/atomic"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/spn/terminal"
)
const (
// CapacityTestOpType is the type ID of the capacity test operation.
CapacityTestOpType = "capacity"
defaultCapacityTestVolume = 50000000 // 50MB
maxCapacityTestVolume = 100000000 // 100MB
defaultCapacityTestMaxTime = 5 * time.Second
maxCapacityTestMaxTime = 15 * time.Second
capacityTestTimeout = 30 * time.Second
capacityTestMsgSize = 1000
capacityTestSendTimeout = 1000 * time.Millisecond
)
var (
capacityTestSendData = make([]byte, capacityTestMsgSize)
capacityTestDataReceivedSignal = []byte("ACK")
capacityTestRunning = abool.New()
)
// CapacityTestOp is used for capacity test operations.
type CapacityTestOp struct { //nolint:maligned
terminal.OperationBase
opts *CapacityTestOptions
started bool
startTime time.Time
senderStarted bool
recvQueue chan *terminal.Msg
dataReceived int
dataReceivedAckWasAckd bool
dataSent *int64
dataSentWasAckd *abool.AtomicBool
testResult int
result chan *terminal.Error
}
// CapacityTestOptions holds options for the capacity test.
type CapacityTestOptions struct {
TestVolume int
MaxTime time.Duration
testing bool
}
// Type returns the type ID.
func (op *CapacityTestOp) Type() string {
return CapacityTestOpType
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: CapacityTestOpType,
Requires: terminal.IsCraneController,
Start: startCapacityTestOp,
})
}
// NewCapacityTestOp runs a capacity test.
func NewCapacityTestOp(t terminal.Terminal, opts *CapacityTestOptions) (*CapacityTestOp, *terminal.Error) {
// Check options.
if opts == nil {
opts = &CapacityTestOptions{
TestVolume: defaultCapacityTestVolume,
MaxTime: defaultCapacityTestMaxTime,
}
}
// Check if another test is already running.
if !opts.testing && !capacityTestRunning.SetToIf(false, true) {
return nil, terminal.ErrTryAgainLater.With("another capacity op is already running")
}
// Create and init.
op := &CapacityTestOp{
opts: opts,
recvQueue: make(chan *terminal.Msg),
dataSent: new(int64),
dataSentWasAckd: abool.New(),
result: make(chan *terminal.Error, 1),
}
// Make capacity test request.
request, err := dsd.Dump(op.opts, dsd.CBOR)
if err != nil {
capacityTestRunning.UnSet()
return nil, terminal.ErrInternalError.With("failed to serialize capactity test options: %w", err)
}
// Send test request.
tErr := t.StartOperation(op, container.New(request), 1*time.Second)
if tErr != nil {
capacityTestRunning.UnSet()
return nil, tErr
}
// Start handler.
module.StartWorker("op capacity handler", op.handler)
return op, nil
}
func startCapacityTestOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Check if another test is already running.
if !capacityTestRunning.SetToIf(false, true) {
return nil, terminal.ErrTryAgainLater.With("another capacity op is already running")
}
// Parse options.
opts := &CapacityTestOptions{}
_, err := dsd.Load(data.CompileData(), opts)
if err != nil {
capacityTestRunning.UnSet()
return nil, terminal.ErrMalformedData.With("failed to parse options: %w", err)
}
// Check options.
if opts.TestVolume > maxCapacityTestVolume {
capacityTestRunning.UnSet()
return nil, terminal.ErrInvalidOptions.With("maximum volume exceeded")
}
if opts.MaxTime > maxCapacityTestMaxTime {
capacityTestRunning.UnSet()
return nil, terminal.ErrInvalidOptions.With("maximum maxtime exceeded")
}
// Create operation.
op := &CapacityTestOp{
opts: opts,
recvQueue: make(chan *terminal.Msg, 1000),
dataSent: new(int64),
dataSentWasAckd: abool.New(),
result: make(chan *terminal.Error, 1),
}
op.InitOperationBase(t, opID)
// Start handler and sender.
op.senderStarted = true
module.StartWorker("op capacity handler", op.handler)
module.StartWorker("op capacity sender", op.sender)
return op, nil
}
func (op *CapacityTestOp) handler(ctx context.Context) error {
defer capacityTestRunning.UnSet()
returnErr := terminal.ErrStopping
defer func() {
// Linters don't get that returnErr is used when directly used as defer.
op.Stop(op, returnErr)
}()
var maxTestTimeReached <-chan time.Time
opTimeout := time.After(capacityTestTimeout)
// Setup unit handling
var msg *terminal.Msg
defer msg.Finish()
// Handle receives.
for {
msg.Finish()
select {
case <-ctx.Done():
returnErr = terminal.ErrCanceled
return nil
case <-opTimeout:
returnErr = terminal.ErrTimeout
return nil
case <-maxTestTimeReached:
returnErr = op.reportMeasuredCapacity()
return nil
case msg = <-op.recvQueue:
// Record start time and start sender.
if !op.started {
op.started = true
op.startTime = time.Now()
maxTestTimeReached = time.After(op.opts.MaxTime)
if !op.senderStarted {
op.senderStarted = true
module.StartWorker("op capacity sender", op.sender)
}
}
// Add to received data counter.
op.dataReceived += msg.Data.Length()
// Check if we received the data received signal.
if msg.Data.Length() == len(capacityTestDataReceivedSignal) &&
bytes.Equal(msg.Data.CompileData(), capacityTestDataReceivedSignal) {
op.dataSentWasAckd.Set()
}
// Send the data received signal when we received the full test volume.
if op.dataReceived >= op.opts.TestVolume && !op.dataReceivedAckWasAckd {
tErr := op.Send(op.NewMsg(capacityTestDataReceivedSignal), capacityTestSendTimeout)
if tErr != nil {
returnErr = tErr.Wrap("failed to send data received signal")
return nil
}
atomic.AddInt64(op.dataSent, int64(len(capacityTestDataReceivedSignal)))
op.dataReceivedAckWasAckd = true
// Flush last message.
op.Flush(10 * time.Second)
}
// Check if we can complete the test.
if op.dataReceivedAckWasAckd &&
op.dataSentWasAckd.IsSet() {
returnErr = op.reportMeasuredCapacity()
return nil
}
}
}
}
func (op *CapacityTestOp) sender(ctx context.Context) error {
for {
// Send next chunk.
msg := op.NewMsg(capacityTestSendData)
msg.Unit.MakeHighPriority()
tErr := op.Send(msg, capacityTestSendTimeout)
if tErr != nil {
op.Stop(op, tErr.Wrap("failed to send capacity test data"))
return nil
}
// Add to sent data counter and stop sending if sending is complete.
if atomic.AddInt64(op.dataSent, int64(len(capacityTestSendData))) >= int64(op.opts.TestVolume) {
return nil
}
// Check if we have received an ack.
if op.dataSentWasAckd.IsSet() {
return nil
}
// Check if op has ended.
if op.Stopped() {
return nil
}
}
}
func (op *CapacityTestOp) reportMeasuredCapacity() *terminal.Error {
// Calculate lane capacity and set it.
timeNeeded := time.Since(op.startTime)
if timeNeeded <= 0 {
timeNeeded = 1
}
duplexBits := float64((int64(op.dataReceived) + atomic.LoadInt64(op.dataSent)) * 8)
duplexNSBitRate := duplexBits / float64(timeNeeded)
bitRate := (duplexNSBitRate / 2) * float64(time.Second)
op.testResult = int(bitRate)
// Save the result to the crane.
if controller, ok := op.Terminal().(*CraneControllerTerminal); ok {
if controller.Crane.ConnectedHub != nil {
controller.Crane.ConnectedHub.GetMeasurements().SetCapacity(op.testResult)
log.Infof(
"docks: measured capacity to %s: %.2f Mbit/s (%.2fMB down / %.2fMB up in %s)",
controller.Crane.ConnectedHub,
float64(op.testResult)/1000000,
float64(op.dataReceived)/1000000,
float64(atomic.LoadInt64(op.dataSent))/1000000,
timeNeeded,
)
return nil
} else if controller.Crane.IsMine() {
return terminal.ErrInternalError.With("capacity operation was run on %s without a connected hub set", controller.Crane)
}
} else if !runningTests {
return terminal.ErrInternalError.With("capacity operation was run on terminal that is not a crane controller, but %T", op.Terminal())
}
return nil
}
// Deliver delivers a message.
func (op *CapacityTestOp) Deliver(msg *terminal.Msg) *terminal.Error {
// Optimized delivery with 1s timeout.
select {
case op.recvQueue <- msg:
default:
select {
case op.recvQueue <- msg:
case <-time.After(1 * time.Second):
msg.Finish()
return terminal.ErrTimeout
}
}
return nil
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *CapacityTestOp) HandleStop(tErr *terminal.Error) (errorToSend *terminal.Error) {
// Return result to waiting routine.
select {
case op.result <- tErr:
default:
}
// Drain the recvQueue to finish the message units.
drain:
for {
select {
case msg := <-op.recvQueue:
msg.Finish()
default:
select {
case msg := <-op.recvQueue:
msg.Finish()
case <-time.After(3 * time.Millisecond):
// Give some additional time buffer to drain the queue.
break drain
}
}
}
// Return error as is.
return tErr
}
// Result returns the result (end error) of the operation.
func (op *CapacityTestOp) Result() <-chan *terminal.Error {
return op.result
}

View File

@@ -0,0 +1,85 @@
package docks
import (
"testing"
"time"
"github.com/safing/portmaster/spn/terminal"
)
var (
testCapacityTestVolume = 1_000_000
testCapacitytestMaxTime = 1 * time.Second
)
func TestCapacityOp(t *testing.T) { //nolint:paralleltest // Performance test.
// Defaults.
testCapacityOp(t, &CapacityTestOptions{
TestVolume: testCapacityTestVolume,
MaxTime: testCapacitytestMaxTime,
testing: true,
})
// Hit max time first.
testCapacityOp(t, &CapacityTestOptions{
TestVolume: testCapacityTestVolume,
MaxTime: 100 * time.Millisecond,
testing: true,
})
// Hit volume first.
testCapacityOp(t, &CapacityTestOptions{
TestVolume: 100_000,
MaxTime: testCapacitytestMaxTime,
testing: true,
})
}
func testCapacityOp(t *testing.T, opts *CapacityTestOptions) {
t.Helper()
var (
capTestDelay = 5 * time.Millisecond
capTestQueueSize uint32 = 10
)
// Create test terminal pair.
a, b, err := terminal.NewSimpleTestTerminalPair(
capTestDelay,
int(capTestQueueSize),
&terminal.TerminalOpts{
FlowControl: terminal.FlowControlDFQ,
FlowControlSize: capTestQueueSize,
},
)
if err != nil {
t.Fatalf("failed to create test terminal pair: %s", err)
}
// Grant permission for op on remote terminal and start op.
b.GrantPermission(terminal.IsCraneController)
op, tErr := NewCapacityTestOp(a, opts)
if tErr != nil {
t.Fatalf("failed to start op: %s", err)
}
// Wait for result and check error.
tErr = <-op.Result()
if !tErr.IsOK() {
t.Fatalf("op failed: %s", tErr)
}
t.Logf("measured capacity: %d bit/s", op.testResult)
// Calculate expected bandwidth.
expectedBitsPerSecond := float64(capacityTestMsgSize*8*int64(capTestQueueSize)) / float64(capTestDelay) * float64(time.Second)
t.Logf("expected capacity: %f bit/s", expectedBitsPerSecond)
// Check if measured bandwidth is within parameters.
if float64(op.testResult) > expectedBitsPerSecond*1.6 {
t.Fatal("measured capacity too high")
}
// TODO: Check if we can raise this to at least 90%.
if float64(op.testResult) < expectedBitsPerSecond*0.2 {
t.Fatal("measured capacity too low")
}
}

393
spn/docks/op_expand.go Normal file
View File

@@ -0,0 +1,393 @@
package docks
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/container"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/terminal"
)
// ExpandOpType is the type ID of the expand operation.
const ExpandOpType string = "expand"
var activeExpandOps = new(int64)
// ExpandOp is used to expand to another Hub.
type ExpandOp struct {
terminal.OperationBase
opts *terminal.TerminalOpts
// ctx is the context of the Terminal.
ctx context.Context
// cancelCtx cancels ctx.
cancelCtx context.CancelFunc
dataRelayed *uint64
ended *abool.AtomicBool
relayTerminal *ExpansionRelayTerminal
// flowControl holds the flow control system.
flowControl terminal.FlowControl
// deliverProxy is populated with the configured deliver function
deliverProxy func(msg *terminal.Msg) *terminal.Error
// recvProxy is populated with the configured recv function
recvProxy func() <-chan *terminal.Msg
// sendProxy is populated with the configured send function
sendProxy func(msg *terminal.Msg, timeout time.Duration)
}
// ExpansionRelayTerminal is a relay used for expansion.
type ExpansionRelayTerminal struct {
terminal.BareTerminal
op *ExpandOp
id uint32
crane *Crane
abandoning *abool.AtomicBool
// flowControl holds the flow control system.
flowControl terminal.FlowControl
// deliverProxy is populated with the configured deliver function
deliverProxy func(msg *terminal.Msg) *terminal.Error
// recvProxy is populated with the configured recv function
recvProxy func() <-chan *terminal.Msg
// sendProxy is populated with the configured send function
sendProxy func(msg *terminal.Msg, timeout time.Duration)
}
// Type returns the type ID.
func (op *ExpandOp) Type() string {
return ExpandOpType
}
// ID returns the operation ID.
func (t *ExpansionRelayTerminal) ID() uint32 {
return t.id
}
// Ctx returns the operation context.
func (op *ExpandOp) Ctx() context.Context {
return op.ctx
}
// Ctx returns the relay terminal context.
func (t *ExpansionRelayTerminal) Ctx() context.Context {
return t.op.ctx
}
// Deliver delivers a message to the relay operation.
func (op *ExpandOp) Deliver(msg *terminal.Msg) *terminal.Error {
return op.deliverProxy(msg)
}
// Deliver delivers a message to the relay terminal.
func (t *ExpansionRelayTerminal) Deliver(msg *terminal.Msg) *terminal.Error {
return t.deliverProxy(msg)
}
// Flush writes all data in the queues.
func (op *ExpandOp) Flush(timeout time.Duration) {
if op.flowControl != nil {
op.flowControl.Flush(timeout)
}
}
// Flush writes all data in the queues.
func (t *ExpansionRelayTerminal) Flush(timeout time.Duration) {
if t.flowControl != nil {
t.flowControl.Flush(timeout)
}
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: ExpandOpType,
Requires: terminal.MayExpand,
Start: expand,
})
}
func expand(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Submit metrics.
newExpandOp.Inc()
// Check if we are running a public hub.
if !conf.PublicHub() {
return nil, terminal.ErrPermissionDenied.With("expanding is only allowed on public hubs")
}
// Parse destination hub ID.
dstData, err := data.GetNextBlock()
if err != nil {
return nil, terminal.ErrMalformedData.With("failed to parse destination: %w", err)
}
// Parse terminal options.
opts, tErr := terminal.ParseTerminalOpts(data)
if tErr != nil {
return nil, tErr.Wrap("failed to parse terminal options")
}
// Get crane with destination.
relayCrane := GetAssignedCrane(string(dstData))
if relayCrane == nil {
return nil, terminal.ErrHubUnavailable.With("no crane assigned to %q", string(dstData))
}
// TODO: Expand outside of hot path.
// Create operation and terminal.
op := &ExpandOp{
opts: opts,
dataRelayed: new(uint64),
ended: abool.New(),
relayTerminal: &ExpansionRelayTerminal{
crane: relayCrane,
id: relayCrane.getNextTerminalID(),
abandoning: abool.New(),
},
}
op.InitOperationBase(t, opID)
op.ctx, op.cancelCtx = context.WithCancel(t.Ctx())
op.relayTerminal.op = op
// Create flow control.
switch opts.FlowControl {
case terminal.FlowControlDFQ:
// Operation
op.flowControl = terminal.NewDuplexFlowQueue(op.ctx, opts.FlowControlSize, op.submitBackwardUpstream)
op.deliverProxy = op.flowControl.Deliver
op.recvProxy = op.flowControl.Receive
op.sendProxy = op.submitBackwardFlowControl
// Relay Terminal
op.relayTerminal.flowControl = terminal.NewDuplexFlowQueue(op.ctx, opts.FlowControlSize, op.submitForwardUpstream)
op.relayTerminal.deliverProxy = op.relayTerminal.flowControl.Deliver
op.relayTerminal.recvProxy = op.relayTerminal.flowControl.Receive
op.relayTerminal.sendProxy = op.submitForwardFlowControl
case terminal.FlowControlNone:
// Operation
deliverToOp := make(chan *terminal.Msg, opts.FlowControlSize)
op.deliverProxy = terminal.MakeDirectDeliveryDeliverFunc(op.ctx, deliverToOp)
op.recvProxy = terminal.MakeDirectDeliveryRecvFunc(deliverToOp)
op.sendProxy = op.submitBackwardUpstream
// Relay Terminal
deliverToRelay := make(chan *terminal.Msg, opts.FlowControlSize)
op.relayTerminal.deliverProxy = terminal.MakeDirectDeliveryDeliverFunc(op.ctx, deliverToRelay)
op.relayTerminal.recvProxy = terminal.MakeDirectDeliveryRecvFunc(deliverToRelay)
op.relayTerminal.sendProxy = op.submitForwardUpstream
case terminal.FlowControlDefault:
fallthrough
default:
return nil, terminal.ErrInternalError.With("unknown flow control type %d", opts.FlowControl)
}
// Establish terminal on destination.
newInitData, tErr := opts.Pack()
if tErr != nil {
return nil, terminal.ErrInternalError.With("failed to re-pack options: %w", err)
}
tErr = op.relayTerminal.crane.EstablishNewTerminal(op.relayTerminal, newInitData)
if tErr != nil {
return nil, tErr
}
// Start workers.
module.StartWorker("expand op forward relay", op.forwardHandler)
module.StartWorker("expand op backward relay", op.backwardHandler)
if op.flowControl != nil {
op.flowControl.StartWorkers(module, "expand op")
}
if op.relayTerminal.flowControl != nil {
op.relayTerminal.flowControl.StartWorkers(module, "expand op terminal")
}
return op, nil
}
func (op *ExpandOp) submitForwardFlowControl(msg *terminal.Msg, timeout time.Duration) {
err := op.relayTerminal.flowControl.Send(msg, timeout)
if err != nil {
msg.Finish()
op.Stop(op, err.Wrap("failed to submit to forward flow control"))
}
}
func (op *ExpandOp) submitBackwardFlowControl(msg *terminal.Msg, timeout time.Duration) {
err := op.flowControl.Send(msg, timeout)
if err != nil {
msg.Finish()
op.Stop(op, err.Wrap("failed to submit to backward flow control"))
}
}
func (op *ExpandOp) submitForwardUpstream(msg *terminal.Msg, timeout time.Duration) {
msg.FlowID = op.relayTerminal.id
if msg.Unit.IsHighPriority() && op.opts.UsePriorityDataMsgs {
msg.Type = terminal.MsgTypePriorityData
} else {
msg.Type = terminal.MsgTypeData
}
err := op.relayTerminal.crane.Send(msg, timeout)
if err != nil {
msg.Finish()
op.Stop(op, err.Wrap("failed to submit to forward upstream"))
}
}
func (op *ExpandOp) submitBackwardUpstream(msg *terminal.Msg, timeout time.Duration) {
msg.FlowID = op.relayTerminal.id
if msg.Unit.IsHighPriority() && op.opts.UsePriorityDataMsgs {
msg.Type = terminal.MsgTypePriorityData
} else {
msg.Type = terminal.MsgTypeData
msg.Unit.RemovePriority()
}
// Note: op.Send() will transform high priority units to priority data msgs.
err := op.Send(msg, timeout)
if err != nil {
msg.Finish()
op.Stop(op, err.Wrap("failed to submit to backward upstream"))
}
}
func (op *ExpandOp) forwardHandler(_ context.Context) error {
// Metrics setup and submitting.
atomic.AddInt64(activeExpandOps, 1)
started := time.Now()
defer func() {
atomic.AddInt64(activeExpandOps, -1)
expandOpDurationHistogram.UpdateDuration(started)
expandOpRelayedDataHistogram.Update(float64(atomic.LoadUint64(op.dataRelayed)))
}()
for {
select {
case msg := <-op.recvProxy():
// Debugging:
// log.Debugf("spn/testing: forwarding at %s: %s", op.FmtID(), spew.Sdump(c.CompileData()))
// Wait for processing slot.
msg.Unit.WaitForSlot()
// Count relayed data for metrics.
atomic.AddUint64(op.dataRelayed, uint64(msg.Data.Length()))
// Receive data from the origin and forward it to the relay.
op.relayTerminal.sendProxy(msg, 1*time.Minute)
case <-op.ctx.Done():
return nil
}
}
}
func (op *ExpandOp) backwardHandler(_ context.Context) error {
for {
select {
case msg := <-op.relayTerminal.recvProxy():
// Debugging:
// log.Debugf("spn/testing: backwarding at %s: %s", op.FmtID(), spew.Sdump(c.CompileData()))
// Wait for processing slot.
msg.Unit.WaitForSlot()
// Count relayed data for metrics.
atomic.AddUint64(op.dataRelayed, uint64(msg.Data.Length()))
// Receive data from the relay and forward it to the origin.
op.sendProxy(msg, 1*time.Minute)
case <-op.ctx.Done():
return nil
}
}
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *ExpandOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
// Flush all messages before stopping.
op.Flush(1 * time.Minute)
op.relayTerminal.Flush(1 * time.Minute)
// Stop connected workers.
op.cancelCtx()
// Abandon connected terminal.
op.relayTerminal.Abandon(nil)
// Add context to error.
if err.IsError() {
return err.Wrap("relay operation failed with")
}
return err
}
// Abandon shuts down the terminal unregistering it from upstream and calling HandleAbandon().
func (t *ExpansionRelayTerminal) Abandon(err *terminal.Error) {
if t.abandoning.SetToIf(false, true) {
module.StartWorker("terminal abandon procedure", func(_ context.Context) error {
t.handleAbandonProcedure(err)
return nil
})
}
}
// HandleAbandon gives the terminal the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Abandon() instead.
func (t *ExpansionRelayTerminal) HandleAbandon(err *terminal.Error) (errorToSend *terminal.Error) {
// Stop the connected relay operation.
t.op.Stop(t.op, err)
// Add context to error.
if err.IsError() {
return err.Wrap("relay terminal failed with")
}
return err
}
// HandleDestruction gives the terminal the ability to clean up.
// The terminal has already fully shut down at this point.
// Should never be called directly. Call Abandon() instead.
func (t *ExpansionRelayTerminal) HandleDestruction(err *terminal.Error) {}
func (t *ExpansionRelayTerminal) handleAbandonProcedure(err *terminal.Error) {
// Call operation stop handle function for proper shutdown cleaning up.
err = t.HandleAbandon(err)
// Flush all messages before stopping.
t.Flush(1 * time.Minute)
// Send error to the connected Operation, if the error is internal.
if !err.IsExternal() {
if err == nil {
err = terminal.ErrStopping
}
msg := terminal.NewMsg(err.Pack())
msg.FlowID = t.ID()
msg.Type = terminal.MsgTypeStop
t.op.submitForwardUpstream(msg, 1*time.Second)
}
}
// FmtID returns the expansion ID hierarchy.
func (op *ExpandOp) FmtID() string {
return fmt.Sprintf("%s>%d <r> %s#%d", op.Terminal().FmtID(), op.ID(), op.relayTerminal.crane.ID, op.relayTerminal.id)
}
// FmtID returns the expansion ID hierarchy.
func (t *ExpansionRelayTerminal) FmtID() string {
return fmt.Sprintf("%s#%d", t.crane.ID, t.id)
}

298
spn/docks/op_latency.go Normal file
View File

@@ -0,0 +1,298 @@
package docks
import (
"bytes"
"context"
"fmt"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portbase/log"
"github.com/safing/portbase/rng"
"github.com/safing/portmaster/spn/terminal"
)
const (
// LatencyTestOpType is the type ID of the latency test operation.
LatencyTestOpType = "latency"
latencyPingRequest = 1
latencyPingResponse = 2
latencyTestNonceSize = 16
latencyTestRuns = 10
)
var (
latencyTestPauseDuration = 1 * time.Second
latencyTestOpTimeout = latencyTestRuns * latencyTestPauseDuration * 3
)
// LatencyTestOp is used to measure latency.
type LatencyTestOp struct {
terminal.OperationBase
}
// LatencyTestClientOp is the client version of LatencyTestOp.
type LatencyTestClientOp struct {
LatencyTestOp
lastPingSentAt time.Time
lastPingNonce []byte
measuredLatencies []time.Duration
responses chan *terminal.Msg
testResult time.Duration
result chan *terminal.Error
}
// Type returns the type ID.
func (op *LatencyTestOp) Type() string {
return LatencyTestOpType
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: LatencyTestOpType,
Requires: terminal.IsCraneController,
Start: startLatencyTestOp,
})
}
// NewLatencyTestOp runs a latency test.
func NewLatencyTestOp(t terminal.Terminal) (*LatencyTestClientOp, *terminal.Error) {
// Create and init.
op := &LatencyTestClientOp{
responses: make(chan *terminal.Msg),
measuredLatencies: make([]time.Duration, 0, latencyTestRuns),
result: make(chan *terminal.Error, 1),
}
// Make ping request.
pingRequest, err := op.createPingRequest()
if err != nil {
return nil, terminal.ErrInternalError.With("%w", err)
}
// Send ping.
tErr := t.StartOperation(op, pingRequest, 1*time.Second)
if tErr != nil {
return nil, tErr
}
// Start handler.
module.StartWorker("op latency handler", op.handler)
return op, nil
}
func (op *LatencyTestClientOp) handler(ctx context.Context) error {
returnErr := terminal.ErrStopping
defer func() {
// Linters don't get that returnErr is used when directly used as defer.
op.Stop(op, returnErr)
}()
var nextTest <-chan time.Time
opTimeout := time.After(latencyTestOpTimeout)
for {
select {
case <-ctx.Done():
return nil
case <-opTimeout:
return nil
case <-nextTest:
// Create ping request msg.
pingRequest, err := op.createPingRequest()
if err != nil {
returnErr = terminal.ErrInternalError.With("%w", err)
return nil
}
msg := op.NewEmptyMsg()
msg.Unit.MakeHighPriority()
msg.Data = pingRequest
// Send it.
tErr := op.Send(msg, latencyTestOpTimeout)
if tErr != nil {
returnErr = tErr.Wrap("failed to send ping request")
return nil
}
op.Flush(1 * time.Second)
nextTest = nil
case msg := <-op.responses:
// Check if the op ended.
if msg == nil {
return nil
}
// Handle response
tErr := op.handleResponse(msg)
if tErr != nil {
returnErr = tErr
return nil //nolint:nilerr
}
// Check if we have enough latency tests.
if len(op.measuredLatencies) >= latencyTestRuns {
returnErr = op.reportMeasuredLatencies()
return nil
}
// Schedule next latency test, if not yet scheduled.
if nextTest == nil {
nextTest = time.After(latencyTestPauseDuration)
}
}
}
}
func (op *LatencyTestClientOp) createPingRequest() (*container.Container, error) {
// Generate nonce.
nonce, err := rng.Bytes(latencyTestNonceSize)
if err != nil {
return nil, fmt.Errorf("failed to create ping nonce")
}
// Set client request state.
op.lastPingSentAt = time.Now()
op.lastPingNonce = nonce
return container.New(
varint.Pack8(latencyPingRequest),
nonce,
), nil
}
func (op *LatencyTestClientOp) handleResponse(msg *terminal.Msg) *terminal.Error {
defer msg.Finish()
rType, err := msg.Data.GetNextN8()
if err != nil {
return terminal.ErrMalformedData.With("failed to get response type: %w", err)
}
switch rType {
case latencyPingResponse:
// Check if the ping nonce matches.
if !bytes.Equal(op.lastPingNonce, msg.Data.CompileData()) {
return terminal.ErrIntegrity.With("ping nonce mismatch")
}
op.lastPingNonce = nil
// Save latency.
op.measuredLatencies = append(op.measuredLatencies, time.Since(op.lastPingSentAt))
return nil
default:
return terminal.ErrIncorrectUsage.With("unknown response type")
}
}
func (op *LatencyTestClientOp) reportMeasuredLatencies() *terminal.Error {
// Find lowest value.
lowestLatency := time.Hour
for _, latency := range op.measuredLatencies {
if latency < lowestLatency {
lowestLatency = latency
}
}
op.testResult = lowestLatency
// Save the result to the crane.
if controller, ok := op.Terminal().(*CraneControllerTerminal); ok {
if controller.Crane.ConnectedHub != nil {
controller.Crane.ConnectedHub.GetMeasurements().SetLatency(op.testResult)
log.Infof("spn/docks: measured latency to %s: %s", controller.Crane.ConnectedHub, op.testResult)
return nil
} else if controller.Crane.IsMine() {
return terminal.ErrInternalError.With("latency operation was run on %s without a connected hub set", controller.Crane)
}
} else if !runningTests {
return terminal.ErrInternalError.With("latency operation was run on terminal that is not a crane controller, but %T", op.Terminal())
}
return nil
}
// Deliver delivers a message to the operation.
func (op *LatencyTestClientOp) Deliver(msg *terminal.Msg) *terminal.Error {
// Optimized delivery with 1s timeout.
select {
case op.responses <- msg:
default:
select {
case op.responses <- msg:
case <-time.After(1 * time.Second):
return terminal.ErrTimeout
}
}
return nil
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *LatencyTestClientOp) HandleStop(tErr *terminal.Error) (errorToSend *terminal.Error) {
close(op.responses)
select {
case op.result <- tErr:
default:
}
return tErr
}
// Result returns the result (end error) of the operation.
func (op *LatencyTestClientOp) Result() <-chan *terminal.Error {
return op.result
}
func startLatencyTestOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Create operation.
op := &LatencyTestOp{}
op.InitOperationBase(t, opID)
// Handle first request.
msg := op.NewEmptyMsg()
msg.Data = data
tErr := op.Deliver(msg)
if tErr != nil {
return nil, tErr
}
return op, nil
}
// Deliver delivers a message to the operation.
func (op *LatencyTestOp) Deliver(msg *terminal.Msg) *terminal.Error {
// Get request type.
rType, err := msg.Data.GetNextN8()
if err != nil {
return terminal.ErrMalformedData.With("failed to get response type: %w", err)
}
switch rType {
case latencyPingRequest:
// Keep the nonce and just replace the msg type.
msg.Data.PrependNumber(latencyPingResponse)
msg.Type = terminal.MsgTypeData
msg.Unit.ReUse()
msg.Unit.MakeHighPriority()
// Send response.
tErr := op.Send(msg, latencyTestOpTimeout)
if tErr != nil {
return tErr.Wrap("failed to send ping response")
}
op.Flush(1 * time.Second)
return nil
default:
return terminal.ErrIncorrectUsage.With("unknown request type")
}
}

View File

@@ -0,0 +1,59 @@
package docks
import (
"testing"
"time"
"github.com/safing/portmaster/spn/terminal"
)
func TestLatencyOp(t *testing.T) {
t.Parallel()
var (
latTestDelay = 10 * time.Millisecond
latTestQueueSize uint32 = 10
)
// Reduce waiting time.
latencyTestPauseDuration = 100 * time.Millisecond
// Create test terminal pair.
a, b, err := terminal.NewSimpleTestTerminalPair(
latTestDelay,
int(latTestQueueSize),
&terminal.TerminalOpts{
FlowControl: terminal.FlowControlNone,
FlowControlSize: latTestQueueSize,
},
)
if err != nil {
t.Fatalf("failed to create test terminal pair: %s", err)
}
// Grant permission for op on remote terminal and start op.
b.GrantPermission(terminal.IsCraneController)
op, tErr := NewLatencyTestOp(a)
if tErr != nil {
t.Fatalf("failed to start op: %s", err)
}
// Wait for result and check error.
tErr = <-op.Result()
if tErr.IsError() {
t.Fatalf("op failed: %s", tErr)
}
t.Logf("measured latency: %f ms", float64(op.testResult)/float64(time.Millisecond))
// Calculate expected latency.
expectedLatency := float64(latTestDelay * 2)
t.Logf("expected latency: %f ms", expectedLatency/float64(time.Millisecond))
// Check if measured latency is within parameters.
if float64(op.testResult) > expectedLatency*1.2 {
t.Fatal("measured latency too high")
}
if float64(op.testResult) < expectedLatency*0.9 {
t.Fatal("measured latency too low")
}
}

150
spn/docks/op_sync_state.go Normal file
View File

@@ -0,0 +1,150 @@
package docks
import (
"context"
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portmaster/spn/conf"
"github.com/safing/portmaster/spn/terminal"
)
// SyncStateOpType is the type ID of the sync state operation.
const SyncStateOpType = "sync/state"
// SyncStateOp is used to sync the crane state.
type SyncStateOp struct {
terminal.OneOffOperationBase
}
// SyncStateMessage holds the sync data.
type SyncStateMessage struct {
Stopping bool
RequestStopping bool
}
// Type returns the type ID.
func (op *SyncStateOp) Type() string {
return SyncStateOpType
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: SyncStateOpType,
Requires: terminal.IsCraneController,
Start: runSyncStateOp,
})
}
// startSyncStateOp starts a worker that runs the sync state operation.
func (crane *Crane) startSyncStateOp() {
module.StartWorker("sync crane state", func(ctx context.Context) error {
tErr := crane.Controller.SyncState(ctx)
if tErr != nil {
return tErr
}
return nil
})
}
// SyncState runs a sync state operation.
func (controller *CraneControllerTerminal) SyncState(ctx context.Context) *terminal.Error {
// Check if we are a public Hub, whether we own the crane and whether the lane is public too.
if !conf.PublicHub() || !controller.Crane.Public() {
return nil
}
// Create and init.
op := &SyncStateOp{}
op.Init()
// Get optimization states.
requestStopping := false
func() {
controller.Crane.NetState.lock.Lock()
defer controller.Crane.NetState.lock.Unlock()
requestStopping = controller.Crane.NetState.stoppingRequested
}()
// Create sync message.
msg := &SyncStateMessage{
Stopping: controller.Crane.stopping.IsSet(),
RequestStopping: requestStopping,
}
data, err := dsd.Dump(msg, dsd.CBOR)
if err != nil {
return terminal.ErrInternalError.With("%w", err)
}
// Send message.
tErr := controller.StartOperation(op, container.New(data), 30*time.Second)
if tErr != nil {
return tErr
}
// Wait for reply
select {
case tErr = <-op.Result:
if tErr.IsError() {
return tErr
}
return nil
case <-ctx.Done():
return nil
case <-time.After(1 * time.Minute):
return terminal.ErrTimeout.With("timed out while waiting for sync crane result")
}
}
func runSyncStateOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Check if we are a on a crane controller.
var ok bool
var controller *CraneControllerTerminal
if controller, ok = t.(*CraneControllerTerminal); !ok {
return nil, terminal.ErrIncorrectUsage.With("can only be used with a crane controller")
}
// Check if we are a public Hub and whether the lane is public too.
if !conf.PublicHub() || !controller.Crane.Public() {
return nil, terminal.ErrPermissionDenied.With("only public lanes can sync crane status")
}
// Load message.
syncState := &SyncStateMessage{}
_, err := dsd.Load(data.CompileData(), syncState)
if err != nil {
return nil, terminal.ErrMalformedData.With("failed to load sync state message: %w", err)
}
// Apply optimization state.
controller.Crane.NetState.lock.Lock()
defer controller.Crane.NetState.lock.Unlock()
controller.Crane.NetState.stoppingRequestedByPeer = syncState.RequestStopping
// Apply crane state only when we don't own the crane.
if !controller.Crane.IsMine() {
// Apply sync state.
var changed bool
if syncState.Stopping {
if controller.Crane.stopping.SetToIf(false, true) {
controller.Crane.NetState.markedStoppingAt = time.Now()
changed = true
}
} else {
if controller.Crane.stopping.SetToIf(true, false) {
controller.Crane.NetState.markedStoppingAt = time.Time{}
changed = true
}
}
// Notify of change.
if changed {
controller.Crane.NotifyUpdate()
}
}
return nil, nil
}

135
spn/docks/op_whoami.go Normal file
View File

@@ -0,0 +1,135 @@
package docks
import (
"time"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portmaster/spn/terminal"
)
const (
// WhoAmIType is the type ID of the latency test operation.
WhoAmIType = "whoami"
whoAmITimeout = 3 * time.Second
)
// WhoAmIOp is used to request some metadata about the other side.
type WhoAmIOp struct {
terminal.OneOffOperationBase
response *WhoAmIResponse
}
// WhoAmIResponse is a whoami response.
type WhoAmIResponse struct {
// Timestamp in nanoseconds
Timestamp int64 `cbor:"t,omitempty" json:"t,omitempty"`
// Addr is the remote address as reported by the crane terminal (IP and port).
Addr string `cbor:"a,omitempty" json:"a,omitempty"`
}
// Type returns the type ID.
func (op *WhoAmIOp) Type() string {
return WhoAmIType
}
func init() {
terminal.RegisterOpType(terminal.OperationFactory{
Type: WhoAmIType,
Start: startWhoAmI,
})
}
// WhoAmI executes a whoami operation and returns the response.
func WhoAmI(t terminal.Terminal) (*WhoAmIResponse, *terminal.Error) {
whoami, err := NewWhoAmIOp(t)
if err.IsError() {
return nil, err
}
// Wait for response.
select {
case tErr := <-whoami.Result:
if tErr.IsError() {
return nil, tErr
}
return whoami.response, nil
case <-time.After(whoAmITimeout * 2):
return nil, terminal.ErrTimeout
}
}
// NewWhoAmIOp starts a new whoami operation.
func NewWhoAmIOp(t terminal.Terminal) (*WhoAmIOp, *terminal.Error) {
// Create operation and init.
op := &WhoAmIOp{}
op.OneOffOperationBase.Init()
// Send ping.
tErr := t.StartOperation(op, nil, whoAmITimeout)
if tErr != nil {
return nil, tErr
}
return op, nil
}
// Deliver delivers a message to the operation.
func (op *WhoAmIOp) Deliver(msg *terminal.Msg) *terminal.Error {
defer msg.Finish()
// Parse response.
response := &WhoAmIResponse{}
_, err := dsd.Load(msg.Data.CompileData(), response)
if err != nil {
return terminal.ErrMalformedData.With("failed to parse ping response: %w", err)
}
op.response = response
return terminal.ErrExplicitAck
}
func startWhoAmI(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
// Get crane terminal, if available.
ct, _ := t.(*CraneTerminal)
// Create response.
r := &WhoAmIResponse{
Timestamp: time.Now().UnixNano(),
}
if ct != nil {
r.Addr = ct.RemoteAddr().String()
}
response, err := dsd.Dump(r, dsd.CBOR)
if err != nil {
return nil, terminal.ErrInternalError.With("failed to create whoami response: %w", err)
}
// Send response.
msg := terminal.NewMsg(response)
msg.FlowID = opID
msg.Unit.MakeHighPriority()
if terminal.UsePriorityDataMsgs {
msg.Type = terminal.MsgTypePriorityData
}
tErr := t.Send(msg, whoAmITimeout)
if tErr != nil {
// Finish message unit on failure.
msg.Finish()
return nil, tErr.With("failed to send ping response")
}
// Operation is just one response and finished successfully.
return nil, nil
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *WhoAmIOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
// Continue with usual handling of inherited base.
return op.OneOffOperationBase.HandleStop(err)
}

View File

@@ -0,0 +1,24 @@
package docks
import (
"testing"
"github.com/safing/portmaster/spn/terminal"
)
func TestWhoAmIOp(t *testing.T) {
t.Parallel()
// Create test terminal pair.
a, _, err := terminal.NewSimpleTestTerminalPair(0, 0, nil)
if err != nil {
t.Fatalf("failed to create test terminal pair: %s", err)
}
// Run op.
resp, tErr := WhoAmI(a)
if tErr.IsError() {
t.Fatal(tErr)
}
t.Logf("whoami: %+v", resp)
}

View File

@@ -0,0 +1,150 @@
package docks
import (
"fmt"
"sync"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/container"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/terminal"
)
// ExpansionTerminal is used for expanding to another Hub.
type ExpansionTerminal struct {
*terminal.TerminalBase
relayOp *ExpansionTerminalRelayOp
changeNotifyFuncReady *abool.AtomicBool
changeNotifyFunc func()
reachableChecked time.Time
reachableLock sync.Mutex
}
// ExpansionTerminalRelayOp is the operation that connects to the relay.
type ExpansionTerminalRelayOp struct {
terminal.OperationBase
expansionTerminal *ExpansionTerminal
}
// Type returns the type ID.
func (op *ExpansionTerminalRelayOp) Type() string {
return ExpandOpType
}
// ExpandTo initiates an expansion.
func ExpandTo(from terminal.Terminal, routeTo string, encryptFor *hub.Hub) (*ExpansionTerminal, *terminal.Error) {
// First, create the local endpoint terminal to generate the init data.
// Create options and bare expansion terminal.
opts := terminal.DefaultExpansionTerminalOpts()
opts.Encrypt = encryptFor != nil
expansion := &ExpansionTerminal{
changeNotifyFuncReady: abool.New(),
}
expansion.relayOp = &ExpansionTerminalRelayOp{
expansionTerminal: expansion,
}
// Create base terminal for expansion.
base, initData, tErr := terminal.NewLocalBaseTerminal(
module.Ctx,
0, // Ignore; The ID of the operation is used for communication.
from.FmtID(),
encryptFor,
opts,
expansion.relayOp,
)
if tErr != nil {
return nil, tErr.Wrap("failed to create expansion terminal base")
}
expansion.TerminalBase = base
base.SetTerminalExtension(expansion)
base.SetTimeout(defaultTerminalIdleTimeout)
// Second, start the actual relay operation.
// Create setup message for relay operation.
opInitData := container.New()
opInitData.AppendAsBlock([]byte(routeTo))
opInitData.AppendContainer(initData)
// Start relay operation on connected Hub.
tErr = from.StartOperation(expansion.relayOp, opInitData, 5*time.Second)
if tErr != nil {
return nil, tErr.Wrap("failed to start expansion operation")
}
// Start Workers.
base.StartWorkers(module, "expansion terminal")
return expansion, nil
}
// SetChangeNotifyFunc sets a callback function that is called when the terminal state changes.
func (t *ExpansionTerminal) SetChangeNotifyFunc(f func()) {
if t.changeNotifyFuncReady.IsSet() {
return
}
t.changeNotifyFunc = f
t.changeNotifyFuncReady.Set()
}
// NeedsReachableCheck returns whether the terminal should be checked if it is
// reachable via the existing network internal relayed connection.
func (t *ExpansionTerminal) NeedsReachableCheck(maxCheckAge time.Duration) bool {
t.reachableLock.Lock()
defer t.reachableLock.Unlock()
return time.Since(t.reachableChecked) > maxCheckAge
}
// MarkReachable marks the terminal as reachable via the existing network
// internal relayed connection.
func (t *ExpansionTerminal) MarkReachable() {
t.reachableLock.Lock()
defer t.reachableLock.Unlock()
t.reachableChecked = time.Now()
}
// HandleDestruction gives the terminal the ability to clean up.
// The terminal has already fully shut down at this point.
// Should never be called directly. Call Abandon() instead.
func (t *ExpansionTerminal) HandleDestruction(err *terminal.Error) {
// Trigger update of connected Pin.
if t.changeNotifyFuncReady.IsSet() {
t.changeNotifyFunc()
}
// Stop the relay operation.
// The error message is arlready sent by the terminal.
t.relayOp.Stop(t.relayOp, nil)
}
// CustomIDFormat formats the terminal ID.
func (t *ExpansionTerminal) CustomIDFormat() string {
return fmt.Sprintf("%s~%d", t.relayOp.Terminal().FmtID(), t.relayOp.ID())
}
// Deliver delivers a message to the operation.
func (op *ExpansionTerminalRelayOp) Deliver(msg *terminal.Msg) *terminal.Error {
// Proxy directly to expansion terminal.
return op.expansionTerminal.Deliver(msg)
}
// HandleStop gives the operation the ability to cleanly shut down.
// The returned error is the error to send to the other side.
// Should never be called directly. Call Stop() instead.
func (op *ExpansionTerminalRelayOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
// Stop the expansion terminal.
// The error message will be sent by the operation.
op.expansionTerminal.Abandon(nil)
return err
}

View File

@@ -0,0 +1,305 @@
package docks
import (
"fmt"
"os"
"runtime/pprof"
"sync"
"testing"
"time"
"github.com/safing/portmaster/spn/access"
"github.com/safing/portmaster/spn/cabin"
"github.com/safing/portmaster/spn/hub"
"github.com/safing/portmaster/spn/ships"
"github.com/safing/portmaster/spn/terminal"
)
const defaultTestQueueSize = 200
func TestExpansion(t *testing.T) {
t.Parallel()
// Test without and with encryption.
for _, encrypt := range []bool{false, true} {
// Test down/up separately and in parallel.
for _, parallel := range []bool{false, true} {
// Test with different flow controls.
for _, fc := range []struct {
flowControl terminal.FlowControlType
flowControlSize uint32
}{
{
flowControl: terminal.FlowControlNone,
flowControlSize: 5,
},
{
flowControl: terminal.FlowControlDFQ,
flowControlSize: defaultTestQueueSize,
},
} {
// Run tests with combined options.
testExpansion(
t,
"expansion-hop-test",
&terminal.TerminalOpts{
Encrypt: encrypt,
Padding: 8,
FlowControl: fc.flowControl,
FlowControlSize: fc.flowControlSize,
},
defaultTestQueueSize,
defaultTestQueueSize,
parallel,
)
}
}
}
stressTestOpts := &terminal.TerminalOpts{
Encrypt: true,
Padding: 8,
FlowControl: terminal.FlowControlDFQ,
FlowControlSize: defaultTestQueueSize,
}
testExpansion(t, "expansion-stress-test-down", stressTestOpts, defaultTestQueueSize*100, 0, false)
testExpansion(t, "expansion-stress-test-up", stressTestOpts, 0, defaultTestQueueSize*100, false)
testExpansion(t, "expansion-stress-test-duplex", stressTestOpts, defaultTestQueueSize*100, defaultTestQueueSize*100, false)
}
func testExpansion( //nolint:maintidx,thelper
t *testing.T,
testID string,
terminalOpts *terminal.TerminalOpts,
clientCountTo,
serverCountTo uint64,
inParallel bool,
) {
testID += fmt.Sprintf(":encrypt=%v,flowType=%d,parallel=%v", terminalOpts.Encrypt, terminalOpts.FlowControl, inParallel)
var identity2, identity3, identity4 *cabin.Identity
var connectedHub2, connectedHub3, connectedHub4 *hub.Hub
if terminalOpts.Encrypt {
identity2, connectedHub2 = getTestIdentity(t)
identity3, connectedHub3 = getTestIdentity(t)
identity4, connectedHub4 = getTestIdentity(t)
}
// Build ships and cranes.
optimalMinLoadSize = 100
ship1to2 := ships.NewTestShip(!terminalOpts.Encrypt, 100)
ship2to3 := ships.NewTestShip(!terminalOpts.Encrypt, 100)
ship3to4 := ships.NewTestShip(!terminalOpts.Encrypt, 100)
var crane1, crane2to1, crane2to3, crane3to2, crane3to4, crane4 *Crane
var craneWg sync.WaitGroup
craneWg.Add(6)
go func() {
var err error
crane1, err = NewCrane(ship1to2, connectedHub2, nil)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not create crane1: %s", testID, err))
}
crane1.ID = "c1"
err = crane1.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not start crane1: %s", testID, err))
}
crane1.ship.MarkPublic()
craneWg.Done()
}()
go func() {
var err error
crane2to1, err = NewCrane(ship1to2.Reverse(), nil, identity2)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not create crane2to1: %s", testID, err))
}
crane2to1.ID = "c2to1"
err = crane2to1.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not start crane2to1: %s", testID, err))
}
crane2to1.ship.MarkPublic()
craneWg.Done()
}()
go func() {
var err error
crane2to3, err = NewCrane(ship2to3, connectedHub3, nil)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not create crane2to3: %s", testID, err))
}
crane2to3.ID = "c2to3"
err = crane2to3.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not start crane2to3: %s", testID, err))
}
crane2to3.ship.MarkPublic()
craneWg.Done()
}()
go func() {
var err error
crane3to2, err = NewCrane(ship2to3.Reverse(), nil, identity3)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not create crane3to2: %s", testID, err))
}
crane3to2.ID = "c3to2"
err = crane3to2.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not start crane3to2: %s", testID, err))
}
crane3to2.ship.MarkPublic()
craneWg.Done()
}()
go func() {
var err error
crane3to4, err = NewCrane(ship3to4, connectedHub4, nil)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not create crane3to4: %s", testID, err))
}
crane3to4.ID = "c3to4"
err = crane3to4.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not start crane3to4: %s", testID, err))
}
crane3to4.ship.MarkPublic()
craneWg.Done()
}()
go func() {
var err error
crane4, err = NewCrane(ship3to4.Reverse(), nil, identity4)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not create crane4: %s", testID, err))
}
crane4.ID = "c4"
err = crane4.Start(module.Ctx)
if err != nil {
panic(fmt.Sprintf("expansion test %s could not start crane4: %s", testID, err))
}
crane4.ship.MarkPublic()
craneWg.Done()
}()
craneWg.Wait()
// Assign cranes.
crane3HubID := testID + "-crane3HubID"
AssignCrane(crane3HubID, crane2to3)
crane4HubID := testID + "-crane4HubID"
AssignCrane(crane4HubID, crane3to4)
t.Logf("expansion test %s: initial setup complete", testID)
// Wait async for test to complete, print stack after timeout.
finished := make(chan struct{})
go func() {
select {
case <-finished:
case <-time.After(30 * time.Second):
fmt.Printf("expansion test %s is taking too long, print stack:\n", testID)
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
os.Exit(1)
}
}()
// Start initial crane.
homeTerminal, initData, tErr := NewLocalCraneTerminal(crane1, nil, &terminal.TerminalOpts{})
if tErr != nil {
t.Fatalf("expansion test %s failed to create home terminal: %s", testID, tErr)
}
tErr = crane1.EstablishNewTerminal(homeTerminal, initData)
if tErr != nil {
t.Fatalf("expansion test %s failed to connect home terminal: %s", testID, tErr)
}
t.Logf("expansion test %s: home terminal setup complete", testID)
time.Sleep(100 * time.Millisecond)
// Start counters for testing.
op0, tErr := terminal.NewCounterOp(homeTerminal, terminal.CounterOpts{
ClientCountTo: clientCountTo,
ServerCountTo: serverCountTo,
})
if tErr != nil {
t.Fatalf("expansion test %s failed to run counter op: %s", testID, tErr)
}
t.Logf("expansion test %s: home terminal counter setup complete", testID)
if !inParallel {
op0.Wait()
}
// Start expansion to crane 3.
opAuthTo2, tErr := access.AuthorizeToTerminal(homeTerminal)
if tErr != nil {
t.Fatalf("expansion test %s failed to auth with home terminal: %s", testID, tErr)
}
tErr = <-opAuthTo2.Result
if tErr.IsError() {
t.Fatalf("expansion test %s failed to auth with home terminal: %s", testID, tErr)
}
expansionTerminalTo3, err := ExpandTo(homeTerminal, crane3HubID, connectedHub3)
if err != nil {
t.Fatalf("expansion test %s failed to expand to %s: %s", testID, crane3HubID, tErr)
}
// Start counters for testing.
op1, tErr := terminal.NewCounterOp(expansionTerminalTo3, terminal.CounterOpts{
ClientCountTo: clientCountTo,
ServerCountTo: serverCountTo,
})
if tErr != nil {
t.Fatalf("expansion test %s failed to run counter op: %s", testID, tErr)
}
t.Logf("expansion test %s: expansion to crane3 and counter setup complete", testID)
if !inParallel {
op1.Wait()
}
// Start expansion to crane 4.
opAuthTo3, tErr := access.AuthorizeToTerminal(expansionTerminalTo3)
if tErr != nil {
t.Fatalf("expansion test %s failed to auth with extenstion terminal: %s", testID, tErr)
}
tErr = <-opAuthTo3.Result
if tErr.IsError() {
t.Fatalf("expansion test %s failed to auth with extenstion terminal: %s", testID, tErr)
}
expansionTerminalTo4, err := ExpandTo(expansionTerminalTo3, crane4HubID, connectedHub4)
if err != nil {
t.Fatalf("expansion test %s failed to expand to %s: %s", testID, crane4HubID, tErr)
}
// Start counters for testing.
op2, tErr := terminal.NewCounterOp(expansionTerminalTo4, terminal.CounterOpts{
ClientCountTo: clientCountTo,
ServerCountTo: serverCountTo,
})
if tErr != nil {
t.Fatalf("expansion test %s failed to run counter op: %s", testID, tErr)
}
t.Logf("expansion test %s: expansion to crane4 and counter setup complete", testID)
op2.Wait()
// Wait for op1 if not already.
if inParallel {
op0.Wait()
op1.Wait()
}
// Wait for completion.
close(finished)
// Wait a little so that all errors can be propagated, so we can truly see
// if we succeeded.
time.Sleep(100 * time.Millisecond)
// Check errors.
if op1.Error != nil {
t.Fatalf("crane test %s counter op1 failed: %s", testID, op1.Error)
}
if op2.Error != nil {
t.Fatalf("crane test %s counter op2 failed: %s", testID, op2.Error)
}
}

202
spn/hub/database.go Normal file
View File

@@ -0,0 +1,202 @@
package hub
import (
"errors"
"fmt"
"sync"
"time"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
)
var (
db = database.NewInterface(&database.Options{
Local: true,
Internal: true,
})
getFromNavigator func(mapName, hubID string) *Hub
)
// MakeHubDBKey makes a hub db key.
func MakeHubDBKey(mapName, hubID string) string {
return fmt.Sprintf("cache:spn/hubs/%s/%s", mapName, hubID)
}
// MakeHubMsgDBKey makes a hub msg db key.
func MakeHubMsgDBKey(mapName string, msgType MsgType, hubID string) string {
return fmt.Sprintf("cache:spn/msgs/%s/%s/%s", mapName, msgType, hubID)
}
// SetNavigatorAccess sets a shortcut function to access hubs from the navigator instead of having go through the database.
// This also reduces the number of object in RAM and better caches parsed attributes.
func SetNavigatorAccess(fn func(mapName, hubID string) *Hub) {
if getFromNavigator == nil {
getFromNavigator = fn
}
}
// GetHub get a Hub from the database - or the navigator, if configured.
func GetHub(mapName string, hubID string) (*Hub, error) {
if getFromNavigator != nil {
hub := getFromNavigator(mapName, hubID)
if hub != nil {
return hub, nil
}
}
return GetHubByKey(MakeHubDBKey(mapName, hubID))
}
// GetHubByKey returns a hub by its raw DB key.
func GetHubByKey(key string) (*Hub, error) {
r, err := db.Get(key)
if err != nil {
return nil, err
}
hub, err := EnsureHub(r)
if err != nil {
return nil, err
}
return hub, nil
}
// EnsureHub makes sure a database record is a Hub.
func EnsureHub(r record.Record) (*Hub, error) {
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
newHub := &Hub{}
err := record.Unwrap(r, newHub)
if err != nil {
return nil, err
}
newHub = prepHub(newHub)
// Fully validate when getting from database.
if err := newHub.Info.validateFormatting(); err != nil {
return nil, fmt.Errorf("announcement failed format validation: %w", err)
}
if err := newHub.Status.validateFormatting(); err != nil {
return nil, fmt.Errorf("status failed format validation: %w", err)
}
if err := newHub.Info.prepare(false); err != nil {
return nil, fmt.Errorf("failed to prepare announcement: %w", err)
}
return newHub, nil
}
// or adjust type
newHub, ok := r.(*Hub)
if !ok {
return nil, fmt.Errorf("record not of type *Hub, but %T", r)
}
newHub = prepHub(newHub)
// Prepare only when already parsed.
if err := newHub.Info.prepare(false); err != nil {
return nil, fmt.Errorf("failed to prepare announcement: %w", err)
}
// ensure status
return newHub, nil
}
func prepHub(h *Hub) *Hub {
if h.Status == nil {
h.Status = &Status{}
}
h.Measurements = getSharedMeasurements(h.ID, h.Measurements)
return h
}
// Save saves to Hub to the correct scope in the database.
func (h *Hub) Save() error {
if !h.KeyIsSet() {
h.SetKey(MakeHubDBKey(h.Map, h.ID))
}
return db.Put(h)
}
// RemoveHubAndMsgs deletes a Hub and it's saved messages from the database.
func RemoveHubAndMsgs(mapName string, hubID string) (err error) {
err = db.Delete(MakeHubDBKey(mapName, hubID))
if err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to delete main hub entry: %w", err)
}
err = db.Delete(MakeHubMsgDBKey(mapName, MsgTypeAnnouncement, hubID))
if err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to delete hub announcement data: %w", err)
}
err = db.Delete(MakeHubMsgDBKey(mapName, MsgTypeStatus, hubID))
if err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to delete hub status data: %w", err)
}
return nil
}
// HubMsg stores raw Hub messages.
type HubMsg struct { //nolint:golint
record.Base
sync.Mutex
ID string
Map string
Type MsgType
Data []byte
Received int64
}
// SaveHubMsg saves a raw (and signed) message received by another Hub.
func SaveHubMsg(id string, mapName string, msgType MsgType, data []byte) error {
// create wrapper record
msg := &HubMsg{
ID: id,
Map: mapName,
Type: msgType,
Data: data,
Received: time.Now().Unix(),
}
// set key
msg.SetKey(MakeHubMsgDBKey(msg.Map, msg.Type, msg.ID))
// save
return db.PutNew(msg)
}
// QueryRawGossipMsgs queries the database for raw gossip messages.
func QueryRawGossipMsgs(mapName string, msgType MsgType) (it *iterator.Iterator, err error) {
it, err = db.Query(query.New(MakeHubMsgDBKey(mapName, msgType, "")))
return
}
// EnsureHubMsg makes sure a database record is a HubMsg.
func EnsureHubMsg(r record.Record) (*HubMsg, error) {
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
newHubMsg := &HubMsg{}
err := record.Unwrap(r, newHubMsg)
if err != nil {
return nil, err
}
return newHubMsg, nil
}
// or adjust type
newHubMsg, ok := r.(*HubMsg)
if !ok {
return nil, fmt.Errorf("record not of type *Hub, but %T", r)
}
return newHubMsg, nil
}

21
spn/hub/errors.go Normal file
View File

@@ -0,0 +1,21 @@
package hub
import "errors"
var (
// ErrMissingInfo signifies that the hub is missing the HubAnnouncement.
ErrMissingInfo = errors.New("hub has no announcement")
// ErrMissingTransports signifies that the hub announcement did not specify any transports.
ErrMissingTransports = errors.New("hub announcement has no transports")
// ErrMissingIPs signifies that the hub announcement did not specify any IPs,
// or none of the IPs is supported by the client.
ErrMissingIPs = errors.New("hub announcement has no (supported) IPs")
// ErrTemporaryValidationError is returned when a validation error might be temporary.
ErrTemporaryValidationError = errors.New("temporary validation error")
// ErrOldData is returned when received data is outdated.
ErrOldData = errors.New("")
)

69
spn/hub/format.go Normal file
View File

@@ -0,0 +1,69 @@
package hub
import (
"fmt"
"net"
"regexp"
"github.com/safing/portmaster/service/network/netutils"
)
// BaselineCharset defines the permitted characters.
var BaselineCharset = regexp.MustCompile(
// Start of charset selection.
`^[` +
// Printable ASCII (character code 32-127), excluding common control characters of different languages: "$%&';<>\` and DELETE.
` !#()*+,\-\./0-9:=?@A-Z[\]^_a-z{|}~` +
// Only latin characters from extended ASCII (character code 128-255).
`ŠŒŽšœžŸ¡¿ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿ` +
// End of charset selection.
`]*$`,
)
func checkStringFormat(fieldName, value string, maxLength int) error {
switch {
case len(value) > maxLength:
return fmt.Errorf("field %s with length of %d exceeds max length of %d", fieldName, len(value), maxLength)
case !BaselineCharset.MatchString(value):
return fmt.Errorf("field %s contains characters not permitted by baseline validation", fieldName)
default:
return nil
}
}
func checkStringSliceFormat(fieldName string, value []string, maxLength, maxStringLength int) error { //nolint:unparam
if len(value) > maxLength {
return fmt.Errorf("field %s with array/slice length of %d exceeds max length of %d", fieldName, len(value), maxLength)
}
for _, s := range value {
if err := checkStringFormat(fieldName, s, maxStringLength); err != nil {
return err
}
}
return nil
}
func checkByteSliceFormat(fieldName string, value []byte, maxLength int) error {
switch {
case len(value) > maxLength:
return fmt.Errorf("field %s with length of %d exceeds max length of %d", fieldName, len(value), maxLength)
default:
return nil
}
}
func checkIPFormat(fieldName string, value net.IP) error {
// Check if there is an IP address.
if value == nil {
return nil
}
switch {
case len(value) != 4 && len(value) != 16:
return fmt.Errorf("field %s has an invalid length of %d for an IP address", fieldName, len(value))
case netutils.GetIPScope(value) == netutils.Invalid:
return fmt.Errorf("field %s holds an invalid IP address: %s", fieldName, value)
default:
return nil
}
}

81
spn/hub/format_test.go Normal file
View File

@@ -0,0 +1,81 @@
package hub
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCheckStringFormat(t *testing.T) {
t.Parallel()
testSet := map[string]bool{
// Printable ASCII (character code 32-127)
" ": true, "!": true, `"`: false, "#": true, "$": false, "%": false, "&": false, "'": false,
"(": true, ")": true, "*": true, "+": true, ",": true, "-": true, ".": true, "/": true,
"0": true, "1": true, "2": true, "3": true, "4": true, "5": true, "6": true, "7": true,
"8": true, "9": true, ":": true, ";": false, "<": false, "=": true, ">": false, "?": true,
"@": true, "A": true, "B": true, "C": true, "D": true, "E": true, "F": true, "G": true,
"H": true, "I": true, "J": true, "K": true, "L": true, "M": true, "N": true, "O": true,
"P": true, "Q": true, "R": true, "S": true, "T": true, "U": true, "V": true, "W": true,
"X": true, "Y": true, "Z": true, "[": true, `\`: false, "]": true, "^": true, "_": true,
"`": false, "a": true, "b": true, "c": true, "d": true, "e": true, "f": true, "g": true,
"h": true, "i": true, "j": true, "k": true, "l": true, "m": true, "n": true, "o": true,
"p": true, "q": true, "r": true, "s": true, "t": true, "u": true, "v": true, "w": true,
"x": true, "y": true, "z": true, "{": true, "|": true, "}": true, "~": true,
// Not testing for DELETE character.
// Extended ASCII (character code 128-255)
"€": false, "": false, "ƒ": false, "„": false, "…": false, "†": false, "‡": false, "ˆ": false,
"‰": false, "Š": true, "": false, "Œ": true, "Ž": true, "": false, "": false, "“": false,
"”": false, "•": false, "": false, "—": false, "˜": false, "™": false, "š": true, "": false,
"œ": true, "ž": true, "Ÿ": true, "¡": true, "¢": false, "£": false, "¤": false, "¥": false,
"¦": false, "§": false, "¨": false, "©": false, "ª": false, "«": false, "¬": false, "®": false,
"¯": false, "°": false, "±": false, "²": false, "³": false, "´": false, "µ": false, "¶": false,
"·": false, "¸": false, "¹": false, "º": false, "»": false, "¼": false, "½": false, "¾": false,
"¿": true, "À": true, "Á": true, "Â": true, "Ã": true, "Ä": true, "Å": true, "Æ": true,
"Ç": true, "È": true, "É": true, "Ê": true, "Ë": true, "Ì": true, "Í": true, "Î": true,
"Ï": true, "Ð": true, "Ñ": true, "Ò": true, "Ó": true, "Ô": true, "Õ": true, "Ö": true,
"×": false, "Ø": true, "Ù": true, "Ú": true, "Û": true, "Ü": true, "Ý": true, "Þ": true,
"ß": true, "à": true, "á": true, "â": true, "ã": true, "ä": true, "å": true, "æ": true,
"ç": true, "è": true, "é": true, "ê": true, "ë": true, "ì": true, "í": true, "î": true,
"ï": true, "ð": true, "ñ": true, "ò": true, "ó": true, "ô": true, "õ": true, "ö": true,
"÷": false, "ø": true, "ù": true, "ú": true, "û": true, "ü": true, "ý": true, "þ": true,
"ÿ": true,
}
for testCharacter, isPermitted := range testSet {
if isPermitted {
assert.NoError(t, checkStringFormat(fmt.Sprintf("test character %q", testCharacter), testCharacter, 3))
} else {
assert.Error(t, checkStringFormat(fmt.Sprintf("test character %q", testCharacter), testCharacter, 3))
}
}
}
func TestCheckIPFormat(t *testing.T) {
t.Parallel()
// IPv4
assert.NoError(t, checkIPFormat("test IP 1.1.1.1", net.IPv4(1, 1, 1, 1)))
assert.NoError(t, checkIPFormat("test IP 192.168.1.1", net.IPv4(192, 168, 1, 1)))
assert.Error(t, checkIPFormat("test IP 255.0.0.1", net.IPv4(255, 0, 0, 1)))
// IPv6
assert.NoError(t, checkIPFormat("test IP ::1", net.ParseIP("::1")))
assert.NoError(t, checkIPFormat("test IP 2606:4700:4700::1111", net.ParseIP("2606:4700:4700::1111")))
// Invalid
assert.Error(t, checkIPFormat("test IP with length 3", net.IP([]byte{0, 0, 0})))
assert.Error(t, checkIPFormat("test IP with length 5", net.IP([]byte{0, 0, 0, 0, 0})))
assert.Error(t, checkIPFormat(
"test IP with length 15",
net.IP([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
))
assert.Error(t, checkIPFormat(
"test IP with length 17",
net.IP([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
))
}

Some files were not shown because too many files have changed in this diff Show More