wip: migrate to mono-repo. SPN has already been moved to spn/
This commit is contained in:
26
spn/TESTING.md
Normal file
26
spn/TESTING.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# Testing SPN
|
||||
|
||||
This page documents ways to test if the SPN works as intended.
|
||||
|
||||
⚠ Work in Progress. Currently we are just collecting helpful things we find.
|
||||
|
||||
## Test Multi-Identity Routing
|
||||
|
||||
In order to test if the multi-identity routing is working, you can request multiple websites to display your public IP.
|
||||
If they show different values, multi-identity routing is working.
|
||||
|
||||
### Websites
|
||||
|
||||
- <https://icanhazip.com>
|
||||
- <https://ipecho.net>
|
||||
- <https://ipinfo.io>
|
||||
- <https://ipinfo.tw>
|
||||
|
||||
### Terminal
|
||||
|
||||
```sh
|
||||
curl https://icanhazip.com
|
||||
curl https://ipecho.net/plain
|
||||
curl https://ipinfo.io/ip
|
||||
curl https://ipinfo.tw/ip
|
||||
```
|
||||
5
spn/TRADEMARKS
Normal file
5
spn/TRADEMARKS
Normal file
@@ -0,0 +1,5 @@
|
||||
The names "Safing", "Portmaster", "SPN" and their logos are trademarks owned by Safing ICS Technologies GmbH (Austria).
|
||||
|
||||
Although our code is free, it is very important that we strictly enforce our trademark rights, in order to be able to protect our users against people who use the marks to commit fraud. This means that, while you have considerable freedom to redistribute and modify our software, there are tight restrictions on your ability to use our names and logos in ways which fall in the domain of trademark law, even when built into binaries that we provide.
|
||||
|
||||
This file is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License. Parts of it were taken from https://www.mozilla.org/en-US/foundation/licensing/.
|
||||
65
spn/access/account/auth.go
Normal file
65
spn/access/account/auth.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Authentication Headers.
|
||||
const (
|
||||
AuthHeaderDevice = "Device-17"
|
||||
AuthHeaderToken = "Token-17"
|
||||
AuthHeaderNextToken = "Next-Token-17"
|
||||
AuthHeaderNextTokenDeprecated = "Next_token_17"
|
||||
)
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrMissingDeviceID = errors.New("missing device ID")
|
||||
ErrMissingToken = errors.New("missing token")
|
||||
)
|
||||
|
||||
// AuthToken holds an authentication token.
|
||||
type AuthToken struct {
|
||||
Device string
|
||||
Token string
|
||||
}
|
||||
|
||||
// GetAuthTokenFromRequest extracts an authentication token from a request.
|
||||
func GetAuthTokenFromRequest(request *http.Request) (*AuthToken, error) {
|
||||
device := request.Header.Get(AuthHeaderDevice)
|
||||
if device == "" {
|
||||
return nil, ErrMissingDeviceID
|
||||
}
|
||||
token := request.Header.Get(AuthHeaderToken)
|
||||
if token == "" {
|
||||
return nil, ErrMissingToken
|
||||
}
|
||||
|
||||
return &AuthToken{
|
||||
Device: device,
|
||||
Token: token,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ApplyTo applies the authentication token to a request.
|
||||
func (at *AuthToken) ApplyTo(request *http.Request) {
|
||||
request.Header.Set(AuthHeaderDevice, at.Device)
|
||||
request.Header.Set(AuthHeaderToken, at.Token)
|
||||
}
|
||||
|
||||
// GetNextTokenFromResponse extracts an authentication token from a response.
|
||||
func GetNextTokenFromResponse(resp *http.Response) (token string, ok bool) {
|
||||
token = resp.Header.Get(AuthHeaderNextToken)
|
||||
if token == "" {
|
||||
// TODO: Remove when fixed on server.
|
||||
token = resp.Header.Get(AuthHeaderNextTokenDeprecated)
|
||||
}
|
||||
|
||||
return token, token != ""
|
||||
}
|
||||
|
||||
// ApplyNextTokenToResponse applies the next authentication token to a response.
|
||||
func ApplyNextTokenToResponse(w http.ResponseWriter, token string) {
|
||||
w.Header().Set(AuthHeaderNextToken, token)
|
||||
}
|
||||
14
spn/access/account/client.go
Normal file
14
spn/access/account/client.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package account
|
||||
|
||||
// Customer Agent URLs.
|
||||
const (
|
||||
CAAuthenticateURL = "/authenticate"
|
||||
CAProfileURL = "/user/profile"
|
||||
CAGetTokensURL = "/tokens"
|
||||
)
|
||||
|
||||
// Customer Hub URLs.
|
||||
const (
|
||||
CHAuthenticateURL = "/v1/authenticate"
|
||||
CHUserProfileURL = "/v1/user_profile"
|
||||
)
|
||||
137
spn/access/account/types.go
Normal file
137
spn/access/account/types.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// User, Subscription and Charge states.
|
||||
const (
|
||||
// UserStateNone is only used within Portmaster for saving information for
|
||||
// logging into the same device.
|
||||
UserStateNone = ""
|
||||
UserStateFresh = "fresh"
|
||||
UserStateQueued = "queued"
|
||||
UserStateApproved = "approved"
|
||||
UserStateSuspended = "suspended"
|
||||
UserStateLoggedOut = "loggedout" // Portmaster only.
|
||||
|
||||
SubscriptionStateManual = "manual" // Manual renewal.
|
||||
SubscriptionStateActive = "active" // Automatic renewal.
|
||||
SubscriptionStateCancelled = "cancelled" // Automatic, but canceled.
|
||||
|
||||
ChargeStatePending = "pending"
|
||||
ChargeStateCompleted = "completed"
|
||||
ChargeStateDead = "dead"
|
||||
)
|
||||
|
||||
// Agent and Hub return statuses.
|
||||
const (
|
||||
// StatusInvalidAuth [401 Unauthorized] is returned when the credentials are
|
||||
// invalid or the user was logged out.
|
||||
StatusInvalidAuth = 401
|
||||
// StatusNoAccess [403 Forbidden] is returned when the user does not have
|
||||
// an active subscription or the subscription does not include the required
|
||||
// feature for the request.
|
||||
StatusNoAccess = 403
|
||||
// StatusInvalidDevice [410 Gone] is returned when the device trying to
|
||||
// log into does not exist.
|
||||
StatusInvalidDevice = 410
|
||||
// StatusReachedDeviceLimit [409 Conflict] is returned when the device limit is reached.
|
||||
StatusReachedDeviceLimit = 409
|
||||
// StatusDeviceInactive [423 Locked] is returned when the device is locked.
|
||||
StatusDeviceInactive = 423
|
||||
// StatusNotLoggedIn [412 Precondition] is returned by the Portmaster, if an action required to be logged in, but the user is not logged in.
|
||||
StatusNotLoggedIn = 412
|
||||
|
||||
// StatusUnknownError is a special status code that signifies an unknown or
|
||||
// unexpected error by the API.
|
||||
StatusUnknownError = -1
|
||||
// StatusConnectionError is a special status code that signifies a
|
||||
// connection error.
|
||||
StatusConnectionError = -2
|
||||
)
|
||||
|
||||
// User describes an SPN user account.
|
||||
type User struct {
|
||||
Username string `json:"username"`
|
||||
State string `json:"state"`
|
||||
Balance int `json:"balance"`
|
||||
Device *Device `json:"device"`
|
||||
Subscription *Subscription `json:"subscription"`
|
||||
CurrentPlan *Plan `json:"current_plan"`
|
||||
NextPlan *Plan `json:"next_plan"`
|
||||
View *View `json:"view"`
|
||||
}
|
||||
|
||||
// MayUseSPN returns whether the user may currently use the SPN.
|
||||
func (u *User) MayUseSPN() bool {
|
||||
return u.MayUse(FeatureSPN)
|
||||
}
|
||||
|
||||
// MayUsePrioritySupport returns whether the user may currently use the priority support.
|
||||
func (u *User) MayUsePrioritySupport() bool {
|
||||
return u.MayUse(FeatureSafingSupport)
|
||||
}
|
||||
|
||||
// MayUse returns whether the user may currently use the feature identified by
|
||||
// the given feature ID.
|
||||
// Leave feature ID empty to check without feature.
|
||||
func (u *User) MayUse(featureID FeatureID) bool {
|
||||
switch {
|
||||
case u == nil:
|
||||
// We need a user, obviously.
|
||||
case u.State != UserStateApproved:
|
||||
// Only approved users may use the SPN.
|
||||
case u.Subscription == nil:
|
||||
// Need a subscription.
|
||||
case u.Subscription.EndsAt == nil:
|
||||
case time.Now().After(*u.Subscription.EndsAt):
|
||||
// Subscription needs to be active.
|
||||
case u.CurrentPlan == nil:
|
||||
// Need a plan / package.
|
||||
case featureID != "" &&
|
||||
!slices.Contains(u.CurrentPlan.FeatureIDs, featureID):
|
||||
// Required feature ID must be in plan / package feature IDs.
|
||||
default:
|
||||
// All checks passed!
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Device describes a device of an SPN user.
|
||||
type Device struct {
|
||||
Name string `json:"name"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// Subscription describes an SPN subscription.
|
||||
type Subscription struct {
|
||||
EndsAt *time.Time `json:"ends_at"`
|
||||
State string `json:"state"`
|
||||
NextBillingDate *time.Time `json:"next_billing_date"`
|
||||
PaymentProvider string `json:"payment_provider"`
|
||||
}
|
||||
|
||||
// FeatureID defines a feature that requires a plan/subscription.
|
||||
type FeatureID string
|
||||
|
||||
// A list of all supported features.
|
||||
const (
|
||||
FeatureSPN = FeatureID("spn")
|
||||
FeatureSafingSupport = FeatureID("support")
|
||||
FeatureHistory = FeatureID("history")
|
||||
FeatureBWVis = FeatureID("bw-vis")
|
||||
FeatureVPNCompat = FeatureID("vpn-compat")
|
||||
)
|
||||
|
||||
// Plan describes an SPN subscription plan.
|
||||
type Plan struct {
|
||||
Name string `json:"name"`
|
||||
Amount int `json:"amount"`
|
||||
Months int `json:"months"`
|
||||
Renewable bool `json:"renewable"`
|
||||
FeatureIDs []FeatureID `json:"feature_ids"`
|
||||
}
|
||||
123
spn/access/account/view.go
Normal file
123
spn/access/account/view.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// View holds metadata that assists in displaying account information.
|
||||
type View struct {
|
||||
Message string
|
||||
ShowAccountData bool
|
||||
ShowAccountButton bool
|
||||
ShowLoginButton bool
|
||||
ShowRefreshButton bool
|
||||
ShowLogoutButton bool
|
||||
}
|
||||
|
||||
// UpdateView updates the view and handles plan/package fallbacks.
|
||||
func (u *User) UpdateView(requestStatusCode int) {
|
||||
v := &View{}
|
||||
|
||||
// Clean up naming and fallbacks when finished.
|
||||
defer func() {
|
||||
// Display "Free" package if no plan is set or if it expired.
|
||||
switch {
|
||||
case u.CurrentPlan == nil,
|
||||
u.Subscription == nil,
|
||||
u.Subscription.EndsAt == nil:
|
||||
// Reset to free plan.
|
||||
u.CurrentPlan = &Plan{
|
||||
Name: "Free",
|
||||
}
|
||||
u.Subscription = nil
|
||||
|
||||
case u.Subscription.NextBillingDate != nil:
|
||||
// Subscription is on auto-renew.
|
||||
// Wait for update from server.
|
||||
|
||||
case time.Since(*u.Subscription.EndsAt) > 0:
|
||||
// Reset to free plan.
|
||||
u.CurrentPlan = &Plan{
|
||||
Name: "Free",
|
||||
}
|
||||
u.Subscription = nil
|
||||
}
|
||||
|
||||
// Prepend "Portmaster " to plan name.
|
||||
// TODO: Remove when Plan/Package naming has been updated.
|
||||
if u.CurrentPlan != nil && !strings.HasPrefix(u.CurrentPlan.Name, "Portmaster ") {
|
||||
u.CurrentPlan.Name = "Portmaster " + u.CurrentPlan.Name
|
||||
}
|
||||
|
||||
// Apply new view to user.
|
||||
u.View = v
|
||||
}()
|
||||
|
||||
// Set view data based on return code.
|
||||
switch requestStatusCode {
|
||||
case StatusInvalidAuth, StatusInvalidDevice, StatusDeviceInactive:
|
||||
// Account deleted or Device inactive or deleted.
|
||||
// When using token based auth, there is no difference between these cases.
|
||||
v.Message = "This device may have been deactivated or removed from your account. Please log in again."
|
||||
v.ShowAccountData = true
|
||||
v.ShowAccountButton = true
|
||||
v.ShowLoginButton = true
|
||||
v.ShowLogoutButton = true
|
||||
return
|
||||
|
||||
case StatusUnknownError:
|
||||
v.Message = "There is an unknown error in the communication with the account server. The shown information may not be accurate. "
|
||||
|
||||
case StatusConnectionError:
|
||||
v.Message = "Portmaster could not connect to the account server. The shown information may not be accurate. "
|
||||
}
|
||||
|
||||
// Set view data based on profile data.
|
||||
switch {
|
||||
case u.State == UserStateLoggedOut:
|
||||
// User logged out.
|
||||
v.ShowAccountButton = true
|
||||
v.ShowLoginButton = true
|
||||
return
|
||||
|
||||
case u.State == UserStateSuspended:
|
||||
// Account is suspended.
|
||||
v.Message += fmt.Sprintf("Your account (%s) was suspended. Please contact support for details.", u.Username)
|
||||
v.ShowAccountButton = true
|
||||
v.ShowRefreshButton = true
|
||||
v.ShowLogoutButton = true
|
||||
return
|
||||
|
||||
case u.Subscription == nil || u.Subscription.EndsAt == nil:
|
||||
// Account has never had a subscription.
|
||||
v.Message += "Get more features. Upgrade today."
|
||||
|
||||
case u.Subscription.NextBillingDate != nil:
|
||||
switch {
|
||||
case time.Since(*u.Subscription.NextBillingDate) > 0:
|
||||
v.Message += "Your auto-renewal seems to be delayed. Please refresh and check the status of your payment. Payment information may be delayed."
|
||||
case time.Until(*u.Subscription.NextBillingDate) < 24*time.Hour:
|
||||
v.Message += "Your subscription will auto-renew soon. Please note that payment information may be delayed."
|
||||
}
|
||||
|
||||
case time.Since(*u.Subscription.EndsAt) > 0:
|
||||
// Subscription expired.
|
||||
if u.CurrentPlan != nil {
|
||||
v.Message += fmt.Sprintf("Your package %s has ended. Extend it on the Account Page.", u.CurrentPlan.Name)
|
||||
} else {
|
||||
v.Message += "Your package has ended. Extend it on the Account Page."
|
||||
}
|
||||
|
||||
case time.Until(*u.Subscription.EndsAt) < 7*24*time.Hour:
|
||||
// Add generic ending soon message if the package ends in less than 7 days.
|
||||
v.Message += "Your package ends soon. Extend it on the Account Page."
|
||||
}
|
||||
|
||||
// Defaults for generally good accounts.
|
||||
v.ShowAccountData = true
|
||||
v.ShowAccountButton = true
|
||||
v.ShowRefreshButton = true
|
||||
v.ShowLogoutButton = true
|
||||
}
|
||||
168
spn/access/api.go
Normal file
168
spn/access/api.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/access/account"
|
||||
)
|
||||
|
||||
func registerAPIEndpoints() error {
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: `spn/account/login`,
|
||||
Write: api.PermitAdmin,
|
||||
WriteMethod: http.MethodPost,
|
||||
HandlerFunc: handleLogin,
|
||||
Name: "SPN Login",
|
||||
Description: "Log into your SPN account.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: `spn/account/logout`,
|
||||
Write: api.PermitAdmin,
|
||||
WriteMethod: http.MethodDelete,
|
||||
ActionFunc: handleLogout,
|
||||
Name: "SPN Logout",
|
||||
Description: "Logout from your SPN account.",
|
||||
Parameters: []api.Parameter{
|
||||
{
|
||||
Method: http.MethodDelete,
|
||||
Field: "purge",
|
||||
Value: "",
|
||||
Description: "If set, account data is purged. Otherwise, the username and device ID are kept in order to log into the same device when logging in with the same user again.",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: `spn/account/user/profile`,
|
||||
Read: api.PermitUser,
|
||||
ReadMethod: http.MethodGet,
|
||||
RecordFunc: handleGetUserProfile,
|
||||
Name: "SPN User Profile",
|
||||
Description: "Get the user profile of the logged in SPN account.",
|
||||
Parameters: []api.Parameter{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Field: "refresh",
|
||||
Value: "",
|
||||
Description: "If set, the user profile is freshly fetched from the account server.",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: `account/features`,
|
||||
Read: api.PermitUser,
|
||||
ReadMethod: http.MethodGet,
|
||||
StructFunc: func(_ *api.Request) (i interface{}, err error) {
|
||||
return struct {
|
||||
Features []Feature
|
||||
}{
|
||||
Features: features,
|
||||
}, nil
|
||||
},
|
||||
Name: "Get Account Features",
|
||||
Description: "Returns all account features.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: `account/features/{id:[A-Za-z0-9_-]+}/icon`,
|
||||
Read: api.PermitUser,
|
||||
ReadMethod: http.MethodGet,
|
||||
Name: "Returns the image of the featuare",
|
||||
MimeType: "image/svg+xml",
|
||||
DataFunc: func(ar *api.Request) (data []byte, err error) {
|
||||
featureID, ok := ar.URLVars["id"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid feature id")
|
||||
}
|
||||
|
||||
for _, feature := range features {
|
||||
if feature.ID == featureID {
|
||||
return []byte(feature.icon), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("feature id not found")
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
// Get username and password.
|
||||
username, password, ok := r.BasicAuth()
|
||||
// Request, if omitted.
|
||||
if !ok || username == "" || password == "" {
|
||||
w.Header().Set("WWW-Authenticate", "Basic realm=SPN Login")
|
||||
http.Error(w, "Login with your SPN account.", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Process login.
|
||||
user, code, err := Login(username, password)
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to login: %s", err)
|
||||
if code == 0 {
|
||||
http.Error(w, "Internal error: "+err.Error(), http.StatusInternalServerError)
|
||||
} else {
|
||||
http.Error(w, err.Error(), code)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Return success.
|
||||
_, _ = w.Write([]byte(
|
||||
fmt.Sprintf("Now logged in as %s as device %s", user.Username, user.Device.Name),
|
||||
))
|
||||
}
|
||||
|
||||
func handleLogout(ar *api.Request) (msg string, err error) {
|
||||
purge := ar.URL.Query().Get("purge") != ""
|
||||
err = Logout(false, purge)
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warningf("spn/access: failed to logout: %s", err)
|
||||
return "", err
|
||||
case purge:
|
||||
return "Logged out and user data purged.", nil
|
||||
default:
|
||||
return "Logged out.", nil
|
||||
}
|
||||
}
|
||||
|
||||
func handleGetUserProfile(ar *api.Request) (r record.Record, err error) {
|
||||
// Check if we are already authenticated.
|
||||
user, err := GetUser()
|
||||
if err != nil || user.State == account.UserStateNone {
|
||||
return nil, api.ErrorWithStatus(
|
||||
ErrNotLoggedIn,
|
||||
account.StatusInvalidAuth,
|
||||
)
|
||||
}
|
||||
|
||||
// Should we refresh the user profile?
|
||||
if ar.URL.Query().Get("refresh") != "" {
|
||||
user, _, err = UpdateUser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
550
spn/access/client.go
Normal file
550
spn/access/client.go
Normal file
@@ -0,0 +1,550 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/access/account"
|
||||
"github.com/safing/portmaster/spn/access/token"
|
||||
)
|
||||
|
||||
// Client URLs.
|
||||
const (
|
||||
AccountServer = "https://api.account.safing.io"
|
||||
LoginPath = "/api/v1/authenticate"
|
||||
UserProfilePath = "/api/v1/user/profile"
|
||||
TokenRequestSetupPath = "/api/v1/token/request/setup" //nolint:gosec
|
||||
TokenRequestIssuePath = "/api/v1/token/request/issue" //nolint:gosec
|
||||
HealthCheckPath = "/api/v1/health"
|
||||
|
||||
defaultDataFormat = dsd.CBOR
|
||||
defaultRequestTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
accountClient = &http.Client{}
|
||||
clientRequestLock sync.Mutex
|
||||
|
||||
// EnableAfterLogin automatically enables the SPN subsystem/module after login.
|
||||
EnableAfterLogin = true
|
||||
)
|
||||
|
||||
type clientRequestOptions struct {
|
||||
method string
|
||||
url string
|
||||
send interface{}
|
||||
recv interface{}
|
||||
requestTimeout time.Duration
|
||||
dataFormat uint8
|
||||
setAuthToken bool
|
||||
requireNextAuthToken bool
|
||||
logoutOnAuthError bool
|
||||
requestSetupFunc func(*http.Request) error
|
||||
}
|
||||
|
||||
func makeClientRequest(opts *clientRequestOptions) (resp *http.Response, err error) {
|
||||
// Get request timeout.
|
||||
if opts.requestTimeout == 0 {
|
||||
opts.requestTimeout = defaultRequestTimeout
|
||||
}
|
||||
// Get context for request.
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
if module.Online() {
|
||||
// Only use module context if online.
|
||||
ctx, cancel = context.WithTimeout(module.Ctx, opts.requestTimeout)
|
||||
defer cancel()
|
||||
} else {
|
||||
// Otherwise, use the background context.
|
||||
ctx, cancel = context.WithTimeout(context.Background(), opts.requestTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Create new request.
|
||||
request, err := http.NewRequestWithContext(ctx, opts.method, opts.url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request structure: %w", err)
|
||||
}
|
||||
|
||||
// Prepare body and content type.
|
||||
if opts.dataFormat == dsd.AUTO {
|
||||
opts.dataFormat = defaultDataFormat
|
||||
}
|
||||
if opts.send != nil {
|
||||
// Add data to body.
|
||||
err = dsd.DumpToHTTPRequest(request, opts.send, opts.dataFormat)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add request body: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Set requested HTTP response format.
|
||||
_, err = dsd.RequestHTTPResponseFormat(request, opts.dataFormat)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set requested response format: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get auth token to apply to request.
|
||||
var authToken *AuthTokenRecord
|
||||
if opts.setAuthToken {
|
||||
authToken, err = GetAuthToken()
|
||||
if err != nil {
|
||||
return nil, ErrNotLoggedIn
|
||||
}
|
||||
authToken.Token.ApplyTo(request)
|
||||
}
|
||||
|
||||
// Do any additional custom request setup.
|
||||
if opts.requestSetupFunc != nil {
|
||||
err = opts.requestSetupFunc(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Make request.
|
||||
resp, err = accountClient.Do(request)
|
||||
if err != nil {
|
||||
updateUserWithFailedRequest(account.StatusConnectionError, false)
|
||||
tokenIssuerFailed()
|
||||
return nil, fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
log.Debugf("spn/access: request to %s returned %s", request.URL, resp.Status)
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
// Handle request error.
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK, http.StatusCreated:
|
||||
// All good!
|
||||
|
||||
case account.StatusInvalidAuth, account.StatusInvalidDevice:
|
||||
// Wrong username / password.
|
||||
updateUserWithFailedRequest(resp.StatusCode, true)
|
||||
return resp, ErrInvalidCredentials
|
||||
|
||||
case account.StatusReachedDeviceLimit:
|
||||
// Device limit is reached.
|
||||
updateUserWithFailedRequest(resp.StatusCode, true)
|
||||
return resp, ErrDeviceLimitReached
|
||||
|
||||
case account.StatusDeviceInactive:
|
||||
// Device is locked.
|
||||
updateUserWithFailedRequest(resp.StatusCode, true)
|
||||
return resp, ErrDeviceIsLocked
|
||||
|
||||
default:
|
||||
updateUserWithFailedRequest(account.StatusUnknownError, false)
|
||||
tokenIssuerFailed()
|
||||
return resp, fmt.Errorf("unexpected reply: [%d] %s", resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
// Save next auth token.
|
||||
if authToken != nil {
|
||||
err = authToken.Update(resp)
|
||||
if err != nil {
|
||||
if errors.Is(err, account.ErrMissingToken) {
|
||||
if opts.requireNextAuthToken {
|
||||
return resp, fmt.Errorf("failed to save next auth token: %w", err)
|
||||
}
|
||||
} else {
|
||||
return resp, fmt.Errorf("failed to save next auth token: %w", err)
|
||||
}
|
||||
}
|
||||
} else if opts.requireNextAuthToken {
|
||||
return resp, fmt.Errorf("failed to save next auth token: %w", account.ErrMissingToken)
|
||||
}
|
||||
|
||||
// Load response data.
|
||||
if opts.recv != nil {
|
||||
_, err = dsd.LoadFromHTTPResponse(resp, opts.recv)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
tokenIssuerIsFailing.UnSet()
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func updateUserWithFailedRequest(statusCode int, disableSubscription bool) {
|
||||
// Get user from database.
|
||||
user, err := GetUser()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrNotLoggedIn) {
|
||||
log.Warningf("spn/access: failed to get user to update with failed request: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
user.Lock()
|
||||
defer user.Unlock()
|
||||
|
||||
// Ignore update if user state is undefined or logged out.
|
||||
if user.State == "" || user.State == account.UserStateLoggedOut {
|
||||
return
|
||||
}
|
||||
|
||||
// Disable the subscription if desired.
|
||||
if disableSubscription && user.Subscription != nil {
|
||||
user.Subscription.EndsAt = nil
|
||||
}
|
||||
|
||||
// Update view with the status code and save user.
|
||||
user.UpdateView(statusCode)
|
||||
}()
|
||||
|
||||
err = user.Save()
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to save user after update with failed request: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Login logs the user into the SPN account with the given username and password.
|
||||
func Login(username, password string) (user *UserRecord, code int, err error) {
|
||||
clientRequestLock.Lock()
|
||||
defer clientRequestLock.Unlock()
|
||||
|
||||
// Trigger account update when done.
|
||||
defer module.TriggerEvent(AccountUpdateEvent, nil)
|
||||
|
||||
// Get previous user.
|
||||
previousUser, err := GetUser()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrNotLoggedIn) {
|
||||
log.Warningf("spn/access: failed to get previous for re-login: %s", err)
|
||||
}
|
||||
previousUser = nil
|
||||
}
|
||||
|
||||
// Create request options.
|
||||
userAccount := &account.User{}
|
||||
requestOptions := &clientRequestOptions{
|
||||
method: http.MethodPost,
|
||||
url: AccountServer + LoginPath,
|
||||
recv: userAccount,
|
||||
dataFormat: dsd.JSON,
|
||||
requestSetupFunc: func(request *http.Request) error {
|
||||
// Add username and password.
|
||||
request.SetBasicAuth(username, password)
|
||||
|
||||
// Try to reuse the device ID, if the username matches the previous user.
|
||||
if previousUser != nil && username == previousUser.Username {
|
||||
request.Header.Set(account.AuthHeaderDevice, previousUser.Device.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Make request.
|
||||
resp, err := makeClientRequest(requestOptions) //nolint:bodyclose // Body is closed in function.
|
||||
if err != nil {
|
||||
if resp != nil && resp.StatusCode == account.StatusInvalidDevice {
|
||||
// Try again without the previous device ID.
|
||||
previousUser = nil
|
||||
log.Info("spn/access: retrying log in without re-using previous device ID")
|
||||
resp, err = makeClientRequest(requestOptions) //nolint:bodyclose // Body is closed in function.
|
||||
}
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// Save new user.
|
||||
now := time.Now()
|
||||
user = &UserRecord{
|
||||
User: userAccount,
|
||||
LoggedInAt: &now,
|
||||
}
|
||||
|
||||
user.UpdateView(0)
|
||||
err = user.Save()
|
||||
if err != nil {
|
||||
return user, resp.StatusCode, fmt.Errorf("failed to save new user profile: %w", err)
|
||||
}
|
||||
|
||||
// Save initial auth token.
|
||||
err = SaveNewAuthToken(user.Device.ID, resp)
|
||||
if err != nil {
|
||||
return user, resp.StatusCode, fmt.Errorf("failed to save initial auth token: %w", err)
|
||||
}
|
||||
|
||||
// Enable the SPN right after login.
|
||||
if user.MayUseSPN() && EnableAfterLogin {
|
||||
enableSPN()
|
||||
}
|
||||
|
||||
log.Infof("spn/access: logged in as %q on device %q", user.Username, user.Device.Name)
|
||||
return user, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// Logout logs the user out of the SPN account.
|
||||
// Specify "shallow" to keep user data in order to display data in the
|
||||
// UI - preferably when logged out be the server.
|
||||
// Specify "purge" in order to fully delete all user account data, even
|
||||
// the device ID so that logging in again will create a new device.
|
||||
func Logout(shallow, purge bool) error {
|
||||
clientRequestLock.Lock()
|
||||
defer clientRequestLock.Unlock()
|
||||
|
||||
// Trigger account update when done.
|
||||
defer module.TriggerEvent(AccountUpdateEvent, nil)
|
||||
|
||||
// Clear caches.
|
||||
clearUserCaches()
|
||||
|
||||
// Clear tokens.
|
||||
clearTokens()
|
||||
|
||||
// Delete auth token.
|
||||
err := db.Delete(authTokenRecordKey)
|
||||
if err != nil && !errors.Is(err, database.ErrNotFound) {
|
||||
return fmt.Errorf("failed to delete auth token: %w", err)
|
||||
}
|
||||
|
||||
// Delete all user data if purging.
|
||||
if purge {
|
||||
err := db.Delete(userRecordKey)
|
||||
if err != nil && !errors.Is(err, database.ErrNotFound) {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
// Disable SPN when the user logs out directly.
|
||||
disableSPN()
|
||||
|
||||
log.Info("spn/access: logged out and purged data")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Else, just update the user.
|
||||
user, err := GetUser()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotLoggedIn) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to load user for logout: %w", err)
|
||||
}
|
||||
|
||||
func() {
|
||||
user.Lock()
|
||||
defer user.Unlock()
|
||||
|
||||
if shallow {
|
||||
// Shallow logout: User stays logged in the UI to display status when
|
||||
// logged out from the Portmaster or Customer Hub.
|
||||
user.User.State = account.UserStateLoggedOut
|
||||
} else {
|
||||
// Proper logout: User is logged out from UI.
|
||||
// Reset all user data, except for username and device ID in order to log
|
||||
// into the same device again.
|
||||
user.User = &account.User{
|
||||
Username: user.Username,
|
||||
Device: &account.Device{
|
||||
ID: user.Device.ID,
|
||||
},
|
||||
}
|
||||
user.LoggedInAt = &time.Time{}
|
||||
}
|
||||
user.UpdateView(0)
|
||||
}()
|
||||
err = user.Save()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save user for logout: %w", err)
|
||||
}
|
||||
|
||||
if shallow {
|
||||
log.Info("spn/access: logged out shallow")
|
||||
} else {
|
||||
log.Info("spn/access: logged out")
|
||||
|
||||
// Disable SPN when the user logs out directly.
|
||||
disableSPN()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateUser fetches the current user information from the server.
|
||||
func UpdateUser() (user *UserRecord, statusCode int, err error) {
|
||||
clientRequestLock.Lock()
|
||||
defer clientRequestLock.Unlock()
|
||||
|
||||
// Trigger account update when done.
|
||||
defer module.TriggerEvent(AccountUpdateEvent, nil)
|
||||
|
||||
// Create request options.
|
||||
userData := &account.User{}
|
||||
requestOptions := &clientRequestOptions{
|
||||
method: http.MethodGet,
|
||||
url: AccountServer + UserProfilePath,
|
||||
recv: userData,
|
||||
dataFormat: dsd.JSON,
|
||||
setAuthToken: true,
|
||||
requireNextAuthToken: true,
|
||||
logoutOnAuthError: true,
|
||||
}
|
||||
|
||||
// Make request.
|
||||
resp, err := makeClientRequest(requestOptions) //nolint:bodyclose // Body is closed in function.
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Save to previous user, if exists.
|
||||
previousUser, err := GetUser()
|
||||
if err == nil {
|
||||
func() {
|
||||
previousUser.Lock()
|
||||
defer previousUser.Unlock()
|
||||
previousUser.User = userData
|
||||
previousUser.UpdateView(resp.StatusCode)
|
||||
}()
|
||||
err := previousUser.Save()
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to save updated user profile: %s", err)
|
||||
}
|
||||
|
||||
// Notify user of nearing end of package.
|
||||
notifyOfPackageEnd(previousUser)
|
||||
|
||||
log.Infof("spn/access: got user profile, updated existing")
|
||||
return previousUser, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// Else, save as new user.
|
||||
now := time.Now()
|
||||
newUser := &UserRecord{
|
||||
User: userData,
|
||||
LoggedInAt: &now,
|
||||
}
|
||||
newUser.UpdateView(resp.StatusCode)
|
||||
err = newUser.Save()
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to save new user profile: %s", err)
|
||||
}
|
||||
|
||||
// Notify user of nearing end of package.
|
||||
notifyOfPackageEnd(newUser)
|
||||
|
||||
log.Infof("spn/access: got user profile, saved as new")
|
||||
return newUser, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// UpdateTokens fetches more tokens for handlers that need it.
|
||||
func UpdateTokens() error {
|
||||
clientRequestLock.Lock()
|
||||
defer clientRequestLock.Unlock()
|
||||
|
||||
// Check if the user may request tokens.
|
||||
user, err := GetUser()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
if !user.MayUseTheSPN() {
|
||||
return ErrMayNotUseSPN
|
||||
}
|
||||
|
||||
// Create setup request, return if not required.
|
||||
setupRequest, setupRequired := token.CreateSetupRequest()
|
||||
var setupResponse *token.SetupResponse
|
||||
if setupRequired {
|
||||
// Request setup data.
|
||||
setupResponse = &token.SetupResponse{}
|
||||
_, err := makeClientRequest(&clientRequestOptions{ //nolint:bodyclose // Body is closed in function.
|
||||
method: http.MethodPost,
|
||||
url: AccountServer + TokenRequestSetupPath,
|
||||
send: setupRequest,
|
||||
recv: setupResponse,
|
||||
dataFormat: dsd.MsgPack,
|
||||
setAuthToken: true,
|
||||
logoutOnAuthError: true,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to request setup data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create request for issuing new tokens.
|
||||
tokenRequest, requestRequired, err := token.CreateTokenRequest(setupResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
if !requestRequired {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Request issuing new tokens.
|
||||
issuedTokens := &token.IssuedTokens{}
|
||||
_, err = makeClientRequest(&clientRequestOptions{ //nolint:bodyclose // Body is closed in function.
|
||||
method: http.MethodPost,
|
||||
url: AccountServer + TokenRequestIssuePath,
|
||||
send: tokenRequest,
|
||||
recv: issuedTokens,
|
||||
dataFormat: dsd.MsgPack,
|
||||
setAuthToken: true,
|
||||
logoutOnAuthError: true,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to request tokens: %w", err)
|
||||
}
|
||||
|
||||
// Save tokens to handlers.
|
||||
err = token.ProcessIssuedTokens(issuedTokens)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process issued tokens: %w", err)
|
||||
}
|
||||
|
||||
// Log new status.
|
||||
regular, fallback := GetTokenAmount(ExpandAndConnectZones)
|
||||
log.Infof(
|
||||
"spn/access: got new tokens, now at %d regular and %d fallback tokens for expand and connect",
|
||||
regular,
|
||||
fallback,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
lastHealthCheckExpires time.Time
|
||||
lastHealthCheckLock sync.Mutex
|
||||
lastHealthCheckValidityDuration = 30 * time.Second
|
||||
)
|
||||
|
||||
func healthCheck() (ok bool) {
|
||||
lastHealthCheckLock.Lock()
|
||||
defer lastHealthCheckLock.Unlock()
|
||||
|
||||
// Return current value if recently checked.
|
||||
if time.Now().Before(lastHealthCheckExpires) {
|
||||
return tokenIssuerIsFailing.IsNotSet()
|
||||
}
|
||||
|
||||
// Check health.
|
||||
_, err := makeClientRequest(&clientRequestOptions{ //nolint:bodyclose // Body is closed in function.
|
||||
method: http.MethodGet,
|
||||
url: AccountServer + HealthCheckPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: token issuer health check failed: %s", err)
|
||||
}
|
||||
// Update health check expiry.
|
||||
lastHealthCheckExpires = time.Now().Add(lastHealthCheckValidityDuration)
|
||||
|
||||
return tokenIssuerIsFailing.IsNotSet()
|
||||
}
|
||||
79
spn/access/client_test.go
Normal file
79
spn/access/client_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
testUsername = os.Getenv("SPN_TEST_USERNAME")
|
||||
testPassword = os.Getenv("SPN_TEST_PASSWORD")
|
||||
)
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
// Skip test in CI.
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
t.Parallel()
|
||||
|
||||
if testUsername == "" || testPassword == "" {
|
||||
t.Fatal("test username or password not configured")
|
||||
}
|
||||
|
||||
loginAndRefresh(t, true, 5)
|
||||
clearUserCaches()
|
||||
loginAndRefresh(t, false, 1)
|
||||
|
||||
err := Logout(false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to log out: %s", err)
|
||||
}
|
||||
t.Logf("logged out")
|
||||
|
||||
loginAndRefresh(t, true, 1)
|
||||
|
||||
err = Logout(false, true)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to log out: %s", err)
|
||||
}
|
||||
t.Logf("logged out with purge")
|
||||
|
||||
loginAndRefresh(t, true, 1)
|
||||
}
|
||||
|
||||
func loginAndRefresh(t *testing.T, doLogin bool, refreshTimes int) {
|
||||
t.Helper()
|
||||
|
||||
if doLogin {
|
||||
_, _, err := Login(testUsername, testPassword)
|
||||
if err != nil {
|
||||
t.Fatalf("login failed: %s", err)
|
||||
}
|
||||
user, err := GetUser()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get user: %s", err)
|
||||
}
|
||||
t.Logf("user (from login): %+v", user.User)
|
||||
t.Logf("device (from login): %+v", user.User.Device)
|
||||
authToken, err := GetAuthToken()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get auth token: %s", err)
|
||||
}
|
||||
t.Logf("auth token: %+v", authToken.Token)
|
||||
}
|
||||
|
||||
for i := 0; i < refreshTimes; i++ {
|
||||
user, _, err := UpdateUser()
|
||||
if err != nil {
|
||||
t.Fatalf("getting profile failed: %s", err)
|
||||
}
|
||||
t.Logf("user (from refresh): %+v", user.User)
|
||||
|
||||
authToken, err := GetAuthToken()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get auth token: %s", err)
|
||||
}
|
||||
t.Logf("auth token: %+v", authToken.Token)
|
||||
}
|
||||
}
|
||||
258
spn/access/database.go
Normal file
258
spn/access/database.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portmaster/spn/access/account"
|
||||
)
|
||||
|
||||
const (
|
||||
userRecordKey = "core:spn/account/user"
|
||||
authTokenRecordKey = "core:spn/account/authtoken" //nolint:gosec // Not a credential.
|
||||
tokenStorageKeyTemplate = "core:spn/account/tokens/%s" //nolint:gosec // Not a credential.
|
||||
)
|
||||
|
||||
var db = database.NewInterface(&database.Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
|
||||
// UserRecord holds a SPN user account.
|
||||
type UserRecord struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
*account.User
|
||||
|
||||
LastNotifiedOfEnd *time.Time
|
||||
LoggedInAt *time.Time
|
||||
}
|
||||
|
||||
// MayUseSPN returns whether the user may currently use the SPN.
|
||||
func (user *UserRecord) MayUseSPN() bool {
|
||||
// Shadow this function in order to allow calls on a nil user.
|
||||
if user == nil || user.User == nil {
|
||||
return false
|
||||
}
|
||||
return user.User.MayUseSPN()
|
||||
}
|
||||
|
||||
// MayUsePrioritySupport returns whether the user may currently use the priority support.
|
||||
func (user *UserRecord) MayUsePrioritySupport() bool {
|
||||
// Shadow this function in order to allow calls on a nil user.
|
||||
if user == nil || user.User == nil {
|
||||
return false
|
||||
}
|
||||
return user.User.MayUsePrioritySupport()
|
||||
}
|
||||
|
||||
// MayUse returns whether the user may currently use the feature identified by
|
||||
// the given feature ID.
|
||||
// Leave feature ID empty to check without feature.
|
||||
func (user *UserRecord) MayUse(featureID account.FeatureID) bool {
|
||||
// Shadow this function in order to allow calls on a nil user.
|
||||
if user == nil || user.User == nil {
|
||||
return false
|
||||
}
|
||||
return user.User.MayUse(featureID)
|
||||
}
|
||||
|
||||
// AuthTokenRecord holds an authentication token.
|
||||
type AuthTokenRecord struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
Token *account.AuthToken
|
||||
}
|
||||
|
||||
// GetToken returns the token from the record.
|
||||
func (authToken *AuthTokenRecord) GetToken() *account.AuthToken {
|
||||
authToken.Lock()
|
||||
defer authToken.Unlock()
|
||||
|
||||
return authToken.Token
|
||||
}
|
||||
|
||||
// SaveNewAuthToken saves a new auth token to the database.
|
||||
func SaveNewAuthToken(deviceID string, resp *http.Response) error {
|
||||
token, ok := account.GetNextTokenFromResponse(resp)
|
||||
if !ok {
|
||||
return account.ErrMissingToken
|
||||
}
|
||||
|
||||
newAuthToken := &AuthTokenRecord{
|
||||
Token: &account.AuthToken{
|
||||
Device: deviceID,
|
||||
Token: token,
|
||||
},
|
||||
}
|
||||
return newAuthToken.Save()
|
||||
}
|
||||
|
||||
// Update updates an existing auth token with the next token from a response.
|
||||
func (authToken *AuthTokenRecord) Update(resp *http.Response) error {
|
||||
token, ok := account.GetNextTokenFromResponse(resp)
|
||||
if !ok {
|
||||
return account.ErrMissingToken
|
||||
}
|
||||
|
||||
// Update token with new account.AuthToken.
|
||||
func() {
|
||||
authToken.Lock()
|
||||
defer authToken.Unlock()
|
||||
|
||||
authToken.Token = &account.AuthToken{
|
||||
Device: authToken.Token.Device,
|
||||
Token: token,
|
||||
}
|
||||
}()
|
||||
|
||||
return authToken.Save()
|
||||
}
|
||||
|
||||
var (
|
||||
accountCacheLock sync.Mutex
|
||||
|
||||
cachedUser *UserRecord
|
||||
cachedUserSet bool
|
||||
|
||||
cachedAuthToken *AuthTokenRecord
|
||||
)
|
||||
|
||||
func clearUserCaches() {
|
||||
accountCacheLock.Lock()
|
||||
defer accountCacheLock.Unlock()
|
||||
|
||||
cachedUser = nil
|
||||
cachedUserSet = false
|
||||
cachedAuthToken = nil
|
||||
}
|
||||
|
||||
// GetUser returns the current user account.
|
||||
// Returns nil when no user is logged in.
|
||||
func GetUser() (*UserRecord, error) {
|
||||
// Check cache.
|
||||
accountCacheLock.Lock()
|
||||
defer accountCacheLock.Unlock()
|
||||
if cachedUserSet {
|
||||
if cachedUser == nil {
|
||||
return nil, ErrNotLoggedIn
|
||||
}
|
||||
return cachedUser, nil
|
||||
}
|
||||
|
||||
// Load from disk.
|
||||
r, err := db.Get(userRecordKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNotFound) {
|
||||
cachedUser = nil
|
||||
cachedUserSet = true
|
||||
return nil, ErrNotLoggedIn
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unwrap record.
|
||||
if r.IsWrapped() {
|
||||
// only allocate a new struct, if we need it
|
||||
newUser := &UserRecord{}
|
||||
err = record.Unwrap(r, newUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cachedUser = newUser
|
||||
cachedUserSet = true
|
||||
return cachedUser, nil
|
||||
}
|
||||
|
||||
// Or adjust type.
|
||||
newUser, ok := r.(*UserRecord)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *UserRecord, but %T", r)
|
||||
}
|
||||
cachedUser = newUser
|
||||
cachedUserSet = true
|
||||
return cachedUser, nil
|
||||
}
|
||||
|
||||
// Save saves the User.
|
||||
func (user *UserRecord) Save() error {
|
||||
// Update cache.
|
||||
accountCacheLock.Lock()
|
||||
defer accountCacheLock.Unlock()
|
||||
cachedUser = user
|
||||
cachedUserSet = true
|
||||
|
||||
// Update view if unset.
|
||||
if user.View == nil {
|
||||
user.UpdateView(0)
|
||||
}
|
||||
|
||||
// Set, check and update metadata.
|
||||
if !user.KeyIsSet() {
|
||||
user.SetKey(userRecordKey)
|
||||
}
|
||||
user.UpdateMeta()
|
||||
|
||||
return db.Put(user)
|
||||
}
|
||||
|
||||
// GetAuthToken returns the current auth token.
|
||||
func GetAuthToken() (*AuthTokenRecord, error) {
|
||||
// Check cache.
|
||||
accountCacheLock.Lock()
|
||||
defer accountCacheLock.Unlock()
|
||||
if cachedAuthToken != nil {
|
||||
return cachedAuthToken, nil
|
||||
}
|
||||
|
||||
// Load from disk.
|
||||
r, err := db.Get(authTokenRecordKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unwrap record.
|
||||
if r.IsWrapped() {
|
||||
// only allocate a new struct, if we need it
|
||||
newAuthRecord := &AuthTokenRecord{}
|
||||
err = record.Unwrap(r, newAuthRecord)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cachedAuthToken = newAuthRecord
|
||||
return newAuthRecord, nil
|
||||
}
|
||||
|
||||
// Or adjust type.
|
||||
newAuthRecord, ok := r.(*AuthTokenRecord)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *AuthTokenRecord, but %T", r)
|
||||
}
|
||||
cachedAuthToken = newAuthRecord
|
||||
return newAuthRecord, nil
|
||||
}
|
||||
|
||||
// Save saves the auth token to the database.
|
||||
func (authToken *AuthTokenRecord) Save() error {
|
||||
// Update cache.
|
||||
accountCacheLock.Lock()
|
||||
defer accountCacheLock.Unlock()
|
||||
cachedAuthToken = authToken
|
||||
|
||||
// Set, check and update metadata.
|
||||
if !authToken.KeyIsSet() {
|
||||
authToken.SetKey(authTokenRecordKey)
|
||||
}
|
||||
authToken.UpdateMeta()
|
||||
authToken.Meta().MakeSecret()
|
||||
authToken.Meta().MakeCrownJewel()
|
||||
|
||||
return db.Put(authToken)
|
||||
}
|
||||
127
spn/access/features.go
Normal file
127
spn/access/features.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package access
|
||||
|
||||
import "github.com/safing/portmaster/spn/access/account"
|
||||
|
||||
// Feature describes a notable part of the program.
|
||||
type Feature struct {
|
||||
Name string
|
||||
ID string
|
||||
RequiredFeatureID account.FeatureID
|
||||
ConfigKey string
|
||||
ConfigScope string
|
||||
InPackage *Package
|
||||
Comment string
|
||||
Beta bool
|
||||
ComingSoon bool
|
||||
icon string
|
||||
}
|
||||
|
||||
// Package combines a set of features.
|
||||
type Package struct {
|
||||
Name string
|
||||
HexColor string
|
||||
InfoURL string
|
||||
}
|
||||
|
||||
var (
|
||||
infoURL = "https://safing.io/pricing/"
|
||||
packageFree = &Package{
|
||||
Name: "Free",
|
||||
HexColor: "#ffffff",
|
||||
InfoURL: infoURL,
|
||||
}
|
||||
packagePlus = &Package{
|
||||
Name: "Plus",
|
||||
HexColor: "#2fcfae",
|
||||
InfoURL: infoURL,
|
||||
}
|
||||
packagePro = &Package{
|
||||
Name: "Pro",
|
||||
HexColor: "#029ad0",
|
||||
InfoURL: infoURL,
|
||||
}
|
||||
features = []Feature{
|
||||
{
|
||||
Name: "Secure DNS",
|
||||
ID: "dns",
|
||||
ConfigScope: "dns/",
|
||||
InPackage: packageFree,
|
||||
icon: `
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round"
|
||||
d="M12 21a9.004 9.004 0 008.716-6.747M12 21a9.004 9.004 0 01-8.716-6.747M12 21c2.485 0 4.5-4.03 4.5-9S14.485 3 12 3m0 18c-2.485 0-4.5-4.03-4.5-9S9.515 3 12 3m0 0a8.997 8.997 0 017.843 4.582M12 3a8.997 8.997 0 00-7.843 4.582m15.686 0A11.953 11.953 0 0112 10.5c-2.998 0-5.74-1.1-7.843-2.918m15.686 0A8.959 8.959 0 0121 12c0 .778-.099 1.533-.284 2.253m0 0A17.919 17.919 0 0112 16.5c-3.162 0-6.133-.815-8.716-2.247m0 0A9.015 9.015 0 013 12c0-1.605.42-3.113 1.157-4.418" />
|
||||
</svg>
|
||||
`,
|
||||
},
|
||||
{
|
||||
Name: "Privacy Filter",
|
||||
ID: "filter",
|
||||
ConfigScope: "filter/",
|
||||
InPackage: packageFree,
|
||||
icon: `
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M3.98 8.223A10.477 10.477 0 001.934 12C3.226 16.338 7.244 19.5 12 19.5c.993 0 1.953-.138 2.863-.395M6.228 6.228A10.45 10.45 0 0112 4.5c4.756 0 8.773 3.162 10.065 7.498a10.523 10.523 0 01-4.293 5.774M6.228 6.228L3 3m3.228 3.228l3.65 3.65m7.894 7.894L21 21m-3.228-3.228l-3.65-3.65m0 0a3 3 0 10-4.243-4.243m4.242 4.242L9.88 9.88" />
|
||||
</svg>
|
||||
`,
|
||||
},
|
||||
{
|
||||
Name: "Network History",
|
||||
ID: string(account.FeatureHistory),
|
||||
RequiredFeatureID: account.FeatureHistory,
|
||||
ConfigKey: "history/enable",
|
||||
ConfigScope: "history/",
|
||||
InPackage: packagePlus,
|
||||
icon: `
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round"
|
||||
d="M12 6.042A8.967 8.967 0 006 3.75c-1.052 0-2.062.18-3 .512v14.25A8.987 8.987 0 016 18c2.305 0 4.408.867 6 2.292m0-14.25a8.966 8.966 0 016-2.292c1.052 0 2.062.18 3 .512v14.25A8.987 8.987 0 0018 18a8.967 8.967 0 00-6 2.292m0-14.25v14.25" />
|
||||
</svg>
|
||||
`,
|
||||
},
|
||||
{
|
||||
Name: "Bandwidth Visibility",
|
||||
ID: string(account.FeatureBWVis),
|
||||
RequiredFeatureID: account.FeatureBWVis,
|
||||
InPackage: packagePlus,
|
||||
Beta: true,
|
||||
icon: `
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round"
|
||||
d="M3 13.125C3 12.504 3.504 12 4.125 12h2.25c.621 0 1.125.504 1.125 1.125v6.75C7.5 20.496 6.996 21 6.375 21h-2.25A1.125 1.125 0 013 19.875v-6.75zM9.75 8.625c0-.621.504-1.125 1.125-1.125h2.25c.621 0 1.125.504 1.125 1.125v11.25c0 .621-.504 1.125-1.125 1.125h-2.25a1.125 1.125 0 01-1.125-1.125V8.625zM16.5 4.125c0-.621.504-1.125 1.125-1.125h2.25C20.496 3 21 3.504 21 4.125v15.75c0 .621-.504 1.125-1.125 1.125h-2.25a1.125 1.125 0 01-1.125-1.125V4.125z" />
|
||||
</svg>
|
||||
`,
|
||||
},
|
||||
{
|
||||
Name: "Safing Support",
|
||||
ID: string(account.FeatureSafingSupport),
|
||||
RequiredFeatureID: account.FeatureSafingSupport,
|
||||
InPackage: packagePlus,
|
||||
icon: `
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round"
|
||||
d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" />
|
||||
</svg>
|
||||
`,
|
||||
},
|
||||
{
|
||||
Name: "Safing Privacy Network",
|
||||
ID: string(account.FeatureSPN),
|
||||
RequiredFeatureID: account.FeatureSPN,
|
||||
ConfigKey: "spn/enable",
|
||||
ConfigScope: "spn/",
|
||||
InPackage: packagePro,
|
||||
icon: `
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" stroke="currentColor" class="text-green-300">
|
||||
<g fill="none" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5">
|
||||
<path
|
||||
d="M6.488 15.581c.782.781.782 2.048 0 2.829-.782.781-2.049.781-2.83 0-.782-.781-.782-2.048 0-2.829.781-.781 2.048-.781 2.83 0M13.415 3.586c.782.781.782 2.048 0 2.829-.782.781-2.049.781-2.83 0-.782-.781-.782-2.048 0-2.829.781-.781 2.049-.781 2.83 0M20.343 15.58c.782.781.782 2.048 0 2.829-.782.781-2.049.781-2.83 0-.782-.781-.782-2.048 0-2.829.781-.781 2.048-.781 2.83 0">
|
||||
</path>
|
||||
<path
|
||||
d="M17.721 18.581C16.269 20.071 14.246 21 12 21c-1.146 0-2.231-.246-3.215-.68M4.293 15.152c-.56-1.999-.352-4.21.769-6.151.574-.995 1.334-1.814 2.205-2.449M13.975 5.254c2.017.512 3.834 1.799 4.957 3.743.569.985.899 2.041 1.018 3.103">
|
||||
</path>
|
||||
</g>
|
||||
</svg>
|
||||
`,
|
||||
},
|
||||
}
|
||||
)
|
||||
194
spn/access/module.go
Normal file
194
spn/access/module.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/spn/access/account"
|
||||
"github.com/safing/portmaster/spn/access/token"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
accountUpdateTask *modules.Task
|
||||
|
||||
tokenIssuerIsFailing = abool.New()
|
||||
tokenIssuerRetryDuration = 10 * time.Minute
|
||||
|
||||
// AccountUpdateEvent is fired when the account has changed in any way.
|
||||
AccountUpdateEvent = "account update"
|
||||
)
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrDeviceIsLocked = errors.New("device is locked")
|
||||
ErrDeviceLimitReached = errors.New("device limit reached")
|
||||
ErrFallbackNotAvailable = errors.New("fallback tokens not available, token issuer is online")
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrMayNotUseSPN = errors.New("may not use SPN")
|
||||
ErrNotLoggedIn = errors.New("not logged in")
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("access", prep, start, stop, "terminal")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
module.RegisterEvent(AccountUpdateEvent, true)
|
||||
|
||||
// Register API handlers.
|
||||
if conf.Client() {
|
||||
err := registerAPIEndpoints()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
// Initialize zones.
|
||||
if err := InitializeZones(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if conf.Client() {
|
||||
// Load tokens from database.
|
||||
loadTokens()
|
||||
|
||||
// Register new task.
|
||||
accountUpdateTask = module.NewTask(
|
||||
"update account",
|
||||
UpdateAccount,
|
||||
).Repeat(24 * time.Hour).Schedule(time.Now().Add(1 * time.Minute))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
if conf.Client() {
|
||||
// Stop account update task.
|
||||
accountUpdateTask.Cancel()
|
||||
accountUpdateTask = nil
|
||||
|
||||
// Store tokens to database.
|
||||
storeTokens()
|
||||
}
|
||||
|
||||
// Reset zones.
|
||||
token.ResetRegistry()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateAccount updates the user account and fetches new tokens, if needed.
|
||||
func UpdateAccount(_ context.Context, task *modules.Task) error {
|
||||
// Retry sooner if the token issuer is failing.
|
||||
defer func() {
|
||||
if tokenIssuerIsFailing.IsSet() && task != nil {
|
||||
task.Schedule(time.Now().Add(tokenIssuerRetryDuration))
|
||||
}
|
||||
}()
|
||||
|
||||
// Get current user.
|
||||
u, err := GetUser()
|
||||
if err == nil {
|
||||
// Do not update if we just updated.
|
||||
if time.Since(time.Unix(u.Meta().Modified, 0)) < 2*time.Minute {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
u, _, err = UpdateUser()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user profile: %w", err)
|
||||
}
|
||||
|
||||
err = UpdateTokens()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tokens: %w", err)
|
||||
}
|
||||
|
||||
// Schedule next check.
|
||||
switch {
|
||||
case u == nil: // No user.
|
||||
case u.Subscription == nil: // No subscription.
|
||||
case u.Subscription.EndsAt == nil: // Subscription not active
|
||||
|
||||
case time.Until(*u.Subscription.EndsAt) < 24*time.Hour &&
|
||||
time.Since(*u.Subscription.EndsAt) < 24*time.Hour:
|
||||
// Update account every hour 24h hours before and after the subscription ends.
|
||||
task.Schedule(time.Now().Add(time.Hour))
|
||||
|
||||
case u.Subscription.NextBillingDate == nil: // No auto-subscription.
|
||||
|
||||
case time.Until(*u.Subscription.NextBillingDate) < 24*time.Hour &&
|
||||
time.Since(*u.Subscription.NextBillingDate) < 24*time.Hour:
|
||||
// Update account every hour 24h hours before and after the next billing date.
|
||||
task.Schedule(time.Now().Add(time.Hour))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableSPN() {
|
||||
err := config.SetConfigOption("spn/enable", true)
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to enable the SPN during login: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func disableSPN() {
|
||||
err := config.SetConfigOption("spn/enable", false)
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to disable the SPN during logout: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TokenIssuerIsFailing returns whether token issuing is currently failing.
|
||||
func TokenIssuerIsFailing() bool {
|
||||
return tokenIssuerIsFailing.IsSet()
|
||||
}
|
||||
|
||||
func tokenIssuerFailed() {
|
||||
if !tokenIssuerIsFailing.SetToIf(false, true) {
|
||||
return
|
||||
}
|
||||
if !module.Online() {
|
||||
return
|
||||
}
|
||||
|
||||
accountUpdateTask.Schedule(time.Now().Add(tokenIssuerRetryDuration))
|
||||
}
|
||||
|
||||
// IsLoggedIn returns whether a User is currently logged in.
|
||||
func (user *UserRecord) IsLoggedIn() bool {
|
||||
user.Lock()
|
||||
defer user.Unlock()
|
||||
|
||||
switch user.State {
|
||||
case account.UserStateNone, account.UserStateLoggedOut:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// MayUseTheSPN returns whether the currently logged in User may use the SPN.
|
||||
func (user *UserRecord) MayUseTheSPN() bool {
|
||||
user.Lock()
|
||||
defer user.Unlock()
|
||||
|
||||
return user.User.MayUseSPN()
|
||||
}
|
||||
13
spn/access/module_test.go
Normal file
13
spn/access/module_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/service/core/pmtesting"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
conf.EnableClient(true)
|
||||
pmtesting.TestMain(m, module)
|
||||
}
|
||||
105
spn/access/notify.go
Normal file
105
spn/access/notify.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/notifications"
|
||||
)
|
||||
|
||||
const (
|
||||
day = 24 * time.Hour
|
||||
week = 7 * day
|
||||
|
||||
endOfPackageNearNotifID = "access:end-of-package-near"
|
||||
)
|
||||
|
||||
func notifyOfPackageEnd(u *UserRecord) {
|
||||
// TODO: Check if subscription auto-renews.
|
||||
|
||||
// Skip if there is not active subscription or if it has ended already.
|
||||
switch {
|
||||
case u.Subscription == nil, // No subscription.
|
||||
u.Subscription.EndsAt == nil, // Subscription not active.
|
||||
u.Subscription.NextBillingDate != nil, // Subscription is auto-renewing.
|
||||
time.Now().After(*u.Subscription.EndsAt): // Subscription has ended.
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate durations.
|
||||
sinceLastNotified := 52 * week // Never.
|
||||
if u.LastNotifiedOfEnd != nil {
|
||||
sinceLastNotified = time.Since(*u.LastNotifiedOfEnd)
|
||||
}
|
||||
untilEnd := time.Until(*u.Subscription.EndsAt)
|
||||
|
||||
// Notify every two days in the week before end.
|
||||
notifType := notifications.Info
|
||||
switch {
|
||||
case untilEnd < week && sinceLastNotified > 2*day:
|
||||
// Notify 7, 5, 3 and 1 days before end.
|
||||
if untilEnd < 4*day {
|
||||
notifType = notifications.Warning
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case u.CurrentPlan != nil && u.CurrentPlan.Months >= 6 &&
|
||||
untilEnd < 4*week && sinceLastNotified > week:
|
||||
// Notify 4, 3 and 2 weeks before end - on long running packages.
|
||||
|
||||
// Get names and messages.
|
||||
packageNameTitle := "Portmaster Package"
|
||||
if u.CurrentPlan != nil {
|
||||
packageNameTitle = u.CurrentPlan.Name
|
||||
}
|
||||
packageNameBody := packageNameTitle
|
||||
if !strings.HasSuffix(packageNameBody, " Package") {
|
||||
packageNameBody += " Package"
|
||||
}
|
||||
|
||||
var endsText string
|
||||
daysUntilEnd := untilEnd / day
|
||||
switch daysUntilEnd { //nolint:exhaustive
|
||||
case 0:
|
||||
endsText = "today"
|
||||
case 1:
|
||||
endsText = "tomorrow"
|
||||
default:
|
||||
endsText = fmt.Sprintf("in %d days", daysUntilEnd)
|
||||
}
|
||||
|
||||
// Send notification.
|
||||
notifications.Notify(¬ifications.Notification{
|
||||
EventID: endOfPackageNearNotifID,
|
||||
Type: notifType,
|
||||
Title: fmt.Sprintf("%s About to Expire", packageNameTitle),
|
||||
Message: fmt.Sprintf(
|
||||
"Your current %s ends %s. Extend it to keep your full privacy protections.",
|
||||
packageNameBody,
|
||||
endsText,
|
||||
),
|
||||
ShowOnSystem: notifType == notifications.Warning,
|
||||
AvailableActions: []*notifications.Action{
|
||||
{
|
||||
Text: "Open Account Page",
|
||||
Type: notifications.ActionTypeOpenURL,
|
||||
Payload: "https://account.safing.io",
|
||||
},
|
||||
{
|
||||
ID: "ack",
|
||||
Text: "Got it!",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Save that we sent a notification.
|
||||
now := time.Now()
|
||||
u.LastNotifiedOfEnd = &now
|
||||
err := u.Save()
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to save user after sending subscription ending soon notification: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
75
spn/access/op_auth.go
Normal file
75
spn/access/op_auth.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/access/token"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// OpTypeAccessCodeAuth is the type ID of the auth operation.
|
||||
const OpTypeAccessCodeAuth = "auth"
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: OpTypeAccessCodeAuth,
|
||||
Start: checkAccessCode,
|
||||
})
|
||||
}
|
||||
|
||||
// AuthorizeOp is used to authorize a session.
|
||||
type AuthorizeOp struct {
|
||||
terminal.OneOffOperationBase
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *AuthorizeOp) Type() string {
|
||||
return OpTypeAccessCodeAuth
|
||||
}
|
||||
|
||||
// AuthorizeToTerminal starts an authorization operation.
|
||||
func AuthorizeToTerminal(t terminal.Terminal) (*AuthorizeOp, *terminal.Error) {
|
||||
op := &AuthorizeOp{}
|
||||
op.Init()
|
||||
|
||||
newToken, err := GetToken(ExpandAndConnectZones)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
tErr := t.StartOperation(op, container.New(newToken.Raw()), 10*time.Second)
|
||||
if tErr != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to init auth op: %w", tErr)
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func checkAccessCode(t terminal.Terminal, opID uint32, initData *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Parse provided access token.
|
||||
receivedToken, err := token.ParseRawToken(initData.CompileData())
|
||||
if err != nil {
|
||||
return nil, terminal.ErrMalformedData.With("failed to parse access token: %w", err)
|
||||
}
|
||||
|
||||
// Check if token is valid.
|
||||
granted, err := VerifyToken(receivedToken)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrPermissionDenied.With("invalid access token: %w", err)
|
||||
}
|
||||
|
||||
// Get the authorizing terminal for applying the granted permission.
|
||||
authTerm, ok := t.(terminal.AuthorizingTerminal)
|
||||
if !ok {
|
||||
return nil, terminal.ErrIncorrectUsage.With("terminal does not handle authorization")
|
||||
}
|
||||
|
||||
// Grant permissions.
|
||||
authTerm.GrantPermission(granted)
|
||||
log.Debugf("spn/access: granted %s permissions via %s zone", t.FmtID(), receivedToken.Zone)
|
||||
|
||||
// End successfully.
|
||||
return nil, terminal.ErrExplicitAck
|
||||
}
|
||||
131
spn/access/storage.go
Normal file
131
spn/access/storage.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/query"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/access/token"
|
||||
)
|
||||
|
||||
func loadTokens() {
|
||||
for _, zone := range persistentZones {
|
||||
// Get handler of zone.
|
||||
handler, ok := token.GetHandler(zone)
|
||||
if !ok {
|
||||
log.Warningf("spn/access: could not find zone %s for loading tokens", zone)
|
||||
continue
|
||||
}
|
||||
|
||||
// Get data from database.
|
||||
r, err := db.Get(fmt.Sprintf(tokenStorageKeyTemplate, zone))
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNotFound) {
|
||||
log.Debugf("spn/access: no %s tokens to load", zone)
|
||||
} else {
|
||||
log.Warningf("spn/access: failed to load %s tokens: %s", zone, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Get wrapper.
|
||||
wrapper, ok := r.(*record.Wrapper)
|
||||
if !ok {
|
||||
log.Warningf("spn/access: failed to parse %s tokens: expected wrapper, got %T", zone, r)
|
||||
continue
|
||||
}
|
||||
|
||||
// Load into handler.
|
||||
err = handler.Load(wrapper.Data)
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to load %s tokens: %s", zone, err)
|
||||
}
|
||||
log.Infof("spn/access: loaded %d %s tokens", handler.Amount(), zone)
|
||||
}
|
||||
}
|
||||
|
||||
func storeTokens() {
|
||||
for _, zone := range persistentZones {
|
||||
// Get handler of zone.
|
||||
handler, ok := token.GetHandler(zone)
|
||||
if !ok {
|
||||
log.Warningf("spn/access: could not find zone %s for storing tokens", zone)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generate storage key.
|
||||
storageKey := fmt.Sprintf(tokenStorageKeyTemplate, zone)
|
||||
|
||||
// Check if there is data to save.
|
||||
amount := handler.Amount()
|
||||
if amount == 0 {
|
||||
// Remove possible old entry from database.
|
||||
err := db.Delete(storageKey)
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to delete possible old %s tokens from storage: %s", zone, err)
|
||||
}
|
||||
log.Debugf("spn/access: no %s tokens to store", zone)
|
||||
continue
|
||||
}
|
||||
|
||||
// Export data.
|
||||
data, err := handler.Save()
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to export %s tokens for storing: %s", zone, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Wrap data into raw record.
|
||||
r, err := record.NewWrapper(storageKey, nil, dsd.RAW, data)
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to prepare %s token export for storing: %s", zone, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Let tokens expire after one month.
|
||||
// This will regularly happen when we switch zones.
|
||||
r.UpdateMeta()
|
||||
r.Meta().MakeSecret()
|
||||
r.Meta().MakeCrownJewel()
|
||||
r.Meta().SetRelativateExpiry(30 * 86400)
|
||||
|
||||
// Save to database.
|
||||
err = db.Put(r)
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to store %s tokens: %s", zone, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("spn/access: stored %d %s tokens", amount, zone)
|
||||
}
|
||||
}
|
||||
|
||||
func clearTokens() {
|
||||
for _, zone := range persistentZones {
|
||||
// Get handler of zone.
|
||||
handler, ok := token.GetHandler(zone)
|
||||
if !ok {
|
||||
log.Warningf("spn/access: could not find zone %s for clearing tokens", zone)
|
||||
continue
|
||||
}
|
||||
|
||||
// Clear tokens.
|
||||
handler.Clear()
|
||||
}
|
||||
|
||||
// Purge database storage prefix.
|
||||
ctx, cancel := context.WithTimeout(module.Ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
n, err := db.Purge(ctx, query.New(fmt.Sprintf(tokenStorageKeyTemplate, "")))
|
||||
if err != nil {
|
||||
log.Warningf("spn/access: failed to clear token storages: %s", err)
|
||||
return
|
||||
}
|
||||
log.Infof("spn/access: cleared %d token storages", n)
|
||||
}
|
||||
15
spn/access/token/errors.go
Normal file
15
spn/access/token/errors.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package token
|
||||
|
||||
import "errors"
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrEmpty = errors.New("token storage is empty")
|
||||
ErrNoZone = errors.New("no zone specified")
|
||||
ErrTokenInvalid = errors.New("token is invalid")
|
||||
ErrTokenMalformed = errors.New("token malformed")
|
||||
ErrTokenUsed = errors.New("token already used")
|
||||
ErrZoneMismatch = errors.New("zone mismatch")
|
||||
ErrZoneTaken = errors.New("zone taken")
|
||||
ErrZoneUnknown = errors.New("zone unknown")
|
||||
)
|
||||
13
spn/access/token/module_test.go
Normal file
13
spn/access/token/module_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/service/core/pmtesting"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
module := modules.Register("token", nil, nil, nil, "rng")
|
||||
pmtesting.TestMain(m, module)
|
||||
}
|
||||
552
spn/access/token/pblind.go
Normal file
552
spn/access/token/pblind.go
Normal file
@@ -0,0 +1,552 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
mrand "math/rand"
|
||||
"sync"
|
||||
|
||||
"github.com/mr-tron/base58"
|
||||
"github.com/rot256/pblind"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
)
|
||||
|
||||
const pblindSecretSize = 32
|
||||
|
||||
// PBlindToken is token based on the pblind library.
|
||||
type PBlindToken struct {
|
||||
Serial int `json:"N,omitempty"`
|
||||
Token []byte `json:"T,omitempty"`
|
||||
Signature *pblind.Signature `json:"S,omitempty"`
|
||||
}
|
||||
|
||||
// Pack packs the token.
|
||||
func (pbt *PBlindToken) Pack() ([]byte, error) {
|
||||
return dsd.Dump(pbt, dsd.CBOR)
|
||||
}
|
||||
|
||||
// UnpackPBlindToken unpacks the token.
|
||||
func UnpackPBlindToken(token []byte) (*PBlindToken, error) {
|
||||
t := &PBlindToken{}
|
||||
|
||||
_, err := dsd.Load(token, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// PBlindHandler is a handler for the pblind tokens.
|
||||
type PBlindHandler struct {
|
||||
sync.Mutex
|
||||
opts *PBlindOptions
|
||||
|
||||
publicKey *pblind.PublicKey
|
||||
privateKey *pblind.SecretKey
|
||||
|
||||
storageLock sync.Mutex
|
||||
Storage []*PBlindToken
|
||||
|
||||
// Client request state.
|
||||
requestStateLock sync.Mutex
|
||||
requestState []RequestState
|
||||
}
|
||||
|
||||
// PBlindOptions are options for the PBlindHandler.
|
||||
type PBlindOptions struct {
|
||||
Zone string
|
||||
CurveName string
|
||||
Curve elliptic.Curve
|
||||
PublicKey string
|
||||
PrivateKey string
|
||||
BatchSize int
|
||||
UseSerials bool
|
||||
RandomizeOrder bool
|
||||
Fallback bool
|
||||
SignalShouldRequest func(Handler)
|
||||
DoubleSpendProtection func([]byte) error
|
||||
}
|
||||
|
||||
// PBlindSignerState is a signer state.
|
||||
type PBlindSignerState struct {
|
||||
signers []*pblind.StateSigner
|
||||
}
|
||||
|
||||
// PBlindSetupResponse is a setup response.
|
||||
type PBlindSetupResponse struct {
|
||||
Msgs []*pblind.Message1
|
||||
}
|
||||
|
||||
// PBlindTokenRequest is a token request.
|
||||
type PBlindTokenRequest struct {
|
||||
Msgs []*pblind.Message2
|
||||
}
|
||||
|
||||
// IssuedPBlindTokens are issued pblind tokens.
|
||||
type IssuedPBlindTokens struct {
|
||||
Msgs []*pblind.Message3
|
||||
}
|
||||
|
||||
// RequestState is a request state.
|
||||
type RequestState struct {
|
||||
Token []byte
|
||||
State *pblind.StateRequester
|
||||
}
|
||||
|
||||
// NewPBlindHandler creates a new pblind handler.
|
||||
func NewPBlindHandler(opts PBlindOptions) (*PBlindHandler, error) {
|
||||
pbh := &PBlindHandler{
|
||||
opts: &opts,
|
||||
}
|
||||
|
||||
// Check curve, get from name.
|
||||
if opts.Curve == nil {
|
||||
switch opts.CurveName {
|
||||
case "P-256":
|
||||
opts.Curve = elliptic.P256()
|
||||
case "P-384":
|
||||
opts.Curve = elliptic.P384()
|
||||
case "P-521":
|
||||
opts.Curve = elliptic.P521()
|
||||
default:
|
||||
return nil, errors.New("no curve supplied")
|
||||
}
|
||||
} else if opts.CurveName != "" {
|
||||
return nil, errors.New("both curve and curve name supplied")
|
||||
}
|
||||
|
||||
// Load keys.
|
||||
switch {
|
||||
case pbh.opts.PrivateKey != "":
|
||||
keyData, err := base58.Decode(pbh.opts.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode private key: %w", err)
|
||||
}
|
||||
pivateKey := pblind.SecretKeyFromBytes(pbh.opts.Curve, keyData)
|
||||
pbh.privateKey = &pivateKey
|
||||
publicKey := pbh.privateKey.GetPublicKey()
|
||||
pbh.publicKey = &publicKey
|
||||
|
||||
// Check public key if also provided.
|
||||
if pbh.opts.PublicKey != "" {
|
||||
if pbh.opts.PublicKey != base58.Encode(pbh.publicKey.Bytes()) {
|
||||
return nil, errors.New("private and public mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
case pbh.opts.PublicKey != "":
|
||||
keyData, err := base58.Decode(pbh.opts.PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode public key: %w", err)
|
||||
}
|
||||
publicKey, err := pblind.PublicKeyFromBytes(pbh.opts.Curve, keyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode public key: %w", err)
|
||||
}
|
||||
pbh.publicKey = &publicKey
|
||||
|
||||
default:
|
||||
return nil, errors.New("no key supplied")
|
||||
}
|
||||
|
||||
return pbh, nil
|
||||
}
|
||||
|
||||
func (pbh *PBlindHandler) makeInfo(serial int) (*pblind.Info, error) {
|
||||
// Gather data for info.
|
||||
infoData := container.New()
|
||||
infoData.AppendAsBlock([]byte(pbh.opts.Zone))
|
||||
if pbh.opts.UseSerials {
|
||||
infoData.AppendInt(serial)
|
||||
}
|
||||
|
||||
// Compress to point.
|
||||
info, err := pblind.CompressInfo(pbh.opts.Curve, infoData.CompileData())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compress info: %w", err)
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// Zone returns the zone name.
|
||||
func (pbh *PBlindHandler) Zone() string {
|
||||
return pbh.opts.Zone
|
||||
}
|
||||
|
||||
// ShouldRequest returns whether the new tokens should be requested.
|
||||
func (pbh *PBlindHandler) ShouldRequest() bool {
|
||||
pbh.storageLock.Lock()
|
||||
defer pbh.storageLock.Unlock()
|
||||
|
||||
return pbh.shouldRequest()
|
||||
}
|
||||
|
||||
func (pbh *PBlindHandler) shouldRequest() bool {
|
||||
// Return true if storage is at or below 10%.
|
||||
return len(pbh.Storage) == 0 || pbh.opts.BatchSize/len(pbh.Storage) > 10
|
||||
}
|
||||
|
||||
// Amount returns the current amount of tokens in this handler.
|
||||
func (pbh *PBlindHandler) Amount() int {
|
||||
pbh.storageLock.Lock()
|
||||
defer pbh.storageLock.Unlock()
|
||||
|
||||
return len(pbh.Storage)
|
||||
}
|
||||
|
||||
// IsFallback returns whether this handler should only be used as a fallback.
|
||||
func (pbh *PBlindHandler) IsFallback() bool {
|
||||
return pbh.opts.Fallback
|
||||
}
|
||||
|
||||
// CreateSetup sets up signers for a request.
|
||||
func (pbh *PBlindHandler) CreateSetup() (state *PBlindSignerState, setupResponse *PBlindSetupResponse, err error) {
|
||||
state = &PBlindSignerState{
|
||||
signers: make([]*pblind.StateSigner, pbh.opts.BatchSize),
|
||||
}
|
||||
setupResponse = &PBlindSetupResponse{
|
||||
Msgs: make([]*pblind.Message1, pbh.opts.BatchSize),
|
||||
}
|
||||
|
||||
// Go through the batch.
|
||||
for i := 0; i < pbh.opts.BatchSize; i++ {
|
||||
info, err := pbh.makeInfo(i + 1)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create info #%d: %w", i, err)
|
||||
}
|
||||
|
||||
// Create signer.
|
||||
signer, err := pblind.CreateSigner(*pbh.privateKey, *info)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create signer #%d: %w", i, err)
|
||||
}
|
||||
state.signers[i] = signer
|
||||
|
||||
// Create request setup.
|
||||
setupMsg, err := signer.CreateMessage1()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create setup msg #%d: %w", i, err)
|
||||
}
|
||||
setupResponse.Msgs[i] = &setupMsg
|
||||
}
|
||||
|
||||
return state, setupResponse, nil
|
||||
}
|
||||
|
||||
// CreateTokenRequest creates a token request to be sent to the token server.
|
||||
func (pbh *PBlindHandler) CreateTokenRequest(requestSetup *PBlindSetupResponse) (request *PBlindTokenRequest, err error) {
|
||||
// Check request setup data.
|
||||
if len(requestSetup.Msgs) != pbh.opts.BatchSize {
|
||||
return nil, fmt.Errorf("invalid request setup msg count of %d", len(requestSetup.Msgs))
|
||||
}
|
||||
|
||||
// Lock and reset the request state.
|
||||
pbh.requestStateLock.Lock()
|
||||
defer pbh.requestStateLock.Unlock()
|
||||
pbh.requestState = make([]RequestState, pbh.opts.BatchSize)
|
||||
request = &PBlindTokenRequest{
|
||||
Msgs: make([]*pblind.Message2, pbh.opts.BatchSize),
|
||||
}
|
||||
|
||||
// Go through the batch.
|
||||
for i := 0; i < pbh.opts.BatchSize; i++ {
|
||||
// Check if we have setup data.
|
||||
if requestSetup.Msgs[i] == nil {
|
||||
return nil, fmt.Errorf("missing setup data #%d", i)
|
||||
}
|
||||
|
||||
// Generate secret token.
|
||||
token := make([]byte, pblindSecretSize)
|
||||
n, err := rand.Read(token) //nolint:gosec // False positive - check the imports.
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get random token #%d: %w", i, err)
|
||||
}
|
||||
if n != pblindSecretSize {
|
||||
return nil, fmt.Errorf("failed to get full random token #%d: only got %d bytes", i, n)
|
||||
}
|
||||
pbh.requestState[i].Token = token
|
||||
|
||||
// Create public metadata.
|
||||
info, err := pbh.makeInfo(i + 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make token info #%d: %w", i, err)
|
||||
}
|
||||
|
||||
// Create request and request state.
|
||||
requester, err := pblind.CreateRequester(*pbh.publicKey, *info, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request state #%d: %w", i, err)
|
||||
}
|
||||
pbh.requestState[i].State = requester
|
||||
|
||||
err = requester.ProcessMessage1(*requestSetup.Msgs[i])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process setup message #%d: %w", i, err)
|
||||
}
|
||||
|
||||
// Create request message.
|
||||
requestMsg, err := requester.CreateMessage2()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request message #%d: %w", i, err)
|
||||
}
|
||||
request.Msgs[i] = &requestMsg
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// IssueTokens sign the requested tokens.
|
||||
func (pbh *PBlindHandler) IssueTokens(state *PBlindSignerState, request *PBlindTokenRequest) (response *IssuedPBlindTokens, err error) {
|
||||
// Check request data.
|
||||
if len(request.Msgs) != pbh.opts.BatchSize {
|
||||
return nil, fmt.Errorf("invalid request msg count of %d", len(request.Msgs))
|
||||
}
|
||||
if len(state.signers) != pbh.opts.BatchSize {
|
||||
return nil, fmt.Errorf("invalid request state count of %d", len(request.Msgs))
|
||||
}
|
||||
|
||||
// Create response.
|
||||
response = &IssuedPBlindTokens{
|
||||
Msgs: make([]*pblind.Message3, pbh.opts.BatchSize),
|
||||
}
|
||||
|
||||
// Go through the batch.
|
||||
for i := 0; i < pbh.opts.BatchSize; i++ {
|
||||
// Check if we have request data.
|
||||
if request.Msgs[i] == nil {
|
||||
return nil, fmt.Errorf("missing request data #%d", i)
|
||||
}
|
||||
|
||||
// Process request msg.
|
||||
err = state.signers[i].ProcessMessage2(*request.Msgs[i])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process request msg #%d: %w", i, err)
|
||||
}
|
||||
|
||||
// Issue token.
|
||||
responseMsg, err := state.signers[i].CreateMessage3()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to issue token #%d: %w", i, err)
|
||||
}
|
||||
response.Msgs[i] = &responseMsg
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ProcessIssuedTokens processes the issued token from the server.
|
||||
func (pbh *PBlindHandler) ProcessIssuedTokens(issuedTokens *IssuedPBlindTokens) error {
|
||||
// Check data.
|
||||
if len(issuedTokens.Msgs) != pbh.opts.BatchSize {
|
||||
return fmt.Errorf("invalid issued token count of %d", len(issuedTokens.Msgs))
|
||||
}
|
||||
|
||||
// Step 1: Process issued tokens.
|
||||
|
||||
// Lock and reset the request state.
|
||||
pbh.requestStateLock.Lock()
|
||||
defer pbh.requestStateLock.Unlock()
|
||||
defer func() {
|
||||
pbh.requestState = make([]RequestState, pbh.opts.BatchSize)
|
||||
}()
|
||||
finalizedTokens := make([]*PBlindToken, pbh.opts.BatchSize)
|
||||
|
||||
// Go through the batch.
|
||||
for i := 0; i < pbh.opts.BatchSize; i++ {
|
||||
// Finalize token.
|
||||
err := pbh.requestState[i].State.ProcessMessage3(*issuedTokens.Msgs[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create final signature #%d: %w", i, err)
|
||||
}
|
||||
|
||||
// Get and check final signature.
|
||||
signature, err := pbh.requestState[i].State.Signature()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create final signature #%d: %w", i, err)
|
||||
}
|
||||
info, err := pbh.makeInfo(i + 1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to make token info #%d: %w", i, err)
|
||||
}
|
||||
if !pbh.publicKey.Check(signature, *info, pbh.requestState[i].Token) {
|
||||
return fmt.Errorf("invalid signature on #%d", i)
|
||||
}
|
||||
|
||||
// Save to temporary slice.
|
||||
newToken := &PBlindToken{
|
||||
Token: pbh.requestState[i].Token,
|
||||
Signature: &signature,
|
||||
}
|
||||
if pbh.opts.UseSerials {
|
||||
newToken.Serial = i + 1
|
||||
}
|
||||
finalizedTokens[i] = newToken
|
||||
}
|
||||
|
||||
// Step 2: Randomize received tokens
|
||||
|
||||
if pbh.opts.RandomizeOrder {
|
||||
rInt, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get seed for shuffle: %w", err)
|
||||
}
|
||||
mr := mrand.New(mrand.NewSource(rInt.Int64())) //nolint:gosec
|
||||
mr.Shuffle(len(finalizedTokens), func(i, j int) {
|
||||
finalizedTokens[i], finalizedTokens[j] = finalizedTokens[j], finalizedTokens[i]
|
||||
})
|
||||
}
|
||||
|
||||
// Step 3: Add tokens to storage.
|
||||
|
||||
// Wait for all processing to be complete, as using tokens from a faulty
|
||||
// batch can be dangerous, as the server could be doing this purposely to
|
||||
// create conditions that may benefit an attacker.
|
||||
|
||||
pbh.storageLock.Lock()
|
||||
defer pbh.storageLock.Unlock()
|
||||
|
||||
// Add finalized tokens to storage.
|
||||
pbh.Storage = append(pbh.Storage, finalizedTokens...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetToken returns a token.
|
||||
func (pbh *PBlindHandler) GetToken() (token *Token, err error) {
|
||||
pbh.storageLock.Lock()
|
||||
defer pbh.storageLock.Unlock()
|
||||
|
||||
// Check if we have supply.
|
||||
if len(pbh.Storage) == 0 {
|
||||
return nil, ErrEmpty
|
||||
}
|
||||
|
||||
// Pack token.
|
||||
data, err := pbh.Storage[0].Pack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pack token: %w", err)
|
||||
}
|
||||
|
||||
// Shift to next token.
|
||||
pbh.Storage = pbh.Storage[1:]
|
||||
|
||||
// Check if we should signal that we should request tokens.
|
||||
if pbh.opts.SignalShouldRequest != nil && pbh.shouldRequest() {
|
||||
pbh.opts.SignalShouldRequest(pbh)
|
||||
}
|
||||
|
||||
return &Token{
|
||||
Zone: pbh.opts.Zone,
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify verifies the given token.
|
||||
func (pbh *PBlindHandler) Verify(token *Token) error {
|
||||
// Check if zone matches.
|
||||
if token.Zone != pbh.opts.Zone {
|
||||
return ErrZoneMismatch
|
||||
}
|
||||
|
||||
// Unpack token.
|
||||
t, err := UnpackPBlindToken(token.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrTokenMalformed, err)
|
||||
}
|
||||
|
||||
// Check if serial is valid.
|
||||
switch {
|
||||
case pbh.opts.UseSerials && t.Serial > 0 && t.Serial <= pbh.opts.BatchSize:
|
||||
// Using serials in accepted range.
|
||||
case !pbh.opts.UseSerials && t.Serial == 0:
|
||||
// Not using serials and serial is zero.
|
||||
default:
|
||||
return fmt.Errorf("%w: invalid serial", ErrTokenMalformed)
|
||||
}
|
||||
|
||||
// Build info for checking signature.
|
||||
info, err := pbh.makeInfo(t.Serial)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrTokenMalformed, err)
|
||||
}
|
||||
|
||||
// Check signature.
|
||||
if !pbh.publicKey.Check(*t.Signature, *info, t.Token) {
|
||||
return ErrTokenInvalid
|
||||
}
|
||||
|
||||
// Check for double spending.
|
||||
if pbh.opts.DoubleSpendProtection != nil {
|
||||
if err := pbh.opts.DoubleSpendProtection(t.Token); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrTokenUsed, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PBlindStorage is a storage for pblind tokens.
|
||||
type PBlindStorage struct {
|
||||
Storage []*PBlindToken
|
||||
}
|
||||
|
||||
// Save serializes and returns the current tokens.
|
||||
func (pbh *PBlindHandler) Save() ([]byte, error) {
|
||||
pbh.storageLock.Lock()
|
||||
defer pbh.storageLock.Unlock()
|
||||
|
||||
if len(pbh.Storage) == 0 {
|
||||
return nil, ErrEmpty
|
||||
}
|
||||
|
||||
s := &PBlindStorage{
|
||||
Storage: pbh.Storage,
|
||||
}
|
||||
|
||||
return dsd.Dump(s, dsd.CBOR)
|
||||
}
|
||||
|
||||
// Load loads the given tokens into the handler.
|
||||
func (pbh *PBlindHandler) Load(data []byte) error {
|
||||
pbh.storageLock.Lock()
|
||||
defer pbh.storageLock.Unlock()
|
||||
|
||||
s := &PBlindStorage{}
|
||||
_, err := dsd.Load(data, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check signatures on load.
|
||||
for _, t := range s.Storage {
|
||||
// Build info for checking signature.
|
||||
info, err := pbh.makeInfo(t.Serial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check signature.
|
||||
if !pbh.publicKey.Check(*t.Signature, *info, t.Token) {
|
||||
return ErrTokenInvalid
|
||||
}
|
||||
}
|
||||
|
||||
pbh.Storage = s.Storage
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear clears all the tokens in the handler.
|
||||
func (pbh *PBlindHandler) Clear() {
|
||||
pbh.storageLock.Lock()
|
||||
defer pbh.storageLock.Unlock()
|
||||
|
||||
pbh.Storage = nil
|
||||
}
|
||||
39
spn/access/token/pblind_gen_test.go
Normal file
39
spn/access/token/pblind_gen_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mr-tron/base58"
|
||||
"github.com/rot256/pblind"
|
||||
)
|
||||
|
||||
func TestGeneratePBlindKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, curve := range []elliptic.Curve{
|
||||
elliptic.P256(),
|
||||
elliptic.P384(),
|
||||
elliptic.P521(),
|
||||
} {
|
||||
privateKey, err := pblind.NewSecretKey(curve)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
publicKey := privateKey.GetPublicKey()
|
||||
|
||||
fmt.Printf(
|
||||
"%s (%dbit) private key: %s\n",
|
||||
curve.Params().Name,
|
||||
curve.Params().BitSize,
|
||||
base58.Encode(privateKey.Bytes()),
|
||||
)
|
||||
fmt.Printf(
|
||||
"%s (%dbit) public key: %s\n",
|
||||
curve.Params().Name,
|
||||
curve.Params().BitSize,
|
||||
base58.Encode(publicKey.Bytes()),
|
||||
)
|
||||
}
|
||||
}
|
||||
260
spn/access/token/pblind_test.go
Normal file
260
spn/access/token/pblind_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"encoding/asn1"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rot256/pblind"
|
||||
)
|
||||
|
||||
const PBlindTestZone = "test-pblind"
|
||||
|
||||
func init() {
|
||||
// Combined testing config.
|
||||
|
||||
h, err := NewPBlindHandler(PBlindOptions{
|
||||
Zone: PBlindTestZone,
|
||||
Curve: elliptic.P256(),
|
||||
PrivateKey: "HbwGtLsqek1Fdwuz1MhNQfiY7tj9EpWHeMWHPZ9c6KYY",
|
||||
UseSerials: true,
|
||||
BatchSize: 1000,
|
||||
RandomizeOrder: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = RegisterPBlindHandler(h)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPBlind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
opts := &PBlindOptions{
|
||||
Zone: PBlindTestZone,
|
||||
Curve: elliptic.P256(),
|
||||
UseSerials: true,
|
||||
BatchSize: 1000,
|
||||
RandomizeOrder: true,
|
||||
}
|
||||
|
||||
// Issuer
|
||||
opts.PrivateKey = "HbwGtLsqek1Fdwuz1MhNQfiY7tj9EpWHeMWHPZ9c6KYY"
|
||||
issuer, err := NewPBlindHandler(*opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Client
|
||||
opts.PrivateKey = ""
|
||||
opts.PublicKey = "285oMDh3w5mxyFgpmmURifKfhkcqwwsdnePpPZ6Nqm8cc"
|
||||
client, err := NewPBlindHandler(*opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verifier
|
||||
verifier, err := NewPBlindHandler(*opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Play through the whole use case.
|
||||
|
||||
signerState, setupResponse, err := issuer.CreateSetup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
request, err := client.CreateTokenRequest(setupResponse)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
issuedTokens, err := issuer.IssueTokens(signerState, request)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = client.ProcessIssuedTokens(issuedTokens)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := client.GetToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = verifier.Verify(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPBlindLibrary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// generate a key-pair
|
||||
|
||||
curve := elliptic.P256()
|
||||
|
||||
sk, _ := pblind.NewSecretKey(curve)
|
||||
pk := sk.GetPublicKey()
|
||||
|
||||
msgStr := []byte("128b_accesstoken")
|
||||
infoStr := []byte("v=1 serial=12345")
|
||||
info, err := pblind.CompressInfo(curve, infoStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
totalStart := time.Now()
|
||||
batchSize := 1000
|
||||
|
||||
signers := make([]*pblind.StateSigner, batchSize)
|
||||
requesters := make([]*pblind.StateRequester, batchSize)
|
||||
toServer := make([][]byte, batchSize)
|
||||
toClient := make([][]byte, batchSize)
|
||||
|
||||
// Create signers and prep requests.
|
||||
start := time.Now()
|
||||
for i := 0; i < batchSize; i++ {
|
||||
signer, err := pblind.CreateSigner(sk, info)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
signers[i] = signer
|
||||
|
||||
msg1S, err := signer.CreateMessage1()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ser1S, err := asn1.Marshal(msg1S)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
toClient[i] = ser1S
|
||||
}
|
||||
t.Logf("created %d signers and request preps in %s", batchSize, time.Since(start))
|
||||
t.Logf("sending %d bytes to client", lenOfByteSlices(toClient))
|
||||
|
||||
// Create requesters and create requests.
|
||||
start = time.Now()
|
||||
for i := 0; i < batchSize; i++ {
|
||||
requester, err := pblind.CreateRequester(pk, info, msgStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
requesters[i] = requester
|
||||
|
||||
var msg1R pblind.Message1
|
||||
_, err = asn1.Unmarshal(toClient[i], &msg1R)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = requester.ProcessMessage1(msg1R)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msg2R, err := requester.CreateMessage2()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ser2R, err := asn1.Marshal(msg2R)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
toServer[i] = ser2R
|
||||
}
|
||||
t.Logf("created %d requesters and requests in %s", batchSize, time.Since(start))
|
||||
t.Logf("sending %d bytes to server", lenOfByteSlices(toServer))
|
||||
|
||||
// Sign requests
|
||||
start = time.Now()
|
||||
for i := 0; i < batchSize; i++ {
|
||||
var msg2S pblind.Message2
|
||||
_, err = asn1.Unmarshal(toServer[i], &msg2S)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = signers[i].ProcessMessage2(msg2S)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msg3S, err := signers[i].CreateMessage3()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ser3S, err := asn1.Marshal(msg3S)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
toClient[i] = ser3S
|
||||
}
|
||||
t.Logf("signed %d requests in %s", batchSize, time.Since(start))
|
||||
t.Logf("sending %d bytes to client", lenOfByteSlices(toClient))
|
||||
|
||||
// Verify signed requests
|
||||
start = time.Now()
|
||||
for i := 0; i < batchSize; i++ {
|
||||
var msg3R pblind.Message3
|
||||
_, err := asn1.Unmarshal(toClient[i], &msg3R)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = requesters[i].ProcessMessage3(msg3R)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
signature, err := requesters[i].Signature()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sig, err := asn1.Marshal(signature)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
toServer[i] = sig
|
||||
|
||||
// check signature
|
||||
if !pk.Check(signature, info, msgStr) {
|
||||
t.Fatal("signature invalid")
|
||||
}
|
||||
}
|
||||
t.Logf("finalized and verified %d signed tokens in %s", batchSize, time.Since(start))
|
||||
t.Logf("stored %d signed tokens in %d bytes", batchSize, lenOfByteSlices(toServer))
|
||||
|
||||
// Verify on server
|
||||
start = time.Now()
|
||||
for i := 0; i < batchSize; i++ {
|
||||
var sig pblind.Signature
|
||||
_, err := asn1.Unmarshal(toServer[i], &sig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check signature
|
||||
if !pk.Check(sig, info, msgStr) {
|
||||
t.Fatal("signature invalid")
|
||||
}
|
||||
}
|
||||
t.Logf("verified %d signed tokens in %s", batchSize, time.Since(start))
|
||||
|
||||
t.Logf("process complete")
|
||||
t.Logf("simulated the whole process for %d tokens in %s", batchSize, time.Since(totalStart))
|
||||
}
|
||||
|
||||
func lenOfByteSlices(v [][]byte) (length int) {
|
||||
for _, s := range v {
|
||||
length += len(s)
|
||||
}
|
||||
return
|
||||
}
|
||||
116
spn/access/token/registry.go
Normal file
116
spn/access/token/registry.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package token
|
||||
|
||||
import "sync"
|
||||
|
||||
// Handler represents a token handling system.
|
||||
type Handler interface {
|
||||
// Zone returns the zone name.
|
||||
Zone() string
|
||||
|
||||
// ShouldRequest returns whether the new tokens should be requested.
|
||||
ShouldRequest() bool
|
||||
|
||||
// Amount returns the current amount of tokens in this handler.
|
||||
Amount() int
|
||||
|
||||
// IsFallback returns whether this handler should only be used as a fallback.
|
||||
IsFallback() bool
|
||||
|
||||
// GetToken returns a token.
|
||||
GetToken() (token *Token, err error)
|
||||
|
||||
// Verify verifies the given token.
|
||||
Verify(token *Token) error
|
||||
|
||||
// Save serializes and returns the current tokens.
|
||||
Save() ([]byte, error)
|
||||
|
||||
// Load loads the given tokens into the handler.
|
||||
Load(data []byte) error
|
||||
|
||||
// Clear clears all the tokens in the handler.
|
||||
Clear()
|
||||
}
|
||||
|
||||
var (
|
||||
registry map[string]Handler
|
||||
pblindRegistry []*PBlindHandler
|
||||
scrambleRegistry []*ScrambleHandler
|
||||
|
||||
registryLock sync.RWMutex
|
||||
)
|
||||
|
||||
func init() {
|
||||
initRegistry()
|
||||
}
|
||||
|
||||
func initRegistry() {
|
||||
registry = make(map[string]Handler)
|
||||
pblindRegistry = make([]*PBlindHandler, 0, 1)
|
||||
scrambleRegistry = make([]*ScrambleHandler, 0, 1)
|
||||
}
|
||||
|
||||
// RegisterPBlindHandler registers a pblind handler with the registry.
|
||||
func RegisterPBlindHandler(h *PBlindHandler) error {
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
if err := registerHandler(h, h.opts.Zone); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pblindRegistry = append(pblindRegistry, h)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterScrambleHandler registers a scramble handler with the registry.
|
||||
func RegisterScrambleHandler(h *ScrambleHandler) error {
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
if err := registerHandler(h, h.opts.Zone); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scrambleRegistry = append(scrambleRegistry, h)
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerHandler(h Handler, zone string) error {
|
||||
if zone == "" {
|
||||
return ErrNoZone
|
||||
}
|
||||
|
||||
_, ok := registry[zone]
|
||||
if ok {
|
||||
return ErrZoneTaken
|
||||
}
|
||||
|
||||
registry[zone] = h
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHandler returns the handler of the given zone.
|
||||
func GetHandler(zone string) (handler Handler, ok bool) {
|
||||
registryLock.RLock()
|
||||
defer registryLock.RUnlock()
|
||||
|
||||
handler, ok = registry[zone]
|
||||
return
|
||||
}
|
||||
|
||||
// ResetRegistry resets the token handler registry.
|
||||
func ResetRegistry() {
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
initRegistry()
|
||||
}
|
||||
|
||||
// RegistrySize returns the amount of handler registered.
|
||||
func RegistrySize() int {
|
||||
registryLock.Lock()
|
||||
defer registryLock.Unlock()
|
||||
|
||||
return len(registry)
|
||||
}
|
||||
244
spn/access/token/request.go
Normal file
244
spn/access/token/request.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mr-tron/base58"
|
||||
)
|
||||
|
||||
const sessionIDSize = 32
|
||||
|
||||
// RequestHandlingState is a request handling state.
|
||||
type RequestHandlingState struct {
|
||||
SessionID string
|
||||
PBlind map[string]*PBlindSignerState
|
||||
}
|
||||
|
||||
// SetupRequest is a setup request.
|
||||
type SetupRequest struct {
|
||||
PBlind map[string]struct{} `json:"PB,omitempty"`
|
||||
}
|
||||
|
||||
// SetupResponse is a setup response.
|
||||
type SetupResponse struct {
|
||||
SessionID string `json:"ID,omitempty"`
|
||||
PBlind map[string]*PBlindSetupResponse `json:"PB,omitempty"`
|
||||
}
|
||||
|
||||
// TokenRequest is a token request.
|
||||
type TokenRequest struct { //nolint:golint // Be explicit.
|
||||
SessionID string `json:"ID,omitempty"`
|
||||
PBlind map[string]*PBlindTokenRequest `json:"PB,omitempty"`
|
||||
Scramble map[string]*ScrambleTokenRequest `json:"S,omitempty"`
|
||||
}
|
||||
|
||||
// IssuedTokens are issued tokens.
|
||||
type IssuedTokens struct {
|
||||
PBlind map[string]*IssuedPBlindTokens `json:"PB,omitempty"`
|
||||
Scramble map[string]*IssuedScrambleTokens `json:"SC,omitempty"`
|
||||
}
|
||||
|
||||
// CreateSetupRequest creates a combined setup request for all registered tokens, if needed.
|
||||
func CreateSetupRequest() (request *SetupRequest, setupRequired bool) {
|
||||
registryLock.RLock()
|
||||
defer registryLock.RUnlock()
|
||||
|
||||
request = &SetupRequest{
|
||||
PBlind: make(map[string]struct{}, len(pblindRegistry)),
|
||||
}
|
||||
|
||||
// Go through handlers and create request setups.
|
||||
for _, pblindHandler := range pblindRegistry {
|
||||
// Check if we need to request with this handler.
|
||||
if pblindHandler.ShouldRequest() {
|
||||
request.PBlind[pblindHandler.Zone()] = struct{}{}
|
||||
setupRequired = true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// HandleSetupRequest handles a setup request for all registered tokens.
|
||||
func HandleSetupRequest(request *SetupRequest) (*RequestHandlingState, *SetupResponse, error) {
|
||||
registryLock.RLock()
|
||||
defer registryLock.RUnlock()
|
||||
|
||||
// Generate session token.
|
||||
randomID := make([]byte, sessionIDSize)
|
||||
n, err := rand.Read(randomID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
if n != sessionIDSize {
|
||||
return nil, nil, fmt.Errorf("failed to get full session ID: only got %d bytes", n)
|
||||
}
|
||||
sessionID := base58.Encode(randomID)
|
||||
|
||||
// Create state and response.
|
||||
state := &RequestHandlingState{
|
||||
SessionID: sessionID,
|
||||
PBlind: make(map[string]*PBlindSignerState, len(pblindRegistry)),
|
||||
}
|
||||
setup := &SetupResponse{
|
||||
SessionID: sessionID,
|
||||
PBlind: make(map[string]*PBlindSetupResponse, len(pblindRegistry)),
|
||||
}
|
||||
|
||||
// Go through handlers and create setups.
|
||||
for _, pblindHandler := range pblindRegistry {
|
||||
// Check if we have a request for this handler.
|
||||
_, ok := request.PBlind[pblindHandler.Zone()]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
plindState, pblindSetup, err := pblindHandler.CreateSetup()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create setup for %s: %w", pblindHandler.Zone(), err)
|
||||
}
|
||||
|
||||
state.PBlind[pblindHandler.Zone()] = plindState
|
||||
setup.PBlind[pblindHandler.Zone()] = pblindSetup
|
||||
}
|
||||
|
||||
return state, setup, nil
|
||||
}
|
||||
|
||||
// CreateTokenRequest creates a token request for all registered tokens.
|
||||
func CreateTokenRequest(setup *SetupResponse) (request *TokenRequest, requestRequired bool, err error) {
|
||||
registryLock.RLock()
|
||||
defer registryLock.RUnlock()
|
||||
|
||||
// Check setup data.
|
||||
if setup != nil && setup.SessionID == "" {
|
||||
return nil, false, errors.New("setup data is missing a session ID")
|
||||
}
|
||||
|
||||
// Create token request.
|
||||
request = &TokenRequest{
|
||||
PBlind: make(map[string]*PBlindTokenRequest, len(pblindRegistry)),
|
||||
Scramble: make(map[string]*ScrambleTokenRequest, len(scrambleRegistry)),
|
||||
}
|
||||
if setup != nil {
|
||||
request.SessionID = setup.SessionID
|
||||
}
|
||||
|
||||
// Go through handlers and create requests.
|
||||
if setup != nil {
|
||||
for _, pblindHandler := range pblindRegistry {
|
||||
// Check if we have setup data for this handler.
|
||||
pblindSetup, ok := setup.PBlind[pblindHandler.Zone()]
|
||||
if !ok {
|
||||
// TODO: Abort if we should have received request data.
|
||||
continue
|
||||
}
|
||||
|
||||
// Create request.
|
||||
pblindRequest, err := pblindHandler.CreateTokenRequest(pblindSetup)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to create token request for %s: %w", pblindHandler.Zone(), err)
|
||||
}
|
||||
|
||||
requestRequired = true
|
||||
request.PBlind[pblindHandler.Zone()] = pblindRequest
|
||||
}
|
||||
}
|
||||
for _, scrambleHandler := range scrambleRegistry {
|
||||
// Check if we need to request with this handler.
|
||||
if scrambleHandler.ShouldRequest() {
|
||||
requestRequired = true
|
||||
request.Scramble[scrambleHandler.Zone()] = scrambleHandler.CreateTokenRequest()
|
||||
}
|
||||
}
|
||||
|
||||
return request, requestRequired, nil
|
||||
}
|
||||
|
||||
// IssueTokens issues tokens for all registered tokens.
|
||||
func IssueTokens(state *RequestHandlingState, request *TokenRequest) (response *IssuedTokens, err error) {
|
||||
registryLock.RLock()
|
||||
defer registryLock.RUnlock()
|
||||
|
||||
// Create token response.
|
||||
response = &IssuedTokens{
|
||||
PBlind: make(map[string]*IssuedPBlindTokens, len(pblindRegistry)),
|
||||
Scramble: make(map[string]*IssuedScrambleTokens, len(scrambleRegistry)),
|
||||
}
|
||||
|
||||
// Go through handlers and create requests.
|
||||
for _, pblindHandler := range pblindRegistry {
|
||||
// Check if we have all the data for issuing.
|
||||
pblindState, ok := state.PBlind[pblindHandler.Zone()]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
pblindRequest, ok := request.PBlind[pblindHandler.Zone()]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Issue tokens.
|
||||
pblindTokens, err := pblindHandler.IssueTokens(pblindState, pblindRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to issue tokens for %s: %w", pblindHandler.Zone(), err)
|
||||
}
|
||||
|
||||
response.PBlind[pblindHandler.Zone()] = pblindTokens
|
||||
}
|
||||
for _, scrambleHandler := range scrambleRegistry {
|
||||
// Check if we have all the data for issuing.
|
||||
scrambleRequest, ok := request.Scramble[scrambleHandler.Zone()]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Issue tokens.
|
||||
scrambleTokens, err := scrambleHandler.IssueTokens(scrambleRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to issue tokens for %s: %w", scrambleHandler.Zone(), err)
|
||||
}
|
||||
|
||||
response.Scramble[scrambleHandler.Zone()] = scrambleTokens
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ProcessIssuedTokens processes issued tokens for all registered tokens.
|
||||
func ProcessIssuedTokens(response *IssuedTokens) error {
|
||||
registryLock.RLock()
|
||||
defer registryLock.RUnlock()
|
||||
|
||||
// Go through handlers and create requests.
|
||||
for _, pblindHandler := range pblindRegistry {
|
||||
// Check if we received tokens.
|
||||
pblindResponse, ok := response.PBlind[pblindHandler.Zone()]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Process issued tokens.
|
||||
err := pblindHandler.ProcessIssuedTokens(pblindResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process issued tokens for %s: %w", pblindHandler.Zone(), err)
|
||||
}
|
||||
}
|
||||
for _, scrambleHandler := range scrambleRegistry {
|
||||
// Check if we received tokens.
|
||||
scrambleResponse, ok := response.Scramble[scrambleHandler.Zone()]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Process issued tokens.
|
||||
err := scrambleHandler.ProcessIssuedTokens(scrambleResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process issued tokens for %s: %w", scrambleHandler.Zone(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
125
spn/access/token/request_test.go
Normal file
125
spn/access/token/request_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
)
|
||||
|
||||
func TestFull(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testStart := time.Now()
|
||||
|
||||
// Roundtrip 1
|
||||
|
||||
start := time.Now()
|
||||
setupRequest, setupRequired := CreateSetupRequest()
|
||||
if !setupRequired {
|
||||
t.Fatal("setup should be required")
|
||||
}
|
||||
setupRequestData, err := dsd.Dump(setupRequest, dsd.CBOR)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
setupRequest = nil // nolint:ineffassign,wastedassign // Just to be sure.
|
||||
t.Logf("setupRequest: %s, %d bytes", time.Since(start), len(setupRequestData))
|
||||
|
||||
start = time.Now()
|
||||
loadedSetupRequest := &SetupRequest{}
|
||||
_, err = dsd.Load(setupRequestData, loadedSetupRequest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serverState, setupResponse, err := HandleSetupRequest(loadedSetupRequest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
setupResponseData, err := dsd.Dump(setupResponse, dsd.CBOR)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
setupResponse = nil // nolint:ineffassign,wastedassign // Just to be sure.
|
||||
t.Logf("setupResponse: %s, %d bytes", time.Since(start), len(setupResponseData))
|
||||
|
||||
// Roundtrip 2
|
||||
|
||||
start = time.Now()
|
||||
loadedSetupResponse := &SetupResponse{}
|
||||
_, err = dsd.Load(setupResponseData, loadedSetupResponse)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
request, requestRequired, err := CreateTokenRequest(loadedSetupResponse)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !requestRequired {
|
||||
t.Fatal("request should be required")
|
||||
}
|
||||
requestData, err := dsd.Dump(request, dsd.CBOR)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
request = nil // nolint:ineffassign,wastedassign // Just to be sure.
|
||||
t.Logf("request: %s, %d bytes", time.Since(start), len(requestData))
|
||||
|
||||
start = time.Now()
|
||||
loadedRequest := &TokenRequest{}
|
||||
_, err = dsd.Load(requestData, loadedRequest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
response, err := IssueTokens(serverState, loadedRequest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
responseData, err := dsd.Dump(response, dsd.CBOR)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
response = nil // nolint:ineffassign,wastedassign // Just to be sure.
|
||||
t.Logf("response: %s, %d bytes", time.Since(start), len(responseData))
|
||||
|
||||
start = time.Now()
|
||||
loadedResponse := &IssuedTokens{}
|
||||
_, err = dsd.Load(responseData, loadedResponse)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ProcessIssuedTokens(loadedResponse)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("processing: %s", time.Since(start))
|
||||
|
||||
// Token Usage
|
||||
|
||||
for _, testZone := range []string{
|
||||
PBlindTestZone,
|
||||
ScrambleTestZone,
|
||||
} {
|
||||
start = time.Now()
|
||||
|
||||
token, err := GetToken(testZone)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tokenData := token.Raw()
|
||||
token = nil // nolint:wastedassign // Just to be sure.
|
||||
|
||||
loadedToken, err := ParseRawToken(tokenData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = VerifyToken(loadedToken)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("using %s token: %s", testZone, time.Since(start))
|
||||
}
|
||||
|
||||
t.Logf("full simulation took %s", time.Since(testStart))
|
||||
}
|
||||
240
spn/access/token/scramble.go
Normal file
240
spn/access/token/scramble.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/mr-tron/base58"
|
||||
|
||||
"github.com/safing/jess/lhash"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
)
|
||||
|
||||
const (
|
||||
scrambleSecretSize = 32
|
||||
)
|
||||
|
||||
// ScrambleToken is token based on hashing.
|
||||
type ScrambleToken struct {
|
||||
Token []byte
|
||||
}
|
||||
|
||||
// Pack packs the token.
|
||||
func (pbt *ScrambleToken) Pack() ([]byte, error) {
|
||||
return pbt.Token, nil
|
||||
}
|
||||
|
||||
// UnpackScrambleToken unpacks the token.
|
||||
func UnpackScrambleToken(token []byte) (*ScrambleToken, error) {
|
||||
return &ScrambleToken{Token: token}, nil
|
||||
}
|
||||
|
||||
// ScrambleHandler is a handler for the scramble tokens.
|
||||
type ScrambleHandler struct {
|
||||
sync.Mutex
|
||||
opts *ScrambleOptions
|
||||
|
||||
storageLock sync.Mutex
|
||||
Storage []*ScrambleToken
|
||||
|
||||
verifiersLock sync.RWMutex
|
||||
verifiers map[string]*ScrambleToken
|
||||
}
|
||||
|
||||
// ScrambleOptions are options for the ScrambleHandler.
|
||||
type ScrambleOptions struct {
|
||||
Zone string
|
||||
Algorithm lhash.Algorithm
|
||||
InitialTokens []string
|
||||
InitialVerifiers []string
|
||||
Fallback bool
|
||||
}
|
||||
|
||||
// ScrambleTokenRequest is a token request.
|
||||
type ScrambleTokenRequest struct{}
|
||||
|
||||
// IssuedScrambleTokens are issued scrambled tokens.
|
||||
type IssuedScrambleTokens struct {
|
||||
Tokens []*ScrambleToken
|
||||
}
|
||||
|
||||
// NewScrambleHandler creates a new scramble handler.
|
||||
func NewScrambleHandler(opts ScrambleOptions) (*ScrambleHandler, error) {
|
||||
sh := &ScrambleHandler{
|
||||
opts: &opts,
|
||||
verifiers: make(map[string]*ScrambleToken, len(opts.InitialTokens)+len(opts.InitialVerifiers)),
|
||||
}
|
||||
|
||||
// Add initial tokens.
|
||||
sh.Storage = make([]*ScrambleToken, len(opts.InitialTokens))
|
||||
for i, token := range opts.InitialTokens {
|
||||
// Add to storage.
|
||||
tokenData, err := base58.Decode(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode initial token %q: %w", token, err)
|
||||
}
|
||||
sh.Storage[i] = &ScrambleToken{
|
||||
Token: tokenData,
|
||||
}
|
||||
|
||||
// Add to verifiers.
|
||||
scrambledToken := lhash.Digest(sh.opts.Algorithm, tokenData).Bytes()
|
||||
sh.verifiers[string(scrambledToken)] = sh.Storage[i]
|
||||
}
|
||||
|
||||
// Add initial verifiers.
|
||||
for _, verifier := range opts.InitialVerifiers {
|
||||
verifierData, err := base58.Decode(verifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode verifier %q: %w", verifier, err)
|
||||
}
|
||||
sh.verifiers[string(verifierData)] = &ScrambleToken{}
|
||||
}
|
||||
|
||||
return sh, nil
|
||||
}
|
||||
|
||||
// Zone returns the zone name.
|
||||
func (sh *ScrambleHandler) Zone() string {
|
||||
return sh.opts.Zone
|
||||
}
|
||||
|
||||
// ShouldRequest returns whether the new tokens should be requested.
|
||||
func (sh *ScrambleHandler) ShouldRequest() bool {
|
||||
sh.storageLock.Lock()
|
||||
defer sh.storageLock.Unlock()
|
||||
|
||||
return len(sh.Storage) == 0
|
||||
}
|
||||
|
||||
// Amount returns the current amount of tokens in this handler.
|
||||
func (sh *ScrambleHandler) Amount() int {
|
||||
sh.storageLock.Lock()
|
||||
defer sh.storageLock.Unlock()
|
||||
|
||||
return len(sh.Storage)
|
||||
}
|
||||
|
||||
// IsFallback returns whether this handler should only be used as a fallback.
|
||||
func (sh *ScrambleHandler) IsFallback() bool {
|
||||
return sh.opts.Fallback
|
||||
}
|
||||
|
||||
// CreateTokenRequest creates a token request to be sent to the token server.
|
||||
func (sh *ScrambleHandler) CreateTokenRequest() (request *ScrambleTokenRequest) {
|
||||
return &ScrambleTokenRequest{}
|
||||
}
|
||||
|
||||
// IssueTokens sign the requested tokens.
|
||||
func (sh *ScrambleHandler) IssueTokens(request *ScrambleTokenRequest) (response *IssuedScrambleTokens, err error) {
|
||||
// Copy the storage.
|
||||
tokens := make([]*ScrambleToken, len(sh.Storage))
|
||||
copy(tokens, sh.Storage)
|
||||
|
||||
return &IssuedScrambleTokens{
|
||||
Tokens: tokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProcessIssuedTokens processes the issued token from the server.
|
||||
func (sh *ScrambleHandler) ProcessIssuedTokens(issuedTokens *IssuedScrambleTokens) error {
|
||||
sh.verifiersLock.RLock()
|
||||
defer sh.verifiersLock.RUnlock()
|
||||
|
||||
// Validate tokens.
|
||||
for i, newToken := range issuedTokens.Tokens {
|
||||
// Scramle token.
|
||||
scrambledToken := lhash.Digest(sh.opts.Algorithm, newToken.Token).Bytes()
|
||||
|
||||
// Check if token is valid.
|
||||
_, ok := sh.verifiers[string(scrambledToken)]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token on #%d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Copy to storage.
|
||||
sh.Storage = issuedTokens.Tokens
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify verifies the given token.
|
||||
func (sh *ScrambleHandler) Verify(token *Token) error {
|
||||
if token.Zone != sh.opts.Zone {
|
||||
return ErrZoneMismatch
|
||||
}
|
||||
|
||||
// Hash the data.
|
||||
scrambledToken := lhash.Digest(sh.opts.Algorithm, token.Data).Bytes()
|
||||
|
||||
sh.verifiersLock.RLock()
|
||||
defer sh.verifiersLock.RUnlock()
|
||||
|
||||
// Check if token is valid.
|
||||
_, ok := sh.verifiers[string(scrambledToken)]
|
||||
if !ok {
|
||||
return ErrTokenInvalid
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetToken returns a token.
|
||||
func (sh *ScrambleHandler) GetToken() (*Token, error) {
|
||||
sh.storageLock.Lock()
|
||||
defer sh.storageLock.Unlock()
|
||||
|
||||
if len(sh.Storage) == 0 {
|
||||
return nil, ErrEmpty
|
||||
}
|
||||
|
||||
return &Token{
|
||||
Zone: sh.opts.Zone,
|
||||
Data: sh.Storage[0].Token,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ScrambleStorage is a storage for scramble tokens.
|
||||
type ScrambleStorage struct {
|
||||
Storage []*ScrambleToken
|
||||
}
|
||||
|
||||
// Save serializes and returns the current tokens.
|
||||
func (sh *ScrambleHandler) Save() ([]byte, error) {
|
||||
sh.storageLock.Lock()
|
||||
defer sh.storageLock.Unlock()
|
||||
|
||||
if len(sh.Storage) == 0 {
|
||||
return nil, ErrEmpty
|
||||
}
|
||||
|
||||
s := &ScrambleStorage{
|
||||
Storage: sh.Storage,
|
||||
}
|
||||
|
||||
return dsd.Dump(s, dsd.CBOR)
|
||||
}
|
||||
|
||||
// Load loads the given tokens into the handler.
|
||||
func (sh *ScrambleHandler) Load(data []byte) error {
|
||||
sh.storageLock.Lock()
|
||||
defer sh.storageLock.Unlock()
|
||||
|
||||
s := &ScrambleStorage{}
|
||||
_, err := dsd.Load(data, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sh.Storage = s.Storage
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear clears all the tokens in the handler.
|
||||
func (sh *ScrambleHandler) Clear() {
|
||||
sh.storageLock.Lock()
|
||||
defer sh.storageLock.Unlock()
|
||||
|
||||
sh.Storage = nil
|
||||
}
|
||||
48
spn/access/token/scramble_gen_test.go
Normal file
48
spn/access/token/scramble_gen_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mr-tron/base58"
|
||||
|
||||
"github.com/safing/jess/lhash"
|
||||
)
|
||||
|
||||
type genAlgs struct {
|
||||
alg lhash.Algorithm
|
||||
name string
|
||||
}
|
||||
|
||||
func TestGenerateScrambleKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, alg := range []genAlgs{
|
||||
{alg: lhash.SHA2_256, name: "SHA2_256"},
|
||||
{alg: lhash.SHA3_256, name: "SHA3_256"},
|
||||
{alg: lhash.SHA3_512, name: "SHA3_512"},
|
||||
{alg: lhash.BLAKE2b_256, name: "BLAKE2b_256"},
|
||||
} {
|
||||
token := make([]byte, scrambleSecretSize)
|
||||
n, err := rand.Read(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != scrambleSecretSize {
|
||||
t.Fatalf("only got %d bytes", n)
|
||||
}
|
||||
scrambledToken := lhash.Digest(alg.alg, token).Bytes()
|
||||
|
||||
fmt.Printf(
|
||||
"%s secret token: %s\n",
|
||||
alg.name,
|
||||
base58.Encode(token),
|
||||
)
|
||||
fmt.Printf(
|
||||
"%s scrambled (public) token: %s\n",
|
||||
alg.name,
|
||||
base58.Encode(scrambledToken),
|
||||
)
|
||||
}
|
||||
}
|
||||
84
spn/access/token/scramble_test.go
Normal file
84
spn/access/token/scramble_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/jess/lhash"
|
||||
)
|
||||
|
||||
const ScrambleTestZone = "test-scramble"
|
||||
|
||||
func init() {
|
||||
// Combined testing config.
|
||||
|
||||
h, err := NewScrambleHandler(ScrambleOptions{
|
||||
Zone: ScrambleTestZone,
|
||||
Algorithm: lhash.SHA2_256,
|
||||
InitialTokens: []string{"2VqJ8BvDew1tUpytZhR7tuvq7ToPpW3tQtHvu3veE3iW"},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = RegisterScrambleHandler(h)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScramble(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
opts := &ScrambleOptions{
|
||||
Zone: ScrambleTestZone,
|
||||
Algorithm: lhash.SHA2_256,
|
||||
}
|
||||
|
||||
// Issuer
|
||||
opts.InitialTokens = []string{"2VqJ8BvDew1tUpytZhR7tuvq7ToPpW3tQtHvu3veE3iW"}
|
||||
issuer, err := NewScrambleHandler(*opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Client
|
||||
opts.InitialTokens = nil
|
||||
opts.InitialVerifiers = []string{"Cy9tz37Xq9NiXGDRU9yicjGU62GjXskE9KqUmuoddSxaE3"}
|
||||
client, err := NewScrambleHandler(*opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verifier
|
||||
verifier, err := NewScrambleHandler(*opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Play through the whole use case.
|
||||
|
||||
request := client.CreateTokenRequest()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
issuedTokens, err := issuer.IssueTokens(request)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = client.ProcessIssuedTokens(issuedTokens)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := client.GetToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = verifier.Verify(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
83
spn/access/token/token.go
Normal file
83
spn/access/token/token.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mr-tron/base58"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
)
|
||||
|
||||
// Token represents a token, consisting of a zone (name) and some data.
|
||||
type Token struct {
|
||||
Zone string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// GetToken returns a token of the given zone.
|
||||
func GetToken(zone string) (*Token, error) {
|
||||
handler, ok := GetHandler(zone)
|
||||
if !ok {
|
||||
return nil, ErrZoneUnknown
|
||||
}
|
||||
|
||||
return handler.GetToken()
|
||||
}
|
||||
|
||||
// VerifyToken verifies the given token.
|
||||
func VerifyToken(token *Token) error {
|
||||
handler, ok := GetHandler(token.Zone)
|
||||
if !ok {
|
||||
return ErrZoneUnknown
|
||||
}
|
||||
|
||||
return handler.Verify(token)
|
||||
}
|
||||
|
||||
// Raw returns the raw format of the token.
|
||||
func (c *Token) Raw() []byte {
|
||||
cont := container.New()
|
||||
cont.Append([]byte(c.Zone))
|
||||
cont.Append([]byte(":"))
|
||||
cont.Append(c.Data)
|
||||
return cont.CompileData()
|
||||
}
|
||||
|
||||
// String returns the stringified format of the token.
|
||||
func (c *Token) String() string {
|
||||
return c.Zone + ":" + base58.Encode(c.Data)
|
||||
}
|
||||
|
||||
// ParseRawToken parses a raw token.
|
||||
func ParseRawToken(code []byte) (*Token, error) {
|
||||
splitted := bytes.SplitN(code, []byte(":"), 2)
|
||||
if len(splitted) < 2 {
|
||||
return nil, errors.New("invalid code format: zone/data separator missing")
|
||||
}
|
||||
|
||||
return &Token{
|
||||
Zone: string(splitted[0]),
|
||||
Data: splitted[1],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseToken parses a stringified token.
|
||||
func ParseToken(code string) (*Token, error) {
|
||||
splitted := strings.SplitN(code, ":", 2)
|
||||
if len(splitted) < 2 {
|
||||
return nil, errors.New("invalid code format: zone/data separator missing")
|
||||
}
|
||||
|
||||
data, err := base58.Decode(splitted[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid code format: %w", err)
|
||||
}
|
||||
|
||||
return &Token{
|
||||
Zone: splitted[0],
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
33
spn/access/token/token_test.go
Normal file
33
spn/access/token/token_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portbase/rng"
|
||||
)
|
||||
|
||||
func TestToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
randomData, err := rng.Bytes(32)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := &Token{
|
||||
Zone: "test",
|
||||
Data: randomData,
|
||||
}
|
||||
|
||||
s := c.String()
|
||||
_, err = ParseToken(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r := c.Raw()
|
||||
_, err = ParseRawToken(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
257
spn/access/zones.go
Normal file
257
spn/access/zones.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/jess/lhash"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/access/token"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
var (
|
||||
// ExpandAndConnectZones are the zones that grant access to the expand and
|
||||
// connect operations.
|
||||
ExpandAndConnectZones = []string{"pblind1", "alpha2", "fallback1"}
|
||||
|
||||
zonePermissions = map[string]terminal.Permission{
|
||||
"pblind1": terminal.AddPermissions(terminal.MayExpand, terminal.MayConnect),
|
||||
"alpha2": terminal.AddPermissions(terminal.MayExpand, terminal.MayConnect),
|
||||
"fallback1": terminal.AddPermissions(terminal.MayExpand, terminal.MayConnect),
|
||||
}
|
||||
persistentZones = ExpandAndConnectZones
|
||||
|
||||
enableTestMode = abool.New()
|
||||
)
|
||||
|
||||
// EnableTestMode enables the test mode, leading the access module to only
|
||||
// register a test zone.
|
||||
// This should not be used to test the access module itself.
|
||||
func EnableTestMode() {
|
||||
enableTestMode.Set()
|
||||
}
|
||||
|
||||
// InitializeZones initialized the permission zones.
|
||||
// It initializes the test zones, if EnableTestMode was called before.
|
||||
// Must only be called once.
|
||||
func InitializeZones() error {
|
||||
// Check if we are testing.
|
||||
if enableTestMode.IsSet() {
|
||||
return initializeTestZone()
|
||||
}
|
||||
|
||||
// Special client zone config.
|
||||
var requestSignalHandler func(token.Handler)
|
||||
if conf.Client() {
|
||||
requestSignalHandler = shouldRequestTokensHandler
|
||||
}
|
||||
|
||||
// Register pblind1 as the first primary zone.
|
||||
ph, err := token.NewPBlindHandler(token.PBlindOptions{
|
||||
Zone: "pblind1",
|
||||
CurveName: "P-256",
|
||||
PublicKey: "eXoJXzXbM66UEsM2eVi9HwyBPLMfVnNrC7gNrsfMUJDs",
|
||||
UseSerials: true,
|
||||
BatchSize: 1000,
|
||||
RandomizeOrder: true,
|
||||
SignalShouldRequest: requestSignalHandler,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create pblind1 token handler: %w", err)
|
||||
}
|
||||
err = token.RegisterPBlindHandler(ph)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register pblind1 token handler: %w", err)
|
||||
}
|
||||
|
||||
// Register fallback1 zone as fallback when the issuer is not available.
|
||||
sh, err := token.NewScrambleHandler(token.ScrambleOptions{
|
||||
Zone: "fallback1",
|
||||
Algorithm: lhash.BLAKE2b_256,
|
||||
InitialVerifiers: []string{"ZwkQoaAttVBMURzeLzNXokFBMAMUUwECfM1iHojcVKBmjk"},
|
||||
Fallback: true,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create fallback1 token handler: %w", err)
|
||||
}
|
||||
err = token.RegisterScrambleHandler(sh)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register fallback1 token handler: %w", err)
|
||||
}
|
||||
|
||||
// Register alpha2 zone for transition phase.
|
||||
sh, err = token.NewScrambleHandler(token.ScrambleOptions{
|
||||
Zone: "alpha2",
|
||||
Algorithm: lhash.BLAKE2b_256,
|
||||
InitialVerifiers: []string{"ZwojEvXZmAv7SZdNe7m94Xzu7F9J8vULqKf7QYtoTpN2tH"},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create alpha2 token handler: %w", err)
|
||||
}
|
||||
err = token.RegisterScrambleHandler(sh)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register alpha2 token handler: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initializeTestZone() error {
|
||||
// Safeguard checks if we should really enable the test zone.
|
||||
if !strings.HasSuffix(os.Args[0], ".test") {
|
||||
return errors.New("tried to enable test mode, but no test binary was detected")
|
||||
}
|
||||
if token.RegistrySize() > 0 {
|
||||
return fmt.Errorf("tried to enable test zone, but %d handlers are already registered", token.RegistrySize())
|
||||
}
|
||||
|
||||
// Reset zones.
|
||||
token.ResetRegistry()
|
||||
|
||||
// Set eligible zones.
|
||||
ExpandAndConnectZones = []string{"unittest"}
|
||||
zonePermissions = map[string]terminal.Permission{
|
||||
"unittest": terminal.AddPermissions(terminal.MayExpand, terminal.MayConnect),
|
||||
}
|
||||
|
||||
// Register unittest zone as for testing.
|
||||
sh, err := token.NewScrambleHandler(token.ScrambleOptions{
|
||||
Zone: "unittest",
|
||||
Algorithm: lhash.BLAKE2b_256,
|
||||
InitialTokens: []string{"6jFqLA93uSLL52utGKrvctG3ZfopSQ8WFqjsRK1c2Svt"},
|
||||
InitialVerifiers: []string{"ZwoEoL59sr81s7WnF2vydGzjeejE3u8CqVafig1NTQzUr7"},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create unittest token handler: %w", err)
|
||||
}
|
||||
err = token.RegisterScrambleHandler(sh)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register unittest token handler: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldRequestTokensHandler(_ token.Handler) {
|
||||
// accountUpdateTask is always set in client mode and when the module is online.
|
||||
// Check if it's set in case this gets executed in other circumstances.
|
||||
if accountUpdateTask == nil {
|
||||
log.Warningf("spn/access: trying to trigger account update, but the task is not available")
|
||||
return
|
||||
}
|
||||
|
||||
accountUpdateTask.StartASAP()
|
||||
}
|
||||
|
||||
// GetTokenAmount returns the amount of tokens for the given zones.
|
||||
func GetTokenAmount(zones []string) (regular, fallback int) {
|
||||
handlerLoop:
|
||||
for _, zone := range zones {
|
||||
// Get handler and check if it should be used.
|
||||
handler, ok := token.GetHandler(zone)
|
||||
if !ok {
|
||||
log.Warningf("spn/access: use of non-registered zone %q", zone)
|
||||
continue handlerLoop
|
||||
}
|
||||
|
||||
if handler.IsFallback() {
|
||||
fallback += handler.Amount()
|
||||
} else {
|
||||
regular += handler.Amount()
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ShouldRequest returns whether tokens should be requested for the given zones.
|
||||
func ShouldRequest(zones []string) (shouldRequest bool) {
|
||||
handlerLoop:
|
||||
for _, zone := range zones {
|
||||
// Get handler and check if it should be used.
|
||||
handler, ok := token.GetHandler(zone)
|
||||
if !ok {
|
||||
log.Warningf("spn/access: use of non-registered zone %q", zone)
|
||||
continue handlerLoop
|
||||
}
|
||||
|
||||
// Go through all handlers every time as this will be the case anyway most
|
||||
// of the time and will help us better catch zone misconfiguration.
|
||||
if handler.ShouldRequest() {
|
||||
shouldRequest = true
|
||||
}
|
||||
}
|
||||
|
||||
return shouldRequest
|
||||
}
|
||||
|
||||
// GetToken returns a token of one of the given zones.
|
||||
func GetToken(zones []string) (t *token.Token, err error) {
|
||||
handlerSelection:
|
||||
for _, zone := range zones {
|
||||
// Get handler and check if it should be used.
|
||||
handler, ok := token.GetHandler(zone)
|
||||
switch {
|
||||
case !ok:
|
||||
log.Warningf("spn/access: use of non-registered zone %q", zone)
|
||||
continue handlerSelection
|
||||
case handler.IsFallback() && !TokenIssuerIsFailing():
|
||||
// Skip fallback zone if everything works.
|
||||
continue handlerSelection
|
||||
}
|
||||
|
||||
// Get token from handler.
|
||||
t, err = token.GetToken(zone)
|
||||
if err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Return existing error, if exists.
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, token.ErrEmpty
|
||||
}
|
||||
|
||||
// VerifyRawToken verifies a raw token.
|
||||
func VerifyRawToken(data []byte) (granted terminal.Permission, err error) {
|
||||
t, err := token.ParseRawToken(data)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
return VerifyToken(t)
|
||||
}
|
||||
|
||||
// VerifyToken verifies a token.
|
||||
func VerifyToken(t *token.Token) (granted terminal.Permission, err error) {
|
||||
handler, ok := token.GetHandler(t.Zone)
|
||||
if !ok {
|
||||
return terminal.NoPermission, token.ErrZoneUnknown
|
||||
}
|
||||
|
||||
// Check if the token is a fallback token.
|
||||
if handler.IsFallback() && !healthCheck() {
|
||||
return terminal.NoPermission, ErrFallbackNotAvailable
|
||||
}
|
||||
|
||||
// Verify token.
|
||||
err = handler.Verify(t)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to verify token: %w", err)
|
||||
}
|
||||
|
||||
// Return permission of zone.
|
||||
granted, ok = zonePermissions[t.Zone]
|
||||
if !ok {
|
||||
return terminal.NoPermission, nil
|
||||
}
|
||||
return granted, nil
|
||||
}
|
||||
392
spn/cabin/config-public.go
Normal file
392
spn/cabin/config-public.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
// Configuration Keys.
|
||||
var (
|
||||
// Name of the node.
|
||||
publicCfgOptionNameKey = "spn/publicHub/name"
|
||||
publicCfgOptionName config.StringOption
|
||||
publicCfgOptionNameDefault = ""
|
||||
publicCfgOptionNameOrder = 512
|
||||
|
||||
// Person or organisation, who is in control of the node (should be same for all nodes of this person or organisation).
|
||||
publicCfgOptionGroupKey = "spn/publicHub/group"
|
||||
publicCfgOptionGroup config.StringOption
|
||||
publicCfgOptionGroupDefault = ""
|
||||
publicCfgOptionGroupOrder = 513
|
||||
|
||||
// Contact possibility (recommended, but optional).
|
||||
publicCfgOptionContactAddressKey = "spn/publicHub/contactAddress"
|
||||
publicCfgOptionContactAddress config.StringOption
|
||||
publicCfgOptionContactAddressDefault = ""
|
||||
publicCfgOptionContactAddressOrder = 514
|
||||
|
||||
// Type of service of the contact address, if not email.
|
||||
publicCfgOptionContactServiceKey = "spn/publicHub/contactService"
|
||||
publicCfgOptionContactService config.StringOption
|
||||
publicCfgOptionContactServiceDefault = ""
|
||||
publicCfgOptionContactServiceOrder = 515
|
||||
|
||||
// Hosters - supply chain (reseller, hosting provider, datacenter operator, ...).
|
||||
publicCfgOptionHostersKey = "spn/publicHub/hosters"
|
||||
publicCfgOptionHosters config.StringArrayOption
|
||||
publicCfgOptionHostersDefault = []string{}
|
||||
publicCfgOptionHostersOrder = 516
|
||||
|
||||
// Datacenter
|
||||
// Format: CC-COMPANY-INTERNALCODE
|
||||
// Eg: DE-Hetzner-FSN1-DC5
|
||||
//.
|
||||
publicCfgOptionDatacenterKey = "spn/publicHub/datacenter"
|
||||
publicCfgOptionDatacenter config.StringOption
|
||||
publicCfgOptionDatacenterDefault = ""
|
||||
publicCfgOptionDatacenterOrder = 517
|
||||
|
||||
// Network Location and Access.
|
||||
|
||||
// IPv4 must be global and accessible.
|
||||
publicCfgOptionIPv4Key = "spn/publicHub/ip4"
|
||||
publicCfgOptionIPv4 config.StringOption
|
||||
publicCfgOptionIPv4Default = ""
|
||||
publicCfgOptionIPv4Order = 518
|
||||
|
||||
// IPv6 must be global and accessible.
|
||||
publicCfgOptionIPv6Key = "spn/publicHub/ip6"
|
||||
publicCfgOptionIPv6 config.StringOption
|
||||
publicCfgOptionIPv6Default = ""
|
||||
publicCfgOptionIPv6Order = 519
|
||||
|
||||
// Transports.
|
||||
publicCfgOptionTransportsKey = "spn/publicHub/transports"
|
||||
publicCfgOptionTransports config.StringArrayOption
|
||||
publicCfgOptionTransportsDefault = []string{
|
||||
"tcp:17",
|
||||
}
|
||||
publicCfgOptionTransportsOrder = 520
|
||||
|
||||
// Entry Policy.
|
||||
publicCfgOptionEntryKey = "spn/publicHub/entry"
|
||||
publicCfgOptionEntry config.StringArrayOption
|
||||
publicCfgOptionEntryDefault = []string{}
|
||||
publicCfgOptionEntryOrder = 521
|
||||
|
||||
// Exit Policy.
|
||||
publicCfgOptionExitKey = "spn/publicHub/exit"
|
||||
publicCfgOptionExit config.StringArrayOption
|
||||
publicCfgOptionExitDefault = []string{"- * TCP/25"}
|
||||
publicCfgOptionExitOrder = 522
|
||||
|
||||
// Allow Unencrypted.
|
||||
publicCfgOptionAllowUnencryptedKey = "spn/publicHub/allowUnencrypted"
|
||||
publicCfgOptionAllowUnencrypted config.BoolOption
|
||||
publicCfgOptionAllowUnencryptedDefault = false
|
||||
publicCfgOptionAllowUnencryptedOrder = 523
|
||||
)
|
||||
|
||||
func prepPublicHubConfig() error {
|
||||
err := config.Register(&config.Option{
|
||||
Name: "Name",
|
||||
Key: publicCfgOptionNameKey,
|
||||
Description: "Human readable name of the Hub.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionNameDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionNameOrder,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionName = config.GetAsString(publicCfgOptionNameKey, publicCfgOptionNameDefault)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Group",
|
||||
Key: publicCfgOptionGroupKey,
|
||||
Description: "Name of the hub group this Hub belongs to.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionGroupDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionGroupOrder,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionGroup = config.GetAsString(publicCfgOptionGroupKey, publicCfgOptionGroupDefault)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Contact Address",
|
||||
Key: publicCfgOptionContactAddressKey,
|
||||
Description: "Contact address where the Hub operator can be reached.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionContactAddressDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionContactAddressOrder,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionContactAddress = config.GetAsString(publicCfgOptionContactAddressKey, publicCfgOptionContactAddressDefault)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Contact Service",
|
||||
Key: publicCfgOptionContactServiceKey,
|
||||
Description: "Name of the service the contact address corresponds to, if not email.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionContactServiceDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionContactServiceOrder,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionContactService = config.GetAsString(publicCfgOptionContactServiceKey, publicCfgOptionContactServiceDefault)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Hosters",
|
||||
Key: publicCfgOptionHostersKey,
|
||||
Description: "List of all involved entities and organisations that are involved in hosting this Hub.",
|
||||
OptType: config.OptTypeStringArray,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionHostersDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionHostersOrder,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionHosters = config.GetAsStringArray(publicCfgOptionHostersKey, publicCfgOptionHostersDefault)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Datacenter",
|
||||
Key: publicCfgOptionDatacenterKey,
|
||||
Description: "Identifier of the datacenter this Hub is hosted in.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionDatacenterDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionDatacenterOrder,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionDatacenter = config.GetAsString(publicCfgOptionDatacenterKey, publicCfgOptionDatacenterDefault)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "IPv4",
|
||||
Key: publicCfgOptionIPv4Key,
|
||||
Description: "IPv4 address of this Hub. Must be globally reachable.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionIPv4Default,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionIPv4Order,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionIPv4 = config.GetAsString(publicCfgOptionIPv4Key, publicCfgOptionIPv4Default)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "IPv6",
|
||||
Key: publicCfgOptionIPv6Key,
|
||||
Description: "IPv6 address of this Hub. Must be globally reachable.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionIPv6Default,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionIPv6Order,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionIPv6 = config.GetAsString(publicCfgOptionIPv6Key, publicCfgOptionIPv6Default)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Transports",
|
||||
Key: publicCfgOptionTransportsKey,
|
||||
Description: "List of transports this Hub supports.",
|
||||
OptType: config.OptTypeStringArray,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionTransportsDefault,
|
||||
ValidationFunc: func(value any) error {
|
||||
if transports, ok := value.([]string); ok {
|
||||
for i, transport := range transports {
|
||||
if _, err := hub.ParseTransport(transport); err != nil {
|
||||
return fmt.Errorf("failed to parse transport #%d: %w", i, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("not a []string, but %T", value)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionTransportsOrder,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionTransports = config.GetAsStringArray(publicCfgOptionTransportsKey, publicCfgOptionTransportsDefault)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Entry",
|
||||
Key: publicCfgOptionEntryKey,
|
||||
Description: "Define an entry policy. The format is the same for the endpoint lists. Default is permit.",
|
||||
OptType: config.OptTypeStringArray,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionEntryDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionEntryOrder,
|
||||
config.DisplayHintAnnotation: endpoints.DisplayHintEndpointList,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionEntry = config.GetAsStringArray(publicCfgOptionEntryKey, publicCfgOptionEntryDefault)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Exit",
|
||||
Key: publicCfgOptionExitKey,
|
||||
Description: "Define an exit policy. The format is the same for the endpoint lists. Default is permit.",
|
||||
OptType: config.OptTypeStringArray,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionExitDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionExitOrder,
|
||||
config.DisplayHintAnnotation: endpoints.DisplayHintEndpointList,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionExit = config.GetAsStringArray(publicCfgOptionExitKey, publicCfgOptionExitDefault)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Allow Unencrypted Connections",
|
||||
Key: publicCfgOptionAllowUnencryptedKey,
|
||||
Description: "Advertise that this Hub is available for handling unencrypted connections, as detected by clients.",
|
||||
OptType: config.OptTypeBool,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: publicCfgOptionAllowUnencryptedDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: publicCfgOptionAllowUnencryptedOrder,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicCfgOptionAllowUnencrypted = config.GetAsBool(publicCfgOptionAllowUnencryptedKey, publicCfgOptionAllowUnencryptedDefault)
|
||||
|
||||
// update defaults from system
|
||||
setDynamicPublicDefaults()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPublicHubInfo() *hub.Announcement {
|
||||
// get configuration
|
||||
info := &hub.Announcement{
|
||||
Name: publicCfgOptionName(),
|
||||
Group: publicCfgOptionGroup(),
|
||||
ContactAddress: publicCfgOptionContactAddress(),
|
||||
ContactService: publicCfgOptionContactService(),
|
||||
Hosters: publicCfgOptionHosters(),
|
||||
Datacenter: publicCfgOptionDatacenter(),
|
||||
Transports: publicCfgOptionTransports(),
|
||||
Entry: publicCfgOptionEntry(),
|
||||
Exit: publicCfgOptionExit(),
|
||||
Flags: []string{},
|
||||
}
|
||||
|
||||
if publicCfgOptionAllowUnencrypted() {
|
||||
info.Flags = append(info.Flags, hub.FlagAllowUnencrypted)
|
||||
}
|
||||
|
||||
ip4 := publicCfgOptionIPv4()
|
||||
if ip4 != "" {
|
||||
ip := net.ParseIP(ip4)
|
||||
if ip == nil {
|
||||
log.Warningf("spn/cabin: invalid %s config: %s", publicCfgOptionIPv4Key, ip4)
|
||||
} else {
|
||||
info.IPv4 = ip
|
||||
}
|
||||
}
|
||||
|
||||
ip6 := publicCfgOptionIPv6()
|
||||
if ip6 != "" {
|
||||
ip := net.ParseIP(ip6)
|
||||
if ip == nil {
|
||||
log.Warningf("spn/cabin: invalid %s config: %s", publicCfgOptionIPv6Key, ip6)
|
||||
} else {
|
||||
info.IPv6 = ip
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func setDynamicPublicDefaults() {
|
||||
// name
|
||||
hostname, err := os.Hostname()
|
||||
if err == nil {
|
||||
err := config.SetDefaultConfigOption(publicCfgOptionNameKey, hostname)
|
||||
if err != nil {
|
||||
log.Warningf("spn/cabin: failed to set %s default to %s", publicCfgOptionNameKey, hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// IPs
|
||||
v4IPs, v6IPs, err := netenv.GetAssignedGlobalAddresses()
|
||||
if err != nil {
|
||||
log.Warningf("spn/cabin: failed to get assigned addresses: %s", err)
|
||||
return
|
||||
}
|
||||
if len(v4IPs) == 1 {
|
||||
err = config.SetDefaultConfigOption(publicCfgOptionIPv4Key, v4IPs[0].String())
|
||||
if err != nil {
|
||||
log.Warningf("spn/cabin: failed to set %s default to %s", publicCfgOptionIPv4Key, v4IPs[0].String())
|
||||
}
|
||||
}
|
||||
if len(v6IPs) == 1 {
|
||||
err = config.SetDefaultConfigOption(publicCfgOptionIPv6Key, v6IPs[0].String())
|
||||
if err != nil {
|
||||
log.Warningf("spn/cabin: failed to set %s default to %s", publicCfgOptionIPv6Key, v6IPs[0].String())
|
||||
}
|
||||
}
|
||||
}
|
||||
98
spn/cabin/database.go
Normal file
98
spn/cabin/database.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
var db = database.NewInterface(nil)
|
||||
|
||||
// LoadIdentity loads an identify with the given key.
|
||||
func LoadIdentity(key string) (id *Identity, changed bool, err error) {
|
||||
r, err := db.Get(key)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
id, err = EnsureIdentity(r)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to parse identity: %w", err)
|
||||
}
|
||||
|
||||
// Check if required fields are present.
|
||||
switch {
|
||||
case id.Hub == nil:
|
||||
return nil, false, errors.New("missing id.Hub")
|
||||
case id.Signet == nil:
|
||||
return nil, false, errors.New("missing id.Signet")
|
||||
case id.Hub.Info == nil:
|
||||
return nil, false, errors.New("missing hub.Info")
|
||||
case id.Hub.Status == nil:
|
||||
return nil, false, errors.New("missing hub.Status")
|
||||
case id.ID != id.Hub.ID:
|
||||
return nil, false, errors.New("hub.ID mismatch")
|
||||
case id.ID != id.Hub.Info.ID:
|
||||
return nil, false, errors.New("hub.Info.ID mismatch")
|
||||
case id.Map == "":
|
||||
return nil, false, errors.New("invalid id.Map")
|
||||
case id.Hub.Map == "":
|
||||
return nil, false, errors.New("invalid hub.Map")
|
||||
case id.Hub.FirstSeen.IsZero():
|
||||
return nil, false, errors.New("missing hub.FirstSeen")
|
||||
case id.Hub.Info.Timestamp == 0:
|
||||
return nil, false, errors.New("missing hub.Info.Timestamp")
|
||||
case id.Hub.Status.Timestamp == 0:
|
||||
return nil, false, errors.New("missing hub.Status.Timestamp")
|
||||
}
|
||||
|
||||
// Run a initial maintenance routine.
|
||||
infoChanged, err := id.MaintainAnnouncement(nil, true)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to initialize announcement: %w", err)
|
||||
}
|
||||
statusChanged, err := id.MaintainStatus(nil, nil, nil, true)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to initialize status: %w", err)
|
||||
}
|
||||
|
||||
// Ensure the Measurements reset the values.
|
||||
measurements := id.Hub.GetMeasurements()
|
||||
measurements.SetLatency(0)
|
||||
measurements.SetCapacity(0)
|
||||
measurements.SetCalculatedCost(hub.MaxCalculatedCost)
|
||||
|
||||
return id, infoChanged || statusChanged, nil
|
||||
}
|
||||
|
||||
// EnsureIdentity makes sure a database record is an Identity.
|
||||
func EnsureIdentity(r record.Record) (*Identity, error) {
|
||||
// unwrap
|
||||
if r.IsWrapped() {
|
||||
// only allocate a new struct, if we need it
|
||||
id := &Identity{}
|
||||
err := record.Unwrap(r, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// or adjust type
|
||||
id, ok := r.(*Identity)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *Identity, but %T", r)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// Save saves the Identity to the database.
|
||||
func (id *Identity) Save() error {
|
||||
if !id.KeyIsSet() {
|
||||
return errors.New("no key set")
|
||||
}
|
||||
|
||||
return db.Put(id)
|
||||
}
|
||||
311
spn/cabin/identity.go
Normal file
311
spn/cabin/identity.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/jess"
|
||||
"github.com/safing/jess/tools"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/info"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultIDKeyScheme is the default jess tool for creating ID keys.
|
||||
DefaultIDKeyScheme = "Ed25519"
|
||||
|
||||
// DefaultIDKeySecurityLevel is the default security level for creating ID keys.
|
||||
DefaultIDKeySecurityLevel = 256 // Ed25519 security level is fixed, setting is ignored.
|
||||
)
|
||||
|
||||
// Identity holds the identity of a Hub.
|
||||
type Identity struct {
|
||||
record.Base
|
||||
|
||||
ID string
|
||||
Map string
|
||||
Hub *hub.Hub
|
||||
Signet *jess.Signet
|
||||
|
||||
ExchKeys map[string]*ExchKey
|
||||
|
||||
infoExportCache []byte
|
||||
statusExportCache []byte
|
||||
}
|
||||
|
||||
// Lock locks the Identity through the Hub lock.
|
||||
func (id *Identity) Lock() {
|
||||
id.Hub.Lock()
|
||||
}
|
||||
|
||||
// Unlock unlocks the Identity through the Hub lock.
|
||||
func (id *Identity) Unlock() {
|
||||
id.Hub.Unlock()
|
||||
}
|
||||
|
||||
// ExchKey holds the private information of a HubKey.
|
||||
type ExchKey struct {
|
||||
Created time.Time
|
||||
Expires time.Time
|
||||
key *jess.Signet
|
||||
tool *tools.Tool
|
||||
}
|
||||
|
||||
// CreateIdentity creates a new identity.
|
||||
func CreateIdentity(ctx context.Context, mapName string) (*Identity, error) {
|
||||
id := &Identity{
|
||||
Map: mapName,
|
||||
ExchKeys: make(map[string]*ExchKey),
|
||||
}
|
||||
|
||||
// create signet
|
||||
signet, recipient, err := hub.CreateHubSignet(DefaultIDKeyScheme, DefaultIDKeySecurityLevel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id.Signet = signet
|
||||
id.ID = signet.ID
|
||||
id.Hub = &hub.Hub{
|
||||
ID: id.ID,
|
||||
Map: mapName,
|
||||
PublicKey: recipient,
|
||||
}
|
||||
|
||||
// initial maintenance routine
|
||||
_, err = id.MaintainAnnouncement(nil, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize announcement: %w", err)
|
||||
}
|
||||
_, err = id.MaintainStatus([]*hub.Lane{}, new(int), nil, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize status: %w", err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// MaintainAnnouncement maintains the Hub's Announcenemt and returns whether
|
||||
// there was a change that should be communicated to other Hubs.
|
||||
// If newInfo is nil, it will be derived from configuration.
|
||||
func (id *Identity) MaintainAnnouncement(newInfo *hub.Announcement, selfcheck bool) (changed bool, err error) {
|
||||
id.Lock()
|
||||
defer id.Unlock()
|
||||
|
||||
// Populate new info with data.
|
||||
if newInfo == nil {
|
||||
newInfo = getPublicHubInfo()
|
||||
}
|
||||
newInfo.ID = id.Hub.ID
|
||||
if id.Hub.Info != nil {
|
||||
newInfo.Timestamp = id.Hub.Info.Timestamp
|
||||
}
|
||||
if !newInfo.Equal(id.Hub.Info) {
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
// Update timestamp.
|
||||
newInfo.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
if changed || selfcheck {
|
||||
// Export new data.
|
||||
newInfoData, err := newInfo.Export(id.signingEnvelope())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to export: %w", err)
|
||||
}
|
||||
|
||||
// Apply the status as all other Hubs would in order to check if it's valid.
|
||||
_, _, _, err = hub.ApplyAnnouncement(id.Hub, newInfoData, conf.MainMapName, conf.MainMapScope, true)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to apply new announcement: %w", err)
|
||||
}
|
||||
id.infoExportCache = newInfoData
|
||||
|
||||
// Save message to hub message storage.
|
||||
err = hub.SaveHubMsg(id.ID, conf.MainMapName, hub.MsgTypeAnnouncement, newInfoData)
|
||||
if err != nil {
|
||||
log.Warningf("spn/cabin: failed to save own new/updated announcement of %s: %s", id.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// MaintainStatus maintains the Hub's Status and returns whether there was a change that should be communicated to other Hubs.
|
||||
func (id *Identity) MaintainStatus(lanes []*hub.Lane, load *int, flags []string, selfcheck bool) (changed bool, err error) {
|
||||
id.Lock()
|
||||
defer id.Unlock()
|
||||
|
||||
// Create a new status or make a copy of the status for editing.
|
||||
var newStatus *hub.Status
|
||||
if id.Hub.Status != nil {
|
||||
newStatus = id.Hub.Status.Copy()
|
||||
} else {
|
||||
newStatus = &hub.Status{}
|
||||
}
|
||||
|
||||
// Update software version.
|
||||
if newStatus.Version != info.Version() {
|
||||
newStatus.Version = info.Version()
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Update keys.
|
||||
keysChanged, err := id.MaintainExchKeys(newStatus, time.Now())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to maintain keys: %w", err)
|
||||
}
|
||||
if keysChanged {
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Update lanes.
|
||||
if lanes != nil && !hub.LanesEqual(newStatus.Lanes, lanes) {
|
||||
newStatus.Lanes = lanes
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Update load.
|
||||
if load != nil && newStatus.Load != *load {
|
||||
newStatus.Load = *load
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Update flags.
|
||||
if !hub.FlagsEqual(newStatus.Flags, flags) {
|
||||
newStatus.Flags = flags
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Update timestamp if something changed.
|
||||
if changed {
|
||||
newStatus.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
if changed || selfcheck {
|
||||
// Export new data.
|
||||
newStatusData, err := newStatus.Export(id.signingEnvelope())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to export: %w", err)
|
||||
}
|
||||
|
||||
// Apply the status as all other Hubs would in order to check if it's valid.
|
||||
_, _, _, err = hub.ApplyStatus(id.Hub, newStatusData, conf.MainMapName, conf.MainMapScope, true)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to apply new status: %w", err)
|
||||
}
|
||||
id.statusExportCache = newStatusData
|
||||
|
||||
// Save message to hub message storage.
|
||||
err = hub.SaveHubMsg(id.ID, conf.MainMapName, hub.MsgTypeStatus, newStatusData)
|
||||
if err != nil {
|
||||
log.Warningf("spn/cabin: failed to save own new/updated status: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// MakeOfflineStatus creates and signs an offline status message.
|
||||
func (id *Identity) MakeOfflineStatus() (offlineStatusExport []byte, err error) {
|
||||
// Make offline status.
|
||||
newStatus := &hub.Status{
|
||||
Timestamp: time.Now().Unix(),
|
||||
Version: info.Version(),
|
||||
Flags: []string{hub.FlagOffline},
|
||||
}
|
||||
|
||||
// Export new data.
|
||||
newStatusData, err := newStatus.Export(id.signingEnvelope())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to export: %w", err)
|
||||
}
|
||||
|
||||
return newStatusData, nil
|
||||
}
|
||||
|
||||
func (id *Identity) signingEnvelope() *jess.Envelope {
|
||||
env := jess.NewUnconfiguredEnvelope()
|
||||
env.SuiteID = jess.SuiteSignV1
|
||||
env.Senders = []*jess.Signet{id.Signet}
|
||||
|
||||
return env
|
||||
}
|
||||
|
||||
// ExportAnnouncement serializes and signs the Announcement.
|
||||
func (id *Identity) ExportAnnouncement() ([]byte, error) {
|
||||
id.Lock()
|
||||
defer id.Unlock()
|
||||
|
||||
if id.infoExportCache == nil {
|
||||
return nil, errors.New("announcement not exported")
|
||||
}
|
||||
|
||||
return id.infoExportCache, nil
|
||||
}
|
||||
|
||||
// ExportStatus serializes and signs the Status.
|
||||
func (id *Identity) ExportStatus() ([]byte, error) {
|
||||
id.Lock()
|
||||
defer id.Unlock()
|
||||
|
||||
if id.statusExportCache == nil {
|
||||
return nil, errors.New("status not exported")
|
||||
}
|
||||
|
||||
return id.statusExportCache, nil
|
||||
}
|
||||
|
||||
// SignHubMsg signs a data blob with the identity's private key.
|
||||
func (id *Identity) SignHubMsg(data []byte) ([]byte, error) {
|
||||
return hub.SignHubMsg(data, id.signingEnvelope(), false)
|
||||
}
|
||||
|
||||
// GetSignet returns the private exchange key with the given ID.
|
||||
func (id *Identity) GetSignet(keyID string, recipient bool) (*jess.Signet, error) {
|
||||
if recipient {
|
||||
return nil, errors.New("cabin.Identity only serves private keys")
|
||||
}
|
||||
|
||||
id.Lock()
|
||||
defer id.Unlock()
|
||||
|
||||
key, ok := id.ExchKeys[keyID]
|
||||
if !ok {
|
||||
return nil, errors.New("the requested key does not exist")
|
||||
}
|
||||
if time.Now().After(key.Expires) || key.key == nil {
|
||||
return nil, errors.New("the requested key has expired")
|
||||
}
|
||||
|
||||
return key.key, nil
|
||||
}
|
||||
|
||||
func (ek *ExchKey) toHubKey() (*hub.Key, error) {
|
||||
if ek.key == nil {
|
||||
return nil, errors.New("no key")
|
||||
}
|
||||
|
||||
// export public key
|
||||
rcpt, err := ek.key.AsRecipient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = rcpt.StoreKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// repackage
|
||||
return &hub.Key{
|
||||
Scheme: rcpt.Scheme,
|
||||
Key: rcpt.Key,
|
||||
Expires: ek.Expires.Unix(),
|
||||
}, nil
|
||||
}
|
||||
129
spn/cabin/identity_test.go
Normal file
129
spn/cabin/identity_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
func TestIdentity(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Register config options for public hub.
|
||||
if err := prepPublicHubConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create new identity.
|
||||
identityTestKey := "core:spn/public/identity"
|
||||
id, err := CreateIdentity(module.Ctx, conf.MainMapName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
id.SetKey(identityTestKey)
|
||||
|
||||
// Check values
|
||||
// Identity
|
||||
assert.NotEmpty(t, id.ID, "id.ID must be set")
|
||||
assert.NotEmpty(t, id.Map, "id.Map must be set")
|
||||
assert.NotNil(t, id.Signet, "id.Signet must be set")
|
||||
assert.NotNil(t, id.infoExportCache, "id.infoExportCache must be set")
|
||||
assert.NotNil(t, id.statusExportCache, "id.statusExportCache must be set")
|
||||
// Hub
|
||||
assert.NotEmpty(t, id.Hub.ID, "hub.ID must be set")
|
||||
assert.NotEmpty(t, id.Hub.Map, "hub.Map must be set")
|
||||
assert.NotZero(t, id.Hub.FirstSeen, "hub.FirstSeen must be set")
|
||||
// Info
|
||||
assert.NotEmpty(t, id.Hub.Info.ID, "info.ID must be set")
|
||||
assert.NotEqual(t, 0, id.Hub.Info.Timestamp, "info.Timestamp must be set")
|
||||
assert.NotEqual(t, "", id.Hub.Info.Name, "info.Name must be set (to hostname)")
|
||||
// Status
|
||||
assert.NotEqual(t, 0, id.Hub.Status.Timestamp, "status.Timestamp must be set")
|
||||
assert.NotEmpty(t, id.Hub.Status.Keys, "status.Keys must be set")
|
||||
|
||||
fmt.Printf("id: %+v\n", id)
|
||||
fmt.Printf("id.hub: %+v\n", id.Hub)
|
||||
fmt.Printf("id.Hub.Info: %+v\n", id.Hub.Info)
|
||||
fmt.Printf("id.Hub.Status: %+v\n", id.Hub.Status)
|
||||
|
||||
// Maintenance is run in creation, so nothing should change now.
|
||||
changed, err := id.MaintainAnnouncement(nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("unexpected change of announcement")
|
||||
}
|
||||
changed, err = id.MaintainStatus(nil, nil, nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("unexpected change of status")
|
||||
}
|
||||
|
||||
// Change lanes.
|
||||
lanes := []*hub.Lane{
|
||||
{
|
||||
ID: "A",
|
||||
Capacity: 1,
|
||||
Latency: 2,
|
||||
},
|
||||
{
|
||||
ID: "B",
|
||||
Capacity: 3,
|
||||
Latency: 4,
|
||||
},
|
||||
{
|
||||
ID: "C",
|
||||
Capacity: 5,
|
||||
Latency: 6,
|
||||
},
|
||||
}
|
||||
changed, err = id.MaintainStatus(lanes, new(int), nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !changed {
|
||||
t.Error("status should have changed")
|
||||
}
|
||||
|
||||
// Change nothing.
|
||||
changed, err = id.MaintainStatus(lanes, new(int), nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("unexpected change of status")
|
||||
}
|
||||
|
||||
// Exporting
|
||||
_, err = id.ExportAnnouncement()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = id.ExportStatus()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Save to and load from database.
|
||||
err = id.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
id2, changed, err := LoadIdentity(identityTestKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("unexpected change")
|
||||
}
|
||||
|
||||
// Check if they match
|
||||
assert.Equal(t, id, id2, "identities should be equal")
|
||||
}
|
||||
179
spn/cabin/keys.go
Normal file
179
spn/cabin/keys.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/jess"
|
||||
"github.com/safing/jess/tools"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
type providedExchKeyScheme struct {
|
||||
id string
|
||||
securityLevel int //nolint:structcheck // TODO
|
||||
tool *tools.Tool
|
||||
}
|
||||
|
||||
var (
|
||||
// validFor defines how long keys are valid for use by clients.
|
||||
validFor = 48 * time.Hour // 2 days
|
||||
// renewBeforeExpiry defines the duration how long before expiry keys should be renewed.
|
||||
renewBeforeExpiry = 24 * time.Hour // 1 day
|
||||
|
||||
// burnAfter defines how long after expiry keys are burnt/deleted.
|
||||
burnAfter = 12 * time.Hour // 1/2 day
|
||||
// reuseAfter defines how long IDs should be blocked after expiry (and not be reused for new keys).
|
||||
reuseAfter = 2 * 7 * 24 * time.Hour // 2 weeks
|
||||
|
||||
// provideExchKeySchemes defines the jess tools for creating exchange keys.
|
||||
provideExchKeySchemes = []*providedExchKeyScheme{
|
||||
{
|
||||
id: "ECDH-X25519",
|
||||
securityLevel: 128, // informative only, security level of ECDH-X25519 is fixed
|
||||
},
|
||||
// TODO: test with rsa keys
|
||||
}
|
||||
)
|
||||
|
||||
func initProvidedExchKeySchemes() error {
|
||||
for _, eks := range provideExchKeySchemes {
|
||||
tool, err := tools.Get(eks.id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eks.tool = tool
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaintainExchKeys maintains the exchange keys, creating new ones and
|
||||
// deprecating and deleting old ones.
|
||||
func (id *Identity) MaintainExchKeys(newStatus *hub.Status, now time.Time) (changed bool, err error) {
|
||||
// create Keys map
|
||||
if id.ExchKeys == nil {
|
||||
id.ExchKeys = make(map[string]*ExchKey)
|
||||
}
|
||||
|
||||
// lifecycle management
|
||||
for keyID, exchKey := range id.ExchKeys {
|
||||
if exchKey.key != nil && now.After(exchKey.Expires.Add(burnAfter)) {
|
||||
// delete key
|
||||
err := exchKey.tool.StaticLogic.BurnKey(exchKey.key)
|
||||
if err != nil {
|
||||
log.Warningf(
|
||||
"spn/cabin: failed to burn key %s (%s) of %s: %s",
|
||||
keyID,
|
||||
exchKey.tool.Info.Name,
|
||||
id.Hub.ID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
// remove reference
|
||||
exchKey.key = nil
|
||||
}
|
||||
if now.After(exchKey.Expires.Add(reuseAfter)) {
|
||||
// remove key
|
||||
delete(id.ExchKeys, keyID)
|
||||
}
|
||||
}
|
||||
|
||||
// find or create current keys
|
||||
for _, eks := range provideExchKeySchemes {
|
||||
found := false
|
||||
for _, exchKey := range id.ExchKeys {
|
||||
if exchKey.key != nil &&
|
||||
exchKey.key.Scheme == eks.id &&
|
||||
now.Before(exchKey.Expires.Add(-renewBeforeExpiry)) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
err := id.createExchKey(eks, now)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create %s exchange key: %w", eks.tool.Info.Name, err)
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// export most recent keys to HubStatus
|
||||
if changed || len(newStatus.Keys) == 0 {
|
||||
// reset
|
||||
newStatus.Keys = make(map[string]*hub.Key)
|
||||
|
||||
// find longest valid key for every provided scheme
|
||||
for _, eks := range provideExchKeySchemes {
|
||||
// find key of scheme that is valid the longest
|
||||
longestValid := &ExchKey{
|
||||
Expires: now,
|
||||
}
|
||||
for _, exchKey := range id.ExchKeys {
|
||||
if exchKey.key != nil &&
|
||||
exchKey.key.Scheme == eks.id &&
|
||||
exchKey.Expires.After(longestValid.Expires) {
|
||||
longestValid = exchKey
|
||||
}
|
||||
}
|
||||
|
||||
// check result
|
||||
if longestValid.key == nil {
|
||||
log.Warningf("spn/cabin: could not find export candidate for exchange key scheme %s", eks.id)
|
||||
continue
|
||||
}
|
||||
|
||||
// export
|
||||
hubKey, err := longestValid.toHubKey()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to export %s exchange key: %w", longestValid.tool.Info.Name, err)
|
||||
}
|
||||
// add
|
||||
newStatus.Keys[longestValid.key.ID] = hubKey
|
||||
}
|
||||
}
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
func (id *Identity) createExchKey(eks *providedExchKeyScheme, now time.Time) error {
|
||||
// get ID
|
||||
var keyID string
|
||||
for i := 0; i < 1000000; i++ { // not forever
|
||||
// generate new ID
|
||||
b, err := rng.Bytes(3)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get random data for key ID: %w", err)
|
||||
}
|
||||
keyID = base64.RawURLEncoding.EncodeToString(b)
|
||||
_, exists := id.ExchKeys[keyID]
|
||||
if !exists {
|
||||
break
|
||||
}
|
||||
}
|
||||
if keyID == "" {
|
||||
return errors.New("unable to find available exchange key ID")
|
||||
}
|
||||
|
||||
// generate key
|
||||
signet := jess.NewSignetBase(eks.tool)
|
||||
signet.ID = keyID
|
||||
// TODO: use security level for key generation
|
||||
if err := signet.GenerateKey(); err != nil {
|
||||
return fmt.Errorf("failed to get new exchange key: %w", err)
|
||||
}
|
||||
|
||||
// add to key map
|
||||
id.ExchKeys[keyID] = &ExchKey{
|
||||
Created: now,
|
||||
Expires: now.Add(validFor),
|
||||
key: signet,
|
||||
tool: eks.tool,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
43
spn/cabin/keys_test.go
Normal file
43
spn/cabin/keys_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
func TestKeyMaintenance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
id, err := CreateIdentity(module.Ctx, conf.MainMapName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iterations := 1000
|
||||
changeCnt := 0
|
||||
|
||||
now := time.Now()
|
||||
for i := 0; i < iterations; i++ {
|
||||
changed, err := id.MaintainExchKeys(id.Hub.Status, now)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if changed {
|
||||
changeCnt++
|
||||
t.Logf("===== exchange keys updated at %s:\n", now)
|
||||
for keyID, exchKey := range id.ExchKeys {
|
||||
t.Logf("[%s] %s %v\n", exchKey.Created, keyID, exchKey.key)
|
||||
}
|
||||
}
|
||||
now = now.Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
if iterations/changeCnt > 25 { // one new key every 24 hours/ticks
|
||||
t.Fatal("more changes than expected")
|
||||
}
|
||||
if len(id.ExchKeys) > 17 { // one new key every day for two weeks + 3 in use
|
||||
t.Fatal("more keys than expected")
|
||||
}
|
||||
}
|
||||
26
spn/cabin/module.go
Normal file
26
spn/cabin/module.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
var module *modules.Module
|
||||
|
||||
func init() {
|
||||
module = modules.Register("cabin", prep, nil, nil, "base", "rng")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
if err := initProvidedExchKeySchemes(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if conf.PublicHub() {
|
||||
if err := prepPublicHubConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
13
spn/cabin/module_test.go
Normal file
13
spn/cabin/module_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/service/core/pmtesting"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
conf.EnablePublicHub(true)
|
||||
pmtesting.TestMain(m, module)
|
||||
}
|
||||
157
spn/cabin/verification.go
Normal file
157
spn/cabin/verification.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/jess"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
var (
|
||||
verificationChallengeSize = 32
|
||||
verificationChallengeMinSize = 16
|
||||
verificationSigningSuite = jess.SuiteSignV1
|
||||
verificationRequirements = jess.NewRequirements().
|
||||
Remove(jess.Confidentiality).
|
||||
Remove(jess.Integrity).
|
||||
Remove(jess.RecipientAuthentication)
|
||||
)
|
||||
|
||||
// Verification is used to verify certain aspects of another Hub.
|
||||
type Verification struct {
|
||||
// Challenge is a random value chosen by the client.
|
||||
Challenge []byte `json:"c"`
|
||||
// Purpose defines the purpose of the verification. Protects against using verification for other purposes.
|
||||
Purpose string `json:"p"`
|
||||
// ClientReference is an optional field for exchanging metadata about the client. Protects against forwarding/relay attacks.
|
||||
ClientReference string `json:"cr"`
|
||||
// ServerReference is an optional field for exchanging metadata about the server. Protects against forwarding/relay attacks.
|
||||
ServerReference string `json:"sr"`
|
||||
}
|
||||
|
||||
// CreateVerificationRequest creates a new verification request with the given
|
||||
// purpose and references.
|
||||
func CreateVerificationRequest(purpose, clientReference, serverReference string) (v *Verification, request []byte, err error) {
|
||||
// Generate random challenge.
|
||||
challenge, err := rng.Bytes(verificationChallengeSize)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate challenge: %w", err)
|
||||
}
|
||||
|
||||
// Create verification object.
|
||||
v = &Verification{
|
||||
Purpose: purpose,
|
||||
ClientReference: clientReference,
|
||||
Challenge: challenge,
|
||||
}
|
||||
|
||||
// Serialize verification.
|
||||
request, err = dsd.Dump(v, dsd.JSON)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to serialize verification request: %w", err)
|
||||
}
|
||||
|
||||
// The server reference is not sent to the server, but needs to be supplied
|
||||
// by the server.
|
||||
v.ServerReference = serverReference
|
||||
|
||||
return v, request, nil
|
||||
}
|
||||
|
||||
// SignVerificationRequest sign a verification request.
|
||||
// The purpose and references must match the request, else the verification
|
||||
// will fail.
|
||||
func (id *Identity) SignVerificationRequest(request []byte, purpose, clientReference, serverReference string) (response []byte, err error) {
|
||||
// Parse request.
|
||||
v := new(Verification)
|
||||
_, err = dsd.Load(request, v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse request: %w", err)
|
||||
}
|
||||
|
||||
// Validate request.
|
||||
if len(v.Challenge) < verificationChallengeMinSize {
|
||||
return nil, errors.New("challenge too small")
|
||||
}
|
||||
if v.Purpose != purpose {
|
||||
return nil, errors.New("purpose mismatch")
|
||||
}
|
||||
if v.ClientReference != clientReference {
|
||||
return nil, errors.New("client reference mismatch")
|
||||
}
|
||||
|
||||
// Assign server reference and serialize.
|
||||
v.ServerReference = serverReference
|
||||
dataToSign, err := dsd.Dump(v, dsd.JSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize verification response: %w", err)
|
||||
}
|
||||
|
||||
// Sign response.
|
||||
e := jess.NewUnconfiguredEnvelope()
|
||||
e.SuiteID = verificationSigningSuite
|
||||
e.Senders = []*jess.Signet{id.Signet}
|
||||
jession, err := e.Correspondence(nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to setup signer: %w", err)
|
||||
}
|
||||
letter, err := jession.Close(dataToSign)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign: %w", err)
|
||||
}
|
||||
|
||||
// Serialize and return.
|
||||
signedResponse, err := letter.ToDSD(dsd.JSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize letter: %w", err)
|
||||
}
|
||||
|
||||
return signedResponse, nil
|
||||
}
|
||||
|
||||
// Verify verifies the verification response and checks if everything is valid.
|
||||
func (v *Verification) Verify(response []byte, h *hub.Hub) error {
|
||||
// Parse response.
|
||||
letter, err := jess.LetterFromDSD(response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
// Verify response.
|
||||
responseData, err := letter.Open(
|
||||
verificationRequirements,
|
||||
&hub.SingleTrustStore{
|
||||
Signet: h.PublicKey,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify response: %w", err)
|
||||
}
|
||||
|
||||
// Parse verified response.
|
||||
responseV := new(Verification)
|
||||
_, err = dsd.Load(responseData, responseV)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse verified response: %w", err)
|
||||
}
|
||||
|
||||
// Validate request.
|
||||
if subtle.ConstantTimeCompare(v.Challenge, responseV.Challenge) != 1 {
|
||||
return errors.New("challenge mismatch")
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(v.Purpose), []byte(responseV.Purpose)) != 1 {
|
||||
return errors.New("purpose mismatch")
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(v.ClientReference), []byte(responseV.ClientReference)) != 1 {
|
||||
return errors.New("client reference mismatch")
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(v.ServerReference), []byte(responseV.ServerReference)) != 1 {
|
||||
return errors.New("server reference mismatch")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
127
spn/cabin/verification_test.go
Normal file
127
spn/cabin/verification_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package cabin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
id, err := CreateIdentity(module.Ctx, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := testVerificationWith(
|
||||
t, id,
|
||||
"a", "b", "c",
|
||||
"a", "b", "c",
|
||||
"", "", "", nil,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := testVerificationWith(
|
||||
t, id,
|
||||
"a", "b", "c",
|
||||
"x", "b", "c",
|
||||
"", "", "", nil,
|
||||
); err == nil {
|
||||
t.Fatal("should fail on purpose mismatch")
|
||||
}
|
||||
|
||||
if err := testVerificationWith(
|
||||
t, id,
|
||||
"a", "b", "c",
|
||||
"a", "x", "c",
|
||||
"", "", "", nil,
|
||||
); err == nil {
|
||||
t.Fatal("should fail on client ref mismatch")
|
||||
}
|
||||
|
||||
if err := testVerificationWith(
|
||||
t, id,
|
||||
"a", "b", "c",
|
||||
"a", "b", "x",
|
||||
"", "", "", nil,
|
||||
); err == nil {
|
||||
t.Fatal("should fail on server ref mismatch")
|
||||
}
|
||||
|
||||
if err := testVerificationWith(
|
||||
t, id,
|
||||
"a", "b", "c",
|
||||
"a", "b", "c",
|
||||
"x", "", "", nil,
|
||||
); err == nil {
|
||||
t.Fatal("should fail on purpose mismatch")
|
||||
}
|
||||
|
||||
if err := testVerificationWith(
|
||||
t, id,
|
||||
"a", "b", "c",
|
||||
"a", "b", "c",
|
||||
"", "x", "", nil,
|
||||
); err == nil {
|
||||
t.Fatal("should fail on client ref mismatch")
|
||||
}
|
||||
|
||||
if err := testVerificationWith(
|
||||
t, id,
|
||||
"a", "b", "c",
|
||||
"a", "b", "c",
|
||||
"", "", "x", nil,
|
||||
); err == nil {
|
||||
t.Fatal("should fail on server ref mismatch")
|
||||
}
|
||||
|
||||
if err := testVerificationWith(
|
||||
t, id,
|
||||
"a", "b", "c",
|
||||
"a", "b", "c",
|
||||
"", "", "", []byte{1, 2, 3, 4},
|
||||
); err == nil {
|
||||
t.Fatal("should fail on challenge mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func testVerificationWith(
|
||||
t *testing.T, id *Identity,
|
||||
purpose1, clientRef1, serverRef1 string, //nolint:unparam
|
||||
purpose2, clientRef2, serverRef2 string,
|
||||
mitmPurpose, mitmClientRef, mitmServerRef string,
|
||||
mitmChallenge []byte,
|
||||
) error {
|
||||
t.Helper()
|
||||
|
||||
v, request, err := CreateVerificationRequest(purpose1, clientRef1, serverRef1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create verification request: %w", err)
|
||||
}
|
||||
|
||||
response, err := id.SignVerificationRequest(request, purpose2, clientRef2, serverRef2)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign verification response: %w", err)
|
||||
}
|
||||
|
||||
if mitmPurpose != "" {
|
||||
v.Purpose = mitmPurpose
|
||||
}
|
||||
if mitmClientRef != "" {
|
||||
v.ClientReference = mitmClientRef
|
||||
}
|
||||
if mitmServerRef != "" {
|
||||
v.ServerReference = mitmServerRef
|
||||
}
|
||||
if mitmChallenge != nil {
|
||||
v.Challenge = mitmChallenge
|
||||
}
|
||||
|
||||
err = v.Verify(response, id.Hub)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
68
spn/captain/api.go
Normal file
68
spn/captain/api.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/query"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
const (
|
||||
apiPathForSPNReInit = "spn/reinit"
|
||||
)
|
||||
|
||||
func registerAPIEndpoints() error {
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: apiPathForSPNReInit,
|
||||
Write: api.PermitAdmin,
|
||||
// BelongsTo: module, // Do not attach to module, as this must run outside of the module.
|
||||
ActionFunc: handleReInit,
|
||||
Name: "Re-initialize SPN",
|
||||
Description: "Stops the SPN, resets all caches and starts it again. The SPN account and settings are not changed.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleReInit(ar *api.Request) (msg string, err error) {
|
||||
// Disable module and check
|
||||
changed := module.Disable()
|
||||
if !changed {
|
||||
return "", errors.New("can only re-initialize when the SPN is enabled")
|
||||
}
|
||||
|
||||
// Run module manager.
|
||||
err = modules.ManageModules()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to stop SPN: %w", err)
|
||||
}
|
||||
|
||||
// Delete SPN cache.
|
||||
db := database.NewInterface(&database.Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
deletedRecords, err := db.Purge(ar.Context(), query.New("cache:spn/"))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to delete SPN cache: %w", err)
|
||||
}
|
||||
|
||||
// Enable module.
|
||||
module.Enable()
|
||||
|
||||
// Run module manager.
|
||||
err = modules.ManageModules()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to start SPN after cache reset: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Completed SPN re-initialization and deleted %d cache records in the process.",
|
||||
deletedRecords,
|
||||
), nil
|
||||
}
|
||||
152
spn/captain/bootstrap.go
Normal file
152
spn/captain/bootstrap.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
)
|
||||
|
||||
// BootstrapFile is used for sideloading bootstrap data.
|
||||
type BootstrapFile struct {
|
||||
Main BootstrapFileEntry
|
||||
}
|
||||
|
||||
// BootstrapFileEntry is the bootstrap data structure for one map.
|
||||
type BootstrapFileEntry struct {
|
||||
Hubs []string
|
||||
}
|
||||
|
||||
var (
|
||||
bootstrapHubFlag string
|
||||
bootstrapFileFlag string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&bootstrapHubFlag, "bootstrap-hub", "", "transport address of hub for bootstrapping with the hub ID in the fragment")
|
||||
flag.StringVar(&bootstrapFileFlag, "bootstrap-file", "", "bootstrap file containing bootstrap hubs - will be initialized if running a public hub and it doesn't exist")
|
||||
}
|
||||
|
||||
// prepBootstrapHubFlag checks the bootstrap-hub argument if it is valid.
|
||||
func prepBootstrapHubFlag() error {
|
||||
if bootstrapHubFlag != "" {
|
||||
_, _, _, err := hub.ParseBootstrapHub(bootstrapHubFlag)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processBootstrapHubFlag processes the bootstrap-hub argument.
|
||||
func processBootstrapHubFlag() error {
|
||||
if bootstrapHubFlag != "" {
|
||||
return navigator.Main.AddBootstrapHubs([]string{bootstrapHubFlag})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processBootstrapFileFlag processes the bootstrap-file argument.
|
||||
func processBootstrapFileFlag() error {
|
||||
if bootstrapFileFlag == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := os.Stat(bootstrapFileFlag)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return createBootstrapFile(bootstrapFileFlag)
|
||||
}
|
||||
return fmt.Errorf("failed to access bootstrap hub file: %w", err)
|
||||
}
|
||||
|
||||
return loadBootstrapFile(bootstrapFileFlag)
|
||||
}
|
||||
|
||||
// bootstrapWithUpdates loads bootstrap hubs from the updates server and imports them.
|
||||
func bootstrapWithUpdates() error {
|
||||
if bootstrapFileFlag != "" {
|
||||
return errors.New("using the bootstrap-file argument disables bootstrapping via the update system")
|
||||
}
|
||||
|
||||
return updateSPNIntel(module.Ctx, nil)
|
||||
}
|
||||
|
||||
// loadBootstrapFile loads a file with bootstrap hub entries and imports them.
|
||||
func loadBootstrapFile(filename string) (err error) {
|
||||
// Load bootstrap file from disk and parse it.
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load bootstrap file: %w", err)
|
||||
}
|
||||
bootstrapFile := &BootstrapFile{}
|
||||
_, err = dsd.Load(data, bootstrapFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse bootstrap file: %w", err)
|
||||
}
|
||||
if len(bootstrapFile.Main.Hubs) == 0 {
|
||||
return errors.New("bootstrap holds no hubs for main map")
|
||||
}
|
||||
|
||||
// Add Hubs to map.
|
||||
err = navigator.Main.AddBootstrapHubs(bootstrapFile.Main.Hubs)
|
||||
if err == nil {
|
||||
log.Infof("spn/captain: loaded bootstrap file %s", filename)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// createBootstrapFile save a bootstrap hub file with an entry of the public identity.
|
||||
func createBootstrapFile(filename string) error {
|
||||
if !conf.PublicHub() {
|
||||
log.Infof("spn/captain: skipped writing a bootstrap hub file, as this is not a public hub")
|
||||
return nil
|
||||
}
|
||||
|
||||
// create bootstrap hub
|
||||
if len(publicIdentity.Hub.Info.Transports) == 0 {
|
||||
return errors.New("public identity has no transports available")
|
||||
}
|
||||
// parse first transport
|
||||
t, err := hub.ParseTransport(publicIdentity.Hub.Info.Transports[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse transport of public identity: %w", err)
|
||||
}
|
||||
// add IP address
|
||||
switch {
|
||||
case publicIdentity.Hub.Info.IPv4 != nil:
|
||||
t.Domain = publicIdentity.Hub.Info.IPv4.String()
|
||||
case publicIdentity.Hub.Info.IPv6 != nil:
|
||||
t.Domain = "[" + publicIdentity.Hub.Info.IPv6.String() + "]"
|
||||
default:
|
||||
return errors.New("public identity has no IP address available")
|
||||
}
|
||||
// add Hub ID
|
||||
t.Option = publicIdentity.Hub.ID
|
||||
// put together
|
||||
bs := &BootstrapFile{
|
||||
Main: BootstrapFileEntry{
|
||||
Hubs: []string{t.String()},
|
||||
},
|
||||
}
|
||||
|
||||
// serialize
|
||||
fileData, err := dsd.Dump(bs, dsd.JSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// save to disk
|
||||
err = os.WriteFile(filename, fileData, 0o0664) //nolint:gosec // Should be able to be read by others.
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("spn/captain: created bootstrap file %s", filename)
|
||||
return nil
|
||||
}
|
||||
506
spn/captain/client.go
Normal file
506
spn/captain/client.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/notifications"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/spn/access"
|
||||
"github.com/safing/portmaster/spn/crew"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
var (
|
||||
ready = abool.New()
|
||||
|
||||
spnLoginButton = notifications.Action{
|
||||
Text: "Login",
|
||||
Type: notifications.ActionTypeOpenPage,
|
||||
Payload: "spn",
|
||||
}
|
||||
spnOpenAccountPage = notifications.Action{
|
||||
Text: "Open Account Page",
|
||||
Type: notifications.ActionTypeOpenURL,
|
||||
Payload: "https://account.safing.io",
|
||||
}
|
||||
)
|
||||
|
||||
// ClientReady signifies if the SPN client is fully ready to handle connections.
|
||||
func ClientReady() bool {
|
||||
return ready.IsSet()
|
||||
}
|
||||
|
||||
type (
|
||||
clientComponentFunc func(ctx context.Context) clientComponentResult
|
||||
clientComponentResult uint8
|
||||
)
|
||||
|
||||
const (
|
||||
clientResultOk clientComponentResult = iota // Continue and clean module status.
|
||||
clientResultRetry // Go back to start of current step, don't clear module status.
|
||||
clientResultReconnect // Stop current connection and start from zero.
|
||||
clientResultShutdown // SPN Module is shutting down.
|
||||
)
|
||||
|
||||
var (
|
||||
clientNetworkChangedFlag = netenv.GetNetworkChangedFlag()
|
||||
clientIneligibleAccountUpdateDelay = 1 * time.Minute
|
||||
clientRetryConnectBackoffDuration = 5 * time.Second
|
||||
clientInitialHealthCheckDelay = 10 * time.Second
|
||||
clientHealthCheckTickDuration = 1 * time.Minute
|
||||
clientHealthCheckTickDurationSleepMode = 5 * time.Minute
|
||||
clientHealthCheckTimeout = 15 * time.Second
|
||||
|
||||
clientHealthCheckTrigger = make(chan struct{}, 1)
|
||||
lastHealthCheck time.Time
|
||||
)
|
||||
|
||||
func triggerClientHealthCheck() {
|
||||
select {
|
||||
case clientHealthCheckTrigger <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func clientManager(ctx context.Context) error {
|
||||
defer func() {
|
||||
ready.UnSet()
|
||||
netenv.ConnectedToSPN.UnSet()
|
||||
resetSPNStatus(StatusDisabled, true)
|
||||
module.Resolve("")
|
||||
clientStopHomeHub(ctx)
|
||||
}()
|
||||
|
||||
module.Hint(
|
||||
"spn:establishing-home-hub",
|
||||
"Connecting to SPN...",
|
||||
"Connecting to the SPN network is in progress.",
|
||||
)
|
||||
|
||||
// TODO: When we are starting and the SPN module is faster online than the
|
||||
// nameserver, then updating the account will fail as the DNS query is
|
||||
// redirected to a closed port.
|
||||
// We also can't add the nameserver as a module dependency, as the nameserver
|
||||
// is not part of the server.
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
|
||||
healthCheckTicker := module.NewSleepyTicker(clientHealthCheckTickDuration, clientHealthCheckTickDurationSleepMode)
|
||||
|
||||
reconnect:
|
||||
for {
|
||||
// Check if we are shutting down.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Reset SPN status.
|
||||
if ready.SetToIf(true, false) {
|
||||
netenv.ConnectedToSPN.UnSet()
|
||||
log.Info("spn/captain: client not ready")
|
||||
}
|
||||
resetSPNStatus(StatusConnecting, true)
|
||||
|
||||
// Check everything and connect to the SPN.
|
||||
for _, clientFunc := range []clientComponentFunc{
|
||||
clientStopHomeHub,
|
||||
clientCheckNetworkReady,
|
||||
clientCheckAccountAndTokens,
|
||||
clientConnectToHomeHub,
|
||||
clientSetActiveConnectionStatus,
|
||||
} {
|
||||
switch clientFunc(ctx) {
|
||||
case clientResultOk:
|
||||
// Continue
|
||||
case clientResultRetry, clientResultReconnect:
|
||||
// Wait for a short time to not loop too quickly.
|
||||
select {
|
||||
case <-time.After(clientRetryConnectBackoffDuration):
|
||||
continue reconnect
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
case clientResultShutdown:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("spn/captain: client is ready")
|
||||
ready.Set()
|
||||
netenv.ConnectedToSPN.Set()
|
||||
|
||||
module.TriggerEvent(SPNConnectedEvent, nil)
|
||||
module.StartWorker("update quick setting countries", navigator.Main.UpdateConfigQuickSettings)
|
||||
|
||||
// Reset last health check value, as we have just connected.
|
||||
lastHealthCheck = time.Now()
|
||||
|
||||
// Back off before starting initial health checks.
|
||||
select {
|
||||
case <-time.After(clientInitialHealthCheckDelay):
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
// Check health of the current SPN connection and monitor the user status.
|
||||
maintainers:
|
||||
for _, clientFunc := range []clientComponentFunc{
|
||||
clientCheckHomeHubConnection,
|
||||
clientCheckAccountAndTokens,
|
||||
clientSetActiveConnectionStatus,
|
||||
} {
|
||||
switch clientFunc(ctx) {
|
||||
case clientResultOk:
|
||||
// Continue
|
||||
case clientResultRetry:
|
||||
// Abort and wait for the next run.
|
||||
break maintainers
|
||||
case clientResultReconnect:
|
||||
continue reconnect
|
||||
case clientResultShutdown:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for signal to run maintenance again.
|
||||
select {
|
||||
case <-healthCheckTicker.Wait():
|
||||
case <-clientHealthCheckTrigger:
|
||||
case <-crew.ConnectErrors():
|
||||
case <-clientNetworkChangedFlag.Signal():
|
||||
clientNetworkChangedFlag.Refresh()
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func clientCheckNetworkReady(ctx context.Context) clientComponentResult {
|
||||
// Check if we are online enough for connecting.
|
||||
switch netenv.GetOnlineStatus() { //nolint:exhaustive
|
||||
case netenv.StatusOffline,
|
||||
netenv.StatusLimited:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return clientResultShutdown
|
||||
case <-time.After(1 * time.Second):
|
||||
return clientResultRetry
|
||||
}
|
||||
}
|
||||
|
||||
return clientResultOk
|
||||
}
|
||||
|
||||
// DisableAccount disables using any account related SPN functionality.
|
||||
// Attempts to use the same will result in errors.
|
||||
var DisableAccount bool
|
||||
|
||||
func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult {
|
||||
if DisableAccount {
|
||||
return clientResultOk
|
||||
}
|
||||
|
||||
// Get SPN user.
|
||||
user, err := access.GetUser()
|
||||
if err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
|
||||
notifications.NotifyError(
|
||||
"spn:failed-to-get-user",
|
||||
"SPN Internal Error",
|
||||
`Please restart Portmaster.`,
|
||||
// TODO: Add restart button.
|
||||
// TODO: Use special UI restart action in order to reload UI on restart.
|
||||
).AttachToModule(module)
|
||||
resetSPNStatus(StatusFailed, true)
|
||||
log.Errorf("spn/captain: client internal error: %s", err)
|
||||
return clientResultReconnect
|
||||
}
|
||||
|
||||
// Check if user is logged in.
|
||||
if user == nil || !user.IsLoggedIn() {
|
||||
notifications.NotifyWarn(
|
||||
"spn:not-logged-in",
|
||||
"SPN Login Required",
|
||||
`Please log in to access the SPN.`,
|
||||
spnLoginButton,
|
||||
).AttachToModule(module)
|
||||
resetSPNStatus(StatusFailed, true)
|
||||
log.Warningf("spn/captain: enabled but not logged in")
|
||||
return clientResultReconnect
|
||||
}
|
||||
|
||||
// Check if user is eligible.
|
||||
if !user.MayUseTheSPN() {
|
||||
// Update user in case there was a change.
|
||||
// Only update here if we need to - there is an update task in the access
|
||||
// module for periodic updates.
|
||||
if time.Now().Add(-clientIneligibleAccountUpdateDelay).After(time.Unix(user.Meta().Modified, 0)) {
|
||||
_, _, err := access.UpdateUser()
|
||||
if err != nil {
|
||||
notifications.NotifyError(
|
||||
"spn:failed-to-update-user",
|
||||
"SPN Account Server Error",
|
||||
fmt.Sprintf(`The status of your SPN account could not be updated: %s`, err),
|
||||
).AttachToModule(module)
|
||||
resetSPNStatus(StatusFailed, true)
|
||||
log.Errorf("spn/captain: failed to update ineligible account: %s", err)
|
||||
return clientResultReconnect
|
||||
}
|
||||
}
|
||||
|
||||
// Check if user is eligible after a possible update.
|
||||
if !user.MayUseTheSPN() {
|
||||
|
||||
// If package is generally valid, then the current package does not have access to the SPN.
|
||||
if user.MayUse("") {
|
||||
notifications.NotifyError(
|
||||
"spn:package-not-eligible",
|
||||
"SPN Not Included In Package",
|
||||
"Your current Portmaster Package does not include access to the SPN. Please upgrade your package on the Account Page.",
|
||||
spnOpenAccountPage,
|
||||
).AttachToModule(module)
|
||||
resetSPNStatus(StatusFailed, true)
|
||||
return clientResultReconnect
|
||||
}
|
||||
|
||||
// Otherwise, include the message from the user view.
|
||||
message := "There is an issue with your Portmaster Package. Please check the Account Page."
|
||||
if user.View != nil && user.View.Message != "" {
|
||||
message = user.View.Message
|
||||
}
|
||||
notifications.NotifyError(
|
||||
"spn:subscription-inactive",
|
||||
"Portmaster Package Issue",
|
||||
"Cannot enable SPN: "+message,
|
||||
spnOpenAccountPage,
|
||||
).AttachToModule(module)
|
||||
resetSPNStatus(StatusFailed, true)
|
||||
return clientResultReconnect
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we have enough tokens.
|
||||
if access.ShouldRequest(access.ExpandAndConnectZones) {
|
||||
err := access.UpdateTokens()
|
||||
if err != nil {
|
||||
log.Errorf("spn/captain: failed to get tokens: %s", err)
|
||||
|
||||
// There was an error updating the account.
|
||||
// Check if we have enough tokens to continue anyway.
|
||||
regular, _ := access.GetTokenAmount(access.ExpandAndConnectZones)
|
||||
if regular == 0 /* && fallback == 0 */ { // TODO: Add fallback token check when fallback was tested on servers.
|
||||
notifications.NotifyError(
|
||||
"spn:tokens-exhausted",
|
||||
"SPN Access Tokens Exhausted",
|
||||
`The Portmaster failed to get new access tokens to access the SPN. The Portmaster will automatically retry to get new access tokens.`,
|
||||
).AttachToModule(module)
|
||||
resetSPNStatus(StatusFailed, false)
|
||||
}
|
||||
return clientResultRetry
|
||||
}
|
||||
}
|
||||
|
||||
return clientResultOk
|
||||
}
|
||||
|
||||
func clientStopHomeHub(ctx context.Context) clientComponentResult {
|
||||
// Don't use the context in this function, as it will likely be canceled
|
||||
// already and would disrupt any context usage in here.
|
||||
|
||||
// Get crane connecting to home.
|
||||
home, _ := navigator.Main.GetHome()
|
||||
if home == nil {
|
||||
return clientResultOk
|
||||
}
|
||||
crane := docks.GetAssignedCrane(home.Hub.ID)
|
||||
if crane == nil {
|
||||
return clientResultOk
|
||||
}
|
||||
|
||||
// Stop crane and all connected terminals.
|
||||
crane.Stop(nil)
|
||||
return clientResultOk
|
||||
}
|
||||
|
||||
func clientConnectToHomeHub(ctx context.Context) clientComponentResult {
|
||||
err := establishHomeHub(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("spn/captain: failed to establish connection to home hub: %s", err)
|
||||
resetSPNStatus(StatusFailed, true)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, ErrAllHomeHubsExcluded):
|
||||
notifications.NotifyError(
|
||||
"spn:all-home-hubs-excluded",
|
||||
"All Home Nodes Excluded",
|
||||
"Your current Home Node Rules exclude all available and eligible SPN Nodes. Please change your rules to allow for at least one available and eligible Home Node.",
|
||||
notifications.Action{
|
||||
Text: "Configure",
|
||||
Type: notifications.ActionTypeOpenSetting,
|
||||
Payload: ¬ifications.ActionTypeOpenSettingPayload{
|
||||
Key: CfgOptionHomeHubPolicyKey,
|
||||
},
|
||||
},
|
||||
).AttachToModule(module)
|
||||
|
||||
case errors.Is(err, ErrReInitSPNSuggested):
|
||||
notifications.NotifyError(
|
||||
"spn:cannot-bootstrap",
|
||||
"SPN Cannot Bootstrap",
|
||||
"The local state of the SPN network is likely outdated. Portmaster was not able to identify a server to connect to. Please re-initialize the SPN using the tools menu or the button on the notification.",
|
||||
notifications.Action{
|
||||
ID: "re-init",
|
||||
Text: "Re-Init SPN",
|
||||
Type: notifications.ActionTypeWebhook,
|
||||
Payload: ¬ifications.ActionTypeWebhookPayload{
|
||||
URL: apiPathForSPNReInit,
|
||||
ResultAction: "display",
|
||||
},
|
||||
},
|
||||
).AttachToModule(module)
|
||||
|
||||
default:
|
||||
notifications.NotifyWarn(
|
||||
"spn:home-hub-failure",
|
||||
"SPN Failed to Connect",
|
||||
fmt.Sprintf("Failed to connect to a home hub: %s. The Portmaster will retry to connect automatically.", err),
|
||||
).AttachToModule(module)
|
||||
}
|
||||
|
||||
return clientResultReconnect
|
||||
}
|
||||
|
||||
// Log new connection.
|
||||
home, _ := navigator.Main.GetHome()
|
||||
if home != nil {
|
||||
log.Infof("spn/captain: established new home %s", home.Hub)
|
||||
}
|
||||
|
||||
return clientResultOk
|
||||
}
|
||||
|
||||
func clientSetActiveConnectionStatus(ctx context.Context) clientComponentResult {
|
||||
// Get current home.
|
||||
home, homeTerminal := navigator.Main.GetHome()
|
||||
if home == nil || homeTerminal == nil {
|
||||
return clientResultReconnect
|
||||
}
|
||||
|
||||
// Resolve any connection error.
|
||||
module.Resolve("")
|
||||
|
||||
// Update SPN Status with connection information, if not already correctly set.
|
||||
spnStatus.Lock()
|
||||
defer spnStatus.Unlock()
|
||||
|
||||
if spnStatus.Status != StatusConnected || spnStatus.HomeHubID != home.Hub.ID {
|
||||
// Fill connection status data.
|
||||
spnStatus.Status = StatusConnected
|
||||
spnStatus.HomeHubID = home.Hub.ID
|
||||
spnStatus.HomeHubName = home.Hub.Info.Name
|
||||
|
||||
connectedIP, _, err := netutils.IPPortFromAddr(homeTerminal.RemoteAddr())
|
||||
if err != nil {
|
||||
spnStatus.ConnectedIP = homeTerminal.RemoteAddr().String()
|
||||
} else {
|
||||
spnStatus.ConnectedIP = connectedIP.String()
|
||||
}
|
||||
spnStatus.ConnectedTransport = homeTerminal.Transport().String()
|
||||
|
||||
geoLoc := home.GetLocation(connectedIP)
|
||||
if geoLoc != nil {
|
||||
spnStatus.ConnectedCountry = &geoLoc.Country
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
spnStatus.ConnectedSince = &now
|
||||
|
||||
// Push new status.
|
||||
pushSPNStatusUpdate()
|
||||
}
|
||||
|
||||
return clientResultOk
|
||||
}
|
||||
|
||||
func clientCheckHomeHubConnection(ctx context.Context) clientComponentResult {
|
||||
// Check the status of the Home Hub.
|
||||
home, homeTerminal := navigator.Main.GetHome()
|
||||
if home == nil || homeTerminal == nil || homeTerminal.IsBeingAbandoned() {
|
||||
return clientResultReconnect
|
||||
}
|
||||
|
||||
// Get crane controller for health check.
|
||||
crane := docks.GetAssignedCrane(home.Hub.ID)
|
||||
if crane == nil {
|
||||
log.Errorf("spn/captain: could not find home hub crane for health check")
|
||||
return clientResultOk
|
||||
}
|
||||
|
||||
// Ping home hub.
|
||||
latency, tErr := pingHome(ctx, crane.Controller, clientHealthCheckTimeout)
|
||||
if tErr != nil {
|
||||
log.Warningf("spn/captain: failed to ping home hub: %s", tErr)
|
||||
|
||||
// Prepare to reconnect to the network.
|
||||
|
||||
// Reset all failing states, as these might have been caused by the failing home hub.
|
||||
navigator.Main.ResetFailingStates(ctx)
|
||||
|
||||
// If the last health check is clearly too long ago, assume that the device was sleeping and do not set the home node to failing yet.
|
||||
if time.Since(lastHealthCheck) > clientHealthCheckTickDuration+
|
||||
clientHealthCheckTickDurationSleepMode+
|
||||
(clientHealthCheckTimeout*2) {
|
||||
return clientResultReconnect
|
||||
}
|
||||
|
||||
// Mark the home hub itself as failing, as we want to try to connect to somewhere else.
|
||||
home.MarkAsFailingFor(5 * time.Minute)
|
||||
|
||||
return clientResultReconnect
|
||||
}
|
||||
lastHealthCheck = time.Now()
|
||||
|
||||
log.Debugf("spn/captain: pinged home hub in %s", latency)
|
||||
return clientResultOk
|
||||
}
|
||||
|
||||
func pingHome(ctx context.Context, t terminal.Terminal, timeout time.Duration) (latency time.Duration, err *terminal.Error) {
|
||||
started := time.Now()
|
||||
|
||||
// Start ping operation.
|
||||
pingOp, tErr := crew.NewPingOp(t)
|
||||
if tErr != nil {
|
||||
return 0, tErr
|
||||
}
|
||||
|
||||
// Wait for response.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, terminal.ErrCanceled
|
||||
case <-time.After(timeout):
|
||||
return 0, terminal.ErrTimeout
|
||||
case result := <-pingOp.Result:
|
||||
if result.Is(terminal.ErrExplicitAck) {
|
||||
return time.Since(started), nil
|
||||
}
|
||||
if result.IsOK() {
|
||||
return 0, result.Wrap("unexpected response")
|
||||
}
|
||||
return 0, result
|
||||
}
|
||||
}
|
||||
253
spn/captain/config.go
Normal file
253
spn/captain/config.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portmaster/service/profile"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
)
|
||||
|
||||
var (
|
||||
// CfgOptionEnableSPNKey is the configuration key for the SPN module.
|
||||
CfgOptionEnableSPNKey = "spn/enable"
|
||||
cfgOptionEnableSPNOrder = 128
|
||||
|
||||
// CfgOptionHomeHubPolicyKey is the configuration key for the SPN home policy.
|
||||
CfgOptionHomeHubPolicyKey = "spn/homePolicy"
|
||||
cfgOptionHomeHubPolicy config.StringArrayOption
|
||||
cfgOptionHomeHubPolicyOrder = 145
|
||||
|
||||
// CfgOptionDNSExitHubPolicyKey is the configuration key for the SPN DNS exit policy.
|
||||
CfgOptionDNSExitHubPolicyKey = "spn/dnsExitPolicy"
|
||||
cfgOptionDNSExitHubPolicy config.StringArrayOption
|
||||
cfgOptionDNSExitHubPolicyOrder = 148
|
||||
|
||||
// CfgOptionUseCommunityNodesKey is the configuration key for whether to use community nodes.
|
||||
CfgOptionUseCommunityNodesKey = "spn/useCommunityNodes"
|
||||
cfgOptionUseCommunityNodes config.BoolOption
|
||||
cfgOptionUseCommunityNodesOrder = 149
|
||||
|
||||
// NonCommunityVerifiedOwners holds a list of verified owners that are not
|
||||
// considered "community".
|
||||
NonCommunityVerifiedOwners = []string{"Safing"}
|
||||
|
||||
// CfgOptionTrustNodeNodesKey is the configuration key for whether additional trusted nodes.
|
||||
CfgOptionTrustNodeNodesKey = "spn/trustNodes"
|
||||
cfgOptionTrustNodeNodes config.StringArrayOption
|
||||
cfgOptionTrustNodeNodesOrder = 150
|
||||
|
||||
// Special Access Code.
|
||||
cfgOptionSpecialAccessCodeKey = "spn/specialAccessCode"
|
||||
cfgOptionSpecialAccessCodeDefault = "none"
|
||||
cfgOptionSpecialAccessCode config.StringOption //nolint:unused // Linter, you drunk?
|
||||
cfgOptionSpecialAccessCodeOrder = 160
|
||||
|
||||
// IPv6 must be global and accessible.
|
||||
cfgOptionBindToAdvertisedKey = "spn/publicHub/bindToAdvertised"
|
||||
cfgOptionBindToAdvertised config.BoolOption
|
||||
cfgOptionBindToAdvertisedDefault = false
|
||||
cfgOptionBindToAdvertisedOrder = 161
|
||||
|
||||
// Config options for use.
|
||||
cfgOptionRoutingAlgorithm config.StringOption
|
||||
)
|
||||
|
||||
func prepConfig() error {
|
||||
// Home Node Rules
|
||||
err := config.Register(&config.Option{
|
||||
Name: "Home Node Rules",
|
||||
Key: CfgOptionHomeHubPolicyKey,
|
||||
Description: `Customize which countries should or should not be used for your Home Node. The Home Node is your entry into the SPN. You connect directly to it and all your connections are routed through it.
|
||||
|
||||
By default, the Portmaster tries to choose the nearest node as your Home Node in order to reduce your exposure to the open Internet.
|
||||
|
||||
Reconnect to the SPN in order to apply new rules.`,
|
||||
Help: profile.SPNRulesHelp,
|
||||
Sensitive: true,
|
||||
OptType: config.OptTypeStringArray,
|
||||
RequiresRestart: true,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
DefaultValue: []string{},
|
||||
Annotations: config.Annotations{
|
||||
config.CategoryAnnotation: "Routing",
|
||||
config.DisplayOrderAnnotation: cfgOptionHomeHubPolicyOrder,
|
||||
config.DisplayHintAnnotation: endpoints.DisplayHintEndpointList,
|
||||
config.QuickSettingsAnnotation: profile.SPNRulesQuickSettings,
|
||||
endpoints.EndpointListVerdictNamesAnnotation: profile.SPNRulesVerdictNames,
|
||||
},
|
||||
ValidationRegex: endpoints.ListEntryValidationRegex,
|
||||
ValidationFunc: endpoints.ValidateEndpointListConfigOption,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfgOptionHomeHubPolicy = config.Concurrent.GetAsStringArray(CfgOptionHomeHubPolicyKey, []string{})
|
||||
|
||||
// DNS Exit Node Rules
|
||||
err = config.Register(&config.Option{
|
||||
Name: "DNS Exit Node Rules",
|
||||
Key: CfgOptionDNSExitHubPolicyKey,
|
||||
Description: `Customize which countries should or should not be used as DNS Exit Nodes.
|
||||
|
||||
By default, the Portmaster will exit DNS requests directly at your Home Node in order to keep them fast and close to your location. This is important, as DNS resolution often takes your approximate location into account when deciding which optimized DNS records are returned to you. As the Portmaster encrypts your DNS requests by default, you effectively gain a two-hop security level for your DNS requests in order to protect your privacy.
|
||||
|
||||
This setting mainly exists for when you need to simulate your presence in another location on a lower level too. This might be necessary to defeat more intelligent geo-blocking systems.`,
|
||||
Help: profile.SPNRulesHelp,
|
||||
Sensitive: true,
|
||||
OptType: config.OptTypeStringArray,
|
||||
RequiresRestart: true,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
DefaultValue: []string{},
|
||||
Annotations: config.Annotations{
|
||||
config.CategoryAnnotation: "Routing",
|
||||
config.DisplayOrderAnnotation: cfgOptionDNSExitHubPolicyOrder,
|
||||
config.DisplayHintAnnotation: endpoints.DisplayHintEndpointList,
|
||||
config.QuickSettingsAnnotation: profile.SPNRulesQuickSettings,
|
||||
endpoints.EndpointListVerdictNamesAnnotation: profile.SPNRulesVerdictNames,
|
||||
},
|
||||
ValidationRegex: endpoints.ListEntryValidationRegex,
|
||||
ValidationFunc: endpoints.ValidateEndpointListConfigOption,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfgOptionDNSExitHubPolicy = config.Concurrent.GetAsStringArray(CfgOptionDNSExitHubPolicyKey, []string{})
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Use Community Nodes",
|
||||
Key: CfgOptionUseCommunityNodesKey,
|
||||
Description: "Use nodes (servers) not operated by Safing themselves. The use of community nodes is recommended as it diversifies the ownership of the nodes you use for your connections and further strengthens your privacy. Plain connections (eg. http, smtp, ...) will never exit via community nodes, making this setting safe to use.",
|
||||
Sensitive: true,
|
||||
OptType: config.OptTypeBool,
|
||||
RequiresRestart: true,
|
||||
DefaultValue: true,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionUseCommunityNodesOrder,
|
||||
config.CategoryAnnotation: "Routing",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfgOptionUseCommunityNodes = config.Concurrent.GetAsBool(CfgOptionUseCommunityNodesKey, true)
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Trust Nodes",
|
||||
Key: CfgOptionTrustNodeNodesKey,
|
||||
Description: "Specify which community nodes to additionally trust. These nodes may then also be used as a Home Node, as well as an Exit Node for unencrypted connections.",
|
||||
Help: "You can specify nodes by their ID or their verified operator.",
|
||||
Sensitive: true,
|
||||
OptType: config.OptTypeStringArray,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
DefaultValue: []string{},
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionTrustNodeNodesOrder,
|
||||
config.CategoryAnnotation: "Routing",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfgOptionTrustNodeNodes = config.Concurrent.GetAsStringArray(CfgOptionTrustNodeNodesKey, []string{})
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Special Access Code",
|
||||
Key: cfgOptionSpecialAccessCodeKey,
|
||||
Description: "Special Access Codes grant access to the SPN for testing or evaluation purposes.",
|
||||
Sensitive: true,
|
||||
OptType: config.OptTypeString,
|
||||
DefaultValue: cfgOptionSpecialAccessCodeDefault,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionSpecialAccessCodeOrder,
|
||||
config.CategoryAnnotation: "Advanced",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfgOptionSpecialAccessCode = config.Concurrent.GetAsString(cfgOptionSpecialAccessCodeKey, "")
|
||||
|
||||
if conf.PublicHub() {
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Connect From Advertised IPs Only",
|
||||
Key: cfgOptionBindToAdvertisedKey,
|
||||
Description: "Only connect from (bind to) the advertised IP addresses.",
|
||||
OptType: config.OptTypeBool,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
DefaultValue: cfgOptionBindToAdvertisedDefault,
|
||||
RequiresRestart: true,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionBindToAdvertisedOrder,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfgOptionBindToAdvertised = config.GetAsBool(cfgOptionBindToAdvertisedKey, cfgOptionBindToAdvertisedDefault)
|
||||
}
|
||||
|
||||
// Config options for use.
|
||||
cfgOptionRoutingAlgorithm = config.Concurrent.GetAsString(profile.CfgOptionRoutingAlgorithmKey, navigator.DefaultRoutingProfileID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
homeHubPolicy endpoints.Endpoints
|
||||
homeHubPolicyLock sync.Mutex
|
||||
homeHubPolicyConfigFlag = config.NewValidityFlag()
|
||||
)
|
||||
|
||||
func getHomeHubPolicy() (endpoints.Endpoints, error) {
|
||||
homeHubPolicyLock.Lock()
|
||||
defer homeHubPolicyLock.Unlock()
|
||||
|
||||
// Return cached value if config is still valid.
|
||||
if homeHubPolicyConfigFlag.IsValid() {
|
||||
return homeHubPolicy, nil
|
||||
}
|
||||
homeHubPolicyConfigFlag.Refresh()
|
||||
|
||||
// Parse new policy.
|
||||
policy, err := endpoints.ParseEndpoints(cfgOptionHomeHubPolicy())
|
||||
if err != nil {
|
||||
homeHubPolicy = nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Save and return the new policy.
|
||||
homeHubPolicy = policy
|
||||
return homeHubPolicy, nil
|
||||
}
|
||||
|
||||
var (
|
||||
dnsExitHubPolicy endpoints.Endpoints
|
||||
dnsExitHubPolicyLock sync.Mutex
|
||||
dnsExitHubPolicyConfigFlag = config.NewValidityFlag()
|
||||
)
|
||||
|
||||
// GetDNSExitHubPolicy return the current DNS exit policy.
|
||||
func GetDNSExitHubPolicy() (endpoints.Endpoints, error) {
|
||||
dnsExitHubPolicyLock.Lock()
|
||||
defer dnsExitHubPolicyLock.Unlock()
|
||||
|
||||
// Return cached value if config is still valid.
|
||||
if dnsExitHubPolicyConfigFlag.IsValid() {
|
||||
return dnsExitHubPolicy, nil
|
||||
}
|
||||
dnsExitHubPolicyConfigFlag.Refresh()
|
||||
|
||||
// Parse new policy.
|
||||
policy, err := endpoints.ParseEndpoints(cfgOptionDNSExitHubPolicy())
|
||||
if err != nil {
|
||||
dnsExitHubPolicy = nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Save and return the new policy.
|
||||
dnsExitHubPolicy = policy
|
||||
return dnsExitHubPolicy, nil
|
||||
}
|
||||
105
spn/captain/establish.go
Normal file
105
spn/captain/establish.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/ships"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// EstablishCrane establishes a crane to another Hub.
|
||||
func EstablishCrane(callerCtx context.Context, dst *hub.Hub) (*docks.Crane, error) {
|
||||
if conf.PublicHub() && dst.ID == publicIdentity.ID {
|
||||
return nil, errors.New("connecting to self")
|
||||
}
|
||||
if docks.GetAssignedCrane(dst.ID) != nil {
|
||||
return nil, fmt.Errorf("route to %s already exists", dst.ID)
|
||||
}
|
||||
|
||||
ship, err := ships.Launch(callerCtx, dst, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to launch ship: %w", err)
|
||||
}
|
||||
|
||||
// On pure clients, mark all ships as public in order to show unmasked data in logs.
|
||||
if conf.Client() && !conf.PublicHub() {
|
||||
ship.MarkPublic()
|
||||
}
|
||||
|
||||
crane, err := docks.NewCrane(ship, dst, publicIdentity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create crane: %w", err)
|
||||
}
|
||||
|
||||
err = crane.Start(callerCtx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start crane: %w", err)
|
||||
}
|
||||
|
||||
// Start gossip op for live map updates.
|
||||
_, tErr := NewGossipOp(crane.Controller)
|
||||
if tErr != nil {
|
||||
crane.Stop(tErr)
|
||||
return nil, fmt.Errorf("failed to start gossip op: %w", tErr)
|
||||
}
|
||||
|
||||
return crane, nil
|
||||
}
|
||||
|
||||
// EstablishPublicLane establishes a crane to another Hub and publishes it.
|
||||
func EstablishPublicLane(ctx context.Context, dst *hub.Hub) (*docks.Crane, *terminal.Error) {
|
||||
// Create new context with timeout.
|
||||
// The maximum timeout is a worst case safeguard.
|
||||
// Keep in mind that multiple IPs and protocols may be tried in all configurations.
|
||||
// Some servers will be (possibly on purpose) hard to reach.
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Connect to destination and establish communication.
|
||||
crane, err := EstablishCrane(ctx, dst)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to establish crane: %w", err)
|
||||
}
|
||||
|
||||
// Publish as Lane.
|
||||
publishOp, tErr := NewPublishOp(crane.Controller, publicIdentity)
|
||||
if tErr != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to publish: %w", err)
|
||||
}
|
||||
|
||||
// Wait for publishing to complete.
|
||||
select {
|
||||
case tErr := <-publishOp.Result():
|
||||
if !tErr.Is(terminal.ErrExplicitAck) {
|
||||
// Stop crane again, because we failed to publish it.
|
||||
defer crane.Stop(nil)
|
||||
return nil, terminal.ErrInternalError.With("failed to publish lane: %w", tErr)
|
||||
}
|
||||
|
||||
case <-crane.Controller.Ctx().Done():
|
||||
defer crane.Stop(nil)
|
||||
return nil, terminal.ErrStopping
|
||||
|
||||
case <-ctx.Done():
|
||||
defer crane.Stop(nil)
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
return nil, terminal.ErrTimeout
|
||||
}
|
||||
return nil, terminal.ErrCanceled
|
||||
}
|
||||
|
||||
// Query all gossip msgs.
|
||||
_, tErr = NewGossipQueryOp(crane.Controller)
|
||||
if tErr != nil {
|
||||
log.Warningf("spn/captain: failed to start initial gossip query: %s", tErr)
|
||||
}
|
||||
|
||||
return crane, nil
|
||||
}
|
||||
28
spn/captain/exceptions.go
Normal file
28
spn/captain/exceptions.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
exceptionLock sync.Mutex
|
||||
exceptIPv4 net.IP
|
||||
exceptIPv6 net.IP
|
||||
)
|
||||
|
||||
func setExceptions(ipv4, ipv6 net.IP) {
|
||||
exceptionLock.Lock()
|
||||
defer exceptionLock.Unlock()
|
||||
|
||||
exceptIPv4 = ipv4
|
||||
exceptIPv6 = ipv6
|
||||
}
|
||||
|
||||
// IsExcepted checks if the given IP is currently excepted from the SPN.
|
||||
func IsExcepted(ip net.IP) bool {
|
||||
exceptionLock.Lock()
|
||||
defer exceptionLock.Unlock()
|
||||
|
||||
return ip.Equal(exceptIPv4) || ip.Equal(exceptIPv6)
|
||||
}
|
||||
38
spn/captain/gossip.go
Normal file
38
spn/captain/gossip.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
gossipOps = make(map[string]*GossipOp)
|
||||
gossipOpsLock sync.RWMutex
|
||||
)
|
||||
|
||||
func registerGossipOp(craneID string, op *GossipOp) {
|
||||
gossipOpsLock.Lock()
|
||||
defer gossipOpsLock.Unlock()
|
||||
|
||||
gossipOps[craneID] = op
|
||||
}
|
||||
|
||||
func deleteGossipOp(craneID string) {
|
||||
gossipOpsLock.Lock()
|
||||
defer gossipOpsLock.Unlock()
|
||||
|
||||
delete(gossipOps, craneID)
|
||||
}
|
||||
|
||||
func gossipRelayMsg(receivedFrom string, msgType GossipMsgType, data []byte) {
|
||||
gossipOpsLock.RLock()
|
||||
defer gossipOpsLock.RUnlock()
|
||||
|
||||
for craneID, gossipOp := range gossipOps {
|
||||
// Don't return same msg back to sender.
|
||||
if craneID == receivedFrom {
|
||||
continue
|
||||
}
|
||||
|
||||
gossipOp.sendMsg(msgType, data)
|
||||
}
|
||||
}
|
||||
47
spn/captain/hooks.go
Normal file
47
spn/captain/hooks.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
)
|
||||
|
||||
func startDockHooks() {
|
||||
docks.RegisterCraneUpdateHook(handleCraneUpdate)
|
||||
}
|
||||
|
||||
func stopDockHooks() {
|
||||
docks.ResetCraneUpdateHook()
|
||||
}
|
||||
|
||||
func handleCraneUpdate(crane *docks.Crane) {
|
||||
if crane == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if conf.Client() && crane.Controller != nil && crane.Controller.Abandoning.IsSet() {
|
||||
// Check connection to home hub.
|
||||
triggerClientHealthCheck()
|
||||
}
|
||||
|
||||
if conf.PublicHub() && crane.Public() {
|
||||
// Update Hub status.
|
||||
updateConnectionStatus()
|
||||
}
|
||||
}
|
||||
|
||||
func updateConnectionStatus() {
|
||||
// Delay updating status for a better chance to combine multiple changes.
|
||||
statusUpdateTask.Schedule(time.Now().Add(maintainStatusUpdateDelay))
|
||||
|
||||
// Check if we lost all connections and trigger a pending restart if we did.
|
||||
for _, crane := range docks.GetAllAssignedCranes() {
|
||||
if crane.Public() && !crane.Stopped() {
|
||||
// There is at least one public and active crane, so don't restart now.
|
||||
return
|
||||
}
|
||||
}
|
||||
updates.TriggerRestartIfPending()
|
||||
}
|
||||
108
spn/captain/intel.go
Normal file
108
spn/captain/intel.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/updater"
|
||||
"github.com/safing/portmaster/service/updates"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
"github.com/safing/portmaster/spn/ships"
|
||||
)
|
||||
|
||||
var (
|
||||
intelResource *updater.File
|
||||
intelResourcePath = "intel/spn/main-intel.yaml"
|
||||
intelResourceMapName = "main"
|
||||
intelResourceUpdateLock sync.Mutex
|
||||
)
|
||||
|
||||
func registerIntelUpdateHook() error {
|
||||
if err := module.RegisterEventHook(
|
||||
updates.ModuleName,
|
||||
updates.ResourceUpdateEvent,
|
||||
"update SPN intel",
|
||||
updateSPNIntel,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := module.RegisterEventHook(
|
||||
"config",
|
||||
config.ChangeEvent,
|
||||
"update SPN intel",
|
||||
updateSPNIntel,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSPNIntel(ctx context.Context, _ interface{}) (err error) {
|
||||
intelResourceUpdateLock.Lock()
|
||||
defer intelResourceUpdateLock.Unlock()
|
||||
|
||||
// Only update SPN intel when using the matching map.
|
||||
if conf.MainMapName != intelResourceMapName {
|
||||
return fmt.Errorf("intel resource not for map %q", conf.MainMapName)
|
||||
}
|
||||
|
||||
// Check if there is something to do.
|
||||
if intelResource != nil && !intelResource.UpgradeAvailable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get intel file and load it from disk.
|
||||
intelResource, err = updates.GetFile(intelResourcePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get SPN intel update: %w", err)
|
||||
}
|
||||
intelData, err := os.ReadFile(intelResource.Path())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load SPN intel update: %w", err)
|
||||
}
|
||||
|
||||
// Parse and apply intel data.
|
||||
intel, err := hub.ParseIntel(intelData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse SPN intel update: %w", err)
|
||||
}
|
||||
|
||||
setVirtualNetworkConfig(intel.VirtualNetworks)
|
||||
return navigator.Main.UpdateIntel(intel, cfgOptionTrustNodeNodes())
|
||||
}
|
||||
|
||||
func resetSPNIntel() {
|
||||
intelResourceUpdateLock.Lock()
|
||||
defer intelResourceUpdateLock.Unlock()
|
||||
|
||||
intelResource = nil
|
||||
}
|
||||
|
||||
func setVirtualNetworkConfig(configs []*hub.VirtualNetworkConfig) {
|
||||
// Do nothing if not public Hub.
|
||||
if !conf.PublicHub() {
|
||||
return
|
||||
}
|
||||
// Reset if there are no virtual networks configured.
|
||||
if len(configs) == 0 {
|
||||
ships.SetVirtualNetworkConfig(nil)
|
||||
}
|
||||
|
||||
// Check if we are in a virtual network.
|
||||
for _, config := range configs {
|
||||
if _, ok := config.Mapping[publicIdentity.Hub.ID]; ok {
|
||||
ships.SetVirtualNetworkConfig(config)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If not, reset - we might have been in one before.
|
||||
ships.SetVirtualNetworkConfig(nil)
|
||||
}
|
||||
219
spn/captain/module.go
Normal file
219
spn/captain/module.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/modules/subsystems"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/crew"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
"github.com/safing/portmaster/spn/patrol"
|
||||
"github.com/safing/portmaster/spn/ships"
|
||||
_ "github.com/safing/portmaster/spn/sluice"
|
||||
)
|
||||
|
||||
const controlledFailureExitCode = 24
|
||||
|
||||
var module *modules.Module
|
||||
|
||||
// SPNConnectedEvent is the name of the event that is fired when the SPN has connected and is ready.
|
||||
const SPNConnectedEvent = "spn connect"
|
||||
|
||||
func init() {
|
||||
module = modules.Register("captain", prep, start, stop, "base", "terminal", "cabin", "ships", "docks", "crew", "navigator", "sluice", "patrol", "netenv")
|
||||
module.RegisterEvent(SPNConnectedEvent, false)
|
||||
subsystems.Register(
|
||||
"spn",
|
||||
"SPN",
|
||||
"Safing Privacy Network",
|
||||
module,
|
||||
"config:spn/",
|
||||
&config.Option{
|
||||
Name: "SPN Module",
|
||||
Key: CfgOptionEnableSPNKey,
|
||||
Description: "Start the Safing Privacy Network module. If turned off, the SPN is fully disabled on this device.",
|
||||
OptType: config.OptTypeBool,
|
||||
DefaultValue: false,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: cfgOptionEnableSPNOrder,
|
||||
config.CategoryAnnotation: "General",
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
// Check if we can parse the bootstrap hub flag.
|
||||
if err := prepBootstrapHubFlag(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register SPN status provider.
|
||||
if err := registerSPNStatusProvider(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register API endpoints.
|
||||
if err := registerAPIEndpoints(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if conf.PublicHub() {
|
||||
// Register API authenticator.
|
||||
if err := api.SetAuthenticator(apiAuthenticator); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := module.RegisterEventHook(
|
||||
"patrol",
|
||||
patrol.ChangeSignalEventName,
|
||||
"trigger hub status maintenance",
|
||||
func(_ context.Context, _ any) error {
|
||||
TriggerHubStatusMaintenance()
|
||||
return nil
|
||||
},
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return prepConfig()
|
||||
}
|
||||
|
||||
func start() error {
|
||||
maskingBytes, err := rng.Bytes(16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get random bytes for masking: %w", err)
|
||||
}
|
||||
ships.EnableMasking(maskingBytes)
|
||||
|
||||
// Initialize intel.
|
||||
if err := registerIntelUpdateHook(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := updateSPNIntel(module.Ctx, nil); err != nil {
|
||||
log.Errorf("spn/captain: failed to update SPN intel: %s", err)
|
||||
}
|
||||
|
||||
// Initialize identity and piers.
|
||||
if conf.PublicHub() {
|
||||
// Load identity.
|
||||
if err := loadPublicIdentity(); err != nil {
|
||||
// We cannot recover from this, set controlled failure (do not retry).
|
||||
modules.SetExitStatusCode(controlledFailureExitCode)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if any networks are configured.
|
||||
if !conf.HubHasIPv4() && !conf.HubHasIPv6() {
|
||||
// We cannot recover from this, set controlled failure (do not retry).
|
||||
modules.SetExitStatusCode(controlledFailureExitCode)
|
||||
|
||||
return errors.New("no IP addresses for Hub configured (or detected)")
|
||||
}
|
||||
|
||||
// Start management of identity and piers.
|
||||
if err := prepPublicIdentityMgmt(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Set ID to display on http info page.
|
||||
ships.DisplayHubID = publicIdentity.ID
|
||||
// Start listeners.
|
||||
if err := startPiers(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Enable connect operation.
|
||||
crew.EnableConnecting(publicIdentity.Hub)
|
||||
}
|
||||
|
||||
// Subscribe to updates of cranes.
|
||||
startDockHooks()
|
||||
|
||||
// bootstrapping
|
||||
if err := processBootstrapHubFlag(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := processBootstrapFileFlag(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// network optimizer
|
||||
if conf.PublicHub() {
|
||||
module.NewTask("optimize network", optimizeNetwork).
|
||||
Repeat(1 * time.Minute).
|
||||
Schedule(time.Now().Add(15 * time.Second))
|
||||
}
|
||||
|
||||
// client + home hub manager
|
||||
if conf.Client() {
|
||||
module.StartServiceWorker("client manager", 0, clientManager)
|
||||
|
||||
// Reset failing hubs when the network changes while not connected.
|
||||
if err := module.RegisterEventHook(
|
||||
"netenv",
|
||||
"network changed",
|
||||
"reset failing hubs",
|
||||
func(_ context.Context, _ interface{}) error {
|
||||
if ready.IsNotSet() {
|
||||
navigator.Main.ResetFailingStates(module.Ctx)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
// Reset intel resource so that it is loaded again when starting.
|
||||
resetSPNIntel()
|
||||
|
||||
// Unregister crane update hook.
|
||||
stopDockHooks()
|
||||
|
||||
// Send shutdown status message.
|
||||
if conf.PublicHub() {
|
||||
publishShutdownStatus()
|
||||
stopPiers()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// apiAuthenticator grants User permissions for local API requests.
|
||||
func apiAuthenticator(r *http.Request, s *http.Server) (*api.AuthToken, error) {
|
||||
// Get remote IP.
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to split host/port: %w", err)
|
||||
}
|
||||
remoteIP := net.ParseIP(host)
|
||||
if remoteIP == nil {
|
||||
return nil, fmt.Errorf("failed to parse remote address %s", host)
|
||||
}
|
||||
|
||||
if !netutils.GetIPScope(remoteIP).IsLocalhost() {
|
||||
return nil, api.ErrAPIAccessDeniedMessage
|
||||
}
|
||||
|
||||
return &api.AuthToken{
|
||||
Read: api.PermitUser,
|
||||
Write: api.PermitUser,
|
||||
}, nil
|
||||
}
|
||||
306
spn/captain/navigation.go
Normal file
306
spn/captain/navigation.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/service/intel"
|
||||
"github.com/safing/portmaster/service/netenv"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
"github.com/safing/portmaster/spn/access"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const stopCraneAfterBeingUnsuggestedFor = 6 * time.Hour
|
||||
|
||||
var (
|
||||
// ErrAllHomeHubsExcluded is returned when all available home hubs were excluded.
|
||||
ErrAllHomeHubsExcluded = errors.New("all home hubs are excluded")
|
||||
|
||||
// ErrReInitSPNSuggested is returned when no home hub can be found, even without rules.
|
||||
ErrReInitSPNSuggested = errors.New("SPN re-init suggested")
|
||||
)
|
||||
|
||||
func establishHomeHub(ctx context.Context) error {
|
||||
// Get own IP.
|
||||
locations, ok := netenv.GetInternetLocation()
|
||||
if !ok || len(locations.All) == 0 {
|
||||
return errors.New("failed to locate own device")
|
||||
}
|
||||
log.Debugf(
|
||||
"spn/captain: looking for new home hub near %s and %s",
|
||||
locations.BestV4(),
|
||||
locations.BestV6(),
|
||||
)
|
||||
|
||||
// Get own entity.
|
||||
// Checking the entity against the entry policies is somewhat hit and miss
|
||||
// anyway, as the device location is an approximation.
|
||||
var myEntity *intel.Entity
|
||||
if dl := locations.BestV4(); dl != nil && dl.IP != nil {
|
||||
myEntity = (&intel.Entity{IP: dl.IP}).Init(0)
|
||||
myEntity.FetchData(ctx)
|
||||
} else if dl := locations.BestV6(); dl != nil && dl.IP != nil {
|
||||
myEntity = (&intel.Entity{IP: dl.IP}).Init(0)
|
||||
myEntity.FetchData(ctx)
|
||||
}
|
||||
|
||||
// Get home hub policy for selecting the home hub.
|
||||
homePolicy, err := getHomeHubPolicy()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build navigation options for searching for a home hub.
|
||||
opts := &navigator.Options{
|
||||
Home: &navigator.HomeHubOptions{
|
||||
HubPolicies: []endpoints.Endpoints{homePolicy},
|
||||
CheckHubPolicyWith: myEntity,
|
||||
},
|
||||
}
|
||||
|
||||
// Add requirement to only use Safing nodes when not using community nodes.
|
||||
if !cfgOptionUseCommunityNodes() {
|
||||
opts.Home.RequireVerifiedOwners = NonCommunityVerifiedOwners
|
||||
}
|
||||
|
||||
// Require a trusted home node when the routing profile requires less than two hops.
|
||||
routingProfile := navigator.GetRoutingProfile(cfgOptionRoutingAlgorithm())
|
||||
if routingProfile.MinHops < 2 {
|
||||
opts.Home.Regard = opts.Home.Regard.Add(navigator.StateTrusted)
|
||||
}
|
||||
|
||||
// Find nearby hubs.
|
||||
findCandidates:
|
||||
candidates, err := navigator.Main.FindNearestHubs(
|
||||
locations.BestV4().LocationOrNil(),
|
||||
locations.BestV6().LocationOrNil(),
|
||||
opts, navigator.HomeHub,
|
||||
)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, navigator.ErrEmptyMap):
|
||||
// bootstrap to the network!
|
||||
err := bootstrapWithUpdates()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
goto findCandidates
|
||||
|
||||
case errors.Is(err, navigator.ErrAllPinsDisregarded):
|
||||
if len(homePolicy) > 0 {
|
||||
return ErrAllHomeHubsExcluded
|
||||
}
|
||||
return ErrReInitSPNSuggested
|
||||
|
||||
default:
|
||||
return fmt.Errorf("failed to find nearby hubs: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Try connecting to a hub.
|
||||
var tries int
|
||||
var candidate *hub.Hub
|
||||
for tries, candidate = range candidates {
|
||||
err = connectToHomeHub(ctx, candidate)
|
||||
if err != nil {
|
||||
// Check if context is canceled.
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
// Check if the SPN protocol is stopping again.
|
||||
if errors.Is(err, terminal.ErrStopping) {
|
||||
return err
|
||||
}
|
||||
log.Warningf("spn/captain: failed to connect to %s as new home: %s", candidate, err)
|
||||
} else {
|
||||
log.Infof("spn/captain: established connection to %s as new home with %d failed tries", candidate, tries)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to a new home hub - tried %d hubs: %w", tries+1, err)
|
||||
}
|
||||
return fmt.Errorf("no home hub candidates available")
|
||||
}
|
||||
|
||||
func connectToHomeHub(ctx context.Context, dst *hub.Hub) error {
|
||||
// Create new context with timeout.
|
||||
// The maximum timeout is a worst case safeguard.
|
||||
// Keep in mind that multiple IPs and protocols may be tried in all configurations.
|
||||
// Some servers will be (possibly on purpose) hard to reach.
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Set and clean up exceptions.
|
||||
setExceptions(dst.Info.IPv4, dst.Info.IPv6)
|
||||
defer setExceptions(nil, nil)
|
||||
|
||||
// Connect to hub.
|
||||
crane, err := EstablishCrane(ctx, dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cleanup connection in case of failure.
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
crane.Stop(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
// Query all gossip msgs on first connection.
|
||||
gossipQuery, tErr := NewGossipQueryOp(crane.Controller)
|
||||
if tErr != nil {
|
||||
log.Warningf("spn/captain: failed to start initial gossip query: %s", tErr)
|
||||
}
|
||||
// Wait for gossip query to complete.
|
||||
select {
|
||||
case <-gossipQuery.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
return context.Canceled
|
||||
}
|
||||
|
||||
// Create communication terminal.
|
||||
homeTerminal, initData, tErr := docks.NewLocalCraneTerminal(crane, nil, terminal.DefaultHomeHubTerminalOpts())
|
||||
if tErr != nil {
|
||||
return tErr.Wrap("failed to create home terminal")
|
||||
}
|
||||
tErr = crane.EstablishNewTerminal(homeTerminal, initData)
|
||||
if tErr != nil {
|
||||
return tErr.Wrap("failed to connect home terminal")
|
||||
}
|
||||
|
||||
if !DisableAccount {
|
||||
// Authenticate to home hub.
|
||||
authOp, tErr := access.AuthorizeToTerminal(homeTerminal)
|
||||
if tErr != nil {
|
||||
return tErr.Wrap("failed to authorize")
|
||||
}
|
||||
select {
|
||||
case tErr := <-authOp.Result:
|
||||
if !tErr.Is(terminal.ErrExplicitAck) {
|
||||
return tErr.Wrap("failed to authenticate to")
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
return terminal.ErrTimeout.With("waiting for auth to complete")
|
||||
case <-ctx.Done():
|
||||
return terminal.ErrStopping
|
||||
}
|
||||
}
|
||||
|
||||
// Set new home on map.
|
||||
ok := navigator.Main.SetHome(dst.ID, homeTerminal)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to set home hub on map")
|
||||
}
|
||||
|
||||
// Assign crane to home hub in order to query it later.
|
||||
docks.AssignCrane(crane.ConnectedHub.ID, crane)
|
||||
|
||||
success = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func optimizeNetwork(ctx context.Context, task *modules.Task) error {
|
||||
if publicIdentity == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
optimize:
|
||||
result, err := navigator.Main.Optimize(nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, navigator.ErrEmptyMap) {
|
||||
// bootstrap to the network!
|
||||
err := bootstrapWithUpdates()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
goto optimize
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Create any new connections.
|
||||
var createdConnections int
|
||||
var attemptedConnections int
|
||||
for _, connectTo := range result.SuggestedConnections {
|
||||
// Skip duplicates.
|
||||
if connectTo.Duplicate {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if connection already exists.
|
||||
crane := docks.GetAssignedCrane(connectTo.Hub.ID)
|
||||
if crane != nil {
|
||||
// Update last suggested timestamp.
|
||||
crane.NetState.UpdateLastSuggestedAt()
|
||||
// Continue crane if stopping.
|
||||
if crane.AbortStopping() {
|
||||
log.Infof("spn/captain: optimization aborted retiring of %s, removed stopping mark", crane)
|
||||
crane.NotifyUpdate()
|
||||
}
|
||||
|
||||
// Create new connections if we have connects left.
|
||||
} else if createdConnections < result.MaxConnect {
|
||||
attemptedConnections++
|
||||
|
||||
crane, tErr := EstablishPublicLane(ctx, connectTo.Hub)
|
||||
if !tErr.IsOK() {
|
||||
log.Warningf("spn/captain: failed to establish lane to %s: %s", connectTo.Hub, tErr)
|
||||
} else {
|
||||
createdConnections++
|
||||
crane.NetState.UpdateLastSuggestedAt()
|
||||
|
||||
log.Infof("spn/captain: established lane to %s", connectTo.Hub)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log optimization result.
|
||||
if attemptedConnections > 0 {
|
||||
log.Infof(
|
||||
"spn/captain: created %d/%d new connections for %s optimization",
|
||||
createdConnections,
|
||||
attemptedConnections,
|
||||
result.Purpose)
|
||||
} else {
|
||||
log.Infof(
|
||||
"spn/captain: checked %d connections for %s optimization",
|
||||
len(result.SuggestedConnections),
|
||||
result.Purpose,
|
||||
)
|
||||
}
|
||||
|
||||
// Retire cranes if unsuggested for a while.
|
||||
if result.StopOthers {
|
||||
for _, crane := range docks.GetAllAssignedCranes() {
|
||||
switch {
|
||||
case crane.Stopped():
|
||||
// Crane already stopped.
|
||||
case crane.IsStopping():
|
||||
// Crane is stopping, forcibly stop if mine and suggested.
|
||||
if crane.IsMine() && crane.NetState.StopSuggested() {
|
||||
crane.Stop(nil)
|
||||
}
|
||||
case crane.IsMine() && crane.NetState.StoppingSuggested():
|
||||
// Mark as stopping if mine and suggested.
|
||||
crane.MarkStopping()
|
||||
case crane.NetState.RequestStoppingSuggested(stopCraneAfterBeingUnsuggestedFor):
|
||||
// Mark as stopping requested.
|
||||
crane.MarkStoppingRequested()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
156
spn/captain/op_gossip.go
Normal file
156
spn/captain/op_gossip.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// GossipOpType is the type ID of the gossip operation.
|
||||
const GossipOpType string = "gossip"
|
||||
|
||||
// GossipMsgType is the gossip message type.
|
||||
type GossipMsgType uint8
|
||||
|
||||
// Gossip Message Types.
|
||||
const (
|
||||
GossipHubAnnouncementMsg GossipMsgType = 1
|
||||
GossipHubStatusMsg GossipMsgType = 2
|
||||
)
|
||||
|
||||
func (msgType GossipMsgType) String() string {
|
||||
switch msgType {
|
||||
case GossipHubAnnouncementMsg:
|
||||
return "hub announcement"
|
||||
case GossipHubStatusMsg:
|
||||
return "hub status"
|
||||
default:
|
||||
return "unknown gossip msg"
|
||||
}
|
||||
}
|
||||
|
||||
// GossipOp is used to gossip Hub messages.
|
||||
type GossipOp struct {
|
||||
terminal.OperationBase
|
||||
|
||||
craneID string
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *GossipOp) Type() string {
|
||||
return GossipOpType
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: GossipOpType,
|
||||
Requires: terminal.IsCraneController,
|
||||
Start: runGossipOp,
|
||||
})
|
||||
}
|
||||
|
||||
// NewGossipOp start a new gossip operation.
|
||||
func NewGossipOp(controller *docks.CraneControllerTerminal) (*GossipOp, *terminal.Error) {
|
||||
// Create and init.
|
||||
op := &GossipOp{
|
||||
craneID: controller.Crane.ID,
|
||||
}
|
||||
err := controller.StartOperation(op, nil, 1*time.Minute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
op.InitOperationBase(controller, op.ID())
|
||||
|
||||
// Register and return.
|
||||
registerGossipOp(controller.Crane.ID, op)
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func runGossipOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Check if we are run by a controller.
|
||||
controller, ok := t.(*docks.CraneControllerTerminal)
|
||||
if !ok {
|
||||
return nil, terminal.ErrIncorrectUsage.With("gossip op may only be started by a crane controller terminal, but was started by %T", t)
|
||||
}
|
||||
|
||||
// Create, init, register and return.
|
||||
op := &GossipOp{
|
||||
craneID: controller.Crane.ID,
|
||||
}
|
||||
op.InitOperationBase(t, opID)
|
||||
registerGossipOp(controller.Crane.ID, op)
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (op *GossipOp) sendMsg(msgType GossipMsgType, data []byte) {
|
||||
// Create message.
|
||||
msg := op.NewEmptyMsg()
|
||||
msg.Data = container.New(
|
||||
varint.Pack8(uint8(msgType)),
|
||||
data,
|
||||
)
|
||||
msg.Unit.MakeHighPriority()
|
||||
|
||||
// Send.
|
||||
err := op.Send(msg, 1*time.Second)
|
||||
if err != nil {
|
||||
log.Debugf("spn/captain: failed to forward %s via %s: %s", msgType, op.craneID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the operation.
|
||||
func (op *GossipOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
defer msg.Finish()
|
||||
|
||||
gossipMsgTypeN, err := msg.Data.GetNextN8()
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to parse gossip message type")
|
||||
}
|
||||
gossipMsgType := GossipMsgType(gossipMsgTypeN)
|
||||
|
||||
// Prepare data.
|
||||
data := msg.Data.CompileData()
|
||||
var announcementData, statusData []byte
|
||||
switch gossipMsgType {
|
||||
case GossipHubAnnouncementMsg:
|
||||
announcementData = data
|
||||
case GossipHubStatusMsg:
|
||||
statusData = data
|
||||
default:
|
||||
log.Warningf("spn/captain: received unknown gossip message type from %s: %d", op.craneID, gossipMsgType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Import and verify.
|
||||
h, forward, tErr := docks.ImportAndVerifyHubInfo(module.Ctx, "", announcementData, statusData, conf.MainMapName, conf.MainMapScope)
|
||||
if tErr != nil {
|
||||
if tErr.Is(hub.ErrOldData) {
|
||||
log.Debugf("spn/captain: ignoring old %s from %s", gossipMsgType, op.craneID)
|
||||
} else {
|
||||
log.Warningf("spn/captain: failed to import %s from %s: %s", gossipMsgType, op.craneID, tErr)
|
||||
}
|
||||
} else if forward {
|
||||
// Only log if we received something to save/forward.
|
||||
log.Infof("spn/captain: received %s for %s", gossipMsgType, h)
|
||||
}
|
||||
|
||||
// Relay data.
|
||||
if forward {
|
||||
gossipRelayMsg(op.craneID, gossipMsgType, data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *GossipOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
|
||||
deleteGossipOp(op.craneID)
|
||||
return err
|
||||
}
|
||||
195
spn/captain/op_gossip_query.go
Normal file
195
spn/captain/op_gossip_query.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// GossipQueryOpType is the type ID of the gossip query operation.
|
||||
const GossipQueryOpType string = "gossip/query"
|
||||
|
||||
// GossipQueryOp is used to query gossip messages.
|
||||
type GossipQueryOp struct {
|
||||
terminal.OperationBase
|
||||
|
||||
t terminal.Terminal
|
||||
client bool
|
||||
importCnt int
|
||||
|
||||
ctx context.Context
|
||||
cancelCtx context.CancelFunc
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *GossipQueryOp) Type() string {
|
||||
return GossipQueryOpType
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: GossipQueryOpType,
|
||||
Requires: terminal.IsCraneController,
|
||||
Start: runGossipQueryOp,
|
||||
})
|
||||
}
|
||||
|
||||
// NewGossipQueryOp starts a new gossip query operation.
|
||||
func NewGossipQueryOp(t terminal.Terminal) (*GossipQueryOp, *terminal.Error) {
|
||||
// Create and init.
|
||||
op := &GossipQueryOp{
|
||||
t: t,
|
||||
client: true,
|
||||
}
|
||||
op.ctx, op.cancelCtx = context.WithCancel(t.Ctx())
|
||||
err := t.StartOperation(op, nil, 1*time.Minute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func runGossipQueryOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Create, init, register and return.
|
||||
op := &GossipQueryOp{t: t}
|
||||
op.ctx, op.cancelCtx = context.WithCancel(t.Ctx())
|
||||
op.InitOperationBase(t, opID)
|
||||
|
||||
module.StartWorker("gossip query handler", op.handler)
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (op *GossipQueryOp) handler(_ context.Context) error {
|
||||
tErr := op.sendMsgs(hub.MsgTypeAnnouncement)
|
||||
if tErr != nil {
|
||||
op.Stop(op, tErr)
|
||||
return nil // Clean worker exit.
|
||||
}
|
||||
|
||||
tErr = op.sendMsgs(hub.MsgTypeStatus)
|
||||
if tErr != nil {
|
||||
op.Stop(op, tErr)
|
||||
return nil // Clean worker exit.
|
||||
}
|
||||
|
||||
op.Stop(op, nil)
|
||||
return nil // Clean worker exit.
|
||||
}
|
||||
|
||||
func (op *GossipQueryOp) sendMsgs(msgType hub.MsgType) *terminal.Error {
|
||||
it, err := hub.QueryRawGossipMsgs(conf.MainMapName, msgType)
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("failed to query: %w", err)
|
||||
}
|
||||
defer it.Cancel()
|
||||
|
||||
iterating:
|
||||
for {
|
||||
select {
|
||||
case r := <-it.Next:
|
||||
// Check if we are done.
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure we're handling a hub msg.
|
||||
hubMsg, err := hub.EnsureHubMsg(r)
|
||||
if err != nil {
|
||||
log.Warningf("spn/captain: failed to load hub msg: %s", err)
|
||||
continue iterating
|
||||
}
|
||||
|
||||
// Create gossip msg.
|
||||
var c *container.Container
|
||||
switch hubMsg.Type {
|
||||
case hub.MsgTypeAnnouncement:
|
||||
c = container.New(
|
||||
varint.Pack8(uint8(GossipHubAnnouncementMsg)),
|
||||
hubMsg.Data,
|
||||
)
|
||||
case hub.MsgTypeStatus:
|
||||
c = container.New(
|
||||
varint.Pack8(uint8(GossipHubStatusMsg)),
|
||||
hubMsg.Data,
|
||||
)
|
||||
default:
|
||||
log.Warningf("spn/captain: unknown hub msg for gossip query at %q: %s", hubMsg.Key(), hubMsg.Type)
|
||||
}
|
||||
|
||||
// Send msg.
|
||||
if c != nil {
|
||||
msg := op.NewEmptyMsg()
|
||||
msg.Unit.MakeHighPriority()
|
||||
msg.Data = c
|
||||
tErr := op.Send(msg, 1*time.Second)
|
||||
if tErr != nil {
|
||||
return tErr.Wrap("failed to send msg")
|
||||
}
|
||||
}
|
||||
|
||||
case <-op.ctx.Done():
|
||||
return terminal.ErrStopping
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver delivers the message to the operation.
|
||||
func (op *GossipQueryOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
defer msg.Finish()
|
||||
|
||||
gossipMsgTypeN, err := msg.Data.GetNextN8()
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to parse gossip message type")
|
||||
}
|
||||
gossipMsgType := GossipMsgType(gossipMsgTypeN)
|
||||
|
||||
// Prepare data.
|
||||
data := msg.Data.CompileData()
|
||||
var announcementData, statusData []byte
|
||||
switch gossipMsgType {
|
||||
case GossipHubAnnouncementMsg:
|
||||
announcementData = data
|
||||
case GossipHubStatusMsg:
|
||||
statusData = data
|
||||
default:
|
||||
log.Warningf("spn/captain: received unknown gossip message type from gossip query: %d", gossipMsgType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Import and verify.
|
||||
h, forward, tErr := docks.ImportAndVerifyHubInfo(module.Ctx, "", announcementData, statusData, conf.MainMapName, conf.MainMapScope)
|
||||
if tErr != nil {
|
||||
log.Warningf("spn/captain: failed to import %s from gossip query: %s", gossipMsgType, tErr)
|
||||
} else {
|
||||
log.Infof("spn/captain: received %s for %s from gossip query", gossipMsgType, h)
|
||||
op.importCnt++
|
||||
}
|
||||
|
||||
// Relay data.
|
||||
if forward {
|
||||
// TODO: Find better way to get craneID.
|
||||
craneID := strings.SplitN(op.t.FmtID(), "#", 2)[0]
|
||||
gossipRelayMsg(craneID, gossipMsgType, data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *GossipQueryOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
|
||||
if op.client {
|
||||
log.Infof("spn/captain: gossip query imported %d entries", op.importCnt)
|
||||
}
|
||||
op.cancelCtx()
|
||||
return err
|
||||
}
|
||||
183
spn/captain/op_publish.go
Normal file
183
spn/captain/op_publish.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// PublishOpType is the type ID of the publish operation.
|
||||
const PublishOpType string = "publish"
|
||||
|
||||
// PublishOp is used to publish a connection.
|
||||
type PublishOp struct {
|
||||
terminal.OperationBase
|
||||
controller *docks.CraneControllerTerminal
|
||||
|
||||
identity *cabin.Identity
|
||||
requestingHub *hub.Hub
|
||||
verification *cabin.Verification
|
||||
result chan *terminal.Error
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *PublishOp) Type() string {
|
||||
return PublishOpType
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: PublishOpType,
|
||||
Requires: terminal.IsCraneController,
|
||||
Start: runPublishOp,
|
||||
})
|
||||
}
|
||||
|
||||
// NewPublishOp start a new publish operation.
|
||||
func NewPublishOp(controller *docks.CraneControllerTerminal, identity *cabin.Identity) (*PublishOp, *terminal.Error) {
|
||||
// Create and init.
|
||||
op := &PublishOp{
|
||||
controller: controller,
|
||||
identity: identity,
|
||||
result: make(chan *terminal.Error, 1),
|
||||
}
|
||||
msg := container.New()
|
||||
|
||||
// Add Hub Announcement.
|
||||
announcementData, err := identity.ExportAnnouncement()
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to export announcement: %w", err)
|
||||
}
|
||||
msg.AppendAsBlock(announcementData)
|
||||
|
||||
// Add Hub Status.
|
||||
statusData, err := identity.ExportStatus()
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to export status: %w", err)
|
||||
}
|
||||
msg.AppendAsBlock(statusData)
|
||||
|
||||
tErr := controller.StartOperation(op, msg, 10*time.Second)
|
||||
if tErr != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func runPublishOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Check if we are run by a controller.
|
||||
controller, ok := t.(*docks.CraneControllerTerminal)
|
||||
if !ok {
|
||||
return nil, terminal.ErrIncorrectUsage.With("publish op may only be started by a crane controller terminal, but was started by %T", t)
|
||||
}
|
||||
|
||||
// Parse and import Announcement and Status.
|
||||
announcementData, err := data.GetNextBlock()
|
||||
if err != nil {
|
||||
return nil, terminal.ErrMalformedData.With("failed to get announcement: %w", err)
|
||||
}
|
||||
statusData, err := data.GetNextBlock()
|
||||
if err != nil {
|
||||
return nil, terminal.ErrMalformedData.With("failed to get status: %w", err)
|
||||
}
|
||||
h, forward, tErr := docks.ImportAndVerifyHubInfo(module.Ctx, "", announcementData, statusData, conf.MainMapName, conf.MainMapScope)
|
||||
if tErr != nil {
|
||||
return nil, tErr.Wrap("failed to import and verify hub")
|
||||
}
|
||||
// Update reference in case it was changed by the import.
|
||||
controller.Crane.ConnectedHub = h
|
||||
|
||||
// Relay data.
|
||||
if forward {
|
||||
gossipRelayMsg(controller.Crane.ID, GossipHubAnnouncementMsg, announcementData)
|
||||
gossipRelayMsg(controller.Crane.ID, GossipHubStatusMsg, statusData)
|
||||
}
|
||||
|
||||
// Create verification request.
|
||||
v, request, err := cabin.CreateVerificationRequest(PublishOpType, "", "")
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to create verification request: %w", err)
|
||||
}
|
||||
|
||||
// Create operation.
|
||||
op := &PublishOp{
|
||||
controller: controller,
|
||||
requestingHub: h,
|
||||
verification: v,
|
||||
result: make(chan *terminal.Error, 1),
|
||||
}
|
||||
op.InitOperationBase(controller, opID)
|
||||
|
||||
// Reply with verification request.
|
||||
tErr = op.Send(op.NewMsg(request), 10*time.Second)
|
||||
if tErr != nil {
|
||||
return nil, tErr.Wrap("failed to send verification request")
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the operation.
|
||||
func (op *PublishOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
defer msg.Finish()
|
||||
|
||||
if op.identity != nil {
|
||||
// Client
|
||||
|
||||
// Sign the received verification request.
|
||||
response, err := op.identity.SignVerificationRequest(msg.Data.CompileData(), PublishOpType, "", "")
|
||||
if err != nil {
|
||||
return terminal.ErrPermissionDenied.With("signing verification request failed: %w", err)
|
||||
}
|
||||
|
||||
return op.Send(op.NewMsg(response), 10*time.Second)
|
||||
} else if op.requestingHub != nil {
|
||||
// Server
|
||||
|
||||
// Verify the signed request.
|
||||
err := op.verification.Verify(msg.Data.CompileData(), op.requestingHub)
|
||||
if err != nil {
|
||||
return terminal.ErrPermissionDenied.With("checking verification request failed: %w", err)
|
||||
}
|
||||
return terminal.ErrExplicitAck
|
||||
}
|
||||
|
||||
return terminal.ErrInternalError.With("invalid operation state")
|
||||
}
|
||||
|
||||
// Result returns the result (end error) of the operation.
|
||||
func (op *PublishOp) Result() <-chan *terminal.Error {
|
||||
return op.result
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *PublishOp) HandleStop(tErr *terminal.Error) (errorToSend *terminal.Error) {
|
||||
if tErr.Is(terminal.ErrExplicitAck) {
|
||||
// TODO: Check for concurrenct access.
|
||||
if op.controller.Crane.ConnectedHub == nil {
|
||||
op.controller.Crane.ConnectedHub = op.requestingHub
|
||||
}
|
||||
|
||||
// Publish crane, abort if it fails.
|
||||
err := op.controller.Crane.Publish()
|
||||
if err != nil {
|
||||
tErr = terminal.ErrInternalError.With("failed to publish crane: %w", err)
|
||||
op.controller.Crane.Stop(tErr)
|
||||
} else {
|
||||
op.controller.Crane.NotifyUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case op.result <- tErr:
|
||||
default:
|
||||
}
|
||||
return tErr
|
||||
}
|
||||
131
spn/captain/piers.go
Normal file
131
spn/captain/piers.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/intel"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/ships"
|
||||
)
|
||||
|
||||
var (
|
||||
dockingRequests = make(chan ships.Ship, 100)
|
||||
piers []ships.Pier
|
||||
)
|
||||
|
||||
func startPiers() error {
|
||||
// Get and check transports.
|
||||
transports := publicIdentity.Hub.Info.Transports
|
||||
if len(transports) == 0 {
|
||||
return errors.New("no transports defined")
|
||||
}
|
||||
|
||||
piers = make([]ships.Pier, 0, len(transports))
|
||||
for _, t := range transports {
|
||||
// Parse transport.
|
||||
transport, err := hub.ParseTransport(t)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot build pier for invalid transport %q: %w", t, err)
|
||||
}
|
||||
|
||||
// Establish pier / listener.
|
||||
pier, err := ships.EstablishPier(transport, dockingRequests)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to establish pier for transport %q: %w", t, err)
|
||||
}
|
||||
|
||||
piers = append(piers, pier)
|
||||
log.Infof("spn/captain: pier for transport %q built", t)
|
||||
}
|
||||
|
||||
// Start worker to handle docking requests.
|
||||
module.StartServiceWorker("docking request handler", 0, dockingRequestHandler)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stopPiers() {
|
||||
for _, pier := range piers {
|
||||
pier.Abolish()
|
||||
}
|
||||
}
|
||||
|
||||
func dockingRequestHandler(ctx context.Context) error {
|
||||
// Sink all waiting ships when this worker ends.
|
||||
// But don't be destructive so the service worker could recover.
|
||||
defer func() {
|
||||
for {
|
||||
select {
|
||||
case ship := <-dockingRequests:
|
||||
if ship != nil {
|
||||
ship.Sink()
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case ship := <-dockingRequests:
|
||||
// Ignore nil ships.
|
||||
if ship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := checkDockingPermission(ctx, ship); err != nil {
|
||||
log.Warningf("spn/captain: denied ship from %s to dock at pier %s: %s", ship.RemoteAddr(), ship.Transport().String(), err)
|
||||
} else {
|
||||
handleDockingRequest(ship)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkDockingPermission(ctx context.Context, ship ships.Ship) error {
|
||||
remoteIP, remotePort, err := netutils.IPPortFromAddr(ship.RemoteAddr())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse remote IP: %w", err)
|
||||
}
|
||||
|
||||
// Create entity.
|
||||
entity := (&intel.Entity{
|
||||
IP: remoteIP,
|
||||
Protocol: uint8(netutils.ProtocolFromNetwork(ship.RemoteAddr().Network())),
|
||||
Port: remotePort,
|
||||
}).Init(ship.Transport().Port)
|
||||
entity.FetchData(ctx)
|
||||
|
||||
// Check against policy.
|
||||
result, reason := publicIdentity.Hub.GetInfo().EntryPolicy().Match(ctx, entity)
|
||||
if result == endpoints.Denied {
|
||||
return fmt.Errorf("entry policy violated: %s", reason)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleDockingRequest(ship ships.Ship) {
|
||||
log.Infof("spn/captain: pemitting %s to dock", ship)
|
||||
|
||||
crane, err := docks.NewCrane(ship, nil, publicIdentity)
|
||||
if err != nil {
|
||||
log.Warningf("spn/captain: failed to commission crane for %s: %s", ship, err)
|
||||
return
|
||||
}
|
||||
|
||||
module.StartWorker("start crane", func(ctx context.Context) error {
|
||||
_ = crane.Start(ctx)
|
||||
// Crane handles errors internally.
|
||||
return nil
|
||||
})
|
||||
}
|
||||
247
spn/captain/public.go
Normal file
247
spn/captain/public.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/metrics"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
"github.com/safing/portmaster/spn/patrol"
|
||||
)
|
||||
|
||||
const (
|
||||
maintainStatusInterval = 15 * time.Minute
|
||||
maintainStatusUpdateDelay = 5 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
publicIdentity *cabin.Identity
|
||||
publicIdentityKey = "core:spn/public/identity"
|
||||
|
||||
publicIdentityUpdateTask *modules.Task
|
||||
statusUpdateTask *modules.Task
|
||||
)
|
||||
|
||||
func loadPublicIdentity() (err error) {
|
||||
var changed bool
|
||||
|
||||
publicIdentity, changed, err = cabin.LoadIdentity(publicIdentityKey)
|
||||
switch {
|
||||
case err == nil:
|
||||
// load was successful
|
||||
log.Infof("spn/captain: loaded public hub identity %s", publicIdentity.Hub.ID)
|
||||
case errors.Is(err, database.ErrNotFound):
|
||||
// does not exist, create new
|
||||
publicIdentity, err = cabin.CreateIdentity(module.Ctx, conf.MainMapName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new identity: %w", err)
|
||||
}
|
||||
publicIdentity.SetKey(publicIdentityKey)
|
||||
changed = true
|
||||
|
||||
log.Infof("spn/captain: created new public hub identity %s", publicIdentity.ID)
|
||||
default:
|
||||
// loading error, abort
|
||||
return fmt.Errorf("failed to load public identity: %w", err)
|
||||
}
|
||||
|
||||
// Save to database if the identity changed.
|
||||
if changed {
|
||||
err = publicIdentity.Save()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save new/updated identity to database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set available networks.
|
||||
conf.SetHubNetworks(
|
||||
publicIdentity.Hub.Info.IPv4 != nil,
|
||||
publicIdentity.Hub.Info.IPv6 != nil,
|
||||
)
|
||||
if cfgOptionBindToAdvertised() {
|
||||
conf.SetBindAddr(publicIdentity.Hub.Info.IPv4, publicIdentity.Hub.Info.IPv6)
|
||||
}
|
||||
|
||||
// Set Home Hub before updating the hub on the map, as this would trigger a
|
||||
// recalculation without a Home Hub.
|
||||
ok := navigator.Main.SetHome(publicIdentity.ID, nil)
|
||||
// Always update the navigator in any case in order to sync the reference to
|
||||
// the active struct of the identity.
|
||||
navigator.Main.UpdateHub(publicIdentity.Hub)
|
||||
// Setting the Home Hub will have failed if the identidy was only just
|
||||
// created - try again if it failed.
|
||||
if !ok {
|
||||
ok = navigator.Main.SetHome(publicIdentity.ID, nil)
|
||||
if !ok {
|
||||
return errors.New("failed to set self as home hub")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepPublicIdentityMgmt() error {
|
||||
publicIdentityUpdateTask = module.NewTask(
|
||||
"maintain public identity",
|
||||
maintainPublicIdentity,
|
||||
)
|
||||
|
||||
statusUpdateTask = module.NewTask(
|
||||
"maintain public status",
|
||||
maintainPublicStatus,
|
||||
).Repeat(maintainStatusInterval)
|
||||
|
||||
return module.RegisterEventHook(
|
||||
"config",
|
||||
"config change",
|
||||
"update public identity from config",
|
||||
func(_ context.Context, _ interface{}) error {
|
||||
// trigger update in 5 minutes
|
||||
publicIdentityUpdateTask.Schedule(time.Now().Add(5 * time.Minute))
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// TriggerHubStatusMaintenance queues the Hub status update task to be executed.
|
||||
func TriggerHubStatusMaintenance() {
|
||||
if statusUpdateTask != nil {
|
||||
statusUpdateTask.Queue()
|
||||
}
|
||||
}
|
||||
|
||||
func maintainPublicIdentity(ctx context.Context, task *modules.Task) error {
|
||||
changed, err := publicIdentity.MaintainAnnouncement(nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to maintain announcement: %w", err)
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update on map.
|
||||
navigator.Main.UpdateHub(publicIdentity.Hub)
|
||||
log.Debug("spn/captain: updated own hub on map after announcement change")
|
||||
|
||||
// export announcement
|
||||
announcementData, err := publicIdentity.ExportAnnouncement()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to export announcement: %w", err)
|
||||
}
|
||||
|
||||
// forward to other connected Hubs
|
||||
gossipRelayMsg("", GossipHubAnnouncementMsg, announcementData)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func maintainPublicStatus(ctx context.Context, task *modules.Task) error {
|
||||
// Get current lanes.
|
||||
cranes := docks.GetAllAssignedCranes()
|
||||
lanes := make([]*hub.Lane, 0, len(cranes))
|
||||
for _, crane := range cranes {
|
||||
// Ignore private, stopped or stopping cranes.
|
||||
if !crane.Public() || crane.Stopped() || crane.IsStopping() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get measurements.
|
||||
measurements := crane.ConnectedHub.GetMeasurements()
|
||||
latency, _ := measurements.GetLatency()
|
||||
capacity, _ := measurements.GetCapacity()
|
||||
|
||||
// Add crane lane.
|
||||
lanes = append(lanes, &hub.Lane{
|
||||
ID: crane.ConnectedHub.ID,
|
||||
Latency: latency,
|
||||
Capacity: capacity,
|
||||
})
|
||||
}
|
||||
// Sort Lanes for comparing.
|
||||
hub.SortLanes(lanes)
|
||||
|
||||
// Get system load and convert to fixed steps.
|
||||
var load int
|
||||
loadAvg, ok := metrics.LoadAvg15()
|
||||
switch {
|
||||
case !ok:
|
||||
load = -1
|
||||
case loadAvg >= 1:
|
||||
load = 100
|
||||
case loadAvg >= 0.95:
|
||||
load = 95
|
||||
case loadAvg >= 0.8:
|
||||
load = 80
|
||||
default:
|
||||
load = 0
|
||||
}
|
||||
if loadAvg >= 0.8 {
|
||||
log.Warningf("spn/captain: publishing 15m system load average of %.2f as %d", loadAvg, load)
|
||||
}
|
||||
|
||||
// Set flags.
|
||||
var flags []string
|
||||
if !patrol.HTTPSConnectivityConfirmed() {
|
||||
flags = append(flags, hub.FlagNetError)
|
||||
}
|
||||
// Sort Lanes for comparing.
|
||||
sort.Strings(flags)
|
||||
|
||||
// Run maintenance with the new data.
|
||||
changed, err := publicIdentity.MaintainStatus(lanes, &load, flags, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to maintain status: %w", err)
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update on map.
|
||||
navigator.Main.UpdateHub(publicIdentity.Hub)
|
||||
log.Debug("spn/captain: updated own hub on map after status change")
|
||||
|
||||
// export status
|
||||
statusData, err := publicIdentity.ExportStatus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to export status: %w", err)
|
||||
}
|
||||
|
||||
// forward to other connected Hubs
|
||||
gossipRelayMsg("", GossipHubStatusMsg, statusData)
|
||||
|
||||
log.Infof(
|
||||
"spn/captain: updated status with load %d and current lanes: %v",
|
||||
publicIdentity.Hub.Status.Load,
|
||||
publicIdentity.Hub.Status.Lanes,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func publishShutdownStatus() {
|
||||
// Create offline status.
|
||||
offlineStatusData, err := publicIdentity.MakeOfflineStatus()
|
||||
if err != nil {
|
||||
log.Errorf("spn/captain: failed to create offline status: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Forward to other connected Hubs.
|
||||
gossipRelayMsg("", GossipHubStatusMsg, offlineStatusData)
|
||||
|
||||
// Leave some time for the message to broadcast.
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
log.Infof("spn/captain: broadcasted offline status")
|
||||
}
|
||||
154
spn/captain/status.go
Normal file
154
spn/captain/status.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package captain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/runtime"
|
||||
"github.com/safing/portbase/utils/debug"
|
||||
"github.com/safing/portmaster/service/intel/geoip"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
)
|
||||
|
||||
// SPNStatus holds SPN status information.
|
||||
type SPNStatus struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
Status SPNStatusName
|
||||
HomeHubID string
|
||||
HomeHubName string
|
||||
ConnectedIP string
|
||||
ConnectedTransport string
|
||||
ConnectedCountry *geoip.CountryInfo
|
||||
ConnectedSince *time.Time
|
||||
}
|
||||
|
||||
// SPNStatusName is a SPN status.
|
||||
type SPNStatusName string
|
||||
|
||||
// SPN Stati.
|
||||
const (
|
||||
StatusFailed SPNStatusName = "failed"
|
||||
StatusDisabled SPNStatusName = "disabled"
|
||||
StatusConnecting SPNStatusName = "connecting"
|
||||
StatusConnected SPNStatusName = "connected"
|
||||
)
|
||||
|
||||
var (
|
||||
spnStatus = &SPNStatus{
|
||||
Status: StatusDisabled,
|
||||
}
|
||||
spnStatusPushFunc runtime.PushFunc
|
||||
)
|
||||
|
||||
func registerSPNStatusProvider() (err error) {
|
||||
spnStatus.SetKey("runtime:spn/status")
|
||||
spnStatus.UpdateMeta()
|
||||
spnStatusPushFunc, err = runtime.Register("spn/status", runtime.ProvideRecord(spnStatus))
|
||||
return
|
||||
}
|
||||
|
||||
func resetSPNStatus(statusName SPNStatusName, overrideEvenIfConnected bool) {
|
||||
// Lock for updating values.
|
||||
spnStatus.Lock()
|
||||
defer spnStatus.Unlock()
|
||||
|
||||
// Ignore when connected and not overriding
|
||||
if !overrideEvenIfConnected && spnStatus.Status == StatusConnected {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset status.
|
||||
spnStatus.Status = statusName
|
||||
spnStatus.HomeHubID = ""
|
||||
spnStatus.HomeHubName = ""
|
||||
spnStatus.ConnectedIP = ""
|
||||
spnStatus.ConnectedTransport = ""
|
||||
spnStatus.ConnectedCountry = nil
|
||||
spnStatus.ConnectedSince = nil
|
||||
|
||||
// Push new status.
|
||||
pushSPNStatusUpdate()
|
||||
}
|
||||
|
||||
// pushSPNStatusUpdate pushes an update of spnStatus, which must be locked.
|
||||
func pushSPNStatusUpdate() {
|
||||
spnStatus.UpdateMeta()
|
||||
spnStatusPushFunc(spnStatus)
|
||||
}
|
||||
|
||||
// GetSPNStatus returns the current SPN status.
|
||||
func GetSPNStatus() *SPNStatus {
|
||||
spnStatus.Lock()
|
||||
defer spnStatus.Unlock()
|
||||
|
||||
return &SPNStatus{
|
||||
Status: spnStatus.Status,
|
||||
HomeHubID: spnStatus.HomeHubID,
|
||||
HomeHubName: spnStatus.HomeHubName,
|
||||
ConnectedIP: spnStatus.ConnectedIP,
|
||||
ConnectedTransport: spnStatus.ConnectedTransport,
|
||||
ConnectedCountry: spnStatus.ConnectedCountry,
|
||||
ConnectedSince: spnStatus.ConnectedSince,
|
||||
}
|
||||
}
|
||||
|
||||
// AddToDebugInfo adds the SPN status to the given debug.Info.
|
||||
func AddToDebugInfo(di *debug.Info) {
|
||||
spnStatus.Lock()
|
||||
defer spnStatus.Unlock()
|
||||
|
||||
// Check if SPN module is enabled.
|
||||
var moduleStatus string
|
||||
spnEnabled := config.GetAsBool(CfgOptionEnableSPNKey, false)
|
||||
if spnEnabled() {
|
||||
moduleStatus = "enabled"
|
||||
} else {
|
||||
moduleStatus = "disabled"
|
||||
}
|
||||
|
||||
// Collect status data.
|
||||
lines := make([]string, 0, 20)
|
||||
lines = append(lines, fmt.Sprintf("HomeHubID: %v", spnStatus.HomeHubID))
|
||||
lines = append(lines, fmt.Sprintf("HomeHubName: %v", spnStatus.HomeHubName))
|
||||
lines = append(lines, fmt.Sprintf("HomeHubIP: %v", spnStatus.ConnectedIP))
|
||||
lines = append(lines, fmt.Sprintf("Transport: %v", spnStatus.ConnectedTransport))
|
||||
if spnStatus.ConnectedSince != nil {
|
||||
lines = append(lines, fmt.Sprintf("Connected: %v ago", time.Since(*spnStatus.ConnectedSince).Round(time.Minute)))
|
||||
}
|
||||
lines = append(lines, "---")
|
||||
lines = append(lines, fmt.Sprintf("Client: %v", conf.Client()))
|
||||
lines = append(lines, fmt.Sprintf("PublicHub: %v", conf.PublicHub()))
|
||||
lines = append(lines, fmt.Sprintf("HubHasIPv4: %v", conf.HubHasIPv4()))
|
||||
lines = append(lines, fmt.Sprintf("HubHasIPv6: %v", conf.HubHasIPv6()))
|
||||
|
||||
// Collect status data of map.
|
||||
if navigator.Main != nil {
|
||||
lines = append(lines, "---")
|
||||
mainMapStats := navigator.Main.Stats()
|
||||
lines = append(lines, fmt.Sprintf("Map %s:", navigator.Main.Name))
|
||||
lines = append(lines, fmt.Sprintf("Active Terminals: %d Hubs", mainMapStats.ActiveTerminals))
|
||||
// Collect hub states.
|
||||
mapStateSummary := make([]string, 0, len(mainMapStats.States))
|
||||
for state, cnt := range mainMapStats.States {
|
||||
if cnt > 0 {
|
||||
mapStateSummary = append(mapStateSummary, fmt.Sprintf("State %s: %d Hubs", state, cnt))
|
||||
}
|
||||
}
|
||||
sort.Strings(mapStateSummary)
|
||||
lines = append(lines, mapStateSummary...)
|
||||
}
|
||||
|
||||
// Add all data as section.
|
||||
di.AddSection(
|
||||
fmt.Sprintf("SPN: %s (module %s)", spnStatus.Status, moduleStatus),
|
||||
debug.UseCodeSection|debug.AddContentLineBreaks,
|
||||
lines...,
|
||||
)
|
||||
}
|
||||
17
spn/conf/map.go
Normal file
17
spn/conf/map.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
)
|
||||
|
||||
// Primary Map Configuration.
|
||||
var (
|
||||
MainMapName = "main"
|
||||
MainMapScope = hub.ScopePublic
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&MainMapName, "spn-map", "main", "set main SPN map - use only for testing")
|
||||
}
|
||||
30
spn/conf/mode.go
Normal file
30
spn/conf/mode.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
var (
|
||||
publicHub = abool.New()
|
||||
client = abool.New()
|
||||
)
|
||||
|
||||
// PublicHub returns whether this is a public Hub.
|
||||
func PublicHub() bool {
|
||||
return publicHub.IsSet()
|
||||
}
|
||||
|
||||
// EnablePublicHub enables the public hub mode.
|
||||
func EnablePublicHub(enable bool) {
|
||||
publicHub.SetTo(enable)
|
||||
}
|
||||
|
||||
// Client returns whether this is a client.
|
||||
func Client() bool {
|
||||
return client.IsSet()
|
||||
}
|
||||
|
||||
// EnableClient enables the client mode.
|
||||
func EnableClient(enable bool) {
|
||||
client.SetTo(enable)
|
||||
}
|
||||
110
spn/conf/networks.go
Normal file
110
spn/conf/networks.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
var (
|
||||
hubHasV4 = abool.New()
|
||||
hubHasV6 = abool.New()
|
||||
)
|
||||
|
||||
// SetHubNetworks sets the available IP networks on the Hub.
|
||||
func SetHubNetworks(v4, v6 bool) {
|
||||
hubHasV4.SetTo(v4)
|
||||
hubHasV6.SetTo(v6)
|
||||
}
|
||||
|
||||
// HubHasIPv4 returns whether the Hub has IPv4 support.
|
||||
func HubHasIPv4() bool {
|
||||
return hubHasV4.IsSet()
|
||||
}
|
||||
|
||||
// HubHasIPv6 returns whether the Hub has IPv6 support.
|
||||
func HubHasIPv6() bool {
|
||||
return hubHasV6.IsSet()
|
||||
}
|
||||
|
||||
var (
|
||||
bindIPv4 net.IP
|
||||
bindIPv6 net.IP
|
||||
bindIPLock sync.Mutex
|
||||
)
|
||||
|
||||
// SetBindAddr sets the preferred connect (bind) addresses.
|
||||
func SetBindAddr(ip4, ip6 net.IP) {
|
||||
bindIPLock.Lock()
|
||||
defer bindIPLock.Unlock()
|
||||
|
||||
bindIPv4 = ip4
|
||||
bindIPv6 = ip6
|
||||
}
|
||||
|
||||
// BindAddrIsSet returns whether any bind address is set.
|
||||
func BindAddrIsSet() bool {
|
||||
bindIPLock.Lock()
|
||||
defer bindIPLock.Unlock()
|
||||
|
||||
return bindIPv4 != nil || bindIPv6 != nil
|
||||
}
|
||||
|
||||
// GetBindAddr returns an address with the preferred binding address for the
|
||||
// given dial network.
|
||||
// The dial network must have a suffix specifying the IP version.
|
||||
func GetBindAddr(dialNetwork string) net.Addr {
|
||||
bindIPLock.Lock()
|
||||
defer bindIPLock.Unlock()
|
||||
|
||||
switch dialNetwork {
|
||||
case "ip4":
|
||||
if bindIPv4 != nil {
|
||||
return &net.IPAddr{IP: bindIPv4}
|
||||
}
|
||||
case "ip6":
|
||||
if bindIPv6 != nil {
|
||||
return &net.IPAddr{IP: bindIPv6}
|
||||
}
|
||||
case "tcp4":
|
||||
if bindIPv4 != nil {
|
||||
return &net.TCPAddr{IP: bindIPv4}
|
||||
}
|
||||
case "tcp6":
|
||||
if bindIPv6 != nil {
|
||||
return &net.TCPAddr{IP: bindIPv6}
|
||||
}
|
||||
case "udp4":
|
||||
if bindIPv4 != nil {
|
||||
return &net.UDPAddr{IP: bindIPv4}
|
||||
}
|
||||
case "udp6":
|
||||
if bindIPv6 != nil {
|
||||
return &net.UDPAddr{IP: bindIPv6}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBindIPs returns the preferred binding IPs.
|
||||
// Returns a slice with a single nil IP if no preferred binding IPs are set.
|
||||
func GetBindIPs() []net.IP {
|
||||
bindIPLock.Lock()
|
||||
defer bindIPLock.Unlock()
|
||||
|
||||
switch {
|
||||
case bindIPv4 == nil && bindIPv6 == nil:
|
||||
// Match most common case first.
|
||||
return []net.IP{nil}
|
||||
case bindIPv4 != nil && bindIPv6 != nil:
|
||||
return []net.IP{bindIPv4, bindIPv6}
|
||||
case bindIPv4 != nil:
|
||||
return []net.IP{bindIPv4}
|
||||
case bindIPv6 != nil:
|
||||
return []net.IP{bindIPv6}
|
||||
}
|
||||
|
||||
return []net.IP{nil}
|
||||
}
|
||||
9
spn/conf/version.go
Normal file
9
spn/conf/version.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package conf
|
||||
|
||||
const (
|
||||
// VersionOne is the first protocol version.
|
||||
VersionOne = 1
|
||||
|
||||
// CurrentVersion always holds the newest version in production.
|
||||
CurrentVersion = 1
|
||||
)
|
||||
482
spn/crew/connect.go
Normal file
482
spn/crew/connect.go
Normal file
@@ -0,0 +1,482 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
"github.com/safing/portmaster/spn/access"
|
||||
"github.com/safing/portmaster/spn/docks"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// connectLock locks all routing operations to mitigate racy stuff for now.
|
||||
// TODO: Find a nice way to parallelize route creation.
|
||||
var connectLock sync.Mutex
|
||||
|
||||
// HandleSluiceRequest handles a sluice request to build a tunnel.
|
||||
func HandleSluiceRequest(connInfo *network.Connection, conn net.Conn) {
|
||||
if conn == nil {
|
||||
log.Debugf("spn/crew: closing tunnel for %s before starting because of shutdown", connInfo)
|
||||
|
||||
// This is called within the connInfo lock.
|
||||
connInfo.Failed("tunnel entry closed", "")
|
||||
connInfo.SaveWhenFinished()
|
||||
return
|
||||
}
|
||||
|
||||
t := &Tunnel{
|
||||
connInfo: connInfo,
|
||||
conn: conn,
|
||||
}
|
||||
module.StartWorker("tunnel handler", t.connectWorker)
|
||||
}
|
||||
|
||||
// Tunnel represents the local information and endpoint of a data tunnel.
|
||||
type Tunnel struct {
|
||||
connInfo *network.Connection
|
||||
conn net.Conn
|
||||
|
||||
dstPin *navigator.Pin
|
||||
dstTerminal terminal.Terminal
|
||||
route *navigator.Route
|
||||
failedTries int
|
||||
stickied bool
|
||||
}
|
||||
|
||||
func (t *Tunnel) connectWorker(ctx context.Context) (err error) {
|
||||
// Get tracing logger.
|
||||
ctx, tracer := log.AddTracer(ctx)
|
||||
defer tracer.Submit()
|
||||
|
||||
// Save start time.
|
||||
started := time.Now()
|
||||
|
||||
// Check the status of the Home Hub.
|
||||
home, homeTerminal := navigator.Main.GetHome()
|
||||
if home == nil || homeTerminal == nil || homeTerminal.IsBeingAbandoned() {
|
||||
reportConnectError(terminal.ErrUnknownError.With("home terminal is abandoned"))
|
||||
|
||||
t.connInfo.Lock()
|
||||
defer t.connInfo.Unlock()
|
||||
t.connInfo.Failed("SPN not ready for tunneling", "")
|
||||
t.connInfo.Save()
|
||||
|
||||
tracer.Infof("spn/crew: not tunneling %s, as the SPN is not ready", t.connInfo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create path through the SPN.
|
||||
err = t.establish(ctx)
|
||||
if err != nil {
|
||||
log.Warningf("spn/crew: failed to establish route for %s: %s", t.connInfo, err)
|
||||
|
||||
// TODO: Clean this up.
|
||||
t.connInfo.Lock()
|
||||
defer t.connInfo.Unlock()
|
||||
t.connInfo.Failed(fmt.Sprintf("SPN failed to establish route: %s", err), "")
|
||||
t.connInfo.Save()
|
||||
|
||||
tracer.Warningf("spn/crew: failed to establish route for %s: %s", t.connInfo, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect via established tunnel.
|
||||
_, tErr := NewConnectOp(t)
|
||||
if tErr != nil {
|
||||
tErr = tErr.Wrap("failed to initialize tunnel")
|
||||
reportConnectError(tErr)
|
||||
|
||||
t.connInfo.Lock()
|
||||
defer t.connInfo.Unlock()
|
||||
t.connInfo.Failed(fmt.Sprintf("SPN failed to initialize data tunnel (connect op): %s", tErr.Error()), "")
|
||||
t.connInfo.Save()
|
||||
|
||||
// TODO: try with another route?
|
||||
tracer.Warningf("spn/crew: failed to initialize data tunnel (connect op) for %s: %s", t.connInfo, err)
|
||||
return tErr
|
||||
}
|
||||
|
||||
// Report time taken to find, build and check route and send connect request.
|
||||
connectOpTTCRDurationHistogram.UpdateDuration(started)
|
||||
|
||||
t.connInfo.Lock()
|
||||
defer t.connInfo.Unlock()
|
||||
addTunnelContextToConnection(t)
|
||||
t.connInfo.Save()
|
||||
|
||||
tracer.Infof("spn/crew: connected %s via %s", t.connInfo, t.dstPin.Hub)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) establish(ctx context.Context) (err error) {
|
||||
var routes *navigator.Routes
|
||||
|
||||
// Check if the destination sticks to a Hub.
|
||||
sticksTo := getStickiedHub(t.connInfo)
|
||||
switch {
|
||||
case sticksTo == nil:
|
||||
// Continue.
|
||||
|
||||
case sticksTo.Avoid:
|
||||
log.Tracer(ctx).Tracef("spn/crew: avoiding %s", sticksTo.Pin.Hub)
|
||||
|
||||
// Avoid this Hub.
|
||||
// TODO: Remember more than one hub to avoid.
|
||||
avoidPolicy := []endpoints.Endpoint{
|
||||
&endpoints.EndpointDomain{
|
||||
OriginalValue: sticksTo.Pin.Hub.ID,
|
||||
Domain: strings.ToLower(sticksTo.Pin.Hub.ID) + ".",
|
||||
},
|
||||
}
|
||||
|
||||
// Append to policies.
|
||||
t.connInfo.TunnelOpts.Destination.HubPolicies = append(t.connInfo.TunnelOpts.Destination.HubPolicies, avoidPolicy)
|
||||
|
||||
default:
|
||||
log.Tracer(ctx).Tracef("spn/crew: using stickied %s", sticksTo.Pin.Hub)
|
||||
|
||||
// Check if the stickied Hub has an active terminal.
|
||||
dstTerminal := sticksTo.Pin.GetActiveTerminal()
|
||||
if dstTerminal != nil {
|
||||
t.dstPin = sticksTo.Pin
|
||||
t.dstTerminal = dstTerminal
|
||||
t.route = sticksTo.Route
|
||||
t.stickied = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// If not, attempt to find a route to the stickied hub.
|
||||
routes, err = navigator.Main.FindRouteToHub(
|
||||
sticksTo.Pin.Hub.ID,
|
||||
t.connInfo.TunnelOpts,
|
||||
)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Tracef("spn/crew: failed to find route to stickied %s: %s", sticksTo.Pin.Hub, err)
|
||||
routes = nil
|
||||
} else {
|
||||
t.stickied = true
|
||||
}
|
||||
}
|
||||
|
||||
// Find possible routes to destination.
|
||||
if routes == nil {
|
||||
log.Tracer(ctx).Trace("spn/crew: finding routes...")
|
||||
routes, err = navigator.Main.FindRoutes(
|
||||
t.connInfo.Entity.IP,
|
||||
t.connInfo.TunnelOpts,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find routes to %s: %w", t.connInfo.Entity.IP, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if routes are okay (again).
|
||||
if len(routes.All) == 0 {
|
||||
return fmt.Errorf("no routes to %s", t.connInfo.Entity.IP)
|
||||
}
|
||||
|
||||
// Try routes until one succeeds.
|
||||
log.Tracer(ctx).Trace("spn/crew: establishing route...")
|
||||
var dstPin *navigator.Pin
|
||||
var dstTerminal terminal.Terminal
|
||||
for tries, route := range routes.All {
|
||||
dstPin, dstTerminal, err = establishRoute(route)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Assign route data to tunnel.
|
||||
t.dstPin = dstPin
|
||||
t.dstTerminal = dstTerminal
|
||||
t.route = route
|
||||
t.failedTries = tries
|
||||
|
||||
// Push changes to Pins and return.
|
||||
navigator.Main.PushPinChanges()
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to establish a route to %s: %w", t.connInfo.Entity.IP, err)
|
||||
}
|
||||
|
||||
type hopCheck struct {
|
||||
pin *navigator.Pin
|
||||
route *navigator.Route
|
||||
expansion *docks.ExpansionTerminal
|
||||
authOp *access.AuthorizeOp
|
||||
pingOp *PingOp
|
||||
}
|
||||
|
||||
func establishRoute(route *navigator.Route) (dstPin *navigator.Pin, dstTerminal terminal.Terminal, err error) {
|
||||
connectLock.Lock()
|
||||
defer connectLock.Unlock()
|
||||
|
||||
// Check for path length.
|
||||
if len(route.Path) < 1 {
|
||||
return nil, nil, errors.New("path too short")
|
||||
}
|
||||
|
||||
// Check for failing hubs in path.
|
||||
for _, hop := range route.Path[1:] {
|
||||
if hop.Pin().GetState().Has(navigator.StateFailing) {
|
||||
return nil, nil, fmt.Errorf("failing hub in path: %s", hop.Pin().Hub.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// Get home hub.
|
||||
previousHop, homeTerminal := navigator.Main.GetHome()
|
||||
if previousHop == nil || homeTerminal == nil {
|
||||
return nil, nil, navigator.ErrHomeHubUnset
|
||||
}
|
||||
// Convert to interface for later use.
|
||||
var previousTerminal terminal.Terminal = homeTerminal
|
||||
|
||||
// Check if first hub in path is the home hub.
|
||||
if route.Path[0].HubID != previousHop.Hub.ID {
|
||||
return nil, nil, errors.New("path start does not match home hub")
|
||||
}
|
||||
|
||||
// Check if path only exists of home hub.
|
||||
if len(route.Path) == 1 {
|
||||
return previousHop, previousTerminal, nil
|
||||
}
|
||||
|
||||
// TODO: Check what needs locking.
|
||||
|
||||
// Build path and save created paths.
|
||||
hopChecks := make([]*hopCheck, 0, len(route.Path)-1)
|
||||
for i, hop := range route.Path[1:] {
|
||||
// Check if we already have a connection to the Hub.
|
||||
activeTerminal := hop.Pin().GetActiveTerminal()
|
||||
if activeTerminal != nil {
|
||||
// Ping terminal if not recently checked.
|
||||
if activeTerminal.NeedsReachableCheck(1 * time.Minute) {
|
||||
pingOp, tErr := NewPingOp(activeTerminal)
|
||||
if tErr.IsError() {
|
||||
return nil, nil, tErr.Wrap("failed start ping to %s", hop.Pin())
|
||||
}
|
||||
// Add for checking results later.
|
||||
hopChecks = append(hopChecks, &hopCheck{
|
||||
pin: hop.Pin(),
|
||||
route: route.CopyUpTo(i + 2),
|
||||
expansion: activeTerminal,
|
||||
pingOp: pingOp,
|
||||
})
|
||||
}
|
||||
|
||||
previousHop = hop.Pin()
|
||||
previousTerminal = activeTerminal
|
||||
continue
|
||||
}
|
||||
|
||||
// Expand to next Hub.
|
||||
expansion, authOp, tErr := expand(previousTerminal, previousHop, hop.Pin())
|
||||
if tErr != nil {
|
||||
return nil, nil, tErr.Wrap("failed to expand to %s", hop.Pin())
|
||||
}
|
||||
|
||||
// Add for checking results later.
|
||||
hopChecks = append(hopChecks, &hopCheck{
|
||||
pin: hop.Pin(),
|
||||
route: route.CopyUpTo(i + 2),
|
||||
expansion: expansion,
|
||||
authOp: authOp,
|
||||
})
|
||||
|
||||
// Save previous pin for next loop or end.
|
||||
previousHop = hop.Pin()
|
||||
previousTerminal = expansion
|
||||
}
|
||||
|
||||
// Check results.
|
||||
for _, check := range hopChecks {
|
||||
switch {
|
||||
case check.authOp != nil:
|
||||
// Wait for authOp result.
|
||||
select {
|
||||
case tErr := <-check.authOp.Result:
|
||||
switch {
|
||||
case tErr.IsError():
|
||||
// There was a network or authentication error.
|
||||
check.pin.MarkAsFailingFor(3 * time.Minute)
|
||||
log.Warningf("spn/crew: failed to auth to %s: %s", check.pin.Hub, tErr)
|
||||
return nil, nil, tErr.Wrap("failed to authenticate to %s: %w", check.pin.Hub, tErr)
|
||||
|
||||
case tErr.Is(terminal.ErrExplicitAck):
|
||||
// Authentication was successful.
|
||||
|
||||
default:
|
||||
// Authentication was aborted.
|
||||
if tErr != nil {
|
||||
tErr = terminal.ErrUnknownError
|
||||
}
|
||||
log.Warningf("spn/crew: auth to %s aborted with %s", check.pin.Hub, tErr)
|
||||
return nil, nil, tErr.Wrap("authentication to %s aborted: %w", check.pin.Hub, tErr)
|
||||
}
|
||||
|
||||
case <-time.After(5 * time.Second):
|
||||
// Mark as failing for just a minute, until server load may be less.
|
||||
check.pin.MarkAsFailingFor(1 * time.Minute)
|
||||
log.Warningf("spn/crew: auth to %s timed out", check.pin.Hub)
|
||||
|
||||
return nil, nil, terminal.ErrTimeout.With("waiting for auth to %s", check.pin.Hub)
|
||||
}
|
||||
|
||||
// Add terminal extension to the map.
|
||||
check.pin.SetActiveTerminal(&navigator.PinConnection{
|
||||
Terminal: check.expansion,
|
||||
Route: check.route,
|
||||
})
|
||||
check.expansion.MarkReachable()
|
||||
log.Infof("spn/crew: added conn to %s via %s", check.pin, check.route)
|
||||
|
||||
case check.pingOp != nil:
|
||||
// Wait for ping result.
|
||||
select {
|
||||
case tErr := <-check.pingOp.Result:
|
||||
if !tErr.Is(terminal.ErrExplicitAck) {
|
||||
// Mark as failing long enough to expire connections and session and shutdown connections.
|
||||
// TODO: Should we forcibly disconnect instead?
|
||||
// TODO: This might also be triggered if a relay fails and ends the operation.
|
||||
check.pin.MarkAsFailingFor(7 * time.Minute)
|
||||
// Forget about existing active terminal, re-create if needed.
|
||||
check.pin.SetActiveTerminal(nil)
|
||||
log.Warningf("spn/crew: failed to check reachability of %s: %s", check.pin.Hub, tErr)
|
||||
|
||||
return nil, nil, tErr.Wrap("failed to check reachability of %s: %w", check.pin.Hub, tErr)
|
||||
}
|
||||
|
||||
case <-time.After(5 * time.Second):
|
||||
// Mark as failing for just a minute, until server load may be less.
|
||||
check.pin.MarkAsFailingFor(1 * time.Minute)
|
||||
// Forget about existing active terminal, re-create if needed.
|
||||
check.pin.SetActiveTerminal(nil)
|
||||
log.Warningf("spn/crew: reachability check to %s timed out", check.pin.Hub)
|
||||
|
||||
return nil, nil, terminal.ErrTimeout.With("waiting for ping to %s", check.pin.Hub)
|
||||
}
|
||||
|
||||
check.expansion.MarkReachable()
|
||||
log.Debugf("spn/crew: checked conn to %s via %s", check.pin.Hub, check.route)
|
||||
|
||||
default:
|
||||
log.Errorf("spn/crew: invalid hop check for %s", check.pin.Hub)
|
||||
return nil, nil, terminal.ErrInternalError.With("invalid hop check")
|
||||
}
|
||||
}
|
||||
|
||||
// Return last hop.
|
||||
return previousHop, previousTerminal, nil
|
||||
}
|
||||
|
||||
func expand(fromTerminal terminal.Terminal, from, to *navigator.Pin) (expansion *docks.ExpansionTerminal, authOp *access.AuthorizeOp, tErr *terminal.Error) {
|
||||
expansion, tErr = docks.ExpandTo(fromTerminal, to.Hub.ID, to.Hub)
|
||||
if tErr != nil {
|
||||
return nil, nil, tErr.Wrap("failed to expand to %s", to.Hub)
|
||||
}
|
||||
|
||||
authOp, tErr = access.AuthorizeToTerminal(expansion)
|
||||
if tErr != nil {
|
||||
expansion.Abandon(nil)
|
||||
return nil, nil, tErr.Wrap("failed to authorize")
|
||||
}
|
||||
|
||||
log.Infof("spn/crew: expanded to %s (from %s)", to.Hub, from.Hub)
|
||||
return expansion, authOp, nil
|
||||
}
|
||||
|
||||
// TunnelContext holds additional information about the tunnel to be added to a
|
||||
// connection.
|
||||
type TunnelContext struct {
|
||||
Path []*TunnelContextHop
|
||||
PathCost float32
|
||||
RoutingAlg string
|
||||
|
||||
tunnel *Tunnel
|
||||
}
|
||||
|
||||
// GetExitNodeID returns the ID of the exit node.
|
||||
// It returns an empty string in case no path exists.
|
||||
func (tc *TunnelContext) GetExitNodeID() string {
|
||||
if len(tc.Path) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return tc.Path[len(tc.Path)-1].ID
|
||||
}
|
||||
|
||||
// StopTunnel stops the tunnel.
|
||||
func (tc *TunnelContext) StopTunnel() error {
|
||||
if tc.tunnel != nil && tc.tunnel.conn != nil {
|
||||
return tc.tunnel.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TunnelContextHop holds hop data for TunnelContext.
|
||||
type TunnelContextHop struct {
|
||||
ID string
|
||||
Name string
|
||||
IPv4 *TunnelContextHopIPInfo `json:",omitempty"`
|
||||
IPv6 *TunnelContextHopIPInfo `json:",omitempty"`
|
||||
}
|
||||
|
||||
// TunnelContextHopIPInfo holds hop IP data for TunnelContextHop.
|
||||
type TunnelContextHopIPInfo struct {
|
||||
IP net.IP
|
||||
Country string
|
||||
ASN uint
|
||||
ASOwner string
|
||||
}
|
||||
|
||||
func addTunnelContextToConnection(t *Tunnel) {
|
||||
// Create and add basic info.
|
||||
tunnelCtx := &TunnelContext{
|
||||
Path: make([]*TunnelContextHop, len(t.route.Path)),
|
||||
PathCost: t.route.TotalCost,
|
||||
RoutingAlg: t.route.Algorithm,
|
||||
tunnel: t,
|
||||
}
|
||||
t.connInfo.TunnelContext = tunnelCtx
|
||||
|
||||
// Add path info.
|
||||
for i, hop := range t.route.Path {
|
||||
// Add hub info.
|
||||
hopCtx := &TunnelContextHop{
|
||||
ID: hop.HubID,
|
||||
Name: hop.Pin().Hub.Info.Name,
|
||||
}
|
||||
tunnelCtx.Path[i] = hopCtx
|
||||
// Add hub IPv4 info.
|
||||
if hop.Pin().Hub.Info.IPv4 != nil {
|
||||
hopCtx.IPv4 = &TunnelContextHopIPInfo{
|
||||
IP: hop.Pin().Hub.Info.IPv4,
|
||||
}
|
||||
if hop.Pin().LocationV4 != nil {
|
||||
hopCtx.IPv4.Country = hop.Pin().LocationV4.Country.Code
|
||||
hopCtx.IPv4.ASN = hop.Pin().LocationV4.AutonomousSystemNumber
|
||||
hopCtx.IPv4.ASOwner = hop.Pin().LocationV4.AutonomousSystemOrganization
|
||||
}
|
||||
}
|
||||
// Add hub IPv6 info.
|
||||
if hop.Pin().Hub.Info.IPv6 != nil {
|
||||
hopCtx.IPv6 = &TunnelContextHopIPInfo{
|
||||
IP: hop.Pin().Hub.Info.IPv6,
|
||||
}
|
||||
if hop.Pin().LocationV6 != nil {
|
||||
hopCtx.IPv6.Country = hop.Pin().LocationV6.Country.Code
|
||||
hopCtx.IPv6.ASN = hop.Pin().LocationV6.AutonomousSystemNumber
|
||||
hopCtx.IPv6.ASOwner = hop.Pin().LocationV6.AutonomousSystemOrganization
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
223
spn/crew/metrics.go
Normal file
223
spn/crew/metrics.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/metrics"
|
||||
)
|
||||
|
||||
var (
|
||||
connectOpCnt *metrics.Counter
|
||||
connectOpCntError *metrics.Counter
|
||||
connectOpCntBadRequest *metrics.Counter
|
||||
connectOpCntCanceled *metrics.Counter
|
||||
connectOpCntFailed *metrics.Counter
|
||||
connectOpCntConnected *metrics.Counter
|
||||
connectOpCntRateLimited *metrics.Counter
|
||||
|
||||
connectOpIncomingBytes *metrics.Counter
|
||||
connectOpOutgoingBytes *metrics.Counter
|
||||
|
||||
connectOpTTCRDurationHistogram *metrics.Histogram
|
||||
connectOpTTFBDurationHistogram *metrics.Histogram
|
||||
connectOpDurationHistogram *metrics.Histogram
|
||||
connectOpIncomingDataHistogram *metrics.Histogram
|
||||
connectOpOutgoingDataHistogram *metrics.Histogram
|
||||
|
||||
metricsRegistered = abool.New()
|
||||
)
|
||||
|
||||
func registerMetrics() (err error) {
|
||||
// Only register metrics once.
|
||||
if !metricsRegistered.SetToIf(false, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect Op Stats on client.
|
||||
|
||||
connectOpCnt, err = metrics.NewCounter(
|
||||
"spn/op/connect/total",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Total Connect Operations",
|
||||
InternalID: "spn_connect_count",
|
||||
Permission: api.PermitUser,
|
||||
Persist: true,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Connect Op Stats on server.
|
||||
|
||||
connectOpCntOptions := &metrics.Options{
|
||||
Name: "SPN Total Connect Operations",
|
||||
Permission: api.PermitUser,
|
||||
Persist: true,
|
||||
}
|
||||
|
||||
connectOpCntError, err = metrics.NewCounter(
|
||||
"spn/op/connect/total",
|
||||
map[string]string{"result": "error"},
|
||||
connectOpCntOptions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpCntBadRequest, err = metrics.NewCounter(
|
||||
"spn/op/connect/total",
|
||||
map[string]string{"result": "bad_request"},
|
||||
connectOpCntOptions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpCntCanceled, err = metrics.NewCounter(
|
||||
"spn/op/connect/total",
|
||||
map[string]string{"result": "canceled"},
|
||||
connectOpCntOptions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpCntFailed, err = metrics.NewCounter(
|
||||
"spn/op/connect/total",
|
||||
map[string]string{"result": "failed"},
|
||||
connectOpCntOptions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpCntConnected, err = metrics.NewCounter(
|
||||
"spn/op/connect/total",
|
||||
map[string]string{"result": "connected"},
|
||||
connectOpCntOptions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpCntRateLimited, err = metrics.NewCounter(
|
||||
"spn/op/connect/total",
|
||||
map[string]string{"result": "rate_limited"},
|
||||
connectOpCntOptions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/op/connect/active",
|
||||
nil,
|
||||
getActiveConnectOpsStat,
|
||||
&metrics.Options{
|
||||
Name: "SPN Active Connect Operations",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpIncomingBytes, err = metrics.NewCounter(
|
||||
"spn/op/connect/incoming/bytes",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Connect Operation Incoming Bytes",
|
||||
InternalID: "spn_connect_in_bytes",
|
||||
Permission: api.PermitUser,
|
||||
Persist: true,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpOutgoingBytes, err = metrics.NewCounter(
|
||||
"spn/op/connect/outgoing/bytes",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Connect Operation Outgoing Bytes",
|
||||
InternalID: "spn_connect_out_bytes",
|
||||
Permission: api.PermitUser,
|
||||
Persist: true,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpTTCRDurationHistogram, err = metrics.NewHistogram(
|
||||
"spn/op/connect/histogram/ttcr/seconds",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Connect Operation time-to-connect-request Histogram",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpTTFBDurationHistogram, err = metrics.NewHistogram(
|
||||
"spn/op/connect/histogram/ttfb/seconds",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Connect Operation time-to-first-byte (from TTCR) Histogram",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpDurationHistogram, err = metrics.NewHistogram(
|
||||
"spn/op/connect/histogram/duration/seconds",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Connect Operation Duration Histogram",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpIncomingDataHistogram, err = metrics.NewHistogram(
|
||||
"spn/op/connect/histogram/incoming/bytes",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Connect Operation Downloaded Data Histogram",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectOpOutgoingDataHistogram, err = metrics.NewHistogram(
|
||||
"spn/op/connect/histogram/outgoing/bytes",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Connect Operation Outgoing Data Histogram",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getActiveConnectOpsStat() float64 {
|
||||
return float64(atomic.LoadInt64(activeConnectOps))
|
||||
}
|
||||
44
spn/crew/module.go
Normal file
44
spn/crew/module.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
var module *modules.Module
|
||||
|
||||
func init() {
|
||||
module = modules.Register("crew", nil, start, stop, "terminal", "docks", "navigator", "intel", "cabin")
|
||||
}
|
||||
|
||||
func start() error {
|
||||
module.NewTask("sticky cleaner", cleanStickyHubs).
|
||||
Repeat(10 * time.Minute)
|
||||
|
||||
return registerMetrics()
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
clearStickyHubs()
|
||||
terminal.StopScheduler()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var connectErrors = make(chan *terminal.Error, 10)
|
||||
|
||||
func reportConnectError(tErr *terminal.Error) {
|
||||
select {
|
||||
case connectErrors <- tErr:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectErrors returns errors of connect operations.
|
||||
// It only has a small and shared buffer and may only be used for indications,
|
||||
// not for full monitoring.
|
||||
func ConnectErrors() <-chan *terminal.Error {
|
||||
return connectErrors
|
||||
}
|
||||
13
spn/crew/module_test.go
Normal file
13
spn/crew/module_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/service/core/pmtesting"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
conf.EnablePublicHub(true)
|
||||
pmtesting.TestMain(m, module)
|
||||
}
|
||||
585
spn/crew/op_connect.go
Normal file
585
spn/crew/op_connect.go
Normal file
@@ -0,0 +1,585 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// ConnectOpType is the type ID for the connection operation.
|
||||
const ConnectOpType string = "connect"
|
||||
|
||||
var activeConnectOps = new(int64)
|
||||
|
||||
// ConnectOp is used to connect data tunnels to servers on the Internet.
|
||||
type ConnectOp struct {
|
||||
terminal.OperationBase
|
||||
|
||||
// Flow Control
|
||||
dfq *terminal.DuplexFlowQueue
|
||||
|
||||
// Context and shutdown handling
|
||||
// ctx is the context of the Terminal.
|
||||
ctx context.Context
|
||||
// cancelCtx cancels ctx.
|
||||
cancelCtx context.CancelFunc
|
||||
// doneWriting signals that the writer has finished writing.
|
||||
doneWriting chan struct{}
|
||||
|
||||
// Metrics
|
||||
incomingTraffic atomic.Uint64
|
||||
outgoingTraffic atomic.Uint64
|
||||
started time.Time
|
||||
|
||||
// Connection
|
||||
t terminal.Terminal
|
||||
conn net.Conn
|
||||
request *ConnectRequest
|
||||
entry bool
|
||||
tunnel *Tunnel
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *ConnectOp) Type() string {
|
||||
return ConnectOpType
|
||||
}
|
||||
|
||||
// Ctx returns the operation context.
|
||||
func (op *ConnectOp) Ctx() context.Context {
|
||||
return op.ctx
|
||||
}
|
||||
|
||||
// ConnectRequest holds all the information necessary for a connect operation.
|
||||
type ConnectRequest struct {
|
||||
Domain string `json:"d,omitempty"`
|
||||
IP net.IP `json:"ip,omitempty"`
|
||||
UsePriorityDataMsgs bool `json:"pr,omitempty"`
|
||||
Protocol packet.IPProtocol `json:"p,omitempty"`
|
||||
Port uint16 `json:"po,omitempty"`
|
||||
QueueSize uint32 `json:"qs,omitempty"`
|
||||
}
|
||||
|
||||
// DialNetwork returns the address of the connect request.
|
||||
func (r *ConnectRequest) DialNetwork() string {
|
||||
if ip4 := r.IP.To4(); ip4 != nil {
|
||||
switch r.Protocol { //nolint:exhaustive // Only looking for supported protocols.
|
||||
case packet.TCP:
|
||||
return "tcp4"
|
||||
case packet.UDP:
|
||||
return "udp4"
|
||||
}
|
||||
} else {
|
||||
switch r.Protocol { //nolint:exhaustive // Only looking for supported protocols.
|
||||
case packet.TCP:
|
||||
return "tcp6"
|
||||
case packet.UDP:
|
||||
return "udp6"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Address returns the address of the connext request.
|
||||
func (r *ConnectRequest) Address() string {
|
||||
return net.JoinHostPort(r.IP.String(), strconv.Itoa(int(r.Port)))
|
||||
}
|
||||
|
||||
func (r *ConnectRequest) String() string {
|
||||
if r.Domain != "" {
|
||||
return fmt.Sprintf("%s (%s %s)", r.Domain, r.Protocol, r.Address())
|
||||
}
|
||||
return fmt.Sprintf("%s %s", r.Protocol, r.Address())
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: ConnectOpType,
|
||||
Requires: terminal.MayConnect,
|
||||
Start: startConnectOp,
|
||||
})
|
||||
}
|
||||
|
||||
// NewConnectOp starts a new connect operation.
|
||||
func NewConnectOp(tunnel *Tunnel) (*ConnectOp, *terminal.Error) {
|
||||
// Submit metrics.
|
||||
connectOpCnt.Inc()
|
||||
|
||||
// Create request.
|
||||
request := &ConnectRequest{
|
||||
Domain: tunnel.connInfo.Entity.Domain,
|
||||
IP: tunnel.connInfo.Entity.IP,
|
||||
Protocol: packet.IPProtocol(tunnel.connInfo.Entity.Protocol),
|
||||
Port: tunnel.connInfo.Entity.Port,
|
||||
UsePriorityDataMsgs: terminal.UsePriorityDataMsgs,
|
||||
}
|
||||
|
||||
// Set defaults.
|
||||
if request.QueueSize == 0 {
|
||||
request.QueueSize = terminal.DefaultQueueSize
|
||||
}
|
||||
|
||||
// Create new op.
|
||||
op := &ConnectOp{
|
||||
doneWriting: make(chan struct{}),
|
||||
t: tunnel.dstTerminal,
|
||||
conn: tunnel.conn,
|
||||
request: request,
|
||||
entry: true,
|
||||
tunnel: tunnel,
|
||||
}
|
||||
op.ctx, op.cancelCtx = context.WithCancel(module.Ctx)
|
||||
op.dfq = terminal.NewDuplexFlowQueue(op.Ctx(), request.QueueSize, op.submitUpstream)
|
||||
|
||||
// Prepare init msg.
|
||||
data, err := dsd.Dump(request, dsd.CBOR)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to pack connect request: %w", err)
|
||||
}
|
||||
|
||||
// Initialize.
|
||||
tErr := op.t.StartOperation(op, container.New(data), 5*time.Second)
|
||||
if err != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
// Setup metrics.
|
||||
op.started = time.Now()
|
||||
|
||||
module.StartWorker("connect op conn reader", op.connReader)
|
||||
module.StartWorker("connect op conn writer", op.connWriter)
|
||||
module.StartWorker("connect op flow handler", op.dfq.FlowHandler)
|
||||
|
||||
log.Infof("spn/crew: connected to %s via %s", request, tunnel.dstPin.Hub)
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func startConnectOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Check if we are running a public hub.
|
||||
if !conf.PublicHub() {
|
||||
return nil, terminal.ErrPermissionDenied.With("connecting is only allowed on public hubs")
|
||||
}
|
||||
|
||||
// Parse connect request.
|
||||
request := &ConnectRequest{}
|
||||
_, err := dsd.Load(data.CompileData(), request)
|
||||
if err != nil {
|
||||
connectOpCntError.Inc() // More like a protocol/system error than a bad request.
|
||||
return nil, terminal.ErrMalformedData.With("failed to parse connect request: %w", err)
|
||||
}
|
||||
if request.QueueSize == 0 || request.QueueSize > terminal.MaxQueueSize {
|
||||
connectOpCntError.Inc() // More like a protocol/system error than a bad request.
|
||||
return nil, terminal.ErrInvalidOptions.With("invalid queue size of %d", request.QueueSize)
|
||||
}
|
||||
|
||||
// Check if IP seems valid.
|
||||
if len(request.IP) != net.IPv4len && len(request.IP) != net.IPv6len {
|
||||
connectOpCntError.Inc() // More like a protocol/system error than a bad request.
|
||||
return nil, terminal.ErrInvalidOptions.With("ip address is not valid")
|
||||
}
|
||||
|
||||
// Create and initialize operation.
|
||||
op := &ConnectOp{
|
||||
doneWriting: make(chan struct{}),
|
||||
t: t,
|
||||
request: request,
|
||||
}
|
||||
op.InitOperationBase(t, opID)
|
||||
op.ctx, op.cancelCtx = context.WithCancel(t.Ctx())
|
||||
op.dfq = terminal.NewDuplexFlowQueue(op.Ctx(), request.QueueSize, op.submitUpstream)
|
||||
|
||||
// Start worker to complete setting up the connection.
|
||||
module.StartWorker("connect op setup", op.handleSetup)
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (op *ConnectOp) handleSetup(_ context.Context) error {
|
||||
// Get terminal session for rate limiting.
|
||||
var session *terminal.Session
|
||||
if sessionTerm, ok := op.t.(terminal.SessionTerminal); ok {
|
||||
session = sessionTerm.GetSession()
|
||||
} else {
|
||||
connectOpCntError.Inc()
|
||||
log.Errorf("spn/crew: %T is not a session terminal, aborting op %s#%d", op.t, op.t.FmtID(), op.ID())
|
||||
op.Stop(op, terminal.ErrInternalError.With("no session available"))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Limit concurrency of connecting.
|
||||
cancelErr := session.LimitConcurrency(op.Ctx(), func() {
|
||||
op.setup(session)
|
||||
})
|
||||
|
||||
// If context was canceled, stop operation.
|
||||
if cancelErr != nil {
|
||||
connectOpCntCanceled.Inc()
|
||||
op.Stop(op, terminal.ErrCanceled.With(cancelErr.Error()))
|
||||
}
|
||||
|
||||
// Do not return a worker error.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (op *ConnectOp) setup(session *terminal.Session) {
|
||||
// Rate limit before connecting.
|
||||
if tErr := session.RateLimit(); tErr != nil {
|
||||
// Add rate limit info to error.
|
||||
if tErr.Is(terminal.ErrRateLimited) {
|
||||
connectOpCntRateLimited.Inc()
|
||||
op.Stop(op, tErr.With(session.RateLimitInfo()))
|
||||
return
|
||||
}
|
||||
|
||||
connectOpCntError.Inc()
|
||||
op.Stop(op, tErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if connection target is in global scope.
|
||||
ipScope := netutils.GetIPScope(op.request.IP)
|
||||
if ipScope != netutils.Global {
|
||||
session.ReportSuspiciousActivity(terminal.SusFactorQuiteUnusual)
|
||||
connectOpCntBadRequest.Inc()
|
||||
op.Stop(op, terminal.ErrPermissionDenied.With("denied request to connect to non-global IP %s", op.request.IP))
|
||||
return
|
||||
}
|
||||
|
||||
// Check exit policy.
|
||||
if tErr := checkExitPolicy(op.request); tErr != nil {
|
||||
session.ReportSuspiciousActivity(terminal.SusFactorQuiteUnusual)
|
||||
connectOpCntBadRequest.Inc()
|
||||
op.Stop(op, tErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Check one last time before connecting if operation was not canceled.
|
||||
if op.Ctx().Err() != nil {
|
||||
op.Stop(op, terminal.ErrCanceled.With(op.Ctx().Err().Error()))
|
||||
connectOpCntCanceled.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Connect to destination.
|
||||
dialNet := op.request.DialNetwork()
|
||||
if dialNet == "" {
|
||||
session.ReportSuspiciousActivity(terminal.SusFactorCommon)
|
||||
connectOpCntBadRequest.Inc()
|
||||
op.Stop(op, terminal.ErrIncorrectUsage.With("protocol %s is not supported", op.request.Protocol))
|
||||
return
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
LocalAddr: conf.GetBindAddr(dialNet),
|
||||
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
|
||||
KeepAlive: -1, // Disable keep-alive.
|
||||
}
|
||||
conn, err := dialer.DialContext(op.Ctx(), dialNet, op.request.Address())
|
||||
if err != nil {
|
||||
// Connection errors are common, but still a bit suspicious.
|
||||
var netError net.Error
|
||||
switch {
|
||||
case errors.As(err, &netError) && netError.Timeout():
|
||||
session.ReportSuspiciousActivity(terminal.SusFactorCommon)
|
||||
connectOpCntFailed.Inc()
|
||||
case errors.Is(err, context.Canceled):
|
||||
session.ReportSuspiciousActivity(terminal.SusFactorCommon)
|
||||
connectOpCntCanceled.Inc()
|
||||
default:
|
||||
session.ReportSuspiciousActivity(terminal.SusFactorWeirdButOK)
|
||||
connectOpCntFailed.Inc()
|
||||
}
|
||||
|
||||
op.Stop(op, terminal.ErrConnectionError.With("failed to connect to %s: %w", op.request, err))
|
||||
return
|
||||
}
|
||||
op.conn = conn
|
||||
|
||||
// Start worker.
|
||||
module.StartWorker("connect op conn reader", op.connReader)
|
||||
module.StartWorker("connect op conn writer", op.connWriter)
|
||||
module.StartWorker("connect op flow handler", op.dfq.FlowHandler)
|
||||
|
||||
connectOpCntConnected.Inc()
|
||||
log.Infof("spn/crew: connected op %s#%d to %s", op.t.FmtID(), op.ID(), op.request)
|
||||
}
|
||||
|
||||
func (op *ConnectOp) submitUpstream(msg *terminal.Msg, timeout time.Duration) {
|
||||
err := op.Send(msg, timeout)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
op.Stop(op, err.Wrap("failed to send data (op) read from %s", op.connectedType()))
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
readBufSize = 1500
|
||||
|
||||
// High priority up to first 10MB.
|
||||
highPrioThreshold = 10_000_000
|
||||
|
||||
// Rate limit to 128 Mbit/s after 1GB traffic.
|
||||
// Do NOT use time.Sleep per packet, as it is very inaccurate and will sleep a lot longer than desired.
|
||||
rateLimitThreshold = 1_000_000_000
|
||||
rateLimitMaxMbit = 128
|
||||
)
|
||||
|
||||
func (op *ConnectOp) connReader(_ context.Context) error {
|
||||
// Metrics setup and submitting.
|
||||
atomic.AddInt64(activeConnectOps, 1)
|
||||
defer func() {
|
||||
atomic.AddInt64(activeConnectOps, -1)
|
||||
connectOpDurationHistogram.UpdateDuration(op.started)
|
||||
connectOpIncomingDataHistogram.Update(float64(op.incomingTraffic.Load()))
|
||||
}()
|
||||
|
||||
rateLimiter := terminal.NewRateLimiter(rateLimitMaxMbit)
|
||||
|
||||
for {
|
||||
// Read from connection.
|
||||
buf := make([]byte, readBufSize)
|
||||
n, err := op.conn.Read(buf)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
op.Stop(op, terminal.ErrStopping.With("connection to %s was closed on read", op.connectedType()))
|
||||
} else {
|
||||
op.Stop(op, terminal.ErrConnectionError.With("failed to read from %s: %w", op.connectedType(), err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if n == 0 {
|
||||
log.Tracef("spn/crew: connect op %s>%d read 0 bytes from %s", op.t.FmtID(), op.ID(), op.connectedType())
|
||||
continue
|
||||
}
|
||||
|
||||
// Submit metrics.
|
||||
connectOpIncomingBytes.Add(n)
|
||||
inBytes := op.incomingTraffic.Add(uint64(n))
|
||||
|
||||
// Rate limit if over threshold.
|
||||
if inBytes > rateLimitThreshold {
|
||||
rateLimiter.Limit(uint64(n))
|
||||
}
|
||||
|
||||
// Create message from data.
|
||||
msg := op.NewMsg(buf[:n])
|
||||
|
||||
// Define priority and possibly wait for slot.
|
||||
switch {
|
||||
case inBytes > highPrioThreshold:
|
||||
msg.Unit.WaitForSlot()
|
||||
case op.request.UsePriorityDataMsgs:
|
||||
msg.Unit.MakeHighPriority()
|
||||
}
|
||||
|
||||
// Send packet.
|
||||
tErr := op.dfq.Send(
|
||||
msg,
|
||||
30*time.Second,
|
||||
)
|
||||
if tErr != nil {
|
||||
msg.Finish()
|
||||
op.Stop(op, tErr.Wrap("failed to send data (dfq) from %s", op.connectedType()))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver delivers a messages to the operation.
|
||||
func (op *ConnectOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
return op.dfq.Deliver(msg)
|
||||
}
|
||||
|
||||
func (op *ConnectOp) connWriter(_ context.Context) error {
|
||||
// Metrics submitting.
|
||||
defer func() {
|
||||
connectOpOutgoingDataHistogram.Update(float64(op.outgoingTraffic.Load()))
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
// Signal that we are done with writing.
|
||||
close(op.doneWriting)
|
||||
// Close connection.
|
||||
_ = op.conn.Close()
|
||||
}()
|
||||
|
||||
var msg *terminal.Msg
|
||||
defer msg.Finish()
|
||||
|
||||
rateLimiter := terminal.NewRateLimiter(rateLimitMaxMbit)
|
||||
|
||||
writing:
|
||||
for {
|
||||
msg.Finish()
|
||||
|
||||
select {
|
||||
case msg = <-op.dfq.Receive():
|
||||
case <-op.ctx.Done():
|
||||
op.Stop(op, terminal.ErrCanceled)
|
||||
return nil
|
||||
default:
|
||||
// Handle all data before also listening for the context cancel.
|
||||
// This ensures all data is written properly before stopping.
|
||||
select {
|
||||
case msg = <-op.dfq.Receive():
|
||||
case op.doneWriting <- struct{}{}:
|
||||
op.Stop(op, terminal.ErrStopping)
|
||||
return nil
|
||||
case <-op.ctx.Done():
|
||||
op.Stop(op, terminal.ErrCanceled)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Instead of compiling data here again, can we send it as in the container?
|
||||
data := msg.Data.CompileData()
|
||||
if len(data) == 0 {
|
||||
continue writing
|
||||
}
|
||||
|
||||
// Submit metrics.
|
||||
connectOpOutgoingBytes.Add(len(data))
|
||||
out := op.outgoingTraffic.Add(uint64(len(data)))
|
||||
|
||||
// Rate limit if over threshold.
|
||||
if out > rateLimitThreshold {
|
||||
rateLimiter.Limit(uint64(len(data)))
|
||||
}
|
||||
|
||||
// Special handling after first data was received on client.
|
||||
if op.entry &&
|
||||
out == uint64(len(data)) {
|
||||
// Report time taken to receive first byte.
|
||||
connectOpTTFBDurationHistogram.UpdateDuration(op.started)
|
||||
|
||||
// If not stickied yet, stick destination to Hub.
|
||||
if !op.tunnel.stickied {
|
||||
op.tunnel.stickDestinationToHub()
|
||||
}
|
||||
}
|
||||
|
||||
// Send all given data.
|
||||
for {
|
||||
n, err := op.conn.Write(data)
|
||||
switch {
|
||||
case err != nil:
|
||||
if errors.Is(err, io.EOF) {
|
||||
op.Stop(op, terminal.ErrStopping.With("connection to %s was closed on write", op.connectedType()))
|
||||
} else {
|
||||
op.Stop(op, terminal.ErrConnectionError.With("failed to send to %s: %w", op.connectedType(), err))
|
||||
}
|
||||
return nil
|
||||
case n == 0:
|
||||
op.Stop(op, terminal.ErrConnectionError.With("sent 0 bytes to %s", op.connectedType()))
|
||||
return nil
|
||||
case n < len(data):
|
||||
// If not all data was sent, try again.
|
||||
log.Debugf("spn/crew: %s#%d only sent %d/%d bytes to %s", op.t.FmtID(), op.ID(), n, len(data), op.connectedType())
|
||||
data = data[n:]
|
||||
default:
|
||||
continue writing
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (op *ConnectOp) connectedType() string {
|
||||
if op.entry {
|
||||
return "origin"
|
||||
}
|
||||
return "destination"
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *ConnectOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
|
||||
if err.IsError() {
|
||||
reportConnectError(err)
|
||||
}
|
||||
|
||||
// If the connection has sent or received any data so far, finish the data
|
||||
// flows as it makes sense.
|
||||
if op.incomingTraffic.Load() > 0 || op.outgoingTraffic.Load() > 0 {
|
||||
// If the op was ended locally, send all data before closing.
|
||||
// If the op was ended remotely, don't bother sending remaining data.
|
||||
if !err.IsExternal() {
|
||||
// Flushing could mean sending a full buffer of 50000 packets.
|
||||
op.dfq.Flush(5 * time.Minute)
|
||||
}
|
||||
|
||||
// If the op was ended remotely, write all remaining received data.
|
||||
// If the op was ended locally, don't bother writing remaining data.
|
||||
if err.IsExternal() {
|
||||
select {
|
||||
case <-op.doneWriting:
|
||||
default:
|
||||
select {
|
||||
case <-op.doneWriting:
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel workers.
|
||||
op.cancelCtx()
|
||||
|
||||
// Special client-side handling.
|
||||
if op.entry {
|
||||
// Mark the connection as failed if there was an error and no data was sent to the app yet.
|
||||
if err.IsError() && op.outgoingTraffic.Load() == 0 {
|
||||
// Set connection to failed and save it to propagate the update.
|
||||
c := op.tunnel.connInfo
|
||||
func() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if err.IsExternal() {
|
||||
c.Failed(fmt.Sprintf(
|
||||
"the exit node reported an error: %s", err,
|
||||
), "")
|
||||
} else {
|
||||
c.Failed(fmt.Sprintf(
|
||||
"connection failed locally: %s", err,
|
||||
), "")
|
||||
}
|
||||
|
||||
c.Save()
|
||||
}()
|
||||
}
|
||||
|
||||
// Avoid connecting to the destination via this Hub if:
|
||||
// - The error is external - ie. from the server.
|
||||
// - The error is a connection error.
|
||||
// - No data was received.
|
||||
// This indicates that there is some network level issue that we can
|
||||
// possibly work around by using another exit node.
|
||||
if err.IsError() && err.IsExternal() &&
|
||||
err.Is(terminal.ErrConnectionError) &&
|
||||
op.outgoingTraffic.Load() == 0 {
|
||||
op.tunnel.avoidDestinationHub()
|
||||
}
|
||||
|
||||
// Don't leak local errors to the server.
|
||||
if !err.IsExternal() {
|
||||
// Change error that is reported.
|
||||
return terminal.ErrStopping
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
115
spn/crew/op_connect_test.go
Normal file
115
spn/crew/op_connect_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/intel"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const (
|
||||
testPadding = 8
|
||||
testQueueSize = 10
|
||||
)
|
||||
|
||||
func TestConnectOp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode, as it interacts with the network")
|
||||
}
|
||||
|
||||
// Create test terminal pair.
|
||||
a, b, err := terminal.NewSimpleTestTerminalPair(0, 0,
|
||||
&terminal.TerminalOpts{
|
||||
FlowControl: terminal.FlowControlDFQ,
|
||||
FlowControlSize: testQueueSize,
|
||||
Padding: testPadding,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test terminal pair: %s", err)
|
||||
}
|
||||
|
||||
// Set up connect op.
|
||||
b.GrantPermission(terminal.MayConnect)
|
||||
conf.EnablePublicHub(true)
|
||||
identity, err := cabin.CreateIdentity(module.Ctx, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create identity: %s", err)
|
||||
}
|
||||
_, err = identity.MaintainAnnouncement(&hub.Announcement{
|
||||
Transports: []string{
|
||||
"tcp:17",
|
||||
},
|
||||
Exit: []string{
|
||||
"+ * */80",
|
||||
"- *",
|
||||
},
|
||||
}, true)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to update identity: %s", err)
|
||||
}
|
||||
EnableConnecting(identity.Hub)
|
||||
|
||||
for i := 0; i < 1; i++ {
|
||||
appConn, sluiceConn := net.Pipe()
|
||||
_, tErr := NewConnectOp(&Tunnel{
|
||||
connInfo: &network.Connection{
|
||||
Entity: (&intel.Entity{
|
||||
Protocol: 6,
|
||||
Port: 80,
|
||||
Domain: "orf.at.",
|
||||
IP: net.IPv4(194, 232, 104, 142),
|
||||
}).Init(0),
|
||||
},
|
||||
conn: sluiceConn,
|
||||
dstTerminal: a,
|
||||
dstPin: &navigator.Pin{
|
||||
Hub: identity.Hub,
|
||||
},
|
||||
})
|
||||
if tErr != nil {
|
||||
t.Fatalf("failed to start connect op: %s", tErr)
|
||||
}
|
||||
|
||||
// Send request.
|
||||
requestURL, err := url.Parse("http://orf.at/")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse request url: %s", err)
|
||||
}
|
||||
r := http.Request{
|
||||
Method: http.MethodHead,
|
||||
URL: requestURL,
|
||||
}
|
||||
err = r.Write(appConn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write request: %s", err)
|
||||
}
|
||||
|
||||
// Recv response.
|
||||
data := make([]byte, 1500)
|
||||
n, err := appConn.Read(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read request: %s", err)
|
||||
}
|
||||
if n == 0 {
|
||||
t.Fatal("received empty reply")
|
||||
}
|
||||
|
||||
t.Log("received data:")
|
||||
fmt.Println(string(data[:n]))
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
149
spn/crew/op_ping.go
Normal file
149
spn/crew/op_ping.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const (
|
||||
// PingOpType is the type ID of the latency test operation.
|
||||
PingOpType = "ping"
|
||||
|
||||
pingOpNonceSize = 16
|
||||
pingOpTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
// PingOp is used to measure latency.
|
||||
type PingOp struct {
|
||||
terminal.OneOffOperationBase
|
||||
|
||||
started time.Time
|
||||
nonce []byte
|
||||
}
|
||||
|
||||
// PingOpRequest is a ping request.
|
||||
type PingOpRequest struct {
|
||||
Nonce []byte `json:"n,omitempty"`
|
||||
}
|
||||
|
||||
// PingOpResponse is a ping response.
|
||||
type PingOpResponse struct {
|
||||
Nonce []byte `json:"n,omitempty"`
|
||||
Time time.Time `json:"t,omitempty"`
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *PingOp) Type() string {
|
||||
return PingOpType
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: PingOpType,
|
||||
Start: startPingOp,
|
||||
})
|
||||
}
|
||||
|
||||
// NewPingOp runs a latency test.
|
||||
func NewPingOp(t terminal.Terminal) (*PingOp, *terminal.Error) {
|
||||
// Generate nonce.
|
||||
nonce, err := rng.Bytes(pingOpNonceSize)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to generate ping nonce: %w", err)
|
||||
}
|
||||
|
||||
// Create operation and init.
|
||||
op := &PingOp{
|
||||
started: time.Now().UTC(),
|
||||
nonce: nonce,
|
||||
}
|
||||
op.OneOffOperationBase.Init()
|
||||
|
||||
// Create request.
|
||||
pingRequest, err := dsd.Dump(&PingOpRequest{
|
||||
Nonce: op.nonce,
|
||||
}, dsd.CBOR)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to create ping request: %w", err)
|
||||
}
|
||||
|
||||
// Send ping.
|
||||
tErr := t.StartOperation(op, container.New(pingRequest), pingOpTimeout)
|
||||
if tErr != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the operation.
|
||||
func (op *PingOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
defer msg.Finish()
|
||||
|
||||
// Parse response.
|
||||
response := &PingOpResponse{}
|
||||
_, err := dsd.Load(msg.Data.CompileData(), response)
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to parse ping response: %w", err)
|
||||
}
|
||||
|
||||
// Check if the nonce matches.
|
||||
if subtle.ConstantTimeCompare(op.nonce, response.Nonce) != 1 {
|
||||
return terminal.ErrIntegrity.With("ping nonce mismatched")
|
||||
}
|
||||
|
||||
return terminal.ErrExplicitAck
|
||||
}
|
||||
|
||||
func startPingOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Parse request.
|
||||
request := &PingOpRequest{}
|
||||
_, err := dsd.Load(data.CompileData(), request)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrMalformedData.With("failed to parse ping request: %w", err)
|
||||
}
|
||||
|
||||
// Create response.
|
||||
response, err := dsd.Dump(&PingOpResponse{
|
||||
Nonce: request.Nonce,
|
||||
Time: time.Now().UTC(),
|
||||
}, dsd.CBOR)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to create ping response: %w", err)
|
||||
}
|
||||
|
||||
// Send response.
|
||||
msg := terminal.NewMsg(response)
|
||||
msg.FlowID = opID
|
||||
msg.Unit.MakeHighPriority()
|
||||
if terminal.UsePriorityDataMsgs {
|
||||
msg.Type = terminal.MsgTypePriorityData
|
||||
}
|
||||
tErr := t.Send(msg, pingOpTimeout)
|
||||
if tErr != nil {
|
||||
// Finish message unit on failure.
|
||||
msg.Finish()
|
||||
return nil, tErr.With("failed to send ping response")
|
||||
}
|
||||
|
||||
// Operation is just one response and finished successfully.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *PingOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
|
||||
// Prevent remote from sending explicit ack, as we use it as a success signal internally.
|
||||
if err.Is(terminal.ErrExplicitAck) && err.IsExternal() {
|
||||
err = terminal.ErrStopping.AsExternal()
|
||||
}
|
||||
|
||||
// Continue with usual handling of inherited base.
|
||||
return op.OneOffOperationBase.HandleStop(err)
|
||||
}
|
||||
32
spn/crew/op_ping_test.go
Normal file
32
spn/crew/op_ping_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
func TestPingOp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create test terminal pair.
|
||||
a, _, err := terminal.NewSimpleTestTerminalPair(0, 0, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test terminal pair: %s", err)
|
||||
}
|
||||
|
||||
// Create ping op.
|
||||
op, tErr := NewPingOp(a)
|
||||
if tErr.IsError() {
|
||||
t.Fatal(tErr)
|
||||
}
|
||||
|
||||
// Wait for result.
|
||||
select {
|
||||
case result := <-op.Result:
|
||||
t.Logf("ping result: %s", result.Error())
|
||||
case <-time.After(pingOpTimeout):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
}
|
||||
51
spn/crew/policy.go
Normal file
51
spn/crew/policy.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/service/intel"
|
||||
"github.com/safing/portmaster/service/profile/endpoints"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
var (
|
||||
connectingHubLock sync.Mutex
|
||||
connectingHub *hub.Hub
|
||||
)
|
||||
|
||||
// EnableConnecting enables connecting from this Hub.
|
||||
func EnableConnecting(my *hub.Hub) {
|
||||
connectingHubLock.Lock()
|
||||
defer connectingHubLock.Unlock()
|
||||
|
||||
connectingHub = my
|
||||
}
|
||||
|
||||
func checkExitPolicy(request *ConnectRequest) *terminal.Error {
|
||||
connectingHubLock.Lock()
|
||||
defer connectingHubLock.Unlock()
|
||||
|
||||
// Check if connect requests are allowed.
|
||||
if connectingHub == nil {
|
||||
return terminal.ErrPermissionDenied.With("connect requests disabled")
|
||||
}
|
||||
|
||||
// Create entity.
|
||||
entity := (&intel.Entity{
|
||||
IP: request.IP,
|
||||
Protocol: uint8(request.Protocol),
|
||||
Port: request.Port,
|
||||
Domain: request.Domain,
|
||||
}).Init(0)
|
||||
entity.FetchData(context.TODO())
|
||||
|
||||
// Check against policy.
|
||||
result, reason := connectingHub.GetInfo().ExitPolicy().Match(context.TODO(), entity)
|
||||
if result == endpoints.Denied {
|
||||
return terminal.ErrPermissionDenied.With("connect request for %s violates the exit policy: %s", request, reason)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
176
spn/crew/sticky.go
Normal file
176
spn/crew/sticky.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package crew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/service/network"
|
||||
"github.com/safing/portmaster/service/network/packet"
|
||||
"github.com/safing/portmaster/spn/navigator"
|
||||
)
|
||||
|
||||
const (
|
||||
stickyTTL = 1 * time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
stickyIPs = make(map[string]*stickyHub)
|
||||
stickyDomains = make(map[string]*stickyHub)
|
||||
stickyLock sync.Mutex
|
||||
)
|
||||
|
||||
type stickyHub struct {
|
||||
Pin *navigator.Pin
|
||||
Route *navigator.Route
|
||||
LastSeen time.Time
|
||||
Avoid bool
|
||||
}
|
||||
|
||||
func (sh *stickyHub) isExpired() bool {
|
||||
return time.Now().Add(-stickyTTL).After(sh.LastSeen)
|
||||
}
|
||||
|
||||
func makeStickyIPKey(conn *network.Connection) string {
|
||||
if p := conn.Process().Profile(); p != nil {
|
||||
return fmt.Sprintf(
|
||||
"%s/%s>%s",
|
||||
p.LocalProfile().Source,
|
||||
p.LocalProfile().ID,
|
||||
conn.Entity.IP,
|
||||
)
|
||||
}
|
||||
|
||||
return "?>" + string(conn.Entity.IP)
|
||||
}
|
||||
|
||||
func makeStickyDomainKey(conn *network.Connection) string {
|
||||
if p := conn.Process().Profile(); p != nil {
|
||||
return fmt.Sprintf(
|
||||
"%s/%s>%s",
|
||||
p.LocalProfile().Source,
|
||||
p.LocalProfile().ID,
|
||||
conn.Entity.Domain,
|
||||
)
|
||||
}
|
||||
|
||||
return "?>" + conn.Entity.Domain
|
||||
}
|
||||
|
||||
func getStickiedHub(conn *network.Connection) (sticksTo *stickyHub) {
|
||||
stickyLock.Lock()
|
||||
defer stickyLock.Unlock()
|
||||
|
||||
// Check if IP is sticky.
|
||||
sticksTo = stickyIPs[makeStickyIPKey(conn)] // byte comparison
|
||||
if sticksTo != nil && !sticksTo.isExpired() {
|
||||
sticksTo.LastSeen = time.Now()
|
||||
}
|
||||
|
||||
// If the IP did not stick and we have a domain, check if that sticks.
|
||||
if sticksTo == nil && conn.Entity.Domain != "" {
|
||||
sticksTo, ok := stickyDomains[makeStickyDomainKey(conn)]
|
||||
if ok && !sticksTo.isExpired() {
|
||||
sticksTo.LastSeen = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// If nothing sticked, return now.
|
||||
if sticksTo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get intel from map before locking pin to avoid simultaneous locking.
|
||||
mapIntel := navigator.Main.GetIntel()
|
||||
|
||||
// Lock Pin for checking.
|
||||
sticksTo.Pin.Lock()
|
||||
defer sticksTo.Pin.Unlock()
|
||||
|
||||
// Check if the stickied Hub supports the needed IP version.
|
||||
switch {
|
||||
case conn.IPVersion == packet.IPv4 && sticksTo.Pin.EntityV4 == nil:
|
||||
// Connection is IPv4, but stickied Hub has no IPv4.
|
||||
return nil
|
||||
case conn.IPVersion == packet.IPv6 && sticksTo.Pin.EntityV6 == nil:
|
||||
// Connection is IPv4, but stickied Hub has no IPv4.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disregard stickied Hub if it is disregard with the current options.
|
||||
matcher := conn.TunnelOpts.Destination.Matcher(mapIntel)
|
||||
if !matcher(sticksTo.Pin) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return fully checked stickied Hub.
|
||||
return sticksTo
|
||||
}
|
||||
|
||||
func (t *Tunnel) stickDestinationToHub() {
|
||||
stickyLock.Lock()
|
||||
defer stickyLock.Unlock()
|
||||
|
||||
// Stick to IP.
|
||||
ipKey := makeStickyIPKey(t.connInfo)
|
||||
stickyIPs[ipKey] = &stickyHub{
|
||||
Pin: t.dstPin,
|
||||
Route: t.route,
|
||||
LastSeen: time.Now(),
|
||||
}
|
||||
log.Infof("spn/crew: sticking %s to %s", ipKey, t.dstPin.Hub)
|
||||
|
||||
// Stick to Domain, if present.
|
||||
if t.connInfo.Entity.Domain != "" {
|
||||
domainKey := makeStickyDomainKey(t.connInfo)
|
||||
stickyDomains[domainKey] = &stickyHub{
|
||||
Pin: t.dstPin,
|
||||
Route: t.route,
|
||||
LastSeen: time.Now(),
|
||||
}
|
||||
log.Infof("spn/crew: sticking %s to %s", domainKey, t.dstPin.Hub)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) avoidDestinationHub() {
|
||||
stickyLock.Lock()
|
||||
defer stickyLock.Unlock()
|
||||
|
||||
// Stick to Hub/IP Pair.
|
||||
ipKey := makeStickyIPKey(t.connInfo)
|
||||
stickyIPs[ipKey] = &stickyHub{
|
||||
Pin: t.dstPin,
|
||||
LastSeen: time.Now(),
|
||||
Avoid: true,
|
||||
}
|
||||
log.Warningf("spn/crew: avoiding %s for %s", t.dstPin.Hub, ipKey)
|
||||
}
|
||||
|
||||
func cleanStickyHubs(ctx context.Context, task *modules.Task) error {
|
||||
stickyLock.Lock()
|
||||
defer stickyLock.Unlock()
|
||||
|
||||
for _, stickyRegistry := range []map[string]*stickyHub{stickyIPs, stickyDomains} {
|
||||
for key, stickedEntry := range stickyRegistry {
|
||||
if stickedEntry.isExpired() {
|
||||
delete(stickyRegistry, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func clearStickyHubs() {
|
||||
stickyLock.Lock()
|
||||
defer stickyLock.Unlock()
|
||||
|
||||
for _, stickyRegistry := range []map[string]*stickyHub{stickyIPs, stickyDomains} {
|
||||
for key := range stickyRegistry {
|
||||
delete(stickyRegistry, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
90
spn/docks/bandwidth_test.go
Normal file
90
spn/docks/bandwidth_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
func TestEffectiveBandwidth(t *testing.T) { //nolint:paralleltest // Run alone.
|
||||
// Skip in CI.
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
var (
|
||||
bwTestDelay = 50 * time.Millisecond
|
||||
bwTestQueueSize uint32 = 1000
|
||||
bwTestVolume = 10000000 // 10MB
|
||||
bwTestTime = 10 * time.Second
|
||||
)
|
||||
|
||||
// Create test terminal pair.
|
||||
a, b, err := terminal.NewSimpleTestTerminalPair(
|
||||
bwTestDelay,
|
||||
int(bwTestQueueSize),
|
||||
&terminal.TerminalOpts{
|
||||
FlowControl: terminal.FlowControlDFQ,
|
||||
FlowControlSize: bwTestQueueSize,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test terminal pair: %s", err)
|
||||
}
|
||||
|
||||
// Grant permission for op on remote terminal and start op.
|
||||
b.GrantPermission(terminal.IsCraneController)
|
||||
|
||||
// Re-use the capacity test for the bandwidth test.
|
||||
op := &CapacityTestOp{
|
||||
opts: &CapacityTestOptions{
|
||||
TestVolume: bwTestVolume,
|
||||
MaxTime: bwTestTime,
|
||||
testing: true,
|
||||
},
|
||||
recvQueue: make(chan *terminal.Msg),
|
||||
dataSent: new(int64),
|
||||
dataSentWasAckd: abool.New(),
|
||||
result: make(chan *terminal.Error, 1),
|
||||
}
|
||||
// Disable sender again.
|
||||
op.senderStarted = true
|
||||
op.dataSentWasAckd.Set()
|
||||
// Make capacity test request.
|
||||
request, err := dsd.Dump(op.opts, dsd.CBOR)
|
||||
if err != nil {
|
||||
t.Fatal(terminal.ErrInternalError.With("failed to serialize capactity test options: %w", err))
|
||||
}
|
||||
// Send test request.
|
||||
tErr := a.StartOperation(op, container.New(request), 1*time.Second)
|
||||
if tErr != nil {
|
||||
t.Fatal(tErr)
|
||||
}
|
||||
// Start handler.
|
||||
module.StartWorker("op capacity handler", op.handler)
|
||||
|
||||
// Wait for result and check error.
|
||||
tErr = <-op.Result()
|
||||
if !tErr.IsOK() {
|
||||
t.Fatalf("op failed: %s", tErr)
|
||||
}
|
||||
t.Logf("measured capacity: %d bit/s", op.testResult)
|
||||
|
||||
// Calculate expected bandwidth.
|
||||
expectedBitsPerSecond := (float64(capacityTestMsgSize*8*int64(bwTestQueueSize)) / float64(bwTestDelay)) * float64(time.Second)
|
||||
t.Logf("expected capacity: %f bit/s", expectedBitsPerSecond)
|
||||
|
||||
// Check if measured bandwidth is within parameters.
|
||||
if float64(op.testResult) > expectedBitsPerSecond*1.6 {
|
||||
t.Fatal("measured capacity too high")
|
||||
}
|
||||
// TODO: Check if we can raise this to at least 90%.
|
||||
if float64(op.testResult) < expectedBitsPerSecond*0.2 {
|
||||
t.Fatal("measured capacity too low")
|
||||
}
|
||||
}
|
||||
100
spn/docks/controller.go
Normal file
100
spn/docks/controller.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// CraneControllerTerminal is a terminal for the crane itself.
|
||||
type CraneControllerTerminal struct {
|
||||
*terminal.TerminalBase
|
||||
|
||||
Crane *Crane
|
||||
}
|
||||
|
||||
// NewLocalCraneControllerTerminal returns a new local crane controller.
|
||||
func NewLocalCraneControllerTerminal(
|
||||
crane *Crane,
|
||||
initMsg *terminal.TerminalOpts,
|
||||
) (*CraneControllerTerminal, *container.Container, *terminal.Error) {
|
||||
// Remove unnecessary options from the crane controller.
|
||||
initMsg.Padding = 0
|
||||
|
||||
// Create Terminal Base.
|
||||
t, initData, err := terminal.NewLocalBaseTerminal(
|
||||
crane.ctx,
|
||||
0,
|
||||
crane.ID,
|
||||
nil,
|
||||
initMsg,
|
||||
terminal.UpstreamSendFunc(crane.sendImportantTerminalMsg),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return initCraneController(crane, t, initMsg), initData, nil
|
||||
}
|
||||
|
||||
// NewRemoteCraneControllerTerminal returns a new remote crane controller.
|
||||
func NewRemoteCraneControllerTerminal(
|
||||
crane *Crane,
|
||||
initData *container.Container,
|
||||
) (*CraneControllerTerminal, *terminal.TerminalOpts, *terminal.Error) {
|
||||
// Create Terminal Base.
|
||||
t, initMsg, err := terminal.NewRemoteBaseTerminal(
|
||||
crane.ctx,
|
||||
0,
|
||||
crane.ID,
|
||||
nil,
|
||||
initData,
|
||||
terminal.UpstreamSendFunc(crane.sendImportantTerminalMsg),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return initCraneController(crane, t, initMsg), initMsg, nil
|
||||
}
|
||||
|
||||
func initCraneController(
|
||||
crane *Crane,
|
||||
t *terminal.TerminalBase,
|
||||
initMsg *terminal.TerminalOpts,
|
||||
) *CraneControllerTerminal {
|
||||
// Create Crane Terminal and assign it as the extended Terminal.
|
||||
cct := &CraneControllerTerminal{
|
||||
TerminalBase: t,
|
||||
Crane: crane,
|
||||
}
|
||||
t.SetTerminalExtension(cct)
|
||||
|
||||
// Assign controller to crane.
|
||||
crane.Controller = cct
|
||||
crane.terminals[cct.ID()] = cct
|
||||
|
||||
// Copy the options to the crane itself.
|
||||
crane.opts = *initMsg
|
||||
|
||||
// Grant crane controller permission.
|
||||
t.GrantPermission(terminal.IsCraneController)
|
||||
|
||||
// Start workers.
|
||||
t.StartWorkers(module, "crane controller terminal")
|
||||
|
||||
return cct
|
||||
}
|
||||
|
||||
// HandleAbandon gives the terminal the ability to cleanly shut down.
|
||||
func (controller *CraneControllerTerminal) HandleAbandon(err *terminal.Error) (errorToSend *terminal.Error) {
|
||||
// Abandon terminal.
|
||||
controller.Crane.AbandonTerminal(0, err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleDestruction gives the terminal the ability to clean up.
|
||||
func (controller *CraneControllerTerminal) HandleDestruction(err *terminal.Error) {
|
||||
// Stop controlled crane.
|
||||
controller.Crane.Stop(nil)
|
||||
}
|
||||
913
spn/docks/crane.go
Normal file
913
spn/docks/crane.go
Normal file
@@ -0,0 +1,913 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/jess"
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/ships"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const (
|
||||
// QOTD holds the quote of the day to return on idling unused connections.
|
||||
QOTD = "Privacy is not an option, and it shouldn't be the price we accept for just getting on the Internet.\nGary Kovacs\n"
|
||||
|
||||
// maxUnloadSize defines the maximum size of a message to unload.
|
||||
maxUnloadSize = 16384
|
||||
maxSegmentLength = 16384
|
||||
maxCraneStoppingDuration = 6 * time.Hour
|
||||
maxCraneStopDuration = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// optimalMinLoadSize defines minimum for Crane.targetLoadSize.
|
||||
optimalMinLoadSize = 3072 // Targeting around 4096.
|
||||
|
||||
// loadingMaxWaitDuration is the maximum time a crane will wait for
|
||||
// additional data to send.
|
||||
loadingMaxWaitDuration = 5 * time.Millisecond
|
||||
)
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrDone = errors.New("crane is done")
|
||||
)
|
||||
|
||||
// Crane is the primary duplexer and connection manager.
|
||||
type Crane struct {
|
||||
// ID is the ID of the Crane.
|
||||
ID string
|
||||
// opts holds options.
|
||||
opts terminal.TerminalOpts
|
||||
|
||||
// ctx is the context of the Terminal.
|
||||
ctx context.Context
|
||||
// cancelCtx cancels ctx.
|
||||
cancelCtx context.CancelFunc
|
||||
// stopping indicates if the Crane will be stopped soon. The Crane may still
|
||||
// be used until stopped, but must not be advertised anymore.
|
||||
stopping *abool.AtomicBool
|
||||
// stopped indicates if the Crane has been stopped. Whoever stopped the Crane
|
||||
// already took care of notifying everyone, so a silent fail is normally the
|
||||
// best response.
|
||||
stopped *abool.AtomicBool
|
||||
// authenticated indicates if there is has been any successful authentication.
|
||||
authenticated *abool.AtomicBool
|
||||
|
||||
// ConnectedHub is the identity of the remote Hub.
|
||||
ConnectedHub *hub.Hub
|
||||
// NetState holds the network optimization state.
|
||||
// It must always be set and the reference must not be changed.
|
||||
// Access to fields within are coordinated by itself.
|
||||
NetState *NetworkOptimizationState
|
||||
// identity is identity of this instance and is usually only populated on a server.
|
||||
identity *cabin.Identity
|
||||
|
||||
// jession is the jess session used for encryption.
|
||||
jession *jess.Session
|
||||
// jessionLock locks jession.
|
||||
jessionLock sync.Mutex
|
||||
|
||||
// Controller is the Crane's Controller Terminal.
|
||||
Controller *CraneControllerTerminal
|
||||
|
||||
// ship represents the underlying physical connection.
|
||||
ship ships.Ship
|
||||
// unloading moves containers from the ship to the crane.
|
||||
unloading chan *container.Container
|
||||
// loading moves containers from the crane to the ship.
|
||||
loading chan *container.Container
|
||||
// terminalMsgs holds containers from terminals waiting to be laoded.
|
||||
terminalMsgs chan *terminal.Msg
|
||||
// controllerMsgs holds important containers from terminals waiting to be laoded.
|
||||
controllerMsgs chan *terminal.Msg
|
||||
|
||||
// terminals holds all the connected terminals.
|
||||
terminals map[uint32]terminal.Terminal
|
||||
// terminalsLock locks terminals.
|
||||
terminalsLock sync.Mutex
|
||||
// nextTerminalID holds the next terminal ID.
|
||||
nextTerminalID uint32
|
||||
|
||||
// targetLoadSize defines the optimal loading size.
|
||||
targetLoadSize int
|
||||
}
|
||||
|
||||
// NewCrane returns a new crane.
|
||||
func NewCrane(ship ships.Ship, connectedHub *hub.Hub, id *cabin.Identity) (*Crane, error) {
|
||||
// Cranes always run in module context.
|
||||
ctx, cancelCtx := context.WithCancel(module.Ctx)
|
||||
|
||||
newCrane := &Crane{
|
||||
ctx: ctx,
|
||||
cancelCtx: cancelCtx,
|
||||
stopping: abool.NewBool(false),
|
||||
stopped: abool.NewBool(false),
|
||||
authenticated: abool.NewBool(false),
|
||||
|
||||
ConnectedHub: connectedHub,
|
||||
NetState: newNetworkOptimizationState(),
|
||||
identity: id,
|
||||
|
||||
ship: ship,
|
||||
unloading: make(chan *container.Container),
|
||||
loading: make(chan *container.Container, 100),
|
||||
terminalMsgs: make(chan *terminal.Msg, 100),
|
||||
controllerMsgs: make(chan *terminal.Msg, 100),
|
||||
|
||||
terminals: make(map[uint32]terminal.Terminal),
|
||||
}
|
||||
err := registerCrane(newCrane)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register crane: %w", err)
|
||||
}
|
||||
|
||||
// Shift next terminal IDs on the server.
|
||||
if !ship.IsMine() {
|
||||
newCrane.nextTerminalID += 4
|
||||
}
|
||||
|
||||
// Calculate target load size.
|
||||
loadSize := ship.LoadSize()
|
||||
if loadSize <= 0 {
|
||||
loadSize = ships.BaseMTU
|
||||
}
|
||||
newCrane.targetLoadSize = loadSize
|
||||
for newCrane.targetLoadSize < optimalMinLoadSize {
|
||||
newCrane.targetLoadSize += loadSize
|
||||
}
|
||||
// Subtract overhead needed for encryption.
|
||||
newCrane.targetLoadSize -= 25 // Manually tested for jess.SuiteWireV1
|
||||
// Subtract space needed for length encoding the final chunk.
|
||||
newCrane.targetLoadSize -= varint.EncodedSize(uint64(newCrane.targetLoadSize))
|
||||
|
||||
return newCrane, nil
|
||||
}
|
||||
|
||||
// IsMine returns whether the crane was started on this side.
|
||||
func (crane *Crane) IsMine() bool {
|
||||
return crane.ship.IsMine()
|
||||
}
|
||||
|
||||
// Public returns whether the crane has been published.
|
||||
func (crane *Crane) Public() bool {
|
||||
return crane.ship.Public()
|
||||
}
|
||||
|
||||
// IsStopping returns whether the crane is stopping.
|
||||
func (crane *Crane) IsStopping() bool {
|
||||
return crane.stopping.IsSet()
|
||||
}
|
||||
|
||||
// MarkStoppingRequested marks the crane as stopping requested.
|
||||
func (crane *Crane) MarkStoppingRequested() {
|
||||
crane.NetState.lock.Lock()
|
||||
defer crane.NetState.lock.Unlock()
|
||||
|
||||
if !crane.NetState.stoppingRequested {
|
||||
crane.NetState.stoppingRequested = true
|
||||
crane.startSyncStateOp()
|
||||
}
|
||||
}
|
||||
|
||||
// MarkStopping marks the crane as stopping.
|
||||
func (crane *Crane) MarkStopping() (stopping bool) {
|
||||
// Can only stop owned cranes.
|
||||
if !crane.IsMine() {
|
||||
return false
|
||||
}
|
||||
|
||||
if !crane.stopping.SetToIf(false, true) {
|
||||
return false
|
||||
}
|
||||
|
||||
crane.NetState.lock.Lock()
|
||||
defer crane.NetState.lock.Unlock()
|
||||
crane.NetState.markedStoppingAt = time.Now()
|
||||
|
||||
crane.startSyncStateOp()
|
||||
return true
|
||||
}
|
||||
|
||||
// AbortStopping aborts the stopping.
|
||||
func (crane *Crane) AbortStopping() (aborted bool) {
|
||||
aborted = crane.stopping.SetToIf(true, false)
|
||||
|
||||
crane.NetState.lock.Lock()
|
||||
defer crane.NetState.lock.Unlock()
|
||||
|
||||
abortedStoppingRequest := crane.NetState.stoppingRequested
|
||||
crane.NetState.stoppingRequested = false
|
||||
crane.NetState.markedStoppingAt = time.Time{}
|
||||
|
||||
// Sync if any state changed.
|
||||
if aborted || abortedStoppingRequest {
|
||||
crane.startSyncStateOp()
|
||||
}
|
||||
|
||||
return aborted
|
||||
}
|
||||
|
||||
// Authenticated returns whether the other side of the crane has authenticated
|
||||
// itself with an access code.
|
||||
func (crane *Crane) Authenticated() bool {
|
||||
return crane.authenticated.IsSet()
|
||||
}
|
||||
|
||||
// Publish publishes the connection as a lane.
|
||||
func (crane *Crane) Publish() error {
|
||||
// Check if crane is connected.
|
||||
if crane.ConnectedHub == nil {
|
||||
return fmt.Errorf("spn/docks: %s: cannot publish without defined connected hub", crane)
|
||||
}
|
||||
|
||||
// Submit metrics.
|
||||
if !crane.Public() {
|
||||
newPublicCranes.Inc()
|
||||
}
|
||||
|
||||
// Mark crane as public.
|
||||
maskedID := crane.ship.MaskAddress(crane.ship.RemoteAddr())
|
||||
crane.ship.MarkPublic()
|
||||
|
||||
// Assign crane to make it available to others.
|
||||
AssignCrane(crane.ConnectedHub.ID, crane)
|
||||
|
||||
log.Infof("spn/docks: %s (was %s) is now public", crane, maskedID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns ship's local address.
|
||||
func (crane *Crane) LocalAddr() net.Addr {
|
||||
return crane.ship.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns ship's local address.
|
||||
func (crane *Crane) RemoteAddr() net.Addr {
|
||||
return crane.ship.RemoteAddr()
|
||||
}
|
||||
|
||||
// Transport returns ship's transport.
|
||||
func (crane *Crane) Transport() *hub.Transport {
|
||||
return crane.ship.Transport()
|
||||
}
|
||||
|
||||
func (crane *Crane) getNextTerminalID() uint32 {
|
||||
crane.terminalsLock.Lock()
|
||||
defer crane.terminalsLock.Unlock()
|
||||
|
||||
for {
|
||||
// Bump to next ID.
|
||||
crane.nextTerminalID += 8
|
||||
|
||||
// Check if it's free.
|
||||
_, ok := crane.terminals[crane.nextTerminalID]
|
||||
if !ok {
|
||||
return crane.nextTerminalID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (crane *Crane) terminalCount() int {
|
||||
crane.terminalsLock.Lock()
|
||||
defer crane.terminalsLock.Unlock()
|
||||
|
||||
return len(crane.terminals)
|
||||
}
|
||||
|
||||
func (crane *Crane) getTerminal(id uint32) (t terminal.Terminal, ok bool) {
|
||||
crane.terminalsLock.Lock()
|
||||
defer crane.terminalsLock.Unlock()
|
||||
|
||||
t, ok = crane.terminals[id]
|
||||
return
|
||||
}
|
||||
|
||||
func (crane *Crane) setTerminal(t terminal.Terminal) {
|
||||
crane.terminalsLock.Lock()
|
||||
defer crane.terminalsLock.Unlock()
|
||||
|
||||
crane.terminals[t.ID()] = t
|
||||
}
|
||||
|
||||
func (crane *Crane) deleteTerminal(id uint32) (t terminal.Terminal, ok bool) {
|
||||
crane.terminalsLock.Lock()
|
||||
defer crane.terminalsLock.Unlock()
|
||||
|
||||
t, ok = crane.terminals[id]
|
||||
if ok {
|
||||
delete(crane.terminals, id)
|
||||
return t, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// AbandonTerminal abandons the terminal with the given ID.
|
||||
func (crane *Crane) AbandonTerminal(id uint32, err *terminal.Error) {
|
||||
// Get active terminal.
|
||||
t, ok := crane.deleteTerminal(id)
|
||||
if ok {
|
||||
// If the terminal was registered, abandon it.
|
||||
|
||||
// Log reason the terminal is ending. Override stopping error with nil.
|
||||
switch {
|
||||
case err == nil || err.IsOK():
|
||||
log.Debugf("spn/docks: %T %s is being abandoned", t, t.FmtID())
|
||||
case err.Is(terminal.ErrStopping):
|
||||
err = nil
|
||||
log.Debugf("spn/docks: %T %s is being abandoned by peer", t, t.FmtID())
|
||||
case err.Is(terminal.ErrNoActivity):
|
||||
err = nil
|
||||
log.Debugf("spn/docks: %T %s is being abandoned due to no activity", t, t.FmtID())
|
||||
default:
|
||||
log.Warningf("spn/docks: %T %s: %s", t, t.FmtID(), err)
|
||||
}
|
||||
|
||||
// Call the terminal's abandon function.
|
||||
t.Abandon(err)
|
||||
} else { //nolint:gocritic
|
||||
// When a crane terminal is abandoned, it calls crane.AbandonTerminal when
|
||||
// finished. This time, the terminal won't be in the registry anymore and
|
||||
// it finished shutting down, so we can now check if the crane needs to be
|
||||
// stopped.
|
||||
|
||||
// If the crane is stopping, check if we can stop.
|
||||
// We can stop when all terminals are abandoned or after a timeout.
|
||||
// FYI: The crane controller will always take up one slot.
|
||||
if crane.stopping.IsSet() &&
|
||||
crane.terminalCount() <= 1 {
|
||||
// Stop the crane in worker, so the caller can do some work.
|
||||
module.StartWorker("retire crane", func(_ context.Context) error {
|
||||
// Let enough time for the last errors to be sent, as terminals are abandoned in a goroutine.
|
||||
time.Sleep(3 * time.Second)
|
||||
crane.Stop(nil)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (crane *Crane) sendImportantTerminalMsg(msg *terminal.Msg, timeout time.Duration) *terminal.Error {
|
||||
select {
|
||||
case crane.controllerMsgs <- msg:
|
||||
return nil
|
||||
case <-crane.ctx.Done():
|
||||
msg.Finish()
|
||||
return terminal.ErrCanceled
|
||||
}
|
||||
}
|
||||
|
||||
// Send is used by others to send a message through the crane.
|
||||
func (crane *Crane) Send(msg *terminal.Msg, timeout time.Duration) *terminal.Error {
|
||||
select {
|
||||
case crane.terminalMsgs <- msg:
|
||||
return nil
|
||||
case <-crane.ctx.Done():
|
||||
msg.Finish()
|
||||
return terminal.ErrCanceled
|
||||
}
|
||||
}
|
||||
|
||||
func (crane *Crane) encrypt(shipment *container.Container) (encrypted *container.Container, err error) {
|
||||
// Skip if encryption is not enabled.
|
||||
if crane.jession == nil {
|
||||
return shipment, nil
|
||||
}
|
||||
|
||||
crane.jessionLock.Lock()
|
||||
defer crane.jessionLock.Unlock()
|
||||
|
||||
letter, err := crane.jession.Close(shipment.CompileData())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encrypted, err = letter.ToWire()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pack letter: %w", err)
|
||||
}
|
||||
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
func (crane *Crane) decrypt(shipment *container.Container) (decrypted *container.Container, err error) {
|
||||
// Skip if encryption is not enabled.
|
||||
if crane.jession == nil {
|
||||
return shipment, nil
|
||||
}
|
||||
|
||||
crane.jessionLock.Lock()
|
||||
defer crane.jessionLock.Unlock()
|
||||
|
||||
letter, err := jess.LetterFromWire(shipment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse letter: %w", err)
|
||||
}
|
||||
|
||||
decryptedData, err := crane.jession.Open(letter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return container.New(decryptedData), nil
|
||||
}
|
||||
|
||||
func (crane *Crane) unloader(workerCtx context.Context) error {
|
||||
// Unclean shutdown safeguard.
|
||||
defer crane.Stop(terminal.ErrUnknownError.With("unloader died"))
|
||||
|
||||
for {
|
||||
// Get first couple bytes to get the packet length.
|
||||
// 2 bytes are enough to encode 65535.
|
||||
// On the other hand, packets can be only 2 bytes small.
|
||||
lenBuf := make([]byte, 2)
|
||||
err := crane.unloadUntilFull(lenBuf)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
crane.Stop(terminal.ErrStopping.With("connection closed"))
|
||||
} else {
|
||||
crane.Stop(terminal.ErrInternalError.With("failed to unload: %w", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unpack length.
|
||||
containerLen, n, err := varint.Unpack64(lenBuf)
|
||||
if err != nil {
|
||||
crane.Stop(terminal.ErrMalformedData.With("failed to get container length: %w", err))
|
||||
return nil
|
||||
}
|
||||
switch {
|
||||
case containerLen <= 0:
|
||||
crane.Stop(terminal.ErrMalformedData.With("received empty container with length %d", containerLen))
|
||||
return nil
|
||||
case containerLen > maxUnloadSize:
|
||||
crane.Stop(terminal.ErrMalformedData.With("received oversized container with length %d", containerLen))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build shipment.
|
||||
var shipmentBuf []byte
|
||||
leftovers := len(lenBuf) - n
|
||||
|
||||
if leftovers == int(containerLen) {
|
||||
// We already have all the shipment data.
|
||||
shipmentBuf = lenBuf[n:]
|
||||
} else {
|
||||
// Create a shipment buffer, copy leftovers and read the rest from the connection.
|
||||
shipmentBuf = make([]byte, containerLen)
|
||||
if leftovers > 0 {
|
||||
copy(shipmentBuf, lenBuf[n:])
|
||||
}
|
||||
|
||||
// Read remaining shipment.
|
||||
err = crane.unloadUntilFull(shipmentBuf[leftovers:])
|
||||
if err != nil {
|
||||
crane.Stop(terminal.ErrInternalError.With("failed to unload: %w", err))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Submit to handler.
|
||||
select {
|
||||
case <-crane.ctx.Done():
|
||||
crane.Stop(nil)
|
||||
return nil
|
||||
case crane.unloading <- container.New(shipmentBuf):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (crane *Crane) unloadUntilFull(buf []byte) error {
|
||||
var bytesRead int
|
||||
for {
|
||||
// Get shipment from ship.
|
||||
n, err := crane.ship.UnloadTo(buf[bytesRead:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
log.Tracef("spn/docks: %s unloaded 0 bytes", crane)
|
||||
}
|
||||
bytesRead += n
|
||||
|
||||
// Return if buffer has been fully filled.
|
||||
if bytesRead == len(buf) {
|
||||
// Submit metrics.
|
||||
crane.submitCraneTrafficStats(bytesRead)
|
||||
crane.NetState.ReportTraffic(uint64(bytesRead), true)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (crane *Crane) handler(workerCtx context.Context) error {
|
||||
var partialShipment *container.Container
|
||||
var segmentLength uint32
|
||||
|
||||
// Unclean shutdown safeguard.
|
||||
defer crane.Stop(terminal.ErrUnknownError.With("handler died"))
|
||||
|
||||
handling:
|
||||
for {
|
||||
select {
|
||||
case <-crane.ctx.Done():
|
||||
crane.Stop(nil)
|
||||
return nil
|
||||
|
||||
case shipment := <-crane.unloading:
|
||||
// log.Debugf("spn/crane %s: before decrypt: %v ... %v", crane.ID, c.CompileData()[:10], c.CompileData()[c.Length()-10:])
|
||||
|
||||
// Decrypt shipment.
|
||||
shipment, err := crane.decrypt(shipment)
|
||||
if err != nil {
|
||||
crane.Stop(terminal.ErrIntegrity.With("failed to decrypt: %w", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process all segments/containers of the shipment.
|
||||
for shipment.HoldsData() {
|
||||
if partialShipment != nil {
|
||||
// Continue processing partial segment.
|
||||
// Append new shipment to previous partial segment.
|
||||
partialShipment.AppendContainer(shipment)
|
||||
shipment, partialShipment = partialShipment, nil
|
||||
}
|
||||
|
||||
// Get next segment length.
|
||||
if segmentLength == 0 {
|
||||
segmentLength, err = shipment.GetNextN32()
|
||||
if err != nil {
|
||||
if errors.Is(err, varint.ErrBufTooSmall) {
|
||||
// Continue handling when there is not yet enough data.
|
||||
partialShipment = shipment
|
||||
segmentLength = 0
|
||||
continue handling
|
||||
}
|
||||
|
||||
crane.Stop(terminal.ErrMalformedData.With("failed to get segment length: %w", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
if segmentLength == 0 {
|
||||
// Remainder is padding.
|
||||
continue handling
|
||||
}
|
||||
|
||||
// Check if the segment is within the boundary.
|
||||
if segmentLength > maxSegmentLength {
|
||||
crane.Stop(terminal.ErrMalformedData.With("received oversized segment with length %d", segmentLength))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we have enough data for the segment.
|
||||
if uint32(shipment.Length()) < segmentLength {
|
||||
partialShipment = shipment
|
||||
continue handling
|
||||
}
|
||||
|
||||
// Get segment from shipment.
|
||||
segment, err := shipment.GetAsContainer(int(segmentLength))
|
||||
if err != nil {
|
||||
crane.Stop(terminal.ErrMalformedData.With("failed to get segment: %w", err))
|
||||
return nil
|
||||
}
|
||||
segmentLength = 0
|
||||
|
||||
// Get terminal ID and message type of segment.
|
||||
terminalID, terminalMsgType, err := terminal.ParseIDType(segment)
|
||||
if err != nil {
|
||||
crane.Stop(terminal.ErrMalformedData.With("failed to get terminal ID and msg type: %w", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
switch terminalMsgType {
|
||||
case terminal.MsgTypeInit:
|
||||
crane.establishTerminal(terminalID, segment)
|
||||
|
||||
case terminal.MsgTypeData, terminal.MsgTypePriorityData:
|
||||
// Get terminal and let it further handle the message.
|
||||
t, ok := crane.getTerminal(terminalID)
|
||||
if ok {
|
||||
// Create msg and set priority.
|
||||
msg := terminal.NewEmptyMsg()
|
||||
msg.FlowID = terminalID
|
||||
msg.Type = terminalMsgType
|
||||
msg.Data = segment
|
||||
if msg.Type == terminal.MsgTypePriorityData {
|
||||
msg.Unit.MakeHighPriority()
|
||||
}
|
||||
// Deliver to terminal.
|
||||
deliveryErr := t.Deliver(msg)
|
||||
if deliveryErr != nil {
|
||||
msg.Finish()
|
||||
// This is a hot path. Start a worker for abandoning the terminal.
|
||||
module.StartWorker("end terminal", func(_ context.Context) error {
|
||||
crane.AbandonTerminal(t.ID(), deliveryErr.Wrap("failed to deliver data"))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
} else {
|
||||
log.Tracef("spn/docks: %s received msg for unknown terminal %d", crane, terminalID)
|
||||
}
|
||||
|
||||
case terminal.MsgTypeStop:
|
||||
// Parse error.
|
||||
receivedErr, err := terminal.ParseExternalError(segment.CompileData())
|
||||
if err != nil {
|
||||
log.Warningf("spn/docks: %s failed to parse abandon error: %s", crane, err)
|
||||
receivedErr = terminal.ErrUnknownError.AsExternal()
|
||||
}
|
||||
// This is a hot path. Start a worker for abandoning the terminal.
|
||||
module.StartWorker("end terminal", func(_ context.Context) error {
|
||||
crane.AbandonTerminal(terminalID, receivedErr)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (crane *Crane) loader(workerCtx context.Context) (err error) {
|
||||
shipment := container.New()
|
||||
var partialShipment *container.Container
|
||||
var loadingTimer *time.Timer
|
||||
|
||||
// Unclean shutdown safeguard.
|
||||
defer crane.Stop(terminal.ErrUnknownError.With("loader died"))
|
||||
|
||||
// Return the loading wait channel if waiting.
|
||||
loadNow := func() <-chan time.Time {
|
||||
if loadingTimer != nil {
|
||||
return loadingTimer.C
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure any received message is finished
|
||||
var msg, firstMsg *terminal.Msg
|
||||
defer msg.Finish()
|
||||
defer firstMsg.Finish()
|
||||
|
||||
for {
|
||||
// Reset first message in shipment.
|
||||
firstMsg.Finish()
|
||||
firstMsg = nil
|
||||
|
||||
fillingShipment:
|
||||
for shipment.Length() < crane.targetLoadSize {
|
||||
// Gather segments until shipment is filled.
|
||||
|
||||
// Prioritize messages from the controller.
|
||||
select {
|
||||
case msg = <-crane.controllerMsgs:
|
||||
case <-crane.ctx.Done():
|
||||
crane.Stop(nil)
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Then listen for all.
|
||||
select {
|
||||
case msg = <-crane.controllerMsgs:
|
||||
case msg = <-crane.terminalMsgs:
|
||||
case <-loadNow():
|
||||
break fillingShipment
|
||||
case <-crane.ctx.Done():
|
||||
crane.Stop(nil)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Debug unit leaks.
|
||||
msg.Debug()
|
||||
|
||||
// Handle new message.
|
||||
if msg != nil {
|
||||
// Pack msg and add to segment.
|
||||
msg.Pack()
|
||||
newSegment := msg.Data
|
||||
|
||||
// Check if this is the first message.
|
||||
// This is the only message where we wait for a slot.
|
||||
if firstMsg == nil {
|
||||
firstMsg = msg
|
||||
firstMsg.Unit.WaitForSlot()
|
||||
} else {
|
||||
msg.Finish()
|
||||
}
|
||||
|
||||
// Check length.
|
||||
if newSegment.Length() > maxSegmentLength {
|
||||
log.Warningf("spn/docks: %s ignored oversized segment with length %d", crane, newSegment.Length())
|
||||
continue fillingShipment
|
||||
}
|
||||
|
||||
// Append to shipment.
|
||||
shipment.AppendContainer(newSegment)
|
||||
|
||||
// Set loading max wait timer on first segment.
|
||||
if loadingTimer == nil {
|
||||
loadingTimer = time.NewTimer(loadingMaxWaitDuration)
|
||||
}
|
||||
|
||||
} else if crane.stopped.IsSet() {
|
||||
// If there is no new segment, this might have been triggered by a
|
||||
// closed channel. Check if the crane is still active.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
sendingShipment:
|
||||
for {
|
||||
// Check if we are over the target load size and split the shipment.
|
||||
if shipment.Length() > crane.targetLoadSize {
|
||||
partialShipment, err = shipment.GetAsContainer(crane.targetLoadSize)
|
||||
if err != nil {
|
||||
crane.Stop(terminal.ErrInternalError.With("failed to split segment: %w", err))
|
||||
return nil
|
||||
}
|
||||
shipment, partialShipment = partialShipment, shipment
|
||||
}
|
||||
|
||||
// Load shipment.
|
||||
err = crane.load(shipment)
|
||||
if err != nil {
|
||||
crane.Stop(terminal.ErrShipSunk.With("failed to load shipment: %w", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset loading timer.
|
||||
loadingTimer = nil
|
||||
|
||||
// Continue loading with partial shipment, or a new one.
|
||||
if partialShipment != nil {
|
||||
// Continue loading with a partial previous shipment.
|
||||
shipment, partialShipment = partialShipment, nil
|
||||
|
||||
// If shipment is not big enough to send immediately, wait for more data.
|
||||
if shipment.Length() < crane.targetLoadSize {
|
||||
loadingTimer = time.NewTimer(loadingMaxWaitDuration)
|
||||
break sendingShipment
|
||||
}
|
||||
|
||||
} else {
|
||||
// Continue loading with new shipment.
|
||||
shipment = container.New()
|
||||
break sendingShipment
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (crane *Crane) load(c *container.Container) error {
|
||||
// Add Padding if needed.
|
||||
if crane.opts.Padding > 0 {
|
||||
paddingNeeded := int(crane.opts.Padding) -
|
||||
((c.Length() + varint.EncodedSize(uint64(c.Length()))) % int(crane.opts.Padding))
|
||||
// As the length changes slightly with the padding, we should avoid loading
|
||||
// lengths around the varint size hops:
|
||||
// - 128
|
||||
// - 16384
|
||||
// - 2097152
|
||||
// - 268435456
|
||||
|
||||
// Pad to target load size at maximum.
|
||||
maxPadding := crane.targetLoadSize - c.Length()
|
||||
if paddingNeeded > maxPadding {
|
||||
paddingNeeded = maxPadding
|
||||
}
|
||||
|
||||
if paddingNeeded > 0 {
|
||||
// Add padding indicator.
|
||||
c.Append([]byte{0})
|
||||
paddingNeeded--
|
||||
|
||||
// Add needed padding data.
|
||||
if paddingNeeded > 0 {
|
||||
padding, err := rng.Bytes(paddingNeeded)
|
||||
if err != nil {
|
||||
log.Debugf("spn/docks: %s failed to get random padding data, using zeros instead", crane)
|
||||
padding = make([]byte, paddingNeeded)
|
||||
}
|
||||
c.Append(padding)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt shipment.
|
||||
c, err := crane.encrypt(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt: %w", err)
|
||||
}
|
||||
|
||||
// Finalize data.
|
||||
c.PrependLength()
|
||||
readyToSend := c.CompileData()
|
||||
|
||||
// Submit metrics.
|
||||
crane.submitCraneTrafficStats(len(readyToSend))
|
||||
crane.NetState.ReportTraffic(uint64(len(readyToSend)), false)
|
||||
|
||||
// Load onto ship.
|
||||
err = crane.ship.Load(readyToSend)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load ship: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the crane.
|
||||
func (crane *Crane) Stop(err *terminal.Error) {
|
||||
if !crane.stopped.SetToIf(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// Log error message.
|
||||
if err != nil {
|
||||
if err.IsOK() {
|
||||
log.Infof("spn/docks: %s is done", crane)
|
||||
} else {
|
||||
log.Warningf("spn/docks: %s is stopping: %s", crane, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Unregister crane.
|
||||
unregisterCrane(crane)
|
||||
|
||||
// Stop all terminals.
|
||||
for _, t := range crane.allTerms() {
|
||||
t.Abandon(err) // Async!
|
||||
}
|
||||
|
||||
// Stop controller.
|
||||
if crane.Controller != nil {
|
||||
crane.Controller.Abandon(err) // Async!
|
||||
}
|
||||
|
||||
// Wait shortly for all terminals to finish abandoning.
|
||||
waitStep := 50 * time.Millisecond
|
||||
for i := time.Duration(0); i < maxCraneStopDuration; i += waitStep {
|
||||
// Check if all terminals are done.
|
||||
if crane.terminalCount() == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(waitStep)
|
||||
}
|
||||
|
||||
// Close connection.
|
||||
crane.ship.Sink()
|
||||
|
||||
// Cancel crane context.
|
||||
crane.cancelCtx()
|
||||
|
||||
// Notify about change.
|
||||
crane.NotifyUpdate()
|
||||
}
|
||||
|
||||
func (crane *Crane) allTerms() []terminal.Terminal {
|
||||
crane.terminalsLock.Lock()
|
||||
defer crane.terminalsLock.Unlock()
|
||||
|
||||
terms := make([]terminal.Terminal, 0, len(crane.terminals))
|
||||
for _, term := range crane.terminals {
|
||||
terms = append(terms, term)
|
||||
}
|
||||
|
||||
return terms
|
||||
}
|
||||
|
||||
func (crane *Crane) String() string {
|
||||
remoteAddr := crane.ship.RemoteAddr()
|
||||
switch {
|
||||
case remoteAddr == nil:
|
||||
return fmt.Sprintf("crane %s", crane.ID)
|
||||
case crane.ship.IsMine():
|
||||
return fmt.Sprintf("crane %s to %s", crane.ID, crane.ship.MaskAddress(crane.ship.RemoteAddr()))
|
||||
default:
|
||||
return fmt.Sprintf("crane %s from %s", crane.ID, crane.ship.MaskAddress(crane.ship.RemoteAddr()))
|
||||
}
|
||||
}
|
||||
|
||||
// Stopped returns whether the crane has stopped.
|
||||
func (crane *Crane) Stopped() bool {
|
||||
return crane.stopped.IsSet()
|
||||
}
|
||||
81
spn/docks/crane_establish.go
Normal file
81
spn/docks/crane_establish.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTerminalIdleTimeout = 15 * time.Minute
|
||||
remoteTerminalIdleTimeout = 30 * time.Minute
|
||||
)
|
||||
|
||||
// EstablishNewTerminal establishes a new terminal with the crane.
|
||||
func (crane *Crane) EstablishNewTerminal(
|
||||
localTerm terminal.Terminal,
|
||||
initData *container.Container,
|
||||
) *terminal.Error {
|
||||
// Create message.
|
||||
msg := terminal.NewEmptyMsg()
|
||||
msg.FlowID = localTerm.ID()
|
||||
msg.Type = terminal.MsgTypeInit
|
||||
msg.Data = initData
|
||||
|
||||
// Register terminal with crane.
|
||||
crane.setTerminal(localTerm)
|
||||
|
||||
// Send message.
|
||||
select {
|
||||
case crane.controllerMsgs <- msg:
|
||||
log.Debugf("spn/docks: %s initiated new terminal %d", crane, localTerm.ID())
|
||||
return nil
|
||||
case <-crane.ctx.Done():
|
||||
crane.AbandonTerminal(localTerm.ID(), terminal.ErrStopping.With("initiation aborted"))
|
||||
return terminal.ErrStopping
|
||||
}
|
||||
}
|
||||
|
||||
func (crane *Crane) establishTerminal(id uint32, initData *container.Container) {
|
||||
// Create new remote crane terminal.
|
||||
newTerminal, _, err := NewRemoteCraneTerminal(
|
||||
crane,
|
||||
id,
|
||||
initData,
|
||||
)
|
||||
if err == nil {
|
||||
// Connections via public cranes have a timeout.
|
||||
if crane.Public() {
|
||||
newTerminal.TerminalBase.SetTimeout(remoteTerminalIdleTimeout)
|
||||
}
|
||||
// Register terminal with crane.
|
||||
crane.setTerminal(newTerminal)
|
||||
log.Debugf("spn/docks: %s established new crane terminal %d", crane, newTerminal.ID())
|
||||
return
|
||||
}
|
||||
|
||||
// If something goes wrong, send an error back.
|
||||
log.Warningf("spn/docks: %s failed to establish crane terminal: %s", crane, err)
|
||||
|
||||
// Build abandon message.
|
||||
msg := terminal.NewMsg(err.Pack())
|
||||
msg.FlowID = id
|
||||
msg.Type = terminal.MsgTypeStop
|
||||
|
||||
// Send message directly, or async.
|
||||
select {
|
||||
case crane.terminalMsgs <- msg:
|
||||
default:
|
||||
// Send error async.
|
||||
module.StartWorker("abandon terminal", func(ctx context.Context) error {
|
||||
select {
|
||||
case crane.terminalMsgs <- msg:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
339
spn/docks/crane_init.go
Normal file
339
spn/docks/crane_init.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/jess"
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
"github.com/safing/portbase/info"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
/*
|
||||
|
||||
Crane Init Message Format:
|
||||
used by init procedures
|
||||
|
||||
- Data [bytes block]
|
||||
- MsgType [varint]
|
||||
- Data [bytes; only when MsgType is Verify or Start*]
|
||||
|
||||
Crane Init Response Format:
|
||||
|
||||
- Data [bytes block]
|
||||
|
||||
Crane Operational Message Format:
|
||||
|
||||
- Data [bytes block]
|
||||
- possibly encrypted
|
||||
|
||||
*/
|
||||
|
||||
// Crane Msg Types.
|
||||
const (
|
||||
CraneMsgTypeEnd = 0
|
||||
CraneMsgTypeInfo = 1
|
||||
CraneMsgTypeRequestHubInfo = 2
|
||||
CraneMsgTypeVerify = 3
|
||||
CraneMsgTypeStartEncrypted = 4
|
||||
CraneMsgTypeStartUnencrypted = 5
|
||||
)
|
||||
|
||||
// Start starts the crane.
|
||||
func (crane *Crane) Start(callerCtx context.Context) error {
|
||||
log.Infof("spn/docks: %s is starting", crane)
|
||||
|
||||
// Submit metrics.
|
||||
newCranes.Inc()
|
||||
|
||||
// Start crane depending on situation.
|
||||
var tErr *terminal.Error
|
||||
if crane.ship.IsMine() {
|
||||
tErr = crane.startLocal(callerCtx)
|
||||
} else {
|
||||
tErr = crane.startRemote(callerCtx)
|
||||
}
|
||||
|
||||
// Stop crane again if starting failed.
|
||||
if tErr != nil {
|
||||
crane.Stop(tErr)
|
||||
return tErr
|
||||
}
|
||||
|
||||
log.Debugf("spn/docks: %s started", crane)
|
||||
// Return an explicit nil for working "!= nil" checks.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (crane *Crane) startLocal(callerCtx context.Context) *terminal.Error {
|
||||
module.StartWorker("crane unloader", crane.unloader)
|
||||
|
||||
if !crane.ship.IsSecure() {
|
||||
// Start encrypted channel.
|
||||
// Check if we have all the data we need from the Hub.
|
||||
if crane.ConnectedHub == nil {
|
||||
return terminal.ErrIncorrectUsage.With("cannot start encrypted channel without connected hub")
|
||||
}
|
||||
|
||||
// Always request hub info, as we don't know if the hub has restarted in
|
||||
// the meantime and lost ephemeral keys.
|
||||
hubInfoRequest := container.New(
|
||||
varint.Pack8(CraneMsgTypeRequestHubInfo),
|
||||
)
|
||||
hubInfoRequest.PrependLength()
|
||||
err := crane.ship.Load(hubInfoRequest.CompileData())
|
||||
if err != nil {
|
||||
return terminal.ErrShipSunk.With("failed to request hub info: %w", err)
|
||||
}
|
||||
|
||||
// Wait for reply.
|
||||
var reply *container.Container
|
||||
select {
|
||||
case reply = <-crane.unloading:
|
||||
case <-time.After(30 * time.Second):
|
||||
return terminal.ErrTimeout.With("waiting for hub info")
|
||||
case <-crane.ctx.Done():
|
||||
return terminal.ErrShipSunk.With("waiting for hub info")
|
||||
case <-callerCtx.Done():
|
||||
return terminal.ErrCanceled.With("waiting for hub info")
|
||||
}
|
||||
|
||||
// Parse and import Announcement and Status.
|
||||
announcementData, err := reply.GetNextBlock()
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to get announcement: %w", err)
|
||||
}
|
||||
statusData, err := reply.GetNextBlock()
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to get status: %w", err)
|
||||
}
|
||||
h, _, tErr := ImportAndVerifyHubInfo(
|
||||
callerCtx,
|
||||
crane.ConnectedHub.ID,
|
||||
announcementData, statusData, conf.MainMapName, conf.MainMapScope,
|
||||
)
|
||||
if tErr != nil {
|
||||
return tErr.Wrap("failed to import and verify hub")
|
||||
}
|
||||
// Update reference in case it was changed by the import.
|
||||
crane.ConnectedHub = h
|
||||
|
||||
// Now, try to select a public key again.
|
||||
signet := crane.ConnectedHub.SelectSignet()
|
||||
if signet == nil {
|
||||
return terminal.ErrHubNotReady.With("failed to select signet (after updating hub info)")
|
||||
}
|
||||
|
||||
// Configure encryption.
|
||||
env := jess.NewUnconfiguredEnvelope()
|
||||
env.SuiteID = jess.SuiteWireV1
|
||||
env.Recipients = []*jess.Signet{signet}
|
||||
|
||||
// Do not encrypt directly, rather get session for future use, then encrypt.
|
||||
crane.jession, err = env.WireCorrespondence(nil)
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("failed to create encryption session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create crane controller.
|
||||
_, initData, tErr := NewLocalCraneControllerTerminal(crane, terminal.DefaultCraneControllerOpts())
|
||||
if tErr != nil {
|
||||
return tErr.Wrap("failed to set up controller")
|
||||
}
|
||||
|
||||
// Prepare init message for sending.
|
||||
if crane.ship.IsSecure() {
|
||||
initData.PrependNumber(CraneMsgTypeStartUnencrypted)
|
||||
} else {
|
||||
// Encrypt controller initializer.
|
||||
letter, err := crane.jession.Close(initData.CompileData())
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("failed to encrypt initial packet: %w", err)
|
||||
}
|
||||
initData, err = letter.ToWire()
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("failed to pack initial packet: %w", err)
|
||||
}
|
||||
initData.PrependNumber(CraneMsgTypeStartEncrypted)
|
||||
}
|
||||
|
||||
// Send start message.
|
||||
initData.PrependLength()
|
||||
err := crane.ship.Load(initData.CompileData())
|
||||
if err != nil {
|
||||
return terminal.ErrShipSunk.With("failed to send init msg: %w", err)
|
||||
}
|
||||
|
||||
// Start remaining workers.
|
||||
module.StartWorker("crane loader", crane.loader)
|
||||
module.StartWorker("crane handler", crane.handler)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (crane *Crane) startRemote(callerCtx context.Context) *terminal.Error {
|
||||
var initMsg *container.Container
|
||||
|
||||
module.StartWorker("crane unloader", crane.unloader)
|
||||
|
||||
handling:
|
||||
for {
|
||||
// Wait for request.
|
||||
var request *container.Container
|
||||
select {
|
||||
case request = <-crane.unloading:
|
||||
|
||||
case <-time.After(30 * time.Second):
|
||||
return terminal.ErrTimeout.With("waiting for crane init msg")
|
||||
case <-crane.ctx.Done():
|
||||
return terminal.ErrShipSunk.With("waiting for crane init msg")
|
||||
case <-callerCtx.Done():
|
||||
return terminal.ErrCanceled.With("waiting for crane init msg")
|
||||
}
|
||||
|
||||
msgType, err := request.GetNextN8()
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to parse crane msg type: %s", err)
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case CraneMsgTypeEnd:
|
||||
// End connection.
|
||||
return terminal.ErrStopping
|
||||
|
||||
case CraneMsgTypeInfo:
|
||||
// Info is a terminating request.
|
||||
err := crane.handleCraneInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("spn/docks: %s sent version info", crane)
|
||||
|
||||
case CraneMsgTypeRequestHubInfo:
|
||||
// Handle Hub info request.
|
||||
err := crane.handleCraneHubInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("spn/docks: %s sent hub info", crane)
|
||||
|
||||
case CraneMsgTypeVerify:
|
||||
// Verify is a terminating request.
|
||||
err := crane.handleCraneVerification(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("spn/docks: %s sent hub verification", crane)
|
||||
|
||||
case CraneMsgTypeStartUnencrypted:
|
||||
initMsg = request
|
||||
|
||||
// Start crane with initMsg.
|
||||
log.Debugf("spn/docks: %s initiated unencrypted channel", crane)
|
||||
break handling
|
||||
|
||||
case CraneMsgTypeStartEncrypted:
|
||||
if crane.identity == nil {
|
||||
return terminal.ErrIncorrectUsage.With("cannot start incoming crane without designated identity")
|
||||
}
|
||||
|
||||
// Set up encryption.
|
||||
letter, err := jess.LetterFromWire(container.New(request.CompileData()))
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to unpack initial packet: %w", err)
|
||||
}
|
||||
crane.jession, err = letter.WireCorrespondence(crane.identity)
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("failed to create encryption session: %w", err)
|
||||
}
|
||||
initMsgData, err := crane.jession.Open(letter)
|
||||
if err != nil {
|
||||
return terminal.ErrIntegrity.With("failed to decrypt initial packet: %w", err)
|
||||
}
|
||||
initMsg = container.New(initMsgData)
|
||||
|
||||
// Start crane with initMsg.
|
||||
log.Debugf("spn/docks: %s initiated encrypted channel", crane)
|
||||
break handling
|
||||
}
|
||||
}
|
||||
|
||||
_, _, err := NewRemoteCraneControllerTerminal(crane, initMsg)
|
||||
if err != nil {
|
||||
return err.Wrap("failed to start crane controller")
|
||||
}
|
||||
|
||||
// Start remaining workers.
|
||||
module.StartWorker("crane loader", crane.loader)
|
||||
module.StartWorker("crane handler", crane.handler)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (crane *Crane) endInit() *terminal.Error {
|
||||
endMsg := container.New(
|
||||
varint.Pack8(CraneMsgTypeEnd),
|
||||
)
|
||||
endMsg.PrependLength()
|
||||
err := crane.ship.Load(endMsg.CompileData())
|
||||
if err != nil {
|
||||
return terminal.ErrShipSunk.With("failed to send end msg: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (crane *Crane) handleCraneInfo() *terminal.Error {
|
||||
// Pack info data.
|
||||
infoData, err := dsd.Dump(info.GetInfo(), dsd.JSON)
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("failed to pack info: %w", err)
|
||||
}
|
||||
msg := container.New(infoData)
|
||||
|
||||
// Manually send reply.
|
||||
msg.PrependLength()
|
||||
err = crane.ship.Load(msg.CompileData())
|
||||
if err != nil {
|
||||
return terminal.ErrShipSunk.With("failed to send info reply: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (crane *Crane) handleCraneHubInfo() *terminal.Error {
|
||||
msg := container.New()
|
||||
|
||||
// Check if we have an identity.
|
||||
if crane.identity == nil {
|
||||
return terminal.ErrIncorrectUsage.With("cannot handle hub info request without designated identity")
|
||||
}
|
||||
|
||||
// Add Hub Announcement.
|
||||
announcementData, err := crane.identity.ExportAnnouncement()
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("failed to export announcement: %w", err)
|
||||
}
|
||||
msg.AppendAsBlock(announcementData)
|
||||
|
||||
// Add Hub Status.
|
||||
statusData, err := crane.identity.ExportStatus()
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("failed to export status: %w", err)
|
||||
}
|
||||
msg.AppendAsBlock(statusData)
|
||||
|
||||
// Manually send reply.
|
||||
msg.PrependLength()
|
||||
err = crane.ship.Load(msg.CompileData())
|
||||
if err != nil {
|
||||
return terminal.ErrShipSunk.With("failed to send hub info reply: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
131
spn/docks/crane_netstate.go
Normal file
131
spn/docks/crane_netstate.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NetStatePeriodInterval defines the interval some of the net state should be reset.
|
||||
const NetStatePeriodInterval = 15 * time.Minute
|
||||
|
||||
// NetworkOptimizationState holds data for optimization purposes.
|
||||
type NetworkOptimizationState struct {
|
||||
lock sync.Mutex
|
||||
|
||||
// lastSuggestedAt holds the time when the connection to the connected Hub was last suggested by the network optimization.
|
||||
lastSuggestedAt time.Time
|
||||
|
||||
// stoppingRequested signifies whether stopping this lane is requested.
|
||||
stoppingRequested bool
|
||||
// stoppingRequestedByPeer signifies whether stopping this lane is requested by the peer.
|
||||
stoppingRequestedByPeer bool
|
||||
// markedStoppingAt holds the time when the crane was last marked as stopping.
|
||||
markedStoppingAt time.Time
|
||||
|
||||
lifetimeBytesIn *uint64
|
||||
lifetimeBytesOut *uint64
|
||||
lifetimeStarted time.Time
|
||||
periodBytesIn *uint64
|
||||
periodBytesOut *uint64
|
||||
periodStarted time.Time
|
||||
}
|
||||
|
||||
func newNetworkOptimizationState() *NetworkOptimizationState {
|
||||
return &NetworkOptimizationState{
|
||||
lifetimeBytesIn: new(uint64),
|
||||
lifetimeBytesOut: new(uint64),
|
||||
lifetimeStarted: time.Now(),
|
||||
periodBytesIn: new(uint64),
|
||||
periodBytesOut: new(uint64),
|
||||
periodStarted: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLastSuggestedAt sets when the lane was last suggested to the current time.
|
||||
func (netState *NetworkOptimizationState) UpdateLastSuggestedAt() {
|
||||
netState.lock.Lock()
|
||||
defer netState.lock.Unlock()
|
||||
|
||||
netState.lastSuggestedAt = time.Now()
|
||||
}
|
||||
|
||||
// StoppingState returns when the stopping state.
|
||||
func (netState *NetworkOptimizationState) StoppingState() (requested, requestedByPeer bool, markedAt time.Time) {
|
||||
netState.lock.Lock()
|
||||
defer netState.lock.Unlock()
|
||||
|
||||
return netState.stoppingRequested, netState.stoppingRequestedByPeer, netState.markedStoppingAt
|
||||
}
|
||||
|
||||
// RequestStoppingSuggested returns whether the crane should request stopping.
|
||||
func (netState *NetworkOptimizationState) RequestStoppingSuggested(maxNotSuggestedDuration time.Duration) bool {
|
||||
netState.lock.Lock()
|
||||
defer netState.lock.Unlock()
|
||||
|
||||
return time.Now().Add(-maxNotSuggestedDuration).After(netState.lastSuggestedAt)
|
||||
}
|
||||
|
||||
// StoppingSuggested returns whether the crane should be marked as stopping.
|
||||
func (netState *NetworkOptimizationState) StoppingSuggested() bool {
|
||||
netState.lock.Lock()
|
||||
defer netState.lock.Unlock()
|
||||
|
||||
return netState.stoppingRequested &&
|
||||
netState.stoppingRequestedByPeer
|
||||
}
|
||||
|
||||
// StopSuggested returns whether the crane should be stopped.
|
||||
func (netState *NetworkOptimizationState) StopSuggested() bool {
|
||||
netState.lock.Lock()
|
||||
defer netState.lock.Unlock()
|
||||
|
||||
return netState.stoppingRequested &&
|
||||
netState.stoppingRequestedByPeer &&
|
||||
!netState.markedStoppingAt.IsZero() &&
|
||||
time.Now().Add(-maxCraneStoppingDuration).After(netState.markedStoppingAt)
|
||||
}
|
||||
|
||||
// ReportTraffic adds the reported transferred data to the traffic stats.
|
||||
func (netState *NetworkOptimizationState) ReportTraffic(bytes uint64, in bool) {
|
||||
if in {
|
||||
atomic.AddUint64(netState.lifetimeBytesIn, bytes)
|
||||
atomic.AddUint64(netState.periodBytesIn, bytes)
|
||||
} else {
|
||||
atomic.AddUint64(netState.lifetimeBytesOut, bytes)
|
||||
atomic.AddUint64(netState.periodBytesOut, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
// LapsePeriod lapses the net state period, if needed.
|
||||
func (netState *NetworkOptimizationState) LapsePeriod() {
|
||||
netState.lock.Lock()
|
||||
defer netState.lock.Unlock()
|
||||
|
||||
// Reset period if interval elapsed.
|
||||
if time.Now().Add(-NetStatePeriodInterval).After(netState.periodStarted) {
|
||||
atomic.StoreUint64(netState.periodBytesIn, 0)
|
||||
atomic.StoreUint64(netState.periodBytesOut, 0)
|
||||
netState.periodStarted = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// GetTrafficStats returns the traffic stats.
|
||||
func (netState *NetworkOptimizationState) GetTrafficStats() (
|
||||
lifetimeBytesIn uint64,
|
||||
lifetimeBytesOut uint64,
|
||||
lifetimeStarted time.Time,
|
||||
periodBytesIn uint64,
|
||||
periodBytesOut uint64,
|
||||
periodStarted time.Time,
|
||||
) {
|
||||
netState.lock.Lock()
|
||||
defer netState.lock.Unlock()
|
||||
|
||||
return atomic.LoadUint64(netState.lifetimeBytesIn),
|
||||
atomic.LoadUint64(netState.lifetimeBytesOut),
|
||||
netState.lifetimeStarted,
|
||||
atomic.LoadUint64(netState.periodBytesIn),
|
||||
atomic.LoadUint64(netState.periodBytesOut),
|
||||
netState.periodStarted
|
||||
}
|
||||
122
spn/docks/crane_terminal.go
Normal file
122
spn/docks/crane_terminal.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// CraneTerminal is a terminal started by a crane.
|
||||
type CraneTerminal struct {
|
||||
*terminal.TerminalBase
|
||||
|
||||
// Add-Ons
|
||||
terminal.SessionAddOn
|
||||
|
||||
crane *Crane
|
||||
}
|
||||
|
||||
// NewLocalCraneTerminal returns a new local crane terminal.
|
||||
func NewLocalCraneTerminal(
|
||||
crane *Crane,
|
||||
remoteHub *hub.Hub,
|
||||
initMsg *terminal.TerminalOpts,
|
||||
) (*CraneTerminal, *container.Container, *terminal.Error) {
|
||||
// Create Terminal Base.
|
||||
t, initData, err := terminal.NewLocalBaseTerminal(
|
||||
crane.ctx,
|
||||
crane.getNextTerminalID(),
|
||||
crane.ID,
|
||||
remoteHub,
|
||||
initMsg,
|
||||
crane,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return initCraneTerminal(crane, t), initData, nil
|
||||
}
|
||||
|
||||
// NewRemoteCraneTerminal returns a new remote crane terminal.
|
||||
func NewRemoteCraneTerminal(
|
||||
crane *Crane,
|
||||
id uint32,
|
||||
initData *container.Container,
|
||||
) (*CraneTerminal, *terminal.TerminalOpts, *terminal.Error) {
|
||||
// Create Terminal Base.
|
||||
t, initMsg, err := terminal.NewRemoteBaseTerminal(
|
||||
crane.ctx,
|
||||
id,
|
||||
crane.ID,
|
||||
crane.identity,
|
||||
initData,
|
||||
crane,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return initCraneTerminal(crane, t), initMsg, nil
|
||||
}
|
||||
|
||||
func initCraneTerminal(
|
||||
crane *Crane,
|
||||
t *terminal.TerminalBase,
|
||||
) *CraneTerminal {
|
||||
// Create Crane Terminal and assign it as the extended Terminal.
|
||||
ct := &CraneTerminal{
|
||||
TerminalBase: t,
|
||||
crane: crane,
|
||||
}
|
||||
t.SetTerminalExtension(ct)
|
||||
|
||||
// Start workers.
|
||||
t.StartWorkers(module, "crane terminal")
|
||||
|
||||
return ct
|
||||
}
|
||||
|
||||
// GrantPermission grants the given permissions.
|
||||
// Additionally, it will mark the crane as authenticated, if not public.
|
||||
func (t *CraneTerminal) GrantPermission(grant terminal.Permission) {
|
||||
// Forward granted permission to base terminal.
|
||||
t.TerminalBase.GrantPermission(grant)
|
||||
|
||||
// Mark crane as authenticated if not public or already authenticated.
|
||||
if !t.crane.Public() && !t.crane.Authenticated() {
|
||||
t.crane.authenticated.Set()
|
||||
|
||||
// Submit metrics.
|
||||
newAuthenticatedCranes.Inc()
|
||||
}
|
||||
}
|
||||
|
||||
// LocalAddr returns the crane's local address.
|
||||
func (t *CraneTerminal) LocalAddr() net.Addr {
|
||||
return t.crane.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the crane's remote address.
|
||||
func (t *CraneTerminal) RemoteAddr() net.Addr {
|
||||
return t.crane.RemoteAddr()
|
||||
}
|
||||
|
||||
// Transport returns the crane's transport.
|
||||
func (t *CraneTerminal) Transport() *hub.Transport {
|
||||
return t.crane.Transport()
|
||||
}
|
||||
|
||||
// IsBeingAbandoned returns whether the terminal is being abandoned.
|
||||
func (t *CraneTerminal) IsBeingAbandoned() bool {
|
||||
return t.Abandoning.IsSet()
|
||||
}
|
||||
|
||||
// HandleDestruction gives the terminal the ability to clean up.
|
||||
// The terminal has already fully shut down at this point.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
func (t *CraneTerminal) HandleDestruction(err *terminal.Error) {
|
||||
t.crane.AbandonTerminal(t.ID(), err)
|
||||
}
|
||||
267
spn/docks/crane_test.go
Normal file
267
spn/docks/crane_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/ships"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
func TestCraneCommunication(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCraneWithCounter(t, "plain-counter-load-100", false, 100, 1000)
|
||||
testCraneWithCounter(t, "plain-counter-load-1000", false, 1000, 1000)
|
||||
testCraneWithCounter(t, "plain-counter-load-10000", false, 10000, 1000)
|
||||
testCraneWithCounter(t, "encrypted-counter", true, 1000, 1000)
|
||||
}
|
||||
|
||||
func testCraneWithCounter(t *testing.T, testID string, encrypting bool, loadSize int, countTo uint64) { //nolint:unparam,thelper
|
||||
var identity *cabin.Identity
|
||||
var connectedHub *hub.Hub
|
||||
if encrypting {
|
||||
identity, connectedHub = getTestIdentity(t)
|
||||
}
|
||||
|
||||
// Build ship and cranes.
|
||||
optimalMinLoadSize = loadSize * 2
|
||||
ship := ships.NewTestShip(!encrypting, loadSize)
|
||||
|
||||
var crane1, crane2 *Crane
|
||||
var craneWg sync.WaitGroup
|
||||
craneWg.Add(2)
|
||||
|
||||
go func() {
|
||||
var err error
|
||||
crane1, err = NewCrane(ship, connectedHub, nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crane test %s could not create crane1: %s", testID, err))
|
||||
}
|
||||
err = crane1.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crane test %s could not start crane1: %s", testID, err))
|
||||
}
|
||||
craneWg.Done()
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
crane2, err = NewCrane(ship.Reverse(), nil, identity)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crane test %s could not create crane2: %s", testID, err))
|
||||
}
|
||||
err = crane2.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crane test %s could not start crane2: %s", testID, err))
|
||||
}
|
||||
craneWg.Done()
|
||||
}()
|
||||
|
||||
craneWg.Wait()
|
||||
t.Logf("crane test %s setup complete", testID)
|
||||
|
||||
// Wait async for test to complete, print stack after timeout.
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Logf("crane test %s is taking too long, print stack:", testID)
|
||||
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Logf("crane1 controller: %+v", crane1.Controller)
|
||||
t.Logf("crane2 controller: %+v", crane2.Controller)
|
||||
|
||||
// Start counters for testing.
|
||||
op1, tErr := terminal.NewCounterOp(crane1.Controller, terminal.CounterOpts{
|
||||
ClientCountTo: countTo,
|
||||
ServerCountTo: countTo,
|
||||
})
|
||||
if tErr != nil {
|
||||
t.Fatalf("crane test %s failed to run counter op: %s", testID, tErr)
|
||||
}
|
||||
|
||||
// Wait for completion.
|
||||
op1.Wait()
|
||||
close(finished)
|
||||
|
||||
// Wait a little so that all errors can be propagated, so we can truly see
|
||||
// if we succeeded.
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Check errors.
|
||||
if op1.Error != nil {
|
||||
t.Fatalf("crane test %s counter op1 failed: %s", testID, op1.Error)
|
||||
}
|
||||
}
|
||||
|
||||
type StreamingTerminal struct {
|
||||
terminal.BareTerminal
|
||||
|
||||
test *testing.T
|
||||
id uint32
|
||||
crane *Crane
|
||||
recv chan *terminal.Msg
|
||||
testData []byte
|
||||
}
|
||||
|
||||
func (t *StreamingTerminal) ID() uint32 {
|
||||
return t.id
|
||||
}
|
||||
|
||||
func (t *StreamingTerminal) Ctx() context.Context {
|
||||
return module.Ctx
|
||||
}
|
||||
|
||||
func (t *StreamingTerminal) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
t.recv <- msg
|
||||
msg.Finish()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *StreamingTerminal) Abandon(err *terminal.Error) {
|
||||
t.crane.AbandonTerminal(t.ID(), err)
|
||||
if err != nil {
|
||||
t.test.Errorf("streaming terminal %d failed: %s", t.id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *StreamingTerminal) FmtID() string {
|
||||
return fmt.Sprintf("test-%d", t.id)
|
||||
}
|
||||
|
||||
func TestCraneLoadingUnloading(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCraneWithStreaming(t, "plain-streaming", false, 100)
|
||||
testCraneWithStreaming(t, "encrypted-streaming", true, 100)
|
||||
}
|
||||
|
||||
func testCraneWithStreaming(t *testing.T, testID string, encrypting bool, loadSize int) { //nolint:thelper
|
||||
var identity *cabin.Identity
|
||||
var connectedHub *hub.Hub
|
||||
if encrypting {
|
||||
identity, connectedHub = getTestIdentity(t)
|
||||
}
|
||||
|
||||
// Build ship and cranes.
|
||||
optimalMinLoadSize = loadSize * 2
|
||||
ship := ships.NewTestShip(!encrypting, loadSize)
|
||||
|
||||
var crane1, crane2 *Crane
|
||||
var craneWg sync.WaitGroup
|
||||
craneWg.Add(2)
|
||||
|
||||
go func() {
|
||||
var err error
|
||||
crane1, err = NewCrane(ship, connectedHub, nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crane test %s could not create crane1: %s", testID, err))
|
||||
}
|
||||
err = crane1.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crane test %s could not start crane1: %s", testID, err))
|
||||
}
|
||||
craneWg.Done()
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
crane2, err = NewCrane(ship.Reverse(), nil, identity)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crane test %s could not create crane2: %s", testID, err))
|
||||
}
|
||||
err = crane2.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crane test %s could not start crane2: %s", testID, err))
|
||||
}
|
||||
craneWg.Done()
|
||||
}()
|
||||
|
||||
craneWg.Wait()
|
||||
t.Logf("crane test %s setup complete", testID)
|
||||
|
||||
// Wait async for test to complete, print stack after timeout.
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Logf("crane test %s is taking too long, print stack:", testID)
|
||||
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Logf("crane1 controller: %+v", crane1.Controller)
|
||||
t.Logf("crane2 controller: %+v", crane2.Controller)
|
||||
|
||||
// Create terminals and run test.
|
||||
st := &StreamingTerminal{
|
||||
test: t,
|
||||
id: 8,
|
||||
crane: crane2,
|
||||
recv: make(chan *terminal.Msg),
|
||||
testData: []byte("The quick brown fox jumps over the lazy dog."),
|
||||
}
|
||||
crane2.terminals[st.ID()] = st
|
||||
|
||||
// Run streaming test.
|
||||
var streamingWg sync.WaitGroup
|
||||
streamingWg.Add(2)
|
||||
count := 10000
|
||||
go func() {
|
||||
for i := 1; i <= count; i++ {
|
||||
msg := terminal.NewMsg(st.testData)
|
||||
msg.FlowID = st.id
|
||||
err := crane1.Send(msg, 1*time.Second)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
crane1.Stop(err.Wrap("failed to submit terminal msg"))
|
||||
}
|
||||
// log.Tracef("spn/testing: + %d", i)
|
||||
}
|
||||
t.Logf("crane test %s done with sending", testID)
|
||||
streamingWg.Done()
|
||||
}()
|
||||
go func() {
|
||||
for i := 1; i <= count; i++ {
|
||||
msg := <-st.recv
|
||||
assert.Equal(t, st.testData, msg.Data.CompileData(), "data mismatched")
|
||||
// log.Tracef("spn/testing: - %d", i)
|
||||
}
|
||||
t.Logf("crane test %s done with receiving", testID)
|
||||
streamingWg.Done()
|
||||
}()
|
||||
|
||||
// Wait for completion.
|
||||
streamingWg.Wait()
|
||||
close(finished)
|
||||
}
|
||||
|
||||
var testIdentity *cabin.Identity
|
||||
|
||||
func getTestIdentity(t *testing.T) (*cabin.Identity, *hub.Hub) {
|
||||
t.Helper()
|
||||
|
||||
if testIdentity == nil {
|
||||
var err error
|
||||
testIdentity, err = cabin.CreateIdentity(module.Ctx, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create identity: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return testIdentity, testIdentity.Hub
|
||||
}
|
||||
85
spn/docks/crane_verify.go
Normal file
85
spn/docks/crane_verify.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const (
|
||||
hubVerificationPurpose = "hub identify verification"
|
||||
)
|
||||
|
||||
// VerifyConnectedHub verifies the connected Hub.
|
||||
func (crane *Crane) VerifyConnectedHub(callerCtx context.Context) error {
|
||||
if !crane.ship.IsMine() || crane.nextTerminalID != 0 || crane.Public() {
|
||||
return errors.New("hub verification can only be executed in init phase by the client")
|
||||
}
|
||||
|
||||
// Create verification request.
|
||||
v, request, err := cabin.CreateVerificationRequest(hubVerificationPurpose, "", "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create verification request: %w", err)
|
||||
}
|
||||
|
||||
// Send it.
|
||||
msg := container.New(
|
||||
varint.Pack8(CraneMsgTypeVerify),
|
||||
request,
|
||||
)
|
||||
msg.PrependLength()
|
||||
err = crane.ship.Load(msg.CompileData())
|
||||
if err != nil {
|
||||
return terminal.ErrShipSunk.With("failed to send verification request: %w", err)
|
||||
}
|
||||
|
||||
// Wait for reply.
|
||||
var reply *container.Container
|
||||
select {
|
||||
case reply = <-crane.unloading:
|
||||
case <-time.After(2 * time.Minute):
|
||||
// Use a big timeout here, as this might keep servers from joining the
|
||||
// network at all, as every servers needs to verify every server, no
|
||||
// matter how far away.
|
||||
return terminal.ErrTimeout.With("waiting for verification reply")
|
||||
case <-crane.ctx.Done():
|
||||
return terminal.ErrShipSunk.With("waiting for verification reply")
|
||||
case <-callerCtx.Done():
|
||||
return terminal.ErrShipSunk.With("waiting for verification reply")
|
||||
}
|
||||
|
||||
// Verify reply.
|
||||
return v.Verify(reply.CompileData(), crane.ConnectedHub)
|
||||
}
|
||||
|
||||
func (crane *Crane) handleCraneVerification(request *container.Container) *terminal.Error {
|
||||
// Check if we have an identity.
|
||||
if crane.identity == nil {
|
||||
return terminal.ErrIncorrectUsage.With("cannot handle verification request without designated identity")
|
||||
}
|
||||
|
||||
response, err := crane.identity.SignVerificationRequest(
|
||||
request.CompileData(),
|
||||
hubVerificationPurpose,
|
||||
"", "",
|
||||
)
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("failed to sign verification request: %w", err)
|
||||
}
|
||||
msg := container.New(response)
|
||||
|
||||
// Manually send reply.
|
||||
msg.PrependLength()
|
||||
err = crane.ship.Load(msg.CompileData())
|
||||
if err != nil {
|
||||
return terminal.ErrShipSunk.With("failed to send verification reply: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
46
spn/docks/cranehooks.go
Normal file
46
spn/docks/cranehooks.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
var (
|
||||
craneUpdateHook func(crane *Crane)
|
||||
craneUpdateHookLock sync.Mutex
|
||||
)
|
||||
|
||||
// RegisterCraneUpdateHook allows the captain to hook into receiving updates for cranes.
|
||||
func RegisterCraneUpdateHook(fn func(crane *Crane)) {
|
||||
craneUpdateHookLock.Lock()
|
||||
defer craneUpdateHookLock.Unlock()
|
||||
|
||||
if craneUpdateHook == nil {
|
||||
craneUpdateHook = fn
|
||||
} else {
|
||||
log.Error("spn/docks: crane update hook already registered")
|
||||
}
|
||||
}
|
||||
|
||||
// ResetCraneUpdateHook resets the hook for receiving updates for cranes.
|
||||
func ResetCraneUpdateHook() {
|
||||
craneUpdateHookLock.Lock()
|
||||
defer craneUpdateHookLock.Unlock()
|
||||
|
||||
craneUpdateHook = nil
|
||||
}
|
||||
|
||||
// NotifyUpdate calls the registers crane update hook function.
|
||||
func (crane *Crane) NotifyUpdate() {
|
||||
if crane == nil {
|
||||
return
|
||||
}
|
||||
|
||||
craneUpdateHookLock.Lock()
|
||||
defer craneUpdateHookLock.Unlock()
|
||||
|
||||
if craneUpdateHook != nil {
|
||||
craneUpdateHook(crane)
|
||||
}
|
||||
}
|
||||
189
spn/docks/hub_import.go
Normal file
189
spn/docks/hub_import.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/ships"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
var hubImportLock sync.Mutex
|
||||
|
||||
// ImportAndVerifyHubInfo imports the given hub message and verifies them.
|
||||
func ImportAndVerifyHubInfo(ctx context.Context, hubID string, announcementData, statusData []byte, mapName string, scope hub.Scope) (h *hub.Hub, forward bool, tErr *terminal.Error) {
|
||||
var firstErr *terminal.Error
|
||||
|
||||
// Synchronize import, as we might easily learn of a new hub from different
|
||||
// gossip channels simultaneously.
|
||||
hubImportLock.Lock()
|
||||
defer hubImportLock.Unlock()
|
||||
|
||||
// Check arguments.
|
||||
if announcementData == nil && statusData == nil {
|
||||
return nil, false, terminal.ErrInternalError.With("no announcement or status supplied")
|
||||
}
|
||||
|
||||
// Import Announcement, if given.
|
||||
var hubKnown, hubChanged bool
|
||||
if announcementData != nil {
|
||||
hubFromMsg, known, changed, err := hub.ApplyAnnouncement(nil, announcementData, mapName, scope, false)
|
||||
if err != nil && firstErr == nil {
|
||||
firstErr = terminal.ErrInternalError.With("failed to apply announcement: %w", err)
|
||||
}
|
||||
if known {
|
||||
hubKnown = true
|
||||
}
|
||||
if changed {
|
||||
hubChanged = true
|
||||
}
|
||||
if hubFromMsg != nil {
|
||||
h = hubFromMsg
|
||||
}
|
||||
}
|
||||
|
||||
// Import Status, if given.
|
||||
if statusData != nil {
|
||||
hubFromMsg, known, changed, err := hub.ApplyStatus(h, statusData, mapName, scope, false)
|
||||
if err != nil && firstErr == nil {
|
||||
firstErr = terminal.ErrInternalError.With("failed to apply status: %w", err)
|
||||
}
|
||||
if known && announcementData == nil {
|
||||
// If we parsed an announcement before, "known" will always be true here,
|
||||
// as we supply hub.ApplyStatus with a hub.
|
||||
hubKnown = true
|
||||
}
|
||||
if changed {
|
||||
hubChanged = true
|
||||
}
|
||||
if hubFromMsg != nil {
|
||||
h = hubFromMsg
|
||||
}
|
||||
}
|
||||
|
||||
// Only continue if we now have a Hub.
|
||||
if h == nil {
|
||||
if firstErr != nil {
|
||||
return nil, false, firstErr
|
||||
}
|
||||
return nil, false, terminal.ErrInternalError.With("got not hub after data import")
|
||||
}
|
||||
|
||||
// Abort if the given hub ID does not match.
|
||||
// We may have just connected to the wrong IP address.
|
||||
if hubID != "" && h.ID != hubID {
|
||||
return nil, false, terminal.ErrInternalError.With("hub mismatch")
|
||||
}
|
||||
|
||||
// Verify hub if:
|
||||
// - There is no error up until here.
|
||||
// - There has been any change.
|
||||
// - The hub is not verified yet.
|
||||
// - We're a public Hub.
|
||||
// - We're not testing.
|
||||
if firstErr == nil && hubChanged && !h.Verified() && conf.PublicHub() && !runningTests {
|
||||
if !conf.HubHasIPv4() && !conf.HubHasIPv6() {
|
||||
firstErr = terminal.ErrInternalError.With("no hub networks set")
|
||||
}
|
||||
if h.Info.IPv4 != nil && conf.HubHasIPv4() {
|
||||
err := verifyHubIP(ctx, h, h.Info.IPv4)
|
||||
if err != nil {
|
||||
firstErr = terminal.ErrIntegrity.With("failed to verify IPv4 address %s of %s: %w", h.Info.IPv4, h, err)
|
||||
}
|
||||
}
|
||||
if h.Info.IPv6 != nil && conf.HubHasIPv6() {
|
||||
err := verifyHubIP(ctx, h, h.Info.IPv6)
|
||||
if err != nil {
|
||||
firstErr = terminal.ErrIntegrity.With("failed to verify IPv6 address %s of %s: %w", h.Info.IPv6, h, err)
|
||||
}
|
||||
}
|
||||
|
||||
if firstErr != nil {
|
||||
func() {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
h.InvalidInfo = true
|
||||
}()
|
||||
log.Warningf("spn/docks: failed to verify IPs of %s: %s", h, firstErr)
|
||||
} else {
|
||||
func() {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
h.VerifiedIPs = true
|
||||
}()
|
||||
log.Infof("spn/docks: verified IPs of %s: IPv4=%s IPv6=%s", h, h.Info.IPv4, h.Info.IPv6)
|
||||
}
|
||||
}
|
||||
|
||||
// Dismiss initial imports with errors.
|
||||
if !hubKnown && firstErr != nil {
|
||||
return nil, false, firstErr
|
||||
}
|
||||
|
||||
// Don't do anything if nothing changed.
|
||||
if !hubChanged {
|
||||
return h, false, firstErr
|
||||
}
|
||||
|
||||
// We now have one of:
|
||||
// - A unknown Hub without error.
|
||||
// - A known Hub without error.
|
||||
// - A known Hub with error, which we want to save and propagate.
|
||||
|
||||
// Save the Hub to the database.
|
||||
err := h.Save()
|
||||
if err != nil {
|
||||
log.Errorf("spn/docks: failed to persist %s: %s", h, err)
|
||||
}
|
||||
|
||||
// Save the raw messages to the database.
|
||||
if announcementData != nil {
|
||||
err = hub.SaveHubMsg(h.ID, h.Map, hub.MsgTypeAnnouncement, announcementData)
|
||||
if err != nil {
|
||||
log.Errorf("spn/docks: failed to save raw announcement msg of %s: %s", h, err)
|
||||
}
|
||||
}
|
||||
if statusData != nil {
|
||||
err = hub.SaveHubMsg(h.ID, h.Map, hub.MsgTypeStatus, statusData)
|
||||
if err != nil {
|
||||
log.Errorf("spn/docks: failed to save raw status msg of %s: %s", h, err)
|
||||
}
|
||||
}
|
||||
|
||||
return h, true, firstErr
|
||||
}
|
||||
|
||||
func verifyHubIP(ctx context.Context, h *hub.Hub, ip net.IP) error {
|
||||
// Create connection.
|
||||
ship, err := ships.Launch(ctx, h, nil, ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to launch ship to %s: %w", ip, err)
|
||||
}
|
||||
|
||||
// Start crane for receiving reply.
|
||||
crane, err := NewCrane(ship, h, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create crane: %w", err)
|
||||
}
|
||||
module.StartWorker("crane unloader", crane.unloader)
|
||||
defer crane.Stop(nil)
|
||||
|
||||
// Verify Hub.
|
||||
err = crane.VerifyConnectedHub(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// End connection.
|
||||
tErr := crane.endInit()
|
||||
if tErr != nil {
|
||||
log.Debugf("spn/docks: failed to end verification connection to %s: %s", ip, tErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
108
spn/docks/measurements.go
Normal file
108
spn/docks/measurements.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/ships"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// Measurement Configuration.
|
||||
const (
|
||||
CraneMeasurementTTLDefault = 30 * time.Minute
|
||||
CraneMeasurementTTLByCostBase = 1 * time.Minute
|
||||
CraneMeasurementTTLByCostMin = 30 * time.Minute
|
||||
CraneMeasurementTTLByCostMax = 3 * time.Hour
|
||||
|
||||
// With a base TTL of 1m, this leads to:
|
||||
// 20c -> 20m -> raised to 30m
|
||||
// 50c -> 50m
|
||||
// 100c -> 1h40m
|
||||
// 1000c -> 16h40m -> capped to 3h.
|
||||
)
|
||||
|
||||
// MeasureHub measures the connection to this Hub and saves the results to the
|
||||
// Hub.
|
||||
func MeasureHub(ctx context.Context, h *hub.Hub, checkExpiryWith time.Duration) *terminal.Error {
|
||||
// Check if we are measuring before building a connection.
|
||||
if capacityTestRunning.IsSet() {
|
||||
return terminal.ErrTryAgainLater.With("another capacity op is already running")
|
||||
}
|
||||
|
||||
// Check if we have a connection to this Hub.
|
||||
crane := GetAssignedCrane(h.ID)
|
||||
if crane == nil {
|
||||
// Connect to Hub.
|
||||
var err error
|
||||
crane, err = establishCraneForMeasuring(ctx, h)
|
||||
if err != nil {
|
||||
return terminal.ErrConnectionError.With("failed to connect to %s: %s", h, err)
|
||||
}
|
||||
// Stop crane if established just for measuring.
|
||||
defer crane.Stop(nil)
|
||||
}
|
||||
|
||||
// Run latency test.
|
||||
_, expires := h.GetMeasurements().GetLatency()
|
||||
if checkExpiryWith == 0 || time.Now().Add(-checkExpiryWith).After(expires) {
|
||||
latOp, tErr := NewLatencyTestOp(crane.Controller)
|
||||
if !tErr.IsOK() {
|
||||
return tErr
|
||||
}
|
||||
select {
|
||||
case tErr = <-latOp.Result():
|
||||
if !tErr.IsOK() {
|
||||
return tErr
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return terminal.ErrCanceled
|
||||
case <-time.After(1 * time.Minute):
|
||||
crane.Controller.StopOperation(latOp, terminal.ErrTimeout)
|
||||
return terminal.ErrTimeout.With("waiting for latency test")
|
||||
}
|
||||
}
|
||||
|
||||
// Run capacity test.
|
||||
_, expires = h.GetMeasurements().GetCapacity()
|
||||
if checkExpiryWith == 0 || time.Now().Add(-checkExpiryWith).After(expires) {
|
||||
capOp, tErr := NewCapacityTestOp(crane.Controller, nil)
|
||||
if !tErr.IsOK() {
|
||||
return tErr
|
||||
}
|
||||
select {
|
||||
case tErr = <-capOp.Result():
|
||||
if !tErr.IsOK() {
|
||||
return tErr
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return terminal.ErrCanceled
|
||||
case <-time.After(1 * time.Minute):
|
||||
crane.Controller.StopOperation(capOp, terminal.ErrTimeout)
|
||||
return terminal.ErrTimeout.With("waiting for capacity test")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func establishCraneForMeasuring(ctx context.Context, dst *hub.Hub) (*Crane, error) {
|
||||
ship, err := ships.Launch(ctx, dst, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to launch ship: %w", err)
|
||||
}
|
||||
|
||||
crane, err := NewCrane(ship, dst, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create crane: %w", err)
|
||||
}
|
||||
|
||||
err = crane.Start(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start crane: %w", err)
|
||||
}
|
||||
|
||||
return crane, nil
|
||||
}
|
||||
404
spn/docks/metrics.go
Normal file
404
spn/docks/metrics.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/metrics"
|
||||
)
|
||||
|
||||
var (
|
||||
newCranes *metrics.Counter
|
||||
newPublicCranes *metrics.Counter
|
||||
newAuthenticatedCranes *metrics.Counter
|
||||
|
||||
trafficBytesPublicCranes *metrics.Counter
|
||||
trafficBytesAuthenticatedCranes *metrics.Counter
|
||||
trafficBytesPrivateCranes *metrics.Counter
|
||||
|
||||
newExpandOp *metrics.Counter
|
||||
expandOpDurationHistogram *metrics.Histogram
|
||||
expandOpRelayedDataHistogram *metrics.Histogram
|
||||
|
||||
metricsRegistered = abool.New()
|
||||
)
|
||||
|
||||
func registerMetrics() (err error) {
|
||||
// Only register metrics once.
|
||||
if !metricsRegistered.SetToIf(false, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Total Crane Stats.
|
||||
|
||||
newCranes, err = metrics.NewCounter(
|
||||
"spn/cranes/total",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN New Cranes",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newPublicCranes, err = metrics.NewCounter(
|
||||
"spn/cranes/public/total",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN New Public Cranes",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newAuthenticatedCranes, err = metrics.NewCounter(
|
||||
"spn/cranes/authenticated/total",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN New Authenticated Cranes",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Active Crane Stats.
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/cranes/active",
|
||||
map[string]string{
|
||||
"status": "public",
|
||||
},
|
||||
getActivePublicCranes,
|
||||
&metrics.Options{
|
||||
Name: "SPN Active Public Cranes",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/cranes/active",
|
||||
map[string]string{
|
||||
"status": "authenticated",
|
||||
},
|
||||
getActiveAuthenticatedCranes,
|
||||
&metrics.Options{
|
||||
Name: "SPN Active Authenticated Cranes",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/cranes/active",
|
||||
map[string]string{
|
||||
"status": "private",
|
||||
},
|
||||
getActivePrivateCranes,
|
||||
&metrics.Options{
|
||||
Name: "SPN Active Private Cranes",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/cranes/active",
|
||||
map[string]string{
|
||||
"status": "stopping",
|
||||
},
|
||||
getActiveStoppingCranes,
|
||||
&metrics.Options{
|
||||
Name: "SPN Active Stopping Cranes",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Crane Traffic Stats.
|
||||
|
||||
trafficBytesPublicCranes, err = metrics.NewCounter(
|
||||
"spn/cranes/bytes",
|
||||
map[string]string{
|
||||
"status": "public",
|
||||
},
|
||||
&metrics.Options{
|
||||
Name: "SPN Public Crane Traffic",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
trafficBytesAuthenticatedCranes, err = metrics.NewCounter(
|
||||
"spn/cranes/bytes",
|
||||
map[string]string{
|
||||
"status": "authenticated",
|
||||
},
|
||||
&metrics.Options{
|
||||
Name: "SPN Authenticated Crane Traffic",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
trafficBytesPrivateCranes, err = metrics.NewCounter(
|
||||
"spn/cranes/bytes",
|
||||
map[string]string{
|
||||
"status": "private",
|
||||
},
|
||||
&metrics.Options{
|
||||
Name: "SPN Private Crane Traffic",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Lane Stats.
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/lanes/latency/avg/seconds",
|
||||
nil,
|
||||
getAvgLaneLatencyStat,
|
||||
&metrics.Options{
|
||||
Name: "SPN Avg Lane Latency",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/lanes/latency/min/seconds",
|
||||
nil,
|
||||
getMinLaneLatencyStat,
|
||||
&metrics.Options{
|
||||
Name: "SPN Min Lane Latency",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/lanes/capacity/avg/bytes",
|
||||
nil,
|
||||
getAvgLaneCapacityStat,
|
||||
&metrics.Options{
|
||||
Name: "SPN Avg Lane Capacity",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/lanes/capacity/max/bytes",
|
||||
nil,
|
||||
getMaxLaneCapacityStat,
|
||||
&metrics.Options{
|
||||
Name: "SPN Max Lane Capacity",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Expand Op Stats.
|
||||
|
||||
newExpandOp, err = metrics.NewCounter(
|
||||
"spn/op/expand/total",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Total Expand Operations",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = metrics.NewGauge(
|
||||
"spn/op/expand/active",
|
||||
nil,
|
||||
getActiveExpandOpsStat,
|
||||
&metrics.Options{
|
||||
Name: "SPN Active Expand Operations",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expandOpDurationHistogram, err = metrics.NewHistogram(
|
||||
"spn/op/expand/histogram/duration/seconds",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Expand Operation Duration Histogram",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expandOpRelayedDataHistogram, err = metrics.NewHistogram(
|
||||
"spn/op/expand/histogram/traffic/bytes",
|
||||
nil,
|
||||
&metrics.Options{
|
||||
Name: "SPN Expand Operation Relayed Data Histogram",
|
||||
Permission: api.PermitUser,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func getActiveExpandOpsStat() float64 {
|
||||
return float64(atomic.LoadInt64(activeExpandOps))
|
||||
}
|
||||
|
||||
var (
|
||||
craneStats *craneGauges
|
||||
craneStatsExpires time.Time
|
||||
craneStatsLock sync.Mutex
|
||||
craneStatsTTL = 55 * time.Second
|
||||
)
|
||||
|
||||
type craneGauges struct {
|
||||
publicActive float64
|
||||
authenticatedActive float64
|
||||
privateActive float64
|
||||
stoppingActive float64
|
||||
|
||||
laneLatencyAvg float64
|
||||
laneLatencyMin float64
|
||||
laneCapacityAvg float64
|
||||
laneCapacityMax float64
|
||||
}
|
||||
|
||||
func getActivePublicCranes() float64 { return getCraneStats().publicActive }
|
||||
func getActiveAuthenticatedCranes() float64 { return getCraneStats().authenticatedActive }
|
||||
func getActivePrivateCranes() float64 { return getCraneStats().privateActive }
|
||||
func getActiveStoppingCranes() float64 { return getCraneStats().stoppingActive }
|
||||
func getAvgLaneLatencyStat() float64 { return getCraneStats().laneLatencyAvg }
|
||||
func getMinLaneLatencyStat() float64 { return getCraneStats().laneLatencyMin }
|
||||
func getAvgLaneCapacityStat() float64 { return getCraneStats().laneCapacityAvg }
|
||||
func getMaxLaneCapacityStat() float64 { return getCraneStats().laneCapacityMax }
|
||||
|
||||
func getCraneStats() *craneGauges {
|
||||
craneStatsLock.Lock()
|
||||
defer craneStatsLock.Unlock()
|
||||
|
||||
// Return cache if still valid.
|
||||
if time.Now().Before(craneStatsExpires) {
|
||||
return craneStats
|
||||
}
|
||||
|
||||
// Refresh.
|
||||
craneStats = &craneGauges{}
|
||||
var laneStatCnt float64
|
||||
for _, crane := range getAllCranes() {
|
||||
switch {
|
||||
case crane.Stopped():
|
||||
continue
|
||||
case crane.IsStopping():
|
||||
craneStats.stoppingActive++
|
||||
continue
|
||||
case crane.Public():
|
||||
craneStats.publicActive++
|
||||
case crane.Authenticated():
|
||||
craneStats.authenticatedActive++
|
||||
continue
|
||||
default:
|
||||
craneStats.privateActive++
|
||||
continue
|
||||
}
|
||||
|
||||
// Get lane stats.
|
||||
if crane.ConnectedHub == nil {
|
||||
continue
|
||||
}
|
||||
measurements := crane.ConnectedHub.GetMeasurements()
|
||||
laneLatency, _ := measurements.GetLatency()
|
||||
if laneLatency == 0 {
|
||||
continue
|
||||
}
|
||||
laneCapacity, _ := measurements.GetCapacity()
|
||||
if laneCapacity == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only use data if both latency and capacity is available.
|
||||
laneStatCnt++
|
||||
|
||||
// Convert to base unit: seconds.
|
||||
latency := laneLatency.Seconds()
|
||||
// Add to avg and set min if lower.
|
||||
craneStats.laneLatencyAvg += latency
|
||||
if craneStats.laneLatencyMin > latency || craneStats.laneLatencyMin == 0 {
|
||||
craneStats.laneLatencyMin = latency
|
||||
}
|
||||
|
||||
// Convert in base unit: bytes.
|
||||
capacity := float64(laneCapacity) / 8
|
||||
// Add to avg and set max if higher.
|
||||
craneStats.laneCapacityAvg += capacity
|
||||
if craneStats.laneCapacityMax < capacity {
|
||||
craneStats.laneCapacityMax = capacity
|
||||
}
|
||||
}
|
||||
|
||||
// Create averages.
|
||||
if laneStatCnt > 0 {
|
||||
craneStats.laneLatencyAvg /= laneStatCnt
|
||||
craneStats.laneCapacityAvg /= laneStatCnt
|
||||
}
|
||||
|
||||
craneStatsExpires = time.Now().Add(craneStatsTTL)
|
||||
return craneStats
|
||||
}
|
||||
|
||||
func (crane *Crane) submitCraneTrafficStats(bytes int) {
|
||||
switch {
|
||||
case crane.Stopped():
|
||||
return
|
||||
case crane.Public():
|
||||
trafficBytesPublicCranes.Add(bytes)
|
||||
case crane.Authenticated():
|
||||
trafficBytesAuthenticatedCranes.Add(bytes)
|
||||
default:
|
||||
trafficBytesPrivateCranes.Add(bytes)
|
||||
}
|
||||
}
|
||||
117
spn/docks/module.go
Normal file
117
spn/docks/module.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/rng"
|
||||
_ "github.com/safing/portmaster/spn/access"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
allCranes = make(map[string]*Crane) // ID = Crane ID
|
||||
assignedCranes = make(map[string]*Crane) // ID = connected Hub ID
|
||||
cranesLock sync.RWMutex
|
||||
|
||||
runningTests bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("docks", nil, start, stopAllCranes, "terminal", "cabin", "access")
|
||||
}
|
||||
|
||||
func start() error {
|
||||
return registerMetrics()
|
||||
}
|
||||
|
||||
func registerCrane(crane *Crane) error {
|
||||
cranesLock.Lock()
|
||||
defer cranesLock.Unlock()
|
||||
|
||||
// Generate new IDs until a unique one is found.
|
||||
for i := 0; i < 100; i++ {
|
||||
// Generate random ID.
|
||||
randomID, err := rng.Bytes(3)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate crane ID: %w", err)
|
||||
}
|
||||
newID := hex.EncodeToString(randomID)
|
||||
|
||||
// Check if ID already exists.
|
||||
_, ok := allCranes[newID]
|
||||
if !ok {
|
||||
crane.ID = newID
|
||||
allCranes[crane.ID] = crane
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("failed to find unique crane ID")
|
||||
}
|
||||
|
||||
func unregisterCrane(crane *Crane) {
|
||||
cranesLock.Lock()
|
||||
defer cranesLock.Unlock()
|
||||
|
||||
delete(allCranes, crane.ID)
|
||||
if crane.ConnectedHub != nil {
|
||||
delete(assignedCranes, crane.ConnectedHub.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func stopAllCranes() error {
|
||||
for _, crane := range getAllCranes() {
|
||||
crane.Stop(nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssignCrane assigns a crane to the given Hub ID.
|
||||
func AssignCrane(hubID string, crane *Crane) {
|
||||
cranesLock.Lock()
|
||||
defer cranesLock.Unlock()
|
||||
|
||||
assignedCranes[hubID] = crane
|
||||
}
|
||||
|
||||
// GetAssignedCrane returns the assigned crane of the given Hub ID.
|
||||
func GetAssignedCrane(hubID string) *Crane {
|
||||
cranesLock.RLock()
|
||||
defer cranesLock.RUnlock()
|
||||
|
||||
crane, ok := assignedCranes[hubID]
|
||||
if ok {
|
||||
return crane
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAllCranes() map[string]*Crane {
|
||||
copiedCranes := make(map[string]*Crane, len(allCranes))
|
||||
|
||||
cranesLock.RLock()
|
||||
defer cranesLock.RUnlock()
|
||||
|
||||
for id, crane := range allCranes {
|
||||
copiedCranes[id] = crane
|
||||
}
|
||||
return copiedCranes
|
||||
}
|
||||
|
||||
// GetAllAssignedCranes returns a copy of the map of all assigned cranes.
|
||||
func GetAllAssignedCranes() map[string]*Crane {
|
||||
copiedCranes := make(map[string]*Crane, len(assignedCranes))
|
||||
|
||||
cranesLock.RLock()
|
||||
defer cranesLock.RUnlock()
|
||||
|
||||
for destination, crane := range assignedCranes {
|
||||
copiedCranes[destination] = crane
|
||||
}
|
||||
return copiedCranes
|
||||
}
|
||||
16
spn/docks/module_test.go
Normal file
16
spn/docks/module_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/service/core/pmtesting"
|
||||
"github.com/safing/portmaster/spn/access"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
runningTests = true
|
||||
conf.EnablePublicHub(true) // Make hub config available.
|
||||
access.EnableTestMode() // Register test zone instead of real ones.
|
||||
pmtesting.TestMain(m, module)
|
||||
}
|
||||
356
spn/docks/op_capacity.go
Normal file
356
spn/docks/op_capacity.go
Normal file
@@ -0,0 +1,356 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const (
|
||||
// CapacityTestOpType is the type ID of the capacity test operation.
|
||||
CapacityTestOpType = "capacity"
|
||||
|
||||
defaultCapacityTestVolume = 50000000 // 50MB
|
||||
maxCapacityTestVolume = 100000000 // 100MB
|
||||
|
||||
defaultCapacityTestMaxTime = 5 * time.Second
|
||||
maxCapacityTestMaxTime = 15 * time.Second
|
||||
capacityTestTimeout = 30 * time.Second
|
||||
|
||||
capacityTestMsgSize = 1000
|
||||
capacityTestSendTimeout = 1000 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
capacityTestSendData = make([]byte, capacityTestMsgSize)
|
||||
capacityTestDataReceivedSignal = []byte("ACK")
|
||||
|
||||
capacityTestRunning = abool.New()
|
||||
)
|
||||
|
||||
// CapacityTestOp is used for capacity test operations.
|
||||
type CapacityTestOp struct { //nolint:maligned
|
||||
terminal.OperationBase
|
||||
|
||||
opts *CapacityTestOptions
|
||||
|
||||
started bool
|
||||
startTime time.Time
|
||||
senderStarted bool
|
||||
|
||||
recvQueue chan *terminal.Msg
|
||||
dataReceived int
|
||||
dataReceivedAckWasAckd bool
|
||||
|
||||
dataSent *int64
|
||||
dataSentWasAckd *abool.AtomicBool
|
||||
|
||||
testResult int
|
||||
result chan *terminal.Error
|
||||
}
|
||||
|
||||
// CapacityTestOptions holds options for the capacity test.
|
||||
type CapacityTestOptions struct {
|
||||
TestVolume int
|
||||
MaxTime time.Duration
|
||||
testing bool
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *CapacityTestOp) Type() string {
|
||||
return CapacityTestOpType
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: CapacityTestOpType,
|
||||
Requires: terminal.IsCraneController,
|
||||
Start: startCapacityTestOp,
|
||||
})
|
||||
}
|
||||
|
||||
// NewCapacityTestOp runs a capacity test.
|
||||
func NewCapacityTestOp(t terminal.Terminal, opts *CapacityTestOptions) (*CapacityTestOp, *terminal.Error) {
|
||||
// Check options.
|
||||
if opts == nil {
|
||||
opts = &CapacityTestOptions{
|
||||
TestVolume: defaultCapacityTestVolume,
|
||||
MaxTime: defaultCapacityTestMaxTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if another test is already running.
|
||||
if !opts.testing && !capacityTestRunning.SetToIf(false, true) {
|
||||
return nil, terminal.ErrTryAgainLater.With("another capacity op is already running")
|
||||
}
|
||||
|
||||
// Create and init.
|
||||
op := &CapacityTestOp{
|
||||
opts: opts,
|
||||
recvQueue: make(chan *terminal.Msg),
|
||||
dataSent: new(int64),
|
||||
dataSentWasAckd: abool.New(),
|
||||
result: make(chan *terminal.Error, 1),
|
||||
}
|
||||
|
||||
// Make capacity test request.
|
||||
request, err := dsd.Dump(op.opts, dsd.CBOR)
|
||||
if err != nil {
|
||||
capacityTestRunning.UnSet()
|
||||
return nil, terminal.ErrInternalError.With("failed to serialize capactity test options: %w", err)
|
||||
}
|
||||
|
||||
// Send test request.
|
||||
tErr := t.StartOperation(op, container.New(request), 1*time.Second)
|
||||
if tErr != nil {
|
||||
capacityTestRunning.UnSet()
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
// Start handler.
|
||||
module.StartWorker("op capacity handler", op.handler)
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func startCapacityTestOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Check if another test is already running.
|
||||
if !capacityTestRunning.SetToIf(false, true) {
|
||||
return nil, terminal.ErrTryAgainLater.With("another capacity op is already running")
|
||||
}
|
||||
|
||||
// Parse options.
|
||||
opts := &CapacityTestOptions{}
|
||||
_, err := dsd.Load(data.CompileData(), opts)
|
||||
if err != nil {
|
||||
capacityTestRunning.UnSet()
|
||||
return nil, terminal.ErrMalformedData.With("failed to parse options: %w", err)
|
||||
}
|
||||
|
||||
// Check options.
|
||||
if opts.TestVolume > maxCapacityTestVolume {
|
||||
capacityTestRunning.UnSet()
|
||||
return nil, terminal.ErrInvalidOptions.With("maximum volume exceeded")
|
||||
}
|
||||
if opts.MaxTime > maxCapacityTestMaxTime {
|
||||
capacityTestRunning.UnSet()
|
||||
return nil, terminal.ErrInvalidOptions.With("maximum maxtime exceeded")
|
||||
}
|
||||
|
||||
// Create operation.
|
||||
op := &CapacityTestOp{
|
||||
opts: opts,
|
||||
recvQueue: make(chan *terminal.Msg, 1000),
|
||||
dataSent: new(int64),
|
||||
dataSentWasAckd: abool.New(),
|
||||
result: make(chan *terminal.Error, 1),
|
||||
}
|
||||
op.InitOperationBase(t, opID)
|
||||
|
||||
// Start handler and sender.
|
||||
op.senderStarted = true
|
||||
module.StartWorker("op capacity handler", op.handler)
|
||||
module.StartWorker("op capacity sender", op.sender)
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (op *CapacityTestOp) handler(ctx context.Context) error {
|
||||
defer capacityTestRunning.UnSet()
|
||||
|
||||
returnErr := terminal.ErrStopping
|
||||
defer func() {
|
||||
// Linters don't get that returnErr is used when directly used as defer.
|
||||
op.Stop(op, returnErr)
|
||||
}()
|
||||
|
||||
var maxTestTimeReached <-chan time.Time
|
||||
opTimeout := time.After(capacityTestTimeout)
|
||||
|
||||
// Setup unit handling
|
||||
var msg *terminal.Msg
|
||||
defer msg.Finish()
|
||||
|
||||
// Handle receives.
|
||||
for {
|
||||
msg.Finish()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
returnErr = terminal.ErrCanceled
|
||||
return nil
|
||||
|
||||
case <-opTimeout:
|
||||
returnErr = terminal.ErrTimeout
|
||||
return nil
|
||||
|
||||
case <-maxTestTimeReached:
|
||||
returnErr = op.reportMeasuredCapacity()
|
||||
return nil
|
||||
|
||||
case msg = <-op.recvQueue:
|
||||
// Record start time and start sender.
|
||||
if !op.started {
|
||||
op.started = true
|
||||
op.startTime = time.Now()
|
||||
maxTestTimeReached = time.After(op.opts.MaxTime)
|
||||
if !op.senderStarted {
|
||||
op.senderStarted = true
|
||||
module.StartWorker("op capacity sender", op.sender)
|
||||
}
|
||||
}
|
||||
|
||||
// Add to received data counter.
|
||||
op.dataReceived += msg.Data.Length()
|
||||
|
||||
// Check if we received the data received signal.
|
||||
if msg.Data.Length() == len(capacityTestDataReceivedSignal) &&
|
||||
bytes.Equal(msg.Data.CompileData(), capacityTestDataReceivedSignal) {
|
||||
op.dataSentWasAckd.Set()
|
||||
}
|
||||
|
||||
// Send the data received signal when we received the full test volume.
|
||||
if op.dataReceived >= op.opts.TestVolume && !op.dataReceivedAckWasAckd {
|
||||
tErr := op.Send(op.NewMsg(capacityTestDataReceivedSignal), capacityTestSendTimeout)
|
||||
if tErr != nil {
|
||||
returnErr = tErr.Wrap("failed to send data received signal")
|
||||
return nil
|
||||
}
|
||||
atomic.AddInt64(op.dataSent, int64(len(capacityTestDataReceivedSignal)))
|
||||
op.dataReceivedAckWasAckd = true
|
||||
|
||||
// Flush last message.
|
||||
op.Flush(10 * time.Second)
|
||||
}
|
||||
|
||||
// Check if we can complete the test.
|
||||
if op.dataReceivedAckWasAckd &&
|
||||
op.dataSentWasAckd.IsSet() {
|
||||
returnErr = op.reportMeasuredCapacity()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (op *CapacityTestOp) sender(ctx context.Context) error {
|
||||
for {
|
||||
// Send next chunk.
|
||||
msg := op.NewMsg(capacityTestSendData)
|
||||
msg.Unit.MakeHighPriority()
|
||||
tErr := op.Send(msg, capacityTestSendTimeout)
|
||||
if tErr != nil {
|
||||
op.Stop(op, tErr.Wrap("failed to send capacity test data"))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add to sent data counter and stop sending if sending is complete.
|
||||
if atomic.AddInt64(op.dataSent, int64(len(capacityTestSendData))) >= int64(op.opts.TestVolume) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we have received an ack.
|
||||
if op.dataSentWasAckd.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if op has ended.
|
||||
if op.Stopped() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (op *CapacityTestOp) reportMeasuredCapacity() *terminal.Error {
|
||||
// Calculate lane capacity and set it.
|
||||
timeNeeded := time.Since(op.startTime)
|
||||
if timeNeeded <= 0 {
|
||||
timeNeeded = 1
|
||||
}
|
||||
duplexBits := float64((int64(op.dataReceived) + atomic.LoadInt64(op.dataSent)) * 8)
|
||||
duplexNSBitRate := duplexBits / float64(timeNeeded)
|
||||
bitRate := (duplexNSBitRate / 2) * float64(time.Second)
|
||||
op.testResult = int(bitRate)
|
||||
|
||||
// Save the result to the crane.
|
||||
if controller, ok := op.Terminal().(*CraneControllerTerminal); ok {
|
||||
if controller.Crane.ConnectedHub != nil {
|
||||
controller.Crane.ConnectedHub.GetMeasurements().SetCapacity(op.testResult)
|
||||
log.Infof(
|
||||
"docks: measured capacity to %s: %.2f Mbit/s (%.2fMB down / %.2fMB up in %s)",
|
||||
controller.Crane.ConnectedHub,
|
||||
float64(op.testResult)/1000000,
|
||||
float64(op.dataReceived)/1000000,
|
||||
float64(atomic.LoadInt64(op.dataSent))/1000000,
|
||||
timeNeeded,
|
||||
)
|
||||
return nil
|
||||
} else if controller.Crane.IsMine() {
|
||||
return terminal.ErrInternalError.With("capacity operation was run on %s without a connected hub set", controller.Crane)
|
||||
}
|
||||
} else if !runningTests {
|
||||
return terminal.ErrInternalError.With("capacity operation was run on terminal that is not a crane controller, but %T", op.Terminal())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deliver delivers a message.
|
||||
func (op *CapacityTestOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
// Optimized delivery with 1s timeout.
|
||||
select {
|
||||
case op.recvQueue <- msg:
|
||||
default:
|
||||
select {
|
||||
case op.recvQueue <- msg:
|
||||
case <-time.After(1 * time.Second):
|
||||
msg.Finish()
|
||||
return terminal.ErrTimeout
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *CapacityTestOp) HandleStop(tErr *terminal.Error) (errorToSend *terminal.Error) {
|
||||
// Return result to waiting routine.
|
||||
select {
|
||||
case op.result <- tErr:
|
||||
default:
|
||||
}
|
||||
|
||||
// Drain the recvQueue to finish the message units.
|
||||
drain:
|
||||
for {
|
||||
select {
|
||||
case msg := <-op.recvQueue:
|
||||
msg.Finish()
|
||||
default:
|
||||
select {
|
||||
case msg := <-op.recvQueue:
|
||||
msg.Finish()
|
||||
case <-time.After(3 * time.Millisecond):
|
||||
// Give some additional time buffer to drain the queue.
|
||||
break drain
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return error as is.
|
||||
return tErr
|
||||
}
|
||||
|
||||
// Result returns the result (end error) of the operation.
|
||||
func (op *CapacityTestOp) Result() <-chan *terminal.Error {
|
||||
return op.result
|
||||
}
|
||||
85
spn/docks/op_capacity_test.go
Normal file
85
spn/docks/op_capacity_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
var (
|
||||
testCapacityTestVolume = 1_000_000
|
||||
testCapacitytestMaxTime = 1 * time.Second
|
||||
)
|
||||
|
||||
func TestCapacityOp(t *testing.T) { //nolint:paralleltest // Performance test.
|
||||
// Defaults.
|
||||
testCapacityOp(t, &CapacityTestOptions{
|
||||
TestVolume: testCapacityTestVolume,
|
||||
MaxTime: testCapacitytestMaxTime,
|
||||
testing: true,
|
||||
})
|
||||
|
||||
// Hit max time first.
|
||||
testCapacityOp(t, &CapacityTestOptions{
|
||||
TestVolume: testCapacityTestVolume,
|
||||
MaxTime: 100 * time.Millisecond,
|
||||
testing: true,
|
||||
})
|
||||
|
||||
// Hit volume first.
|
||||
testCapacityOp(t, &CapacityTestOptions{
|
||||
TestVolume: 100_000,
|
||||
MaxTime: testCapacitytestMaxTime,
|
||||
testing: true,
|
||||
})
|
||||
}
|
||||
|
||||
func testCapacityOp(t *testing.T, opts *CapacityTestOptions) {
|
||||
t.Helper()
|
||||
|
||||
var (
|
||||
capTestDelay = 5 * time.Millisecond
|
||||
capTestQueueSize uint32 = 10
|
||||
)
|
||||
|
||||
// Create test terminal pair.
|
||||
a, b, err := terminal.NewSimpleTestTerminalPair(
|
||||
capTestDelay,
|
||||
int(capTestQueueSize),
|
||||
&terminal.TerminalOpts{
|
||||
FlowControl: terminal.FlowControlDFQ,
|
||||
FlowControlSize: capTestQueueSize,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test terminal pair: %s", err)
|
||||
}
|
||||
|
||||
// Grant permission for op on remote terminal and start op.
|
||||
b.GrantPermission(terminal.IsCraneController)
|
||||
op, tErr := NewCapacityTestOp(a, opts)
|
||||
if tErr != nil {
|
||||
t.Fatalf("failed to start op: %s", err)
|
||||
}
|
||||
|
||||
// Wait for result and check error.
|
||||
tErr = <-op.Result()
|
||||
if !tErr.IsOK() {
|
||||
t.Fatalf("op failed: %s", tErr)
|
||||
}
|
||||
t.Logf("measured capacity: %d bit/s", op.testResult)
|
||||
|
||||
// Calculate expected bandwidth.
|
||||
expectedBitsPerSecond := float64(capacityTestMsgSize*8*int64(capTestQueueSize)) / float64(capTestDelay) * float64(time.Second)
|
||||
t.Logf("expected capacity: %f bit/s", expectedBitsPerSecond)
|
||||
|
||||
// Check if measured bandwidth is within parameters.
|
||||
if float64(op.testResult) > expectedBitsPerSecond*1.6 {
|
||||
t.Fatal("measured capacity too high")
|
||||
}
|
||||
// TODO: Check if we can raise this to at least 90%.
|
||||
if float64(op.testResult) < expectedBitsPerSecond*0.2 {
|
||||
t.Fatal("measured capacity too low")
|
||||
}
|
||||
}
|
||||
393
spn/docks/op_expand.go
Normal file
393
spn/docks/op_expand.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// ExpandOpType is the type ID of the expand operation.
|
||||
const ExpandOpType string = "expand"
|
||||
|
||||
var activeExpandOps = new(int64)
|
||||
|
||||
// ExpandOp is used to expand to another Hub.
|
||||
type ExpandOp struct {
|
||||
terminal.OperationBase
|
||||
opts *terminal.TerminalOpts
|
||||
|
||||
// ctx is the context of the Terminal.
|
||||
ctx context.Context
|
||||
// cancelCtx cancels ctx.
|
||||
cancelCtx context.CancelFunc
|
||||
|
||||
dataRelayed *uint64
|
||||
ended *abool.AtomicBool
|
||||
|
||||
relayTerminal *ExpansionRelayTerminal
|
||||
|
||||
// flowControl holds the flow control system.
|
||||
flowControl terminal.FlowControl
|
||||
// deliverProxy is populated with the configured deliver function
|
||||
deliverProxy func(msg *terminal.Msg) *terminal.Error
|
||||
// recvProxy is populated with the configured recv function
|
||||
recvProxy func() <-chan *terminal.Msg
|
||||
// sendProxy is populated with the configured send function
|
||||
sendProxy func(msg *terminal.Msg, timeout time.Duration)
|
||||
}
|
||||
|
||||
// ExpansionRelayTerminal is a relay used for expansion.
|
||||
type ExpansionRelayTerminal struct {
|
||||
terminal.BareTerminal
|
||||
|
||||
op *ExpandOp
|
||||
|
||||
id uint32
|
||||
crane *Crane
|
||||
|
||||
abandoning *abool.AtomicBool
|
||||
|
||||
// flowControl holds the flow control system.
|
||||
flowControl terminal.FlowControl
|
||||
// deliverProxy is populated with the configured deliver function
|
||||
deliverProxy func(msg *terminal.Msg) *terminal.Error
|
||||
// recvProxy is populated with the configured recv function
|
||||
recvProxy func() <-chan *terminal.Msg
|
||||
// sendProxy is populated with the configured send function
|
||||
sendProxy func(msg *terminal.Msg, timeout time.Duration)
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *ExpandOp) Type() string {
|
||||
return ExpandOpType
|
||||
}
|
||||
|
||||
// ID returns the operation ID.
|
||||
func (t *ExpansionRelayTerminal) ID() uint32 {
|
||||
return t.id
|
||||
}
|
||||
|
||||
// Ctx returns the operation context.
|
||||
func (op *ExpandOp) Ctx() context.Context {
|
||||
return op.ctx
|
||||
}
|
||||
|
||||
// Ctx returns the relay terminal context.
|
||||
func (t *ExpansionRelayTerminal) Ctx() context.Context {
|
||||
return t.op.ctx
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the relay operation.
|
||||
func (op *ExpandOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
return op.deliverProxy(msg)
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the relay terminal.
|
||||
func (t *ExpansionRelayTerminal) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
return t.deliverProxy(msg)
|
||||
}
|
||||
|
||||
// Flush writes all data in the queues.
|
||||
func (op *ExpandOp) Flush(timeout time.Duration) {
|
||||
if op.flowControl != nil {
|
||||
op.flowControl.Flush(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// Flush writes all data in the queues.
|
||||
func (t *ExpansionRelayTerminal) Flush(timeout time.Duration) {
|
||||
if t.flowControl != nil {
|
||||
t.flowControl.Flush(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: ExpandOpType,
|
||||
Requires: terminal.MayExpand,
|
||||
Start: expand,
|
||||
})
|
||||
}
|
||||
|
||||
func expand(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Submit metrics.
|
||||
newExpandOp.Inc()
|
||||
|
||||
// Check if we are running a public hub.
|
||||
if !conf.PublicHub() {
|
||||
return nil, terminal.ErrPermissionDenied.With("expanding is only allowed on public hubs")
|
||||
}
|
||||
|
||||
// Parse destination hub ID.
|
||||
dstData, err := data.GetNextBlock()
|
||||
if err != nil {
|
||||
return nil, terminal.ErrMalformedData.With("failed to parse destination: %w", err)
|
||||
}
|
||||
|
||||
// Parse terminal options.
|
||||
opts, tErr := terminal.ParseTerminalOpts(data)
|
||||
if tErr != nil {
|
||||
return nil, tErr.Wrap("failed to parse terminal options")
|
||||
}
|
||||
|
||||
// Get crane with destination.
|
||||
relayCrane := GetAssignedCrane(string(dstData))
|
||||
if relayCrane == nil {
|
||||
return nil, terminal.ErrHubUnavailable.With("no crane assigned to %q", string(dstData))
|
||||
}
|
||||
|
||||
// TODO: Expand outside of hot path.
|
||||
|
||||
// Create operation and terminal.
|
||||
op := &ExpandOp{
|
||||
opts: opts,
|
||||
dataRelayed: new(uint64),
|
||||
ended: abool.New(),
|
||||
relayTerminal: &ExpansionRelayTerminal{
|
||||
crane: relayCrane,
|
||||
id: relayCrane.getNextTerminalID(),
|
||||
abandoning: abool.New(),
|
||||
},
|
||||
}
|
||||
op.InitOperationBase(t, opID)
|
||||
op.ctx, op.cancelCtx = context.WithCancel(t.Ctx())
|
||||
op.relayTerminal.op = op
|
||||
|
||||
// Create flow control.
|
||||
switch opts.FlowControl {
|
||||
case terminal.FlowControlDFQ:
|
||||
// Operation
|
||||
op.flowControl = terminal.NewDuplexFlowQueue(op.ctx, opts.FlowControlSize, op.submitBackwardUpstream)
|
||||
op.deliverProxy = op.flowControl.Deliver
|
||||
op.recvProxy = op.flowControl.Receive
|
||||
op.sendProxy = op.submitBackwardFlowControl
|
||||
// Relay Terminal
|
||||
op.relayTerminal.flowControl = terminal.NewDuplexFlowQueue(op.ctx, opts.FlowControlSize, op.submitForwardUpstream)
|
||||
op.relayTerminal.deliverProxy = op.relayTerminal.flowControl.Deliver
|
||||
op.relayTerminal.recvProxy = op.relayTerminal.flowControl.Receive
|
||||
op.relayTerminal.sendProxy = op.submitForwardFlowControl
|
||||
case terminal.FlowControlNone:
|
||||
// Operation
|
||||
deliverToOp := make(chan *terminal.Msg, opts.FlowControlSize)
|
||||
op.deliverProxy = terminal.MakeDirectDeliveryDeliverFunc(op.ctx, deliverToOp)
|
||||
op.recvProxy = terminal.MakeDirectDeliveryRecvFunc(deliverToOp)
|
||||
op.sendProxy = op.submitBackwardUpstream
|
||||
// Relay Terminal
|
||||
deliverToRelay := make(chan *terminal.Msg, opts.FlowControlSize)
|
||||
op.relayTerminal.deliverProxy = terminal.MakeDirectDeliveryDeliverFunc(op.ctx, deliverToRelay)
|
||||
op.relayTerminal.recvProxy = terminal.MakeDirectDeliveryRecvFunc(deliverToRelay)
|
||||
op.relayTerminal.sendProxy = op.submitForwardUpstream
|
||||
case terminal.FlowControlDefault:
|
||||
fallthrough
|
||||
default:
|
||||
return nil, terminal.ErrInternalError.With("unknown flow control type %d", opts.FlowControl)
|
||||
}
|
||||
|
||||
// Establish terminal on destination.
|
||||
newInitData, tErr := opts.Pack()
|
||||
if tErr != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to re-pack options: %w", err)
|
||||
}
|
||||
tErr = op.relayTerminal.crane.EstablishNewTerminal(op.relayTerminal, newInitData)
|
||||
if tErr != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
// Start workers.
|
||||
module.StartWorker("expand op forward relay", op.forwardHandler)
|
||||
module.StartWorker("expand op backward relay", op.backwardHandler)
|
||||
if op.flowControl != nil {
|
||||
op.flowControl.StartWorkers(module, "expand op")
|
||||
}
|
||||
if op.relayTerminal.flowControl != nil {
|
||||
op.relayTerminal.flowControl.StartWorkers(module, "expand op terminal")
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (op *ExpandOp) submitForwardFlowControl(msg *terminal.Msg, timeout time.Duration) {
|
||||
err := op.relayTerminal.flowControl.Send(msg, timeout)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
op.Stop(op, err.Wrap("failed to submit to forward flow control"))
|
||||
}
|
||||
}
|
||||
|
||||
func (op *ExpandOp) submitBackwardFlowControl(msg *terminal.Msg, timeout time.Duration) {
|
||||
err := op.flowControl.Send(msg, timeout)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
op.Stop(op, err.Wrap("failed to submit to backward flow control"))
|
||||
}
|
||||
}
|
||||
|
||||
func (op *ExpandOp) submitForwardUpstream(msg *terminal.Msg, timeout time.Duration) {
|
||||
msg.FlowID = op.relayTerminal.id
|
||||
if msg.Unit.IsHighPriority() && op.opts.UsePriorityDataMsgs {
|
||||
msg.Type = terminal.MsgTypePriorityData
|
||||
} else {
|
||||
msg.Type = terminal.MsgTypeData
|
||||
}
|
||||
err := op.relayTerminal.crane.Send(msg, timeout)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
op.Stop(op, err.Wrap("failed to submit to forward upstream"))
|
||||
}
|
||||
}
|
||||
|
||||
func (op *ExpandOp) submitBackwardUpstream(msg *terminal.Msg, timeout time.Duration) {
|
||||
msg.FlowID = op.relayTerminal.id
|
||||
if msg.Unit.IsHighPriority() && op.opts.UsePriorityDataMsgs {
|
||||
msg.Type = terminal.MsgTypePriorityData
|
||||
} else {
|
||||
msg.Type = terminal.MsgTypeData
|
||||
msg.Unit.RemovePriority()
|
||||
}
|
||||
// Note: op.Send() will transform high priority units to priority data msgs.
|
||||
err := op.Send(msg, timeout)
|
||||
if err != nil {
|
||||
msg.Finish()
|
||||
op.Stop(op, err.Wrap("failed to submit to backward upstream"))
|
||||
}
|
||||
}
|
||||
|
||||
func (op *ExpandOp) forwardHandler(_ context.Context) error {
|
||||
// Metrics setup and submitting.
|
||||
atomic.AddInt64(activeExpandOps, 1)
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
atomic.AddInt64(activeExpandOps, -1)
|
||||
expandOpDurationHistogram.UpdateDuration(started)
|
||||
expandOpRelayedDataHistogram.Update(float64(atomic.LoadUint64(op.dataRelayed)))
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-op.recvProxy():
|
||||
// Debugging:
|
||||
// log.Debugf("spn/testing: forwarding at %s: %s", op.FmtID(), spew.Sdump(c.CompileData()))
|
||||
|
||||
// Wait for processing slot.
|
||||
msg.Unit.WaitForSlot()
|
||||
|
||||
// Count relayed data for metrics.
|
||||
atomic.AddUint64(op.dataRelayed, uint64(msg.Data.Length()))
|
||||
|
||||
// Receive data from the origin and forward it to the relay.
|
||||
op.relayTerminal.sendProxy(msg, 1*time.Minute)
|
||||
|
||||
case <-op.ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (op *ExpandOp) backwardHandler(_ context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case msg := <-op.relayTerminal.recvProxy():
|
||||
// Debugging:
|
||||
// log.Debugf("spn/testing: backwarding at %s: %s", op.FmtID(), spew.Sdump(c.CompileData()))
|
||||
|
||||
// Wait for processing slot.
|
||||
msg.Unit.WaitForSlot()
|
||||
|
||||
// Count relayed data for metrics.
|
||||
atomic.AddUint64(op.dataRelayed, uint64(msg.Data.Length()))
|
||||
|
||||
// Receive data from the relay and forward it to the origin.
|
||||
op.sendProxy(msg, 1*time.Minute)
|
||||
|
||||
case <-op.ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *ExpandOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
|
||||
// Flush all messages before stopping.
|
||||
op.Flush(1 * time.Minute)
|
||||
op.relayTerminal.Flush(1 * time.Minute)
|
||||
|
||||
// Stop connected workers.
|
||||
op.cancelCtx()
|
||||
|
||||
// Abandon connected terminal.
|
||||
op.relayTerminal.Abandon(nil)
|
||||
|
||||
// Add context to error.
|
||||
if err.IsError() {
|
||||
return err.Wrap("relay operation failed with")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Abandon shuts down the terminal unregistering it from upstream and calling HandleAbandon().
|
||||
func (t *ExpansionRelayTerminal) Abandon(err *terminal.Error) {
|
||||
if t.abandoning.SetToIf(false, true) {
|
||||
module.StartWorker("terminal abandon procedure", func(_ context.Context) error {
|
||||
t.handleAbandonProcedure(err)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAbandon gives the terminal the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
func (t *ExpansionRelayTerminal) HandleAbandon(err *terminal.Error) (errorToSend *terminal.Error) {
|
||||
// Stop the connected relay operation.
|
||||
t.op.Stop(t.op, err)
|
||||
|
||||
// Add context to error.
|
||||
if err.IsError() {
|
||||
return err.Wrap("relay terminal failed with")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleDestruction gives the terminal the ability to clean up.
|
||||
// The terminal has already fully shut down at this point.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
func (t *ExpansionRelayTerminal) HandleDestruction(err *terminal.Error) {}
|
||||
|
||||
func (t *ExpansionRelayTerminal) handleAbandonProcedure(err *terminal.Error) {
|
||||
// Call operation stop handle function for proper shutdown cleaning up.
|
||||
err = t.HandleAbandon(err)
|
||||
|
||||
// Flush all messages before stopping.
|
||||
t.Flush(1 * time.Minute)
|
||||
|
||||
// Send error to the connected Operation, if the error is internal.
|
||||
if !err.IsExternal() {
|
||||
if err == nil {
|
||||
err = terminal.ErrStopping
|
||||
}
|
||||
|
||||
msg := terminal.NewMsg(err.Pack())
|
||||
msg.FlowID = t.ID()
|
||||
msg.Type = terminal.MsgTypeStop
|
||||
t.op.submitForwardUpstream(msg, 1*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// FmtID returns the expansion ID hierarchy.
|
||||
func (op *ExpandOp) FmtID() string {
|
||||
return fmt.Sprintf("%s>%d <r> %s#%d", op.Terminal().FmtID(), op.ID(), op.relayTerminal.crane.ID, op.relayTerminal.id)
|
||||
}
|
||||
|
||||
// FmtID returns the expansion ID hierarchy.
|
||||
func (t *ExpansionRelayTerminal) FmtID() string {
|
||||
return fmt.Sprintf("%s#%d", t.crane.ID, t.id)
|
||||
}
|
||||
298
spn/docks/op_latency.go
Normal file
298
spn/docks/op_latency.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/varint"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/rng"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const (
|
||||
// LatencyTestOpType is the type ID of the latency test operation.
|
||||
LatencyTestOpType = "latency"
|
||||
|
||||
latencyPingRequest = 1
|
||||
latencyPingResponse = 2
|
||||
|
||||
latencyTestNonceSize = 16
|
||||
latencyTestRuns = 10
|
||||
)
|
||||
|
||||
var (
|
||||
latencyTestPauseDuration = 1 * time.Second
|
||||
latencyTestOpTimeout = latencyTestRuns * latencyTestPauseDuration * 3
|
||||
)
|
||||
|
||||
// LatencyTestOp is used to measure latency.
|
||||
type LatencyTestOp struct {
|
||||
terminal.OperationBase
|
||||
}
|
||||
|
||||
// LatencyTestClientOp is the client version of LatencyTestOp.
|
||||
type LatencyTestClientOp struct {
|
||||
LatencyTestOp
|
||||
|
||||
lastPingSentAt time.Time
|
||||
lastPingNonce []byte
|
||||
measuredLatencies []time.Duration
|
||||
responses chan *terminal.Msg
|
||||
testResult time.Duration
|
||||
|
||||
result chan *terminal.Error
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *LatencyTestOp) Type() string {
|
||||
return LatencyTestOpType
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: LatencyTestOpType,
|
||||
Requires: terminal.IsCraneController,
|
||||
Start: startLatencyTestOp,
|
||||
})
|
||||
}
|
||||
|
||||
// NewLatencyTestOp runs a latency test.
|
||||
func NewLatencyTestOp(t terminal.Terminal) (*LatencyTestClientOp, *terminal.Error) {
|
||||
// Create and init.
|
||||
op := &LatencyTestClientOp{
|
||||
responses: make(chan *terminal.Msg),
|
||||
measuredLatencies: make([]time.Duration, 0, latencyTestRuns),
|
||||
result: make(chan *terminal.Error, 1),
|
||||
}
|
||||
|
||||
// Make ping request.
|
||||
pingRequest, err := op.createPingRequest()
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("%w", err)
|
||||
}
|
||||
|
||||
// Send ping.
|
||||
tErr := t.StartOperation(op, pingRequest, 1*time.Second)
|
||||
if tErr != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
// Start handler.
|
||||
module.StartWorker("op latency handler", op.handler)
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (op *LatencyTestClientOp) handler(ctx context.Context) error {
|
||||
returnErr := terminal.ErrStopping
|
||||
defer func() {
|
||||
// Linters don't get that returnErr is used when directly used as defer.
|
||||
op.Stop(op, returnErr)
|
||||
}()
|
||||
|
||||
var nextTest <-chan time.Time
|
||||
opTimeout := time.After(latencyTestOpTimeout)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
|
||||
case <-opTimeout:
|
||||
return nil
|
||||
|
||||
case <-nextTest:
|
||||
// Create ping request msg.
|
||||
pingRequest, err := op.createPingRequest()
|
||||
if err != nil {
|
||||
returnErr = terminal.ErrInternalError.With("%w", err)
|
||||
return nil
|
||||
}
|
||||
msg := op.NewEmptyMsg()
|
||||
msg.Unit.MakeHighPriority()
|
||||
msg.Data = pingRequest
|
||||
|
||||
// Send it.
|
||||
tErr := op.Send(msg, latencyTestOpTimeout)
|
||||
if tErr != nil {
|
||||
returnErr = tErr.Wrap("failed to send ping request")
|
||||
return nil
|
||||
}
|
||||
op.Flush(1 * time.Second)
|
||||
|
||||
nextTest = nil
|
||||
|
||||
case msg := <-op.responses:
|
||||
// Check if the op ended.
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle response
|
||||
tErr := op.handleResponse(msg)
|
||||
if tErr != nil {
|
||||
returnErr = tErr
|
||||
return nil //nolint:nilerr
|
||||
}
|
||||
|
||||
// Check if we have enough latency tests.
|
||||
if len(op.measuredLatencies) >= latencyTestRuns {
|
||||
returnErr = op.reportMeasuredLatencies()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Schedule next latency test, if not yet scheduled.
|
||||
if nextTest == nil {
|
||||
nextTest = time.After(latencyTestPauseDuration)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (op *LatencyTestClientOp) createPingRequest() (*container.Container, error) {
|
||||
// Generate nonce.
|
||||
nonce, err := rng.Bytes(latencyTestNonceSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ping nonce")
|
||||
}
|
||||
|
||||
// Set client request state.
|
||||
op.lastPingSentAt = time.Now()
|
||||
op.lastPingNonce = nonce
|
||||
|
||||
return container.New(
|
||||
varint.Pack8(latencyPingRequest),
|
||||
nonce,
|
||||
), nil
|
||||
}
|
||||
|
||||
func (op *LatencyTestClientOp) handleResponse(msg *terminal.Msg) *terminal.Error {
|
||||
defer msg.Finish()
|
||||
|
||||
rType, err := msg.Data.GetNextN8()
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to get response type: %w", err)
|
||||
}
|
||||
|
||||
switch rType {
|
||||
case latencyPingResponse:
|
||||
// Check if the ping nonce matches.
|
||||
if !bytes.Equal(op.lastPingNonce, msg.Data.CompileData()) {
|
||||
return terminal.ErrIntegrity.With("ping nonce mismatch")
|
||||
}
|
||||
op.lastPingNonce = nil
|
||||
// Save latency.
|
||||
op.measuredLatencies = append(op.measuredLatencies, time.Since(op.lastPingSentAt))
|
||||
|
||||
return nil
|
||||
default:
|
||||
return terminal.ErrIncorrectUsage.With("unknown response type")
|
||||
}
|
||||
}
|
||||
|
||||
func (op *LatencyTestClientOp) reportMeasuredLatencies() *terminal.Error {
|
||||
// Find lowest value.
|
||||
lowestLatency := time.Hour
|
||||
for _, latency := range op.measuredLatencies {
|
||||
if latency < lowestLatency {
|
||||
lowestLatency = latency
|
||||
}
|
||||
}
|
||||
op.testResult = lowestLatency
|
||||
|
||||
// Save the result to the crane.
|
||||
if controller, ok := op.Terminal().(*CraneControllerTerminal); ok {
|
||||
if controller.Crane.ConnectedHub != nil {
|
||||
controller.Crane.ConnectedHub.GetMeasurements().SetLatency(op.testResult)
|
||||
log.Infof("spn/docks: measured latency to %s: %s", controller.Crane.ConnectedHub, op.testResult)
|
||||
return nil
|
||||
} else if controller.Crane.IsMine() {
|
||||
return terminal.ErrInternalError.With("latency operation was run on %s without a connected hub set", controller.Crane)
|
||||
}
|
||||
} else if !runningTests {
|
||||
return terminal.ErrInternalError.With("latency operation was run on terminal that is not a crane controller, but %T", op.Terminal())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the operation.
|
||||
func (op *LatencyTestClientOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
// Optimized delivery with 1s timeout.
|
||||
select {
|
||||
case op.responses <- msg:
|
||||
default:
|
||||
select {
|
||||
case op.responses <- msg:
|
||||
case <-time.After(1 * time.Second):
|
||||
return terminal.ErrTimeout
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *LatencyTestClientOp) HandleStop(tErr *terminal.Error) (errorToSend *terminal.Error) {
|
||||
close(op.responses)
|
||||
select {
|
||||
case op.result <- tErr:
|
||||
default:
|
||||
}
|
||||
return tErr
|
||||
}
|
||||
|
||||
// Result returns the result (end error) of the operation.
|
||||
func (op *LatencyTestClientOp) Result() <-chan *terminal.Error {
|
||||
return op.result
|
||||
}
|
||||
|
||||
func startLatencyTestOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Create operation.
|
||||
op := &LatencyTestOp{}
|
||||
op.InitOperationBase(t, opID)
|
||||
|
||||
// Handle first request.
|
||||
msg := op.NewEmptyMsg()
|
||||
msg.Data = data
|
||||
tErr := op.Deliver(msg)
|
||||
if tErr != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the operation.
|
||||
func (op *LatencyTestOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
// Get request type.
|
||||
rType, err := msg.Data.GetNextN8()
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to get response type: %w", err)
|
||||
}
|
||||
|
||||
switch rType {
|
||||
case latencyPingRequest:
|
||||
// Keep the nonce and just replace the msg type.
|
||||
msg.Data.PrependNumber(latencyPingResponse)
|
||||
msg.Type = terminal.MsgTypeData
|
||||
msg.Unit.ReUse()
|
||||
msg.Unit.MakeHighPriority()
|
||||
|
||||
// Send response.
|
||||
tErr := op.Send(msg, latencyTestOpTimeout)
|
||||
if tErr != nil {
|
||||
return tErr.Wrap("failed to send ping response")
|
||||
}
|
||||
op.Flush(1 * time.Second)
|
||||
|
||||
return nil
|
||||
|
||||
default:
|
||||
return terminal.ErrIncorrectUsage.With("unknown request type")
|
||||
}
|
||||
}
|
||||
59
spn/docks/op_latency_test.go
Normal file
59
spn/docks/op_latency_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
func TestLatencyOp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
latTestDelay = 10 * time.Millisecond
|
||||
latTestQueueSize uint32 = 10
|
||||
)
|
||||
|
||||
// Reduce waiting time.
|
||||
latencyTestPauseDuration = 100 * time.Millisecond
|
||||
|
||||
// Create test terminal pair.
|
||||
a, b, err := terminal.NewSimpleTestTerminalPair(
|
||||
latTestDelay,
|
||||
int(latTestQueueSize),
|
||||
&terminal.TerminalOpts{
|
||||
FlowControl: terminal.FlowControlNone,
|
||||
FlowControlSize: latTestQueueSize,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test terminal pair: %s", err)
|
||||
}
|
||||
|
||||
// Grant permission for op on remote terminal and start op.
|
||||
b.GrantPermission(terminal.IsCraneController)
|
||||
op, tErr := NewLatencyTestOp(a)
|
||||
if tErr != nil {
|
||||
t.Fatalf("failed to start op: %s", err)
|
||||
}
|
||||
|
||||
// Wait for result and check error.
|
||||
tErr = <-op.Result()
|
||||
if tErr.IsError() {
|
||||
t.Fatalf("op failed: %s", tErr)
|
||||
}
|
||||
t.Logf("measured latency: %f ms", float64(op.testResult)/float64(time.Millisecond))
|
||||
|
||||
// Calculate expected latency.
|
||||
expectedLatency := float64(latTestDelay * 2)
|
||||
t.Logf("expected latency: %f ms", expectedLatency/float64(time.Millisecond))
|
||||
|
||||
// Check if measured latency is within parameters.
|
||||
if float64(op.testResult) > expectedLatency*1.2 {
|
||||
t.Fatal("measured latency too high")
|
||||
}
|
||||
if float64(op.testResult) < expectedLatency*0.9 {
|
||||
t.Fatal("measured latency too low")
|
||||
}
|
||||
}
|
||||
150
spn/docks/op_sync_state.go
Normal file
150
spn/docks/op_sync_state.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portmaster/spn/conf"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// SyncStateOpType is the type ID of the sync state operation.
|
||||
const SyncStateOpType = "sync/state"
|
||||
|
||||
// SyncStateOp is used to sync the crane state.
|
||||
type SyncStateOp struct {
|
||||
terminal.OneOffOperationBase
|
||||
}
|
||||
|
||||
// SyncStateMessage holds the sync data.
|
||||
type SyncStateMessage struct {
|
||||
Stopping bool
|
||||
RequestStopping bool
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *SyncStateOp) Type() string {
|
||||
return SyncStateOpType
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: SyncStateOpType,
|
||||
Requires: terminal.IsCraneController,
|
||||
Start: runSyncStateOp,
|
||||
})
|
||||
}
|
||||
|
||||
// startSyncStateOp starts a worker that runs the sync state operation.
|
||||
func (crane *Crane) startSyncStateOp() {
|
||||
module.StartWorker("sync crane state", func(ctx context.Context) error {
|
||||
tErr := crane.Controller.SyncState(ctx)
|
||||
if tErr != nil {
|
||||
return tErr
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SyncState runs a sync state operation.
|
||||
func (controller *CraneControllerTerminal) SyncState(ctx context.Context) *terminal.Error {
|
||||
// Check if we are a public Hub, whether we own the crane and whether the lane is public too.
|
||||
if !conf.PublicHub() || !controller.Crane.Public() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create and init.
|
||||
op := &SyncStateOp{}
|
||||
op.Init()
|
||||
|
||||
// Get optimization states.
|
||||
requestStopping := false
|
||||
func() {
|
||||
controller.Crane.NetState.lock.Lock()
|
||||
defer controller.Crane.NetState.lock.Unlock()
|
||||
|
||||
requestStopping = controller.Crane.NetState.stoppingRequested
|
||||
}()
|
||||
|
||||
// Create sync message.
|
||||
msg := &SyncStateMessage{
|
||||
Stopping: controller.Crane.stopping.IsSet(),
|
||||
RequestStopping: requestStopping,
|
||||
}
|
||||
data, err := dsd.Dump(msg, dsd.CBOR)
|
||||
if err != nil {
|
||||
return terminal.ErrInternalError.With("%w", err)
|
||||
}
|
||||
|
||||
// Send message.
|
||||
tErr := controller.StartOperation(op, container.New(data), 30*time.Second)
|
||||
if tErr != nil {
|
||||
return tErr
|
||||
}
|
||||
|
||||
// Wait for reply
|
||||
select {
|
||||
case tErr = <-op.Result:
|
||||
if tErr.IsError() {
|
||||
return tErr
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(1 * time.Minute):
|
||||
return terminal.ErrTimeout.With("timed out while waiting for sync crane result")
|
||||
}
|
||||
}
|
||||
|
||||
func runSyncStateOp(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Check if we are a on a crane controller.
|
||||
var ok bool
|
||||
var controller *CraneControllerTerminal
|
||||
if controller, ok = t.(*CraneControllerTerminal); !ok {
|
||||
return nil, terminal.ErrIncorrectUsage.With("can only be used with a crane controller")
|
||||
}
|
||||
|
||||
// Check if we are a public Hub and whether the lane is public too.
|
||||
if !conf.PublicHub() || !controller.Crane.Public() {
|
||||
return nil, terminal.ErrPermissionDenied.With("only public lanes can sync crane status")
|
||||
}
|
||||
|
||||
// Load message.
|
||||
syncState := &SyncStateMessage{}
|
||||
_, err := dsd.Load(data.CompileData(), syncState)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrMalformedData.With("failed to load sync state message: %w", err)
|
||||
}
|
||||
|
||||
// Apply optimization state.
|
||||
controller.Crane.NetState.lock.Lock()
|
||||
defer controller.Crane.NetState.lock.Unlock()
|
||||
controller.Crane.NetState.stoppingRequestedByPeer = syncState.RequestStopping
|
||||
|
||||
// Apply crane state only when we don't own the crane.
|
||||
if !controller.Crane.IsMine() {
|
||||
// Apply sync state.
|
||||
var changed bool
|
||||
if syncState.Stopping {
|
||||
if controller.Crane.stopping.SetToIf(false, true) {
|
||||
controller.Crane.NetState.markedStoppingAt = time.Now()
|
||||
changed = true
|
||||
}
|
||||
} else {
|
||||
if controller.Crane.stopping.SetToIf(true, false) {
|
||||
controller.Crane.NetState.markedStoppingAt = time.Time{}
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Notify of change.
|
||||
if changed {
|
||||
controller.Crane.NotifyUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
135
spn/docks/op_whoami.go
Normal file
135
spn/docks/op_whoami.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portbase/formats/dsd"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const (
|
||||
// WhoAmIType is the type ID of the latency test operation.
|
||||
WhoAmIType = "whoami"
|
||||
|
||||
whoAmITimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
// WhoAmIOp is used to request some metadata about the other side.
|
||||
type WhoAmIOp struct {
|
||||
terminal.OneOffOperationBase
|
||||
|
||||
response *WhoAmIResponse
|
||||
}
|
||||
|
||||
// WhoAmIResponse is a whoami response.
|
||||
type WhoAmIResponse struct {
|
||||
// Timestamp in nanoseconds
|
||||
Timestamp int64 `cbor:"t,omitempty" json:"t,omitempty"`
|
||||
|
||||
// Addr is the remote address as reported by the crane terminal (IP and port).
|
||||
Addr string `cbor:"a,omitempty" json:"a,omitempty"`
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *WhoAmIOp) Type() string {
|
||||
return WhoAmIType
|
||||
}
|
||||
|
||||
func init() {
|
||||
terminal.RegisterOpType(terminal.OperationFactory{
|
||||
Type: WhoAmIType,
|
||||
Start: startWhoAmI,
|
||||
})
|
||||
}
|
||||
|
||||
// WhoAmI executes a whoami operation and returns the response.
|
||||
func WhoAmI(t terminal.Terminal) (*WhoAmIResponse, *terminal.Error) {
|
||||
whoami, err := NewWhoAmIOp(t)
|
||||
if err.IsError() {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait for response.
|
||||
select {
|
||||
case tErr := <-whoami.Result:
|
||||
if tErr.IsError() {
|
||||
return nil, tErr
|
||||
}
|
||||
return whoami.response, nil
|
||||
case <-time.After(whoAmITimeout * 2):
|
||||
return nil, terminal.ErrTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// NewWhoAmIOp starts a new whoami operation.
|
||||
func NewWhoAmIOp(t terminal.Terminal) (*WhoAmIOp, *terminal.Error) {
|
||||
// Create operation and init.
|
||||
op := &WhoAmIOp{}
|
||||
op.OneOffOperationBase.Init()
|
||||
|
||||
// Send ping.
|
||||
tErr := t.StartOperation(op, nil, whoAmITimeout)
|
||||
if tErr != nil {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the operation.
|
||||
func (op *WhoAmIOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
defer msg.Finish()
|
||||
|
||||
// Parse response.
|
||||
response := &WhoAmIResponse{}
|
||||
_, err := dsd.Load(msg.Data.CompileData(), response)
|
||||
if err != nil {
|
||||
return terminal.ErrMalformedData.With("failed to parse ping response: %w", err)
|
||||
}
|
||||
|
||||
op.response = response
|
||||
return terminal.ErrExplicitAck
|
||||
}
|
||||
|
||||
func startWhoAmI(t terminal.Terminal, opID uint32, data *container.Container) (terminal.Operation, *terminal.Error) {
|
||||
// Get crane terminal, if available.
|
||||
ct, _ := t.(*CraneTerminal)
|
||||
|
||||
// Create response.
|
||||
r := &WhoAmIResponse{
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
}
|
||||
if ct != nil {
|
||||
r.Addr = ct.RemoteAddr().String()
|
||||
}
|
||||
response, err := dsd.Dump(r, dsd.CBOR)
|
||||
if err != nil {
|
||||
return nil, terminal.ErrInternalError.With("failed to create whoami response: %w", err)
|
||||
}
|
||||
|
||||
// Send response.
|
||||
msg := terminal.NewMsg(response)
|
||||
msg.FlowID = opID
|
||||
msg.Unit.MakeHighPriority()
|
||||
if terminal.UsePriorityDataMsgs {
|
||||
msg.Type = terminal.MsgTypePriorityData
|
||||
}
|
||||
tErr := t.Send(msg, whoAmITimeout)
|
||||
if tErr != nil {
|
||||
// Finish message unit on failure.
|
||||
msg.Finish()
|
||||
return nil, tErr.With("failed to send ping response")
|
||||
}
|
||||
|
||||
// Operation is just one response and finished successfully.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *WhoAmIOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
|
||||
// Continue with usual handling of inherited base.
|
||||
return op.OneOffOperationBase.HandleStop(err)
|
||||
}
|
||||
24
spn/docks/op_whoami_test.go
Normal file
24
spn/docks/op_whoami_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
func TestWhoAmIOp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create test terminal pair.
|
||||
a, _, err := terminal.NewSimpleTestTerminalPair(0, 0, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test terminal pair: %s", err)
|
||||
}
|
||||
|
||||
// Run op.
|
||||
resp, tErr := WhoAmI(a)
|
||||
if tErr.IsError() {
|
||||
t.Fatal(tErr)
|
||||
}
|
||||
t.Logf("whoami: %+v", resp)
|
||||
}
|
||||
150
spn/docks/terminal_expansion.go
Normal file
150
spn/docks/terminal_expansion.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
// ExpansionTerminal is used for expanding to another Hub.
|
||||
type ExpansionTerminal struct {
|
||||
*terminal.TerminalBase
|
||||
|
||||
relayOp *ExpansionTerminalRelayOp
|
||||
|
||||
changeNotifyFuncReady *abool.AtomicBool
|
||||
changeNotifyFunc func()
|
||||
|
||||
reachableChecked time.Time
|
||||
reachableLock sync.Mutex
|
||||
}
|
||||
|
||||
// ExpansionTerminalRelayOp is the operation that connects to the relay.
|
||||
type ExpansionTerminalRelayOp struct {
|
||||
terminal.OperationBase
|
||||
|
||||
expansionTerminal *ExpansionTerminal
|
||||
}
|
||||
|
||||
// Type returns the type ID.
|
||||
func (op *ExpansionTerminalRelayOp) Type() string {
|
||||
return ExpandOpType
|
||||
}
|
||||
|
||||
// ExpandTo initiates an expansion.
|
||||
func ExpandTo(from terminal.Terminal, routeTo string, encryptFor *hub.Hub) (*ExpansionTerminal, *terminal.Error) {
|
||||
// First, create the local endpoint terminal to generate the init data.
|
||||
|
||||
// Create options and bare expansion terminal.
|
||||
opts := terminal.DefaultExpansionTerminalOpts()
|
||||
opts.Encrypt = encryptFor != nil
|
||||
expansion := &ExpansionTerminal{
|
||||
changeNotifyFuncReady: abool.New(),
|
||||
}
|
||||
expansion.relayOp = &ExpansionTerminalRelayOp{
|
||||
expansionTerminal: expansion,
|
||||
}
|
||||
|
||||
// Create base terminal for expansion.
|
||||
base, initData, tErr := terminal.NewLocalBaseTerminal(
|
||||
module.Ctx,
|
||||
0, // Ignore; The ID of the operation is used for communication.
|
||||
from.FmtID(),
|
||||
encryptFor,
|
||||
opts,
|
||||
expansion.relayOp,
|
||||
)
|
||||
if tErr != nil {
|
||||
return nil, tErr.Wrap("failed to create expansion terminal base")
|
||||
}
|
||||
expansion.TerminalBase = base
|
||||
base.SetTerminalExtension(expansion)
|
||||
base.SetTimeout(defaultTerminalIdleTimeout)
|
||||
|
||||
// Second, start the actual relay operation.
|
||||
|
||||
// Create setup message for relay operation.
|
||||
opInitData := container.New()
|
||||
opInitData.AppendAsBlock([]byte(routeTo))
|
||||
opInitData.AppendContainer(initData)
|
||||
|
||||
// Start relay operation on connected Hub.
|
||||
tErr = from.StartOperation(expansion.relayOp, opInitData, 5*time.Second)
|
||||
if tErr != nil {
|
||||
return nil, tErr.Wrap("failed to start expansion operation")
|
||||
}
|
||||
|
||||
// Start Workers.
|
||||
base.StartWorkers(module, "expansion terminal")
|
||||
|
||||
return expansion, nil
|
||||
}
|
||||
|
||||
// SetChangeNotifyFunc sets a callback function that is called when the terminal state changes.
|
||||
func (t *ExpansionTerminal) SetChangeNotifyFunc(f func()) {
|
||||
if t.changeNotifyFuncReady.IsSet() {
|
||||
return
|
||||
}
|
||||
t.changeNotifyFunc = f
|
||||
t.changeNotifyFuncReady.Set()
|
||||
}
|
||||
|
||||
// NeedsReachableCheck returns whether the terminal should be checked if it is
|
||||
// reachable via the existing network internal relayed connection.
|
||||
func (t *ExpansionTerminal) NeedsReachableCheck(maxCheckAge time.Duration) bool {
|
||||
t.reachableLock.Lock()
|
||||
defer t.reachableLock.Unlock()
|
||||
|
||||
return time.Since(t.reachableChecked) > maxCheckAge
|
||||
}
|
||||
|
||||
// MarkReachable marks the terminal as reachable via the existing network
|
||||
// internal relayed connection.
|
||||
func (t *ExpansionTerminal) MarkReachable() {
|
||||
t.reachableLock.Lock()
|
||||
defer t.reachableLock.Unlock()
|
||||
|
||||
t.reachableChecked = time.Now()
|
||||
}
|
||||
|
||||
// HandleDestruction gives the terminal the ability to clean up.
|
||||
// The terminal has already fully shut down at this point.
|
||||
// Should never be called directly. Call Abandon() instead.
|
||||
func (t *ExpansionTerminal) HandleDestruction(err *terminal.Error) {
|
||||
// Trigger update of connected Pin.
|
||||
if t.changeNotifyFuncReady.IsSet() {
|
||||
t.changeNotifyFunc()
|
||||
}
|
||||
|
||||
// Stop the relay operation.
|
||||
// The error message is arlready sent by the terminal.
|
||||
t.relayOp.Stop(t.relayOp, nil)
|
||||
}
|
||||
|
||||
// CustomIDFormat formats the terminal ID.
|
||||
func (t *ExpansionTerminal) CustomIDFormat() string {
|
||||
return fmt.Sprintf("%s~%d", t.relayOp.Terminal().FmtID(), t.relayOp.ID())
|
||||
}
|
||||
|
||||
// Deliver delivers a message to the operation.
|
||||
func (op *ExpansionTerminalRelayOp) Deliver(msg *terminal.Msg) *terminal.Error {
|
||||
// Proxy directly to expansion terminal.
|
||||
return op.expansionTerminal.Deliver(msg)
|
||||
}
|
||||
|
||||
// HandleStop gives the operation the ability to cleanly shut down.
|
||||
// The returned error is the error to send to the other side.
|
||||
// Should never be called directly. Call Stop() instead.
|
||||
func (op *ExpansionTerminalRelayOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error) {
|
||||
// Stop the expansion terminal.
|
||||
// The error message will be sent by the operation.
|
||||
op.expansionTerminal.Abandon(nil)
|
||||
|
||||
return err
|
||||
}
|
||||
305
spn/docks/terminal_expansion_test.go
Normal file
305
spn/docks/terminal_expansion_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package docks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/spn/access"
|
||||
"github.com/safing/portmaster/spn/cabin"
|
||||
"github.com/safing/portmaster/spn/hub"
|
||||
"github.com/safing/portmaster/spn/ships"
|
||||
"github.com/safing/portmaster/spn/terminal"
|
||||
)
|
||||
|
||||
const defaultTestQueueSize = 200
|
||||
|
||||
func TestExpansion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test without and with encryption.
|
||||
for _, encrypt := range []bool{false, true} {
|
||||
// Test down/up separately and in parallel.
|
||||
for _, parallel := range []bool{false, true} {
|
||||
// Test with different flow controls.
|
||||
for _, fc := range []struct {
|
||||
flowControl terminal.FlowControlType
|
||||
flowControlSize uint32
|
||||
}{
|
||||
{
|
||||
flowControl: terminal.FlowControlNone,
|
||||
flowControlSize: 5,
|
||||
},
|
||||
{
|
||||
flowControl: terminal.FlowControlDFQ,
|
||||
flowControlSize: defaultTestQueueSize,
|
||||
},
|
||||
} {
|
||||
// Run tests with combined options.
|
||||
testExpansion(
|
||||
t,
|
||||
"expansion-hop-test",
|
||||
&terminal.TerminalOpts{
|
||||
Encrypt: encrypt,
|
||||
Padding: 8,
|
||||
FlowControl: fc.flowControl,
|
||||
FlowControlSize: fc.flowControlSize,
|
||||
},
|
||||
defaultTestQueueSize,
|
||||
defaultTestQueueSize,
|
||||
parallel,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stressTestOpts := &terminal.TerminalOpts{
|
||||
Encrypt: true,
|
||||
Padding: 8,
|
||||
FlowControl: terminal.FlowControlDFQ,
|
||||
FlowControlSize: defaultTestQueueSize,
|
||||
}
|
||||
testExpansion(t, "expansion-stress-test-down", stressTestOpts, defaultTestQueueSize*100, 0, false)
|
||||
testExpansion(t, "expansion-stress-test-up", stressTestOpts, 0, defaultTestQueueSize*100, false)
|
||||
testExpansion(t, "expansion-stress-test-duplex", stressTestOpts, defaultTestQueueSize*100, defaultTestQueueSize*100, false)
|
||||
}
|
||||
|
||||
func testExpansion( //nolint:maintidx,thelper
|
||||
t *testing.T,
|
||||
testID string,
|
||||
terminalOpts *terminal.TerminalOpts,
|
||||
clientCountTo,
|
||||
serverCountTo uint64,
|
||||
inParallel bool,
|
||||
) {
|
||||
testID += fmt.Sprintf(":encrypt=%v,flowType=%d,parallel=%v", terminalOpts.Encrypt, terminalOpts.FlowControl, inParallel)
|
||||
|
||||
var identity2, identity3, identity4 *cabin.Identity
|
||||
var connectedHub2, connectedHub3, connectedHub4 *hub.Hub
|
||||
if terminalOpts.Encrypt {
|
||||
identity2, connectedHub2 = getTestIdentity(t)
|
||||
identity3, connectedHub3 = getTestIdentity(t)
|
||||
identity4, connectedHub4 = getTestIdentity(t)
|
||||
}
|
||||
|
||||
// Build ships and cranes.
|
||||
optimalMinLoadSize = 100
|
||||
ship1to2 := ships.NewTestShip(!terminalOpts.Encrypt, 100)
|
||||
ship2to3 := ships.NewTestShip(!terminalOpts.Encrypt, 100)
|
||||
ship3to4 := ships.NewTestShip(!terminalOpts.Encrypt, 100)
|
||||
|
||||
var crane1, crane2to1, crane2to3, crane3to2, crane3to4, crane4 *Crane
|
||||
var craneWg sync.WaitGroup
|
||||
craneWg.Add(6)
|
||||
|
||||
go func() {
|
||||
var err error
|
||||
crane1, err = NewCrane(ship1to2, connectedHub2, nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not create crane1: %s", testID, err))
|
||||
}
|
||||
crane1.ID = "c1"
|
||||
err = crane1.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not start crane1: %s", testID, err))
|
||||
}
|
||||
crane1.ship.MarkPublic()
|
||||
craneWg.Done()
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
crane2to1, err = NewCrane(ship1to2.Reverse(), nil, identity2)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not create crane2to1: %s", testID, err))
|
||||
}
|
||||
crane2to1.ID = "c2to1"
|
||||
err = crane2to1.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not start crane2to1: %s", testID, err))
|
||||
}
|
||||
crane2to1.ship.MarkPublic()
|
||||
craneWg.Done()
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
crane2to3, err = NewCrane(ship2to3, connectedHub3, nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not create crane2to3: %s", testID, err))
|
||||
}
|
||||
crane2to3.ID = "c2to3"
|
||||
err = crane2to3.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not start crane2to3: %s", testID, err))
|
||||
}
|
||||
crane2to3.ship.MarkPublic()
|
||||
craneWg.Done()
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
crane3to2, err = NewCrane(ship2to3.Reverse(), nil, identity3)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not create crane3to2: %s", testID, err))
|
||||
}
|
||||
crane3to2.ID = "c3to2"
|
||||
err = crane3to2.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not start crane3to2: %s", testID, err))
|
||||
}
|
||||
crane3to2.ship.MarkPublic()
|
||||
craneWg.Done()
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
crane3to4, err = NewCrane(ship3to4, connectedHub4, nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not create crane3to4: %s", testID, err))
|
||||
}
|
||||
crane3to4.ID = "c3to4"
|
||||
err = crane3to4.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not start crane3to4: %s", testID, err))
|
||||
}
|
||||
crane3to4.ship.MarkPublic()
|
||||
craneWg.Done()
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
crane4, err = NewCrane(ship3to4.Reverse(), nil, identity4)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not create crane4: %s", testID, err))
|
||||
}
|
||||
crane4.ID = "c4"
|
||||
err = crane4.Start(module.Ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("expansion test %s could not start crane4: %s", testID, err))
|
||||
}
|
||||
crane4.ship.MarkPublic()
|
||||
craneWg.Done()
|
||||
}()
|
||||
craneWg.Wait()
|
||||
|
||||
// Assign cranes.
|
||||
crane3HubID := testID + "-crane3HubID"
|
||||
AssignCrane(crane3HubID, crane2to3)
|
||||
crane4HubID := testID + "-crane4HubID"
|
||||
AssignCrane(crane4HubID, crane3to4)
|
||||
|
||||
t.Logf("expansion test %s: initial setup complete", testID)
|
||||
|
||||
// Wait async for test to complete, print stack after timeout.
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(30 * time.Second):
|
||||
fmt.Printf("expansion test %s is taking too long, print stack:\n", testID)
|
||||
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start initial crane.
|
||||
homeTerminal, initData, tErr := NewLocalCraneTerminal(crane1, nil, &terminal.TerminalOpts{})
|
||||
if tErr != nil {
|
||||
t.Fatalf("expansion test %s failed to create home terminal: %s", testID, tErr)
|
||||
}
|
||||
tErr = crane1.EstablishNewTerminal(homeTerminal, initData)
|
||||
if tErr != nil {
|
||||
t.Fatalf("expansion test %s failed to connect home terminal: %s", testID, tErr)
|
||||
}
|
||||
|
||||
t.Logf("expansion test %s: home terminal setup complete", testID)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Start counters for testing.
|
||||
op0, tErr := terminal.NewCounterOp(homeTerminal, terminal.CounterOpts{
|
||||
ClientCountTo: clientCountTo,
|
||||
ServerCountTo: serverCountTo,
|
||||
})
|
||||
if tErr != nil {
|
||||
t.Fatalf("expansion test %s failed to run counter op: %s", testID, tErr)
|
||||
}
|
||||
t.Logf("expansion test %s: home terminal counter setup complete", testID)
|
||||
if !inParallel {
|
||||
op0.Wait()
|
||||
}
|
||||
|
||||
// Start expansion to crane 3.
|
||||
opAuthTo2, tErr := access.AuthorizeToTerminal(homeTerminal)
|
||||
if tErr != nil {
|
||||
t.Fatalf("expansion test %s failed to auth with home terminal: %s", testID, tErr)
|
||||
}
|
||||
tErr = <-opAuthTo2.Result
|
||||
if tErr.IsError() {
|
||||
t.Fatalf("expansion test %s failed to auth with home terminal: %s", testID, tErr)
|
||||
}
|
||||
expansionTerminalTo3, err := ExpandTo(homeTerminal, crane3HubID, connectedHub3)
|
||||
if err != nil {
|
||||
t.Fatalf("expansion test %s failed to expand to %s: %s", testID, crane3HubID, tErr)
|
||||
}
|
||||
|
||||
// Start counters for testing.
|
||||
op1, tErr := terminal.NewCounterOp(expansionTerminalTo3, terminal.CounterOpts{
|
||||
ClientCountTo: clientCountTo,
|
||||
ServerCountTo: serverCountTo,
|
||||
})
|
||||
if tErr != nil {
|
||||
t.Fatalf("expansion test %s failed to run counter op: %s", testID, tErr)
|
||||
}
|
||||
|
||||
t.Logf("expansion test %s: expansion to crane3 and counter setup complete", testID)
|
||||
if !inParallel {
|
||||
op1.Wait()
|
||||
}
|
||||
|
||||
// Start expansion to crane 4.
|
||||
opAuthTo3, tErr := access.AuthorizeToTerminal(expansionTerminalTo3)
|
||||
if tErr != nil {
|
||||
t.Fatalf("expansion test %s failed to auth with extenstion terminal: %s", testID, tErr)
|
||||
}
|
||||
tErr = <-opAuthTo3.Result
|
||||
if tErr.IsError() {
|
||||
t.Fatalf("expansion test %s failed to auth with extenstion terminal: %s", testID, tErr)
|
||||
}
|
||||
|
||||
expansionTerminalTo4, err := ExpandTo(expansionTerminalTo3, crane4HubID, connectedHub4)
|
||||
if err != nil {
|
||||
t.Fatalf("expansion test %s failed to expand to %s: %s", testID, crane4HubID, tErr)
|
||||
}
|
||||
|
||||
// Start counters for testing.
|
||||
op2, tErr := terminal.NewCounterOp(expansionTerminalTo4, terminal.CounterOpts{
|
||||
ClientCountTo: clientCountTo,
|
||||
ServerCountTo: serverCountTo,
|
||||
})
|
||||
if tErr != nil {
|
||||
t.Fatalf("expansion test %s failed to run counter op: %s", testID, tErr)
|
||||
}
|
||||
|
||||
t.Logf("expansion test %s: expansion to crane4 and counter setup complete", testID)
|
||||
op2.Wait()
|
||||
|
||||
// Wait for op1 if not already.
|
||||
if inParallel {
|
||||
op0.Wait()
|
||||
op1.Wait()
|
||||
}
|
||||
|
||||
// Wait for completion.
|
||||
close(finished)
|
||||
|
||||
// Wait a little so that all errors can be propagated, so we can truly see
|
||||
// if we succeeded.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check errors.
|
||||
if op1.Error != nil {
|
||||
t.Fatalf("crane test %s counter op1 failed: %s", testID, op1.Error)
|
||||
}
|
||||
if op2.Error != nil {
|
||||
t.Fatalf("crane test %s counter op2 failed: %s", testID, op2.Error)
|
||||
}
|
||||
}
|
||||
202
spn/hub/database.go
Normal file
202
spn/hub/database.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/iterator"
|
||||
"github.com/safing/portbase/database/query"
|
||||
"github.com/safing/portbase/database/record"
|
||||
)
|
||||
|
||||
var (
|
||||
db = database.NewInterface(&database.Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
|
||||
getFromNavigator func(mapName, hubID string) *Hub
|
||||
)
|
||||
|
||||
// MakeHubDBKey makes a hub db key.
|
||||
func MakeHubDBKey(mapName, hubID string) string {
|
||||
return fmt.Sprintf("cache:spn/hubs/%s/%s", mapName, hubID)
|
||||
}
|
||||
|
||||
// MakeHubMsgDBKey makes a hub msg db key.
|
||||
func MakeHubMsgDBKey(mapName string, msgType MsgType, hubID string) string {
|
||||
return fmt.Sprintf("cache:spn/msgs/%s/%s/%s", mapName, msgType, hubID)
|
||||
}
|
||||
|
||||
// SetNavigatorAccess sets a shortcut function to access hubs from the navigator instead of having go through the database.
|
||||
// This also reduces the number of object in RAM and better caches parsed attributes.
|
||||
func SetNavigatorAccess(fn func(mapName, hubID string) *Hub) {
|
||||
if getFromNavigator == nil {
|
||||
getFromNavigator = fn
|
||||
}
|
||||
}
|
||||
|
||||
// GetHub get a Hub from the database - or the navigator, if configured.
|
||||
func GetHub(mapName string, hubID string) (*Hub, error) {
|
||||
if getFromNavigator != nil {
|
||||
hub := getFromNavigator(mapName, hubID)
|
||||
if hub != nil {
|
||||
return hub, nil
|
||||
}
|
||||
}
|
||||
|
||||
return GetHubByKey(MakeHubDBKey(mapName, hubID))
|
||||
}
|
||||
|
||||
// GetHubByKey returns a hub by its raw DB key.
|
||||
func GetHubByKey(key string) (*Hub, error) {
|
||||
r, err := db.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hub, err := EnsureHub(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hub, nil
|
||||
}
|
||||
|
||||
// EnsureHub makes sure a database record is a Hub.
|
||||
func EnsureHub(r record.Record) (*Hub, error) {
|
||||
// unwrap
|
||||
if r.IsWrapped() {
|
||||
// only allocate a new struct, if we need it
|
||||
newHub := &Hub{}
|
||||
err := record.Unwrap(r, newHub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newHub = prepHub(newHub)
|
||||
|
||||
// Fully validate when getting from database.
|
||||
if err := newHub.Info.validateFormatting(); err != nil {
|
||||
return nil, fmt.Errorf("announcement failed format validation: %w", err)
|
||||
}
|
||||
if err := newHub.Status.validateFormatting(); err != nil {
|
||||
return nil, fmt.Errorf("status failed format validation: %w", err)
|
||||
}
|
||||
if err := newHub.Info.prepare(false); err != nil {
|
||||
return nil, fmt.Errorf("failed to prepare announcement: %w", err)
|
||||
}
|
||||
|
||||
return newHub, nil
|
||||
}
|
||||
|
||||
// or adjust type
|
||||
newHub, ok := r.(*Hub)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *Hub, but %T", r)
|
||||
}
|
||||
newHub = prepHub(newHub)
|
||||
|
||||
// Prepare only when already parsed.
|
||||
if err := newHub.Info.prepare(false); err != nil {
|
||||
return nil, fmt.Errorf("failed to prepare announcement: %w", err)
|
||||
}
|
||||
|
||||
// ensure status
|
||||
return newHub, nil
|
||||
}
|
||||
|
||||
func prepHub(h *Hub) *Hub {
|
||||
if h.Status == nil {
|
||||
h.Status = &Status{}
|
||||
}
|
||||
h.Measurements = getSharedMeasurements(h.ID, h.Measurements)
|
||||
return h
|
||||
}
|
||||
|
||||
// Save saves to Hub to the correct scope in the database.
|
||||
func (h *Hub) Save() error {
|
||||
if !h.KeyIsSet() {
|
||||
h.SetKey(MakeHubDBKey(h.Map, h.ID))
|
||||
}
|
||||
|
||||
return db.Put(h)
|
||||
}
|
||||
|
||||
// RemoveHubAndMsgs deletes a Hub and it's saved messages from the database.
|
||||
func RemoveHubAndMsgs(mapName string, hubID string) (err error) {
|
||||
err = db.Delete(MakeHubDBKey(mapName, hubID))
|
||||
if err != nil && !errors.Is(err, database.ErrNotFound) {
|
||||
return fmt.Errorf("failed to delete main hub entry: %w", err)
|
||||
}
|
||||
|
||||
err = db.Delete(MakeHubMsgDBKey(mapName, MsgTypeAnnouncement, hubID))
|
||||
if err != nil && !errors.Is(err, database.ErrNotFound) {
|
||||
return fmt.Errorf("failed to delete hub announcement data: %w", err)
|
||||
}
|
||||
|
||||
err = db.Delete(MakeHubMsgDBKey(mapName, MsgTypeStatus, hubID))
|
||||
if err != nil && !errors.Is(err, database.ErrNotFound) {
|
||||
return fmt.Errorf("failed to delete hub status data: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HubMsg stores raw Hub messages.
|
||||
type HubMsg struct { //nolint:golint
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
ID string
|
||||
Map string
|
||||
Type MsgType
|
||||
Data []byte
|
||||
|
||||
Received int64
|
||||
}
|
||||
|
||||
// SaveHubMsg saves a raw (and signed) message received by another Hub.
|
||||
func SaveHubMsg(id string, mapName string, msgType MsgType, data []byte) error {
|
||||
// create wrapper record
|
||||
msg := &HubMsg{
|
||||
ID: id,
|
||||
Map: mapName,
|
||||
Type: msgType,
|
||||
Data: data,
|
||||
Received: time.Now().Unix(),
|
||||
}
|
||||
// set key
|
||||
msg.SetKey(MakeHubMsgDBKey(msg.Map, msg.Type, msg.ID))
|
||||
// save
|
||||
return db.PutNew(msg)
|
||||
}
|
||||
|
||||
// QueryRawGossipMsgs queries the database for raw gossip messages.
|
||||
func QueryRawGossipMsgs(mapName string, msgType MsgType) (it *iterator.Iterator, err error) {
|
||||
it, err = db.Query(query.New(MakeHubMsgDBKey(mapName, msgType, "")))
|
||||
return
|
||||
}
|
||||
|
||||
// EnsureHubMsg makes sure a database record is a HubMsg.
|
||||
func EnsureHubMsg(r record.Record) (*HubMsg, error) {
|
||||
// unwrap
|
||||
if r.IsWrapped() {
|
||||
// only allocate a new struct, if we need it
|
||||
newHubMsg := &HubMsg{}
|
||||
err := record.Unwrap(r, newHubMsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newHubMsg, nil
|
||||
}
|
||||
|
||||
// or adjust type
|
||||
newHubMsg, ok := r.(*HubMsg)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *Hub, but %T", r)
|
||||
}
|
||||
return newHubMsg, nil
|
||||
}
|
||||
21
spn/hub/errors.go
Normal file
21
spn/hub/errors.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package hub
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrMissingInfo signifies that the hub is missing the HubAnnouncement.
|
||||
ErrMissingInfo = errors.New("hub has no announcement")
|
||||
|
||||
// ErrMissingTransports signifies that the hub announcement did not specify any transports.
|
||||
ErrMissingTransports = errors.New("hub announcement has no transports")
|
||||
|
||||
// ErrMissingIPs signifies that the hub announcement did not specify any IPs,
|
||||
// or none of the IPs is supported by the client.
|
||||
ErrMissingIPs = errors.New("hub announcement has no (supported) IPs")
|
||||
|
||||
// ErrTemporaryValidationError is returned when a validation error might be temporary.
|
||||
ErrTemporaryValidationError = errors.New("temporary validation error")
|
||||
|
||||
// ErrOldData is returned when received data is outdated.
|
||||
ErrOldData = errors.New("")
|
||||
)
|
||||
69
spn/hub/format.go
Normal file
69
spn/hub/format.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
|
||||
"github.com/safing/portmaster/service/network/netutils"
|
||||
)
|
||||
|
||||
// BaselineCharset defines the permitted characters.
|
||||
var BaselineCharset = regexp.MustCompile(
|
||||
// Start of charset selection.
|
||||
`^[` +
|
||||
// Printable ASCII (character code 32-127), excluding common control characters of different languages: "$%&';<>\` and DELETE.
|
||||
` !#()*+,\-\./0-9:=?@A-Z[\]^_a-z{|}~` +
|
||||
// Only latin characters from extended ASCII (character code 128-255).
|
||||
`ŠŒŽšœžŸ¡¿ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿ` +
|
||||
// End of charset selection.
|
||||
`]*$`,
|
||||
)
|
||||
|
||||
func checkStringFormat(fieldName, value string, maxLength int) error {
|
||||
switch {
|
||||
case len(value) > maxLength:
|
||||
return fmt.Errorf("field %s with length of %d exceeds max length of %d", fieldName, len(value), maxLength)
|
||||
case !BaselineCharset.MatchString(value):
|
||||
return fmt.Errorf("field %s contains characters not permitted by baseline validation", fieldName)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkStringSliceFormat(fieldName string, value []string, maxLength, maxStringLength int) error { //nolint:unparam
|
||||
if len(value) > maxLength {
|
||||
return fmt.Errorf("field %s with array/slice length of %d exceeds max length of %d", fieldName, len(value), maxLength)
|
||||
}
|
||||
for _, s := range value {
|
||||
if err := checkStringFormat(fieldName, s, maxStringLength); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkByteSliceFormat(fieldName string, value []byte, maxLength int) error {
|
||||
switch {
|
||||
case len(value) > maxLength:
|
||||
return fmt.Errorf("field %s with length of %d exceeds max length of %d", fieldName, len(value), maxLength)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkIPFormat(fieldName string, value net.IP) error {
|
||||
// Check if there is an IP address.
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(value) != 4 && len(value) != 16:
|
||||
return fmt.Errorf("field %s has an invalid length of %d for an IP address", fieldName, len(value))
|
||||
case netutils.GetIPScope(value) == netutils.Invalid:
|
||||
return fmt.Errorf("field %s holds an invalid IP address: %s", fieldName, value)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
81
spn/hub/format_test.go
Normal file
81
spn/hub/format_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCheckStringFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testSet := map[string]bool{
|
||||
// Printable ASCII (character code 32-127)
|
||||
" ": true, "!": true, `"`: false, "#": true, "$": false, "%": false, "&": false, "'": false,
|
||||
"(": true, ")": true, "*": true, "+": true, ",": true, "-": true, ".": true, "/": true,
|
||||
"0": true, "1": true, "2": true, "3": true, "4": true, "5": true, "6": true, "7": true,
|
||||
"8": true, "9": true, ":": true, ";": false, "<": false, "=": true, ">": false, "?": true,
|
||||
"@": true, "A": true, "B": true, "C": true, "D": true, "E": true, "F": true, "G": true,
|
||||
"H": true, "I": true, "J": true, "K": true, "L": true, "M": true, "N": true, "O": true,
|
||||
"P": true, "Q": true, "R": true, "S": true, "T": true, "U": true, "V": true, "W": true,
|
||||
"X": true, "Y": true, "Z": true, "[": true, `\`: false, "]": true, "^": true, "_": true,
|
||||
"`": false, "a": true, "b": true, "c": true, "d": true, "e": true, "f": true, "g": true,
|
||||
"h": true, "i": true, "j": true, "k": true, "l": true, "m": true, "n": true, "o": true,
|
||||
"p": true, "q": true, "r": true, "s": true, "t": true, "u": true, "v": true, "w": true,
|
||||
"x": true, "y": true, "z": true, "{": true, "|": true, "}": true, "~": true,
|
||||
// Not testing for DELETE character.
|
||||
|
||||
// Extended ASCII (character code 128-255)
|
||||
"€": false, "‚": false, "ƒ": false, "„": false, "…": false, "†": false, "‡": false, "ˆ": false,
|
||||
"‰": false, "Š": true, "‹": false, "Œ": true, "Ž": true, "‘": false, "’": false, "“": false,
|
||||
"”": false, "•": false, "–": false, "—": false, "˜": false, "™": false, "š": true, "›": false,
|
||||
"œ": true, "ž": true, "Ÿ": true, "¡": true, "¢": false, "£": false, "¤": false, "¥": false,
|
||||
"¦": false, "§": false, "¨": false, "©": false, "ª": false, "«": false, "¬": false, "®": false,
|
||||
"¯": false, "°": false, "±": false, "²": false, "³": false, "´": false, "µ": false, "¶": false,
|
||||
"·": false, "¸": false, "¹": false, "º": false, "»": false, "¼": false, "½": false, "¾": false,
|
||||
"¿": true, "À": true, "Á": true, "Â": true, "Ã": true, "Ä": true, "Å": true, "Æ": true,
|
||||
"Ç": true, "È": true, "É": true, "Ê": true, "Ë": true, "Ì": true, "Í": true, "Î": true,
|
||||
"Ï": true, "Ð": true, "Ñ": true, "Ò": true, "Ó": true, "Ô": true, "Õ": true, "Ö": true,
|
||||
"×": false, "Ø": true, "Ù": true, "Ú": true, "Û": true, "Ü": true, "Ý": true, "Þ": true,
|
||||
"ß": true, "à": true, "á": true, "â": true, "ã": true, "ä": true, "å": true, "æ": true,
|
||||
"ç": true, "è": true, "é": true, "ê": true, "ë": true, "ì": true, "í": true, "î": true,
|
||||
"ï": true, "ð": true, "ñ": true, "ò": true, "ó": true, "ô": true, "õ": true, "ö": true,
|
||||
"÷": false, "ø": true, "ù": true, "ú": true, "û": true, "ü": true, "ý": true, "þ": true,
|
||||
"ÿ": true,
|
||||
}
|
||||
|
||||
for testCharacter, isPermitted := range testSet {
|
||||
if isPermitted {
|
||||
assert.NoError(t, checkStringFormat(fmt.Sprintf("test character %q", testCharacter), testCharacter, 3))
|
||||
} else {
|
||||
assert.Error(t, checkStringFormat(fmt.Sprintf("test character %q", testCharacter), testCharacter, 3))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIPFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// IPv4
|
||||
assert.NoError(t, checkIPFormat("test IP 1.1.1.1", net.IPv4(1, 1, 1, 1)))
|
||||
assert.NoError(t, checkIPFormat("test IP 192.168.1.1", net.IPv4(192, 168, 1, 1)))
|
||||
assert.Error(t, checkIPFormat("test IP 255.0.0.1", net.IPv4(255, 0, 0, 1)))
|
||||
|
||||
// IPv6
|
||||
assert.NoError(t, checkIPFormat("test IP ::1", net.ParseIP("::1")))
|
||||
assert.NoError(t, checkIPFormat("test IP 2606:4700:4700::1111", net.ParseIP("2606:4700:4700::1111")))
|
||||
|
||||
// Invalid
|
||||
assert.Error(t, checkIPFormat("test IP with length 3", net.IP([]byte{0, 0, 0})))
|
||||
assert.Error(t, checkIPFormat("test IP with length 5", net.IP([]byte{0, 0, 0, 0, 0})))
|
||||
assert.Error(t, checkIPFormat(
|
||||
"test IP with length 15",
|
||||
net.IP([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
|
||||
))
|
||||
assert.Error(t, checkIPFormat(
|
||||
"test IP with length 17",
|
||||
net.IP([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
|
||||
))
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user