Restructure modules (#1572)
* Move portbase into monorepo * Add new simple module mgr * [WIP] Switch to new simple module mgr * Add StateMgr and more worker variants * [WIP] Switch more modules * [WIP] Switch more modules * [WIP] swtich more modules * [WIP] switch all SPN modules * [WIP] switch all service modules * [WIP] Convert all workers to the new module system * [WIP] add new task system to module manager * [WIP] Add second take for scheduling workers * [WIP] Add FIXME for bugs in new scheduler * [WIP] Add minor improvements to scheduler * [WIP] Add new worker scheduler * [WIP] Fix more bug related to new module system * [WIP] Fix start handing of the new module system * [WIP] Improve startup process * [WIP] Fix minor issues * [WIP] Fix missing subsystem in settings * [WIP] Initialize managers in constructor * [WIP] Move module event initialization to constrictors * [WIP] Fix setting for enabling and disabling the SPN module * [WIP] Move API registeration into module construction * [WIP] Update states mgr for all modules * [WIP] Add CmdLine operation support * Add state helper methods to module group and instance * Add notification and module status handling to status package * Fix starting issues * Remove pilot widget and update security lock to new status data * Remove debug logs * Improve http server shutdown * Add workaround for cleanly shutting down firewall+netquery * Improve logging * Add syncing states with notifications for new module system * Improve starting, stopping, shutdown; resolve FIXMEs/TODOs * [WIP] Fix most unit tests * Review new module system and fix minor issues * Push shutdown and restart events again via API * Set sleep mode via interface * Update example/template module * [WIP] Fix spn/cabin unit test * Remove deprecated UI elements * Make log output more similar for the logging transition phase * Switch spn hub and observer cmds to new module system * Fix log sources * Make worker mgr less error prone * Fix tests and minor issues * Fix observation hub * Improve shutdown and restart handling * Split up big connection.go source file * Move varint and dsd packages to structures repo * Improve expansion test * Fix linter warnings * Fix interception module on windows * Fix linter errors --------- Co-authored-by: Vladimir Stoilov <vladimir@safing.io>
This commit is contained in:
173
base/api/api_bridge.go
Normal file
173
base/api/api_bridge.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/base/database"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
endpointBridgeRemoteAddress = "websocket-bridge"
|
||||
apiDatabaseName = "api"
|
||||
)
|
||||
|
||||
func registerEndpointBridgeDB() error {
|
||||
if _, err := database.Register(&database.Database{
|
||||
Name: apiDatabaseName,
|
||||
Description: "API Bridge",
|
||||
StorageType: "injected",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := database.InjectDatabase("api", &endpointBridgeStorage{})
|
||||
return err
|
||||
}
|
||||
|
||||
type endpointBridgeStorage struct {
|
||||
storage.InjectBase
|
||||
}
|
||||
|
||||
// EndpointBridgeRequest holds a bridged request API request.
|
||||
type EndpointBridgeRequest struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
Method string
|
||||
Path string
|
||||
Query map[string]string
|
||||
Data []byte
|
||||
MimeType string
|
||||
}
|
||||
|
||||
// EndpointBridgeResponse holds a bridged request API response.
|
||||
type EndpointBridgeResponse struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
MimeType string
|
||||
Body string
|
||||
}
|
||||
|
||||
// Get returns a database record.
|
||||
func (ebs *endpointBridgeStorage) Get(key string) (record.Record, error) {
|
||||
if key == "" {
|
||||
return nil, database.ErrNotFound
|
||||
}
|
||||
|
||||
return callAPI(&EndpointBridgeRequest{
|
||||
Method: http.MethodGet,
|
||||
Path: key,
|
||||
})
|
||||
}
|
||||
|
||||
// Get returns the metadata of a database record.
|
||||
func (ebs *endpointBridgeStorage) GetMeta(key string) (*record.Meta, error) {
|
||||
// This interface is an API, always return a fresh copy.
|
||||
m := &record.Meta{}
|
||||
m.Update()
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (ebs *endpointBridgeStorage) Put(r record.Record) (record.Record, error) {
|
||||
if r.DatabaseKey() == "" {
|
||||
return nil, database.ErrNotFound
|
||||
}
|
||||
|
||||
// Prepare data.
|
||||
var ebr *EndpointBridgeRequest
|
||||
if r.IsWrapped() {
|
||||
// Only allocate a new struct, if we need it.
|
||||
ebr = &EndpointBridgeRequest{}
|
||||
err := record.Unwrap(r, ebr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
var ok bool
|
||||
ebr, ok = r.(*EndpointBridgeRequest)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *EndpointBridgeRequest, but %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
// Override path with key to mitigate sneaky stuff.
|
||||
ebr.Path = r.DatabaseKey()
|
||||
return callAPI(ebr)
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the database is read only.
|
||||
func (ebs *endpointBridgeStorage) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func callAPI(ebr *EndpointBridgeRequest) (record.Record, error) {
|
||||
// Add API prefix to path.
|
||||
requestURL := path.Join(apiV1Path, ebr.Path)
|
||||
// Check if path is correct. (Defense in depth)
|
||||
if !strings.HasPrefix(requestURL, apiV1Path) {
|
||||
return nil, fmt.Errorf("bridged request for %q violates scope", ebr.Path)
|
||||
}
|
||||
|
||||
// Apply default Method.
|
||||
if ebr.Method == "" {
|
||||
if len(ebr.Data) > 0 {
|
||||
ebr.Method = http.MethodPost
|
||||
} else {
|
||||
ebr.Method = http.MethodGet
|
||||
}
|
||||
}
|
||||
|
||||
// Build URL.
|
||||
u, err := url.ParseRequestURI(requestURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build bridged request url: %w", err)
|
||||
}
|
||||
// Build query values.
|
||||
if ebr.Query != nil && len(ebr.Query) > 0 {
|
||||
query := url.Values{}
|
||||
for k, v := range ebr.Query {
|
||||
query.Set(k, v)
|
||||
}
|
||||
u.RawQuery = query.Encode()
|
||||
}
|
||||
|
||||
// Create request and response objects.
|
||||
r := httptest.NewRequest(ebr.Method, u.String(), bytes.NewBuffer(ebr.Data))
|
||||
r.RemoteAddr = endpointBridgeRemoteAddress
|
||||
if ebr.MimeType != "" {
|
||||
r.Header.Set("Content-Type", ebr.MimeType)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
// Let the API handle the request.
|
||||
server.Handler.ServeHTTP(w, r)
|
||||
switch w.Code {
|
||||
case 200:
|
||||
// Everything okay, continue.
|
||||
case 500:
|
||||
// A Go error was returned internally.
|
||||
// We can safely return this as an error.
|
||||
return nil, fmt.Errorf("bridged api call failed: %s", w.Body.String())
|
||||
default:
|
||||
return nil, fmt.Errorf("bridged api call returned unexpected error code %d", w.Code)
|
||||
}
|
||||
|
||||
response := &EndpointBridgeResponse{
|
||||
MimeType: w.Header().Get("Content-Type"),
|
||||
Body: w.Body.String(),
|
||||
}
|
||||
response.SetKey(apiDatabaseName + ":" + ebr.Path)
|
||||
response.UpdateMeta()
|
||||
|
||||
return response, nil
|
||||
}
|
||||
30
base/api/auth_wrapper.go
Normal file
30
base/api/auth_wrapper.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package api
|
||||
|
||||
import "net/http"
|
||||
|
||||
// WrapInAuthHandler wraps a simple http.HandlerFunc into a handler that
|
||||
// exposes the required API permissions for this handler.
|
||||
func WrapInAuthHandler(fn http.HandlerFunc, read, write Permission) http.Handler {
|
||||
return &wrappedAuthenticatedHandler{
|
||||
HandlerFunc: fn,
|
||||
read: read,
|
||||
write: write,
|
||||
}
|
||||
}
|
||||
|
||||
type wrappedAuthenticatedHandler struct {
|
||||
http.HandlerFunc
|
||||
|
||||
read Permission
|
||||
write Permission
|
||||
}
|
||||
|
||||
// ReadPermission returns the read permission for the handler.
|
||||
func (wah *wrappedAuthenticatedHandler) ReadPermission(r *http.Request) Permission {
|
||||
return wah.read
|
||||
}
|
||||
|
||||
// WritePermission returns the write permission for the handler.
|
||||
func (wah *wrappedAuthenticatedHandler) WritePermission(r *http.Request) Permission {
|
||||
return wah.write
|
||||
}
|
||||
598
base/api/authentication.go
Normal file
598
base/api/authentication.go
Normal file
@@ -0,0 +1,598 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/config"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/base/rng"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionCookieName = "Portmaster-API-Token"
|
||||
sessionCookieTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
apiKeys = make(map[string]*AuthToken)
|
||||
apiKeysLock sync.Mutex
|
||||
|
||||
authFnSet = abool.New()
|
||||
authFn AuthenticatorFunc
|
||||
|
||||
sessions = make(map[string]*session)
|
||||
sessionsLock sync.Mutex
|
||||
|
||||
// ErrAPIAccessDeniedMessage should be wrapped by errors returned by
|
||||
// AuthenticatorFunc in order to signify a blocked request, including a error
|
||||
// message for the user. This is an empty message on purpose, as to allow the
|
||||
// function to define the full text of the error shown to the user.
|
||||
ErrAPIAccessDeniedMessage = errors.New("")
|
||||
)
|
||||
|
||||
// Permission defines an API requests permission.
|
||||
type Permission int8
|
||||
|
||||
const (
|
||||
// NotFound declares that the operation does not exist.
|
||||
NotFound Permission = -2
|
||||
|
||||
// Dynamic declares that the operation requires permission to be processed,
|
||||
// but anyone can execute the operation, as it reacts to permissions itself.
|
||||
Dynamic Permission = -1
|
||||
|
||||
// NotSupported declares that the operation is not supported.
|
||||
NotSupported Permission = 0
|
||||
|
||||
// PermitAnyone declares that anyone can execute the operation without any
|
||||
// authentication.
|
||||
PermitAnyone Permission = 1
|
||||
|
||||
// PermitUser declares that the operation may be executed by authenticated
|
||||
// third party applications that are categorized as representing a simple
|
||||
// user and is limited in access.
|
||||
PermitUser Permission = 2
|
||||
|
||||
// PermitAdmin declares that the operation may be executed by authenticated
|
||||
// third party applications that are categorized as representing an
|
||||
// administrator and has broad in access.
|
||||
PermitAdmin Permission = 3
|
||||
|
||||
// PermitSelf declares that the operation may only be executed by the
|
||||
// software itself and its own (first party) components.
|
||||
PermitSelf Permission = 4
|
||||
)
|
||||
|
||||
// AuthenticatorFunc is a function that can be set as the authenticator for the
|
||||
// API endpoint. If none is set, all requests will have full access.
|
||||
// The returned AuthToken represents the permissions that the request has.
|
||||
type AuthenticatorFunc func(r *http.Request, s *http.Server) (*AuthToken, error)
|
||||
|
||||
// AuthToken represents either a set of required or granted permissions.
|
||||
// All attributes must be set when the struct is built and must not be changed
|
||||
// later. Functions may be called at any time.
|
||||
// The Write permission implicitly also includes reading.
|
||||
type AuthToken struct {
|
||||
Read Permission
|
||||
Write Permission
|
||||
ValidUntil *time.Time
|
||||
}
|
||||
|
||||
type session struct {
|
||||
sync.Mutex
|
||||
|
||||
token *AuthToken
|
||||
validUntil time.Time
|
||||
}
|
||||
|
||||
// Expired returns whether the session has expired.
|
||||
func (sess *session) Expired() bool {
|
||||
sess.Lock()
|
||||
defer sess.Unlock()
|
||||
|
||||
return time.Now().After(sess.validUntil)
|
||||
}
|
||||
|
||||
// Refresh refreshes the validity of the session with the given TTL.
|
||||
func (sess *session) Refresh(ttl time.Duration) {
|
||||
sess.Lock()
|
||||
defer sess.Unlock()
|
||||
|
||||
sess.validUntil = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
// AuthenticatedHandler defines the handler interface to specify custom
|
||||
// permission for an API handler. The returned permission is the required
|
||||
// permission for the request to proceed.
|
||||
type AuthenticatedHandler interface {
|
||||
ReadPermission(r *http.Request) Permission
|
||||
WritePermission(r *http.Request) Permission
|
||||
}
|
||||
|
||||
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted.
|
||||
func SetAuthenticator(fn AuthenticatorFunc) error {
|
||||
if module.online.Load() {
|
||||
return ErrAuthenticationImmutable
|
||||
}
|
||||
|
||||
if !authFnSet.SetToIf(false, true) {
|
||||
return ErrAuthenticationAlreadySet
|
||||
}
|
||||
|
||||
authFn = fn
|
||||
return nil
|
||||
}
|
||||
|
||||
func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler, readMethod bool) *AuthToken {
|
||||
tracer := log.Tracer(r.Context())
|
||||
|
||||
// Get required permission for target handler.
|
||||
requiredPermission := PermitSelf
|
||||
if authdHandler, ok := targetHandler.(AuthenticatedHandler); ok {
|
||||
if readMethod {
|
||||
requiredPermission = authdHandler.ReadPermission(r)
|
||||
} else {
|
||||
requiredPermission = authdHandler.WritePermission(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we need to do any authentication at all.
|
||||
switch requiredPermission { //nolint:exhaustive
|
||||
case NotFound:
|
||||
// Not found.
|
||||
tracer.Debug("api: no API endpoint registered for this path")
|
||||
http.Error(w, "Not found.", http.StatusNotFound)
|
||||
return nil
|
||||
case NotSupported:
|
||||
// A read or write permission can be marked as not supported.
|
||||
tracer.Trace("api: authenticated handler reported: not supported")
|
||||
http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
case PermitAnyone:
|
||||
// Don't process permissions, as we don't need them.
|
||||
tracer.Tracef("api: granted %s access to public handler", r.RemoteAddr)
|
||||
return &AuthToken{
|
||||
Read: PermitAnyone,
|
||||
Write: PermitAnyone,
|
||||
}
|
||||
case Dynamic:
|
||||
// Continue processing permissions, but treat as PermitAnyone.
|
||||
requiredPermission = PermitAnyone
|
||||
}
|
||||
|
||||
// The required permission must match the request permission values after
|
||||
// handling the specials.
|
||||
if requiredPermission < PermitAnyone || requiredPermission > PermitSelf {
|
||||
tracer.Warningf(
|
||||
"api: handler returned invalid permission: %s (%d)",
|
||||
requiredPermission,
|
||||
requiredPermission,
|
||||
)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authenticate request.
|
||||
token, handled := checkAuth(w, r, requiredPermission > PermitAnyone)
|
||||
switch {
|
||||
case handled:
|
||||
return nil
|
||||
case token == nil:
|
||||
// Use default permissions.
|
||||
token = &AuthToken{
|
||||
Read: PermitAnyone,
|
||||
Write: PermitAnyone,
|
||||
}
|
||||
}
|
||||
|
||||
// Get effective permission for request.
|
||||
var requestPermission Permission
|
||||
if readMethod {
|
||||
requestPermission = token.Read
|
||||
} else {
|
||||
requestPermission = token.Write
|
||||
}
|
||||
|
||||
// Check for valid request permission.
|
||||
if requestPermission < PermitAnyone || requestPermission > PermitSelf {
|
||||
tracer.Warningf(
|
||||
"api: authenticator returned invalid permission: %s (%d)",
|
||||
requestPermission,
|
||||
requestPermission,
|
||||
)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check permission.
|
||||
if requestPermission < requiredPermission {
|
||||
// If the token is strictly public, return an authentication request.
|
||||
if token.Read == PermitAnyone && token.Write == PermitAnyone {
|
||||
w.Header().Set(
|
||||
"WWW-Authenticate",
|
||||
`Bearer realm="Portmaster API" domain="/"`,
|
||||
)
|
||||
http.Error(w, "Authorization required.", http.StatusUnauthorized)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise just inform of insufficient permissions.
|
||||
http.Error(w, "Insufficient permissions.", http.StatusForbidden)
|
||||
return nil
|
||||
}
|
||||
|
||||
tracer.Tracef("api: granted %s access to protected handler", r.RemoteAddr)
|
||||
|
||||
// Make a copy of the AuthToken in order mitigate the handler poisoning the
|
||||
// token, as changes would apply to future requests.
|
||||
return &AuthToken{
|
||||
Read: token.Read,
|
||||
Write: token.Write,
|
||||
}
|
||||
}
|
||||
|
||||
func checkAuth(w http.ResponseWriter, r *http.Request, authRequired bool) (token *AuthToken, handled bool) {
|
||||
// Return highest possible permissions in dev mode.
|
||||
if devMode() {
|
||||
return &AuthToken{
|
||||
Read: PermitSelf,
|
||||
Write: PermitSelf,
|
||||
}, false
|
||||
}
|
||||
|
||||
// Database Bridge Access.
|
||||
if r.RemoteAddr == endpointBridgeRemoteAddress {
|
||||
return &AuthToken{
|
||||
Read: dbCompatibilityPermission,
|
||||
Write: dbCompatibilityPermission,
|
||||
}, false
|
||||
}
|
||||
|
||||
// Check for valid API key.
|
||||
token = checkAPIKey(r)
|
||||
if token != nil {
|
||||
return token, false
|
||||
}
|
||||
|
||||
// Check for valid session cookie.
|
||||
token = checkSessionCookie(r)
|
||||
if token != nil {
|
||||
return token, false
|
||||
}
|
||||
|
||||
// Check if an external authentication method is available.
|
||||
if !authFnSet.IsSet() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Authenticate externally.
|
||||
token, err := authFn(r, server)
|
||||
if err != nil {
|
||||
// Check if the authentication process failed internally.
|
||||
if !errors.Is(err, ErrAPIAccessDeniedMessage) {
|
||||
log.Tracer(r.Context()).Errorf("api: authenticator failed: %s", err)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// Return authentication failure message if authentication is required.
|
||||
if authRequired {
|
||||
log.Tracer(r.Context()).Warningf("api: denying api access from %s", r.RemoteAddr)
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Abort if no token is returned.
|
||||
if token == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Create session cookie for authenticated request.
|
||||
err = createSession(w, r, token)
|
||||
if err != nil {
|
||||
log.Tracer(r.Context()).Warningf("api: failed to create session: %s", err)
|
||||
}
|
||||
return token, false
|
||||
}
|
||||
|
||||
func checkAPIKey(r *http.Request) *AuthToken {
|
||||
// Get API key from request.
|
||||
key := r.Header.Get("Authorization")
|
||||
if key == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse API key.
|
||||
switch {
|
||||
case strings.HasPrefix(key, "Bearer "):
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
case strings.HasPrefix(key, "Basic "):
|
||||
user, pass, _ := r.BasicAuth()
|
||||
key = user + pass
|
||||
default:
|
||||
log.Tracer(r.Context()).Tracef(
|
||||
"api: provided api key type %s is unsupported", strings.Split(key, " ")[0],
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKeysLock.Lock()
|
||||
defer apiKeysLock.Unlock()
|
||||
|
||||
// Check if the provided API key exists.
|
||||
token, ok := apiKeys[key]
|
||||
if !ok {
|
||||
log.Tracer(r.Context()).Tracef(
|
||||
"api: provided api key %s... is unknown", key[:4],
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Abort if the token is expired.
|
||||
if token.ValidUntil != nil && time.Now().After(*token.ValidUntil) {
|
||||
log.Tracer(r.Context()).Warningf("api: denying api access from %s using expired token", r.RemoteAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func updateAPIKeys() {
|
||||
apiKeysLock.Lock()
|
||||
defer apiKeysLock.Unlock()
|
||||
|
||||
log.Debug("api: importing possibly updated API keys from config")
|
||||
|
||||
// Delete current keys.
|
||||
for k := range apiKeys {
|
||||
delete(apiKeys, k)
|
||||
}
|
||||
|
||||
// whether or not we found expired API keys that should be removed
|
||||
// from the setting
|
||||
hasExpiredKeys := false
|
||||
|
||||
// a list of valid API keys. Used when hasExpiredKeys is set to true.
|
||||
// in that case we'll update the setting to only contain validAPIKeys
|
||||
validAPIKeys := []string{}
|
||||
|
||||
// Parse new keys.
|
||||
for _, key := range configuredAPIKeys() {
|
||||
u, err := url.Parse(key)
|
||||
if err != nil {
|
||||
log.Errorf("api: failed to parse configured API key %s: %s", key, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if u.Path == "" {
|
||||
log.Errorf("api: malformed API key %s: missing path section", key)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Create token with default permissions.
|
||||
token := &AuthToken{
|
||||
Read: PermitAnyone,
|
||||
Write: PermitAnyone,
|
||||
}
|
||||
|
||||
// Update with configured permissions.
|
||||
q := u.Query()
|
||||
// Parse read permission.
|
||||
readPermission, err := parseAPIPermission(q.Get("read"))
|
||||
if err != nil {
|
||||
log.Errorf("api: invalid API key %s: %s", key, err)
|
||||
continue
|
||||
}
|
||||
token.Read = readPermission
|
||||
// Parse write permission.
|
||||
writePermission, err := parseAPIPermission(q.Get("write"))
|
||||
if err != nil {
|
||||
log.Errorf("api: invalid API key %s: %s", key, err)
|
||||
continue
|
||||
}
|
||||
token.Write = writePermission
|
||||
|
||||
expireStr := q.Get("expires")
|
||||
if expireStr != "" {
|
||||
validUntil, err := time.Parse(time.RFC3339, expireStr)
|
||||
if err != nil {
|
||||
log.Errorf("api: invalid API key %s: %s", key, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// continue to the next token if this one is already invalid
|
||||
if time.Now().After(validUntil) {
|
||||
// mark the key as expired so we'll remove it from the setting afterwards
|
||||
hasExpiredKeys = true
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
token.ValidUntil = &validUntil
|
||||
}
|
||||
|
||||
// Save token.
|
||||
apiKeys[u.Path] = token
|
||||
validAPIKeys = append(validAPIKeys, key)
|
||||
}
|
||||
|
||||
if hasExpiredKeys {
|
||||
module.mgr.Go("api key cleanup", func(ctx *mgr.WorkerCtx) error {
|
||||
if err := config.SetConfigOption(CfgAPIKeys, validAPIKeys); err != nil {
|
||||
log.Errorf("api: failed to remove expired API keys: %s", err)
|
||||
} else {
|
||||
log.Infof("api: removed expired API keys from %s", CfgAPIKeys)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkSessionCookie(r *http.Request) *AuthToken {
|
||||
// Get session cookie from request.
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if session cookie is registered.
|
||||
sessionsLock.Lock()
|
||||
sess, ok := sessions[c.Value]
|
||||
sessionsLock.Unlock()
|
||||
if !ok {
|
||||
log.Tracer(r.Context()).Tracef("api: provided session cookie %s is unknown", c.Value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if session is still valid.
|
||||
if sess.Expired() {
|
||||
log.Tracer(r.Context()).Tracef("api: provided session cookie %s has expired", c.Value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refresh session and return.
|
||||
sess.Refresh(sessionCookieTTL)
|
||||
log.Tracer(r.Context()).Tracef("api: session cookie %s is valid, refreshing", c.Value)
|
||||
return sess.token
|
||||
}
|
||||
|
||||
func createSession(w http.ResponseWriter, r *http.Request, token *AuthToken) error {
|
||||
// Generate new session key.
|
||||
secret, err := rng.Bytes(32) // 256 bit
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionKey := base64.RawURLEncoding.EncodeToString(secret)
|
||||
|
||||
// Set token cookie in response.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: sessionKey,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
// Create session.
|
||||
sess := &session{
|
||||
token: token,
|
||||
}
|
||||
sess.Refresh(sessionCookieTTL)
|
||||
|
||||
// Save session.
|
||||
sessionsLock.Lock()
|
||||
defer sessionsLock.Unlock()
|
||||
sessions[sessionKey] = sess
|
||||
log.Tracer(r.Context()).Debug("api: issued session cookie")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanSessions(_ *mgr.WorkerCtx) error {
|
||||
sessionsLock.Lock()
|
||||
defer sessionsLock.Unlock()
|
||||
|
||||
for sessionKey, sess := range sessions {
|
||||
if sess.Expired() {
|
||||
delete(sessions, sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteSession(sessionKey string) {
|
||||
sessionsLock.Lock()
|
||||
defer sessionsLock.Unlock()
|
||||
|
||||
delete(sessions, sessionKey)
|
||||
}
|
||||
|
||||
func getEffectiveMethod(r *http.Request) (eMethod string, readMethod bool, ok bool) {
|
||||
method := r.Method
|
||||
|
||||
// Get CORS request method if OPTIONS request.
|
||||
if r.Method == http.MethodOptions {
|
||||
method = r.Header.Get("Access-Control-Request-Method")
|
||||
if method == "" {
|
||||
return "", false, false
|
||||
}
|
||||
}
|
||||
|
||||
switch method {
|
||||
case http.MethodGet, http.MethodHead:
|
||||
return http.MethodGet, true, true
|
||||
case http.MethodPost, http.MethodPut, http.MethodDelete:
|
||||
return method, false, true
|
||||
default:
|
||||
return "", false, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseAPIPermission(s string) (Permission, error) {
|
||||
switch strings.ToLower(s) {
|
||||
case "", "anyone":
|
||||
return PermitAnyone, nil
|
||||
case "user":
|
||||
return PermitUser, nil
|
||||
case "admin":
|
||||
return PermitAdmin, nil
|
||||
default:
|
||||
return PermitAnyone, fmt.Errorf("invalid permission: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (p Permission) String() string {
|
||||
switch p {
|
||||
case NotSupported:
|
||||
return "NotSupported"
|
||||
case Dynamic:
|
||||
return "Dynamic"
|
||||
case PermitAnyone:
|
||||
return "PermitAnyone"
|
||||
case PermitUser:
|
||||
return "PermitUser"
|
||||
case PermitAdmin:
|
||||
return "PermitAdmin"
|
||||
case PermitSelf:
|
||||
return "PermitSelf"
|
||||
case NotFound:
|
||||
return "NotFound"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Role returns a string representation of the permission role.
|
||||
func (p Permission) Role() string {
|
||||
switch p {
|
||||
case PermitAnyone:
|
||||
return "Anyone"
|
||||
case PermitUser:
|
||||
return "User"
|
||||
case PermitAdmin:
|
||||
return "Admin"
|
||||
case PermitSelf:
|
||||
return "Self"
|
||||
case Dynamic, NotFound, NotSupported:
|
||||
return "Invalid"
|
||||
default:
|
||||
return "Invalid"
|
||||
}
|
||||
}
|
||||
186
base/api/authentication_test.go
Normal file
186
base/api/authentication_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var testToken = new(AuthToken)
|
||||
|
||||
func testAuthenticator(r *http.Request, s *http.Server) (*AuthToken, error) {
|
||||
switch {
|
||||
case testToken.Read == -127 || testToken.Write == -127:
|
||||
return nil, errors.New("test error")
|
||||
case testToken.Read == -128 || testToken.Write == -128:
|
||||
return nil, fmt.Errorf("%wdenied", ErrAPIAccessDeniedMessage)
|
||||
default:
|
||||
return testToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
type testAuthHandler struct {
|
||||
Read Permission
|
||||
Write Permission
|
||||
}
|
||||
|
||||
func (ah *testAuthHandler) ReadPermission(r *http.Request) Permission {
|
||||
return ah.Read
|
||||
}
|
||||
|
||||
func (ah *testAuthHandler) WritePermission(r *http.Request) Permission {
|
||||
return ah.Write
|
||||
}
|
||||
|
||||
func (ah *testAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if request is as expected.
|
||||
ar := GetAPIRequest(r)
|
||||
switch {
|
||||
case ar == nil:
|
||||
http.Error(w, "ar == nil", http.StatusInternalServerError)
|
||||
case ar.AuthToken == nil:
|
||||
http.Error(w, "ar.AuthToken == nil", http.StatusInternalServerError)
|
||||
default:
|
||||
http.Error(w, "auth success", http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func makeAuthTestPath(reading bool, p Permission) string {
|
||||
if reading {
|
||||
return fmt.Sprintf("/test/auth/read/%s", p)
|
||||
}
|
||||
return fmt.Sprintf("/test/auth/write/%s", p)
|
||||
}
|
||||
|
||||
func TestPermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testHandler := &mainHandler{
|
||||
mux: mainMux,
|
||||
}
|
||||
|
||||
// Define permissions that need testing.
|
||||
permissionsToTest := []Permission{
|
||||
NotSupported,
|
||||
PermitAnyone,
|
||||
PermitUser,
|
||||
PermitAdmin,
|
||||
PermitSelf,
|
||||
Dynamic,
|
||||
NotFound,
|
||||
100, // Test a too high value.
|
||||
-100, // Test a too low value.
|
||||
-127, // Simulate authenticator failure.
|
||||
-128, // Simulate authentication denied message.
|
||||
}
|
||||
|
||||
// Register test handlers.
|
||||
for _, p := range permissionsToTest {
|
||||
RegisterHandler(makeAuthTestPath(true, p), &testAuthHandler{Read: p})
|
||||
RegisterHandler(makeAuthTestPath(false, p), &testAuthHandler{Write: p})
|
||||
}
|
||||
|
||||
// Test all the combinations.
|
||||
for _, requestPerm := range permissionsToTest {
|
||||
for _, handlerPerm := range permissionsToTest {
|
||||
for _, method := range []string{
|
||||
http.MethodGet,
|
||||
http.MethodHead,
|
||||
http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodDelete,
|
||||
} {
|
||||
|
||||
// Set request permission for test requests.
|
||||
_, reading, _ := getEffectiveMethod(&http.Request{Method: method})
|
||||
if reading {
|
||||
testToken.Read = requestPerm
|
||||
testToken.Write = NotSupported
|
||||
} else {
|
||||
testToken.Read = NotSupported
|
||||
testToken.Write = requestPerm
|
||||
}
|
||||
|
||||
// Evaluate expected result.
|
||||
var expectSuccess bool
|
||||
switch {
|
||||
case handlerPerm == PermitAnyone:
|
||||
// This is fast-tracked. There are not additional checks.
|
||||
expectSuccess = true
|
||||
case handlerPerm == Dynamic:
|
||||
// This is turned into PermitAnyone in the authenticator.
|
||||
// But authentication is still processed and the result still gets
|
||||
// sanity checked!
|
||||
if requestPerm >= PermitAnyone &&
|
||||
requestPerm <= PermitSelf {
|
||||
expectSuccess = true
|
||||
}
|
||||
// Another special case is when the handler requires permission to be
|
||||
// processed but the authenticator fails to authenticate the request.
|
||||
// In this case, a fallback token with PermitAnyone is used.
|
||||
if requestPerm == -128 {
|
||||
// -128 is used to simulate a permission denied message.
|
||||
expectSuccess = true
|
||||
}
|
||||
case handlerPerm <= NotSupported:
|
||||
// Invalid handler permission.
|
||||
case handlerPerm > PermitSelf:
|
||||
// Invalid handler permission.
|
||||
case requestPerm <= NotSupported:
|
||||
// Invalid request permission.
|
||||
case requestPerm > PermitSelf:
|
||||
// Invalid request permission.
|
||||
case requestPerm < handlerPerm:
|
||||
// Valid, but insufficient request permission.
|
||||
default:
|
||||
expectSuccess = true
|
||||
}
|
||||
|
||||
if expectSuccess {
|
||||
// Test for success.
|
||||
if !assert.HTTPBodyContains(
|
||||
t,
|
||||
testHandler.ServeHTTP,
|
||||
method,
|
||||
makeAuthTestPath(reading, handlerPerm),
|
||||
nil,
|
||||
"auth success",
|
||||
) {
|
||||
t.Errorf(
|
||||
"%s with %s (%d) to handler %s (%d)",
|
||||
method,
|
||||
requestPerm, requestPerm,
|
||||
handlerPerm, handlerPerm,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Test for error.
|
||||
if !assert.HTTPError(t,
|
||||
testHandler.ServeHTTP,
|
||||
method,
|
||||
makeAuthTestPath(reading, handlerPerm),
|
||||
nil,
|
||||
) {
|
||||
t.Errorf(
|
||||
"%s with %s (%d) to handler %s (%d)",
|
||||
method,
|
||||
requestPerm, requestPerm,
|
||||
handlerPerm, handlerPerm,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionDefinitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if NotSupported != 0 {
|
||||
t.Fatalf("NotSupported must be zero, was %v", NotSupported)
|
||||
}
|
||||
}
|
||||
57
base/api/client/api.go
Normal file
57
base/api/client/api.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package client
|
||||
|
||||
// Get sends a get command to the API.
|
||||
func (c *Client) Get(key string, handleFunc func(*Message)) *Operation {
|
||||
op := c.NewOperation(handleFunc)
|
||||
op.Send(msgRequestGet, key, nil)
|
||||
return op
|
||||
}
|
||||
|
||||
// Query sends a query command to the API.
|
||||
func (c *Client) Query(query string, handleFunc func(*Message)) *Operation {
|
||||
op := c.NewOperation(handleFunc)
|
||||
op.Send(msgRequestQuery, query, nil)
|
||||
return op
|
||||
}
|
||||
|
||||
// Sub sends a sub command to the API.
|
||||
func (c *Client) Sub(query string, handleFunc func(*Message)) *Operation {
|
||||
op := c.NewOperation(handleFunc)
|
||||
op.Send(msgRequestSub, query, nil)
|
||||
return op
|
||||
}
|
||||
|
||||
// Qsub sends a qsub command to the API.
|
||||
func (c *Client) Qsub(query string, handleFunc func(*Message)) *Operation {
|
||||
op := c.NewOperation(handleFunc)
|
||||
op.Send(msgRequestQsub, query, nil)
|
||||
return op
|
||||
}
|
||||
|
||||
// Create sends a create command to the API.
|
||||
func (c *Client) Create(key string, value interface{}, handleFunc func(*Message)) *Operation {
|
||||
op := c.NewOperation(handleFunc)
|
||||
op.Send(msgRequestCreate, key, value)
|
||||
return op
|
||||
}
|
||||
|
||||
// Update sends an update command to the API.
|
||||
func (c *Client) Update(key string, value interface{}, handleFunc func(*Message)) *Operation {
|
||||
op := c.NewOperation(handleFunc)
|
||||
op.Send(msgRequestUpdate, key, value)
|
||||
return op
|
||||
}
|
||||
|
||||
// Insert sends an insert command to the API.
|
||||
func (c *Client) Insert(key string, value interface{}, handleFunc func(*Message)) *Operation {
|
||||
op := c.NewOperation(handleFunc)
|
||||
op.Send(msgRequestInsert, key, value)
|
||||
return op
|
||||
}
|
||||
|
||||
// Delete sends a delete command to the API.
|
||||
func (c *Client) Delete(key string, handleFunc func(*Message)) *Operation {
|
||||
op := c.NewOperation(handleFunc)
|
||||
op.Send(msgRequestDelete, key, nil)
|
||||
return op
|
||||
}
|
||||
240
base/api/client/client.go
Normal file
240
base/api/client/client.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
const (
|
||||
backOffTimer = 1 * time.Second
|
||||
|
||||
offlineSignal uint8 = 0
|
||||
onlineSignal uint8 = 1
|
||||
)
|
||||
|
||||
// The Client enables easy interaction with the API.
|
||||
type Client struct {
|
||||
sync.Mutex
|
||||
|
||||
server string
|
||||
|
||||
onlineSignal chan struct{}
|
||||
offlineSignal chan struct{}
|
||||
shutdownSignal chan struct{}
|
||||
lastSignal uint8
|
||||
|
||||
send chan *Message
|
||||
resend chan *Message
|
||||
recv chan *Message
|
||||
|
||||
operations map[string]*Operation
|
||||
nextOpID uint64
|
||||
|
||||
lastError string
|
||||
}
|
||||
|
||||
// NewClient returns a new Client.
|
||||
func NewClient(server string) *Client {
|
||||
c := &Client{
|
||||
server: server,
|
||||
onlineSignal: make(chan struct{}),
|
||||
offlineSignal: make(chan struct{}),
|
||||
shutdownSignal: make(chan struct{}),
|
||||
lastSignal: offlineSignal,
|
||||
send: make(chan *Message, 100),
|
||||
resend: make(chan *Message, 1),
|
||||
recv: make(chan *Message, 100),
|
||||
operations: make(map[string]*Operation),
|
||||
}
|
||||
go c.handler()
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect connects to the API once.
|
||||
func (c *Client) Connect() error {
|
||||
defer c.signalOffline()
|
||||
|
||||
err := c.wsConnect()
|
||||
if err != nil && err.Error() != c.lastError {
|
||||
log.Errorf("client: error connecting to Portmaster: %s", err)
|
||||
c.lastError = err.Error()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// StayConnected calls Connect again whenever the connection is lost.
|
||||
func (c *Client) StayConnected() {
|
||||
log.Infof("client: connecting to Portmaster at %s", c.server)
|
||||
|
||||
_ = c.Connect()
|
||||
for {
|
||||
select {
|
||||
case <-time.After(backOffTimer):
|
||||
log.Infof("client: reconnecting...")
|
||||
_ = c.Connect()
|
||||
case <-c.shutdownSignal:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown shuts the client down.
|
||||
func (c *Client) Shutdown() {
|
||||
select {
|
||||
case <-c.shutdownSignal:
|
||||
default:
|
||||
close(c.shutdownSignal)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) signalOnline() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.lastSignal == offlineSignal {
|
||||
log.Infof("client: went online")
|
||||
c.offlineSignal = make(chan struct{})
|
||||
close(c.onlineSignal)
|
||||
c.lastSignal = onlineSignal
|
||||
|
||||
// resend unsent request
|
||||
for _, op := range c.operations {
|
||||
if op.resuscitationEnabled.IsSet() && op.request.sent != nil && op.request.sent.SetToIf(true, false) {
|
||||
op.client.send <- op.request
|
||||
log.Infof("client: resuscitated %s %s %s", op.request.OpID, op.request.Type, op.request.Key)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) signalOffline() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.lastSignal == onlineSignal {
|
||||
log.Infof("client: went offline")
|
||||
c.onlineSignal = make(chan struct{})
|
||||
close(c.offlineSignal)
|
||||
c.lastSignal = offlineSignal
|
||||
|
||||
// signal offline status to operations
|
||||
for _, op := range c.operations {
|
||||
op.handle(&Message{
|
||||
OpID: op.ID,
|
||||
Type: MsgOffline,
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Online returns a closed channel read if the client is connected to the API.
|
||||
func (c *Client) Online() <-chan struct{} {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
return c.onlineSignal
|
||||
}
|
||||
|
||||
// Offline returns a closed channel read if the client is not connected to the API.
|
||||
func (c *Client) Offline() <-chan struct{} {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
return c.offlineSignal
|
||||
}
|
||||
|
||||
func (c *Client) handler() {
|
||||
for {
|
||||
select {
|
||||
|
||||
case m := <-c.recv:
|
||||
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
op, ok := c.operations[m.OpID]
|
||||
c.Unlock()
|
||||
|
||||
if ok {
|
||||
log.Tracef("client: [%s] received %s msg: %s", m.OpID, m.Type, m.Key)
|
||||
op.handle(m)
|
||||
} else {
|
||||
log.Tracef("client: received message for unknown operation %s", m.OpID)
|
||||
}
|
||||
|
||||
case <-c.shutdownSignal:
|
||||
return
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Operation represents a single operation by a client.
|
||||
type Operation struct {
|
||||
ID string
|
||||
request *Message
|
||||
client *Client
|
||||
handleFunc func(*Message)
|
||||
handler chan *Message
|
||||
resuscitationEnabled *abool.AtomicBool
|
||||
}
|
||||
|
||||
func (op *Operation) handle(m *Message) {
|
||||
if op.handleFunc != nil {
|
||||
op.handleFunc(m)
|
||||
} else {
|
||||
select {
|
||||
case op.handler <- m:
|
||||
default:
|
||||
log.Warningf("client: handler channel of operation %s overflowed", op.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel the operation.
|
||||
func (op *Operation) Cancel() {
|
||||
op.client.Lock()
|
||||
defer op.client.Unlock()
|
||||
delete(op.client.operations, op.ID)
|
||||
close(op.handler)
|
||||
}
|
||||
|
||||
// Send sends a request to the API.
|
||||
func (op *Operation) Send(command, text string, data interface{}) {
|
||||
op.request = &Message{
|
||||
OpID: op.ID,
|
||||
Type: command,
|
||||
Key: text,
|
||||
Value: data,
|
||||
sent: abool.NewBool(false),
|
||||
}
|
||||
log.Tracef("client: [%s] sending %s msg: %s", op.request.OpID, op.request.Type, op.request.Key)
|
||||
op.client.send <- op.request
|
||||
}
|
||||
|
||||
// EnableResuscitation will resend the request after reconnecting to the API.
|
||||
func (op *Operation) EnableResuscitation() {
|
||||
op.resuscitationEnabled.Set()
|
||||
}
|
||||
|
||||
// NewOperation returns a new operation.
|
||||
func (c *Client) NewOperation(handleFunc func(*Message)) *Operation {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
c.nextOpID++
|
||||
op := &Operation{
|
||||
ID: fmt.Sprintf("#%d", c.nextOpID),
|
||||
client: c,
|
||||
handleFunc: handleFunc,
|
||||
handler: make(chan *Message, 100),
|
||||
resuscitationEnabled: abool.NewBool(false),
|
||||
}
|
||||
c.operations[op.ID] = op
|
||||
return op
|
||||
}
|
||||
28
base/api/client/const.go
Normal file
28
base/api/client/const.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package client
|
||||
|
||||
// Message Types.
|
||||
const (
|
||||
msgRequestGet = "get"
|
||||
msgRequestQuery = "query"
|
||||
msgRequestSub = "sub"
|
||||
msgRequestQsub = "qsub"
|
||||
msgRequestCreate = "create"
|
||||
msgRequestUpdate = "update"
|
||||
msgRequestInsert = "insert"
|
||||
msgRequestDelete = "delete"
|
||||
|
||||
MsgOk = "ok"
|
||||
MsgError = "error"
|
||||
MsgDone = "done"
|
||||
MsgSuccess = "success"
|
||||
MsgUpdate = "upd"
|
||||
MsgNew = "new"
|
||||
MsgDelete = "del"
|
||||
MsgWarning = "warning"
|
||||
|
||||
MsgOffline = "offline" // special message type for signaling the handler that the connection was lost
|
||||
|
||||
apiSeperator = "|"
|
||||
)
|
||||
|
||||
var apiSeperatorBytes = []byte(apiSeperator)
|
||||
95
base/api/client/message.go
Normal file
95
base/api/client/message.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/structures/container"
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
// ErrMalformedMessage is returned when a malformed message was encountered.
|
||||
var ErrMalformedMessage = errors.New("malformed message")
|
||||
|
||||
// Message is an API message.
|
||||
type Message struct {
|
||||
OpID string
|
||||
Type string
|
||||
Key string
|
||||
RawValue []byte
|
||||
Value interface{}
|
||||
sent *abool.AtomicBool
|
||||
}
|
||||
|
||||
// ParseMessage parses the given raw data and returns a Message.
|
||||
func ParseMessage(data []byte) (*Message, error) {
|
||||
parts := bytes.SplitN(data, apiSeperatorBytes, 4)
|
||||
if len(parts) < 2 {
|
||||
return nil, ErrMalformedMessage
|
||||
}
|
||||
|
||||
m := &Message{
|
||||
OpID: string(parts[0]),
|
||||
Type: string(parts[1]),
|
||||
}
|
||||
|
||||
switch m.Type {
|
||||
case MsgOk, MsgUpdate, MsgNew:
|
||||
// parse key and data
|
||||
// 127|ok|<key>|<data>
|
||||
// 127|upd|<key>|<data>
|
||||
// 127|new|<key>|<data>
|
||||
if len(parts) != 4 {
|
||||
return nil, ErrMalformedMessage
|
||||
}
|
||||
m.Key = string(parts[2])
|
||||
m.RawValue = parts[3]
|
||||
case MsgDelete:
|
||||
// parse key
|
||||
// 127|del|<key>
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrMalformedMessage
|
||||
}
|
||||
m.Key = string(parts[2])
|
||||
case MsgWarning, MsgError:
|
||||
// parse message
|
||||
// 127|error|<message>
|
||||
// 127|warning|<message> // error with single record, operation continues
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrMalformedMessage
|
||||
}
|
||||
m.Key = string(parts[2])
|
||||
case MsgDone, MsgSuccess:
|
||||
// nothing more to do
|
||||
// 127|success
|
||||
// 127|done
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Pack serializes a message into a []byte slice.
|
||||
func (m *Message) Pack() ([]byte, error) {
|
||||
c := container.New([]byte(m.OpID), apiSeperatorBytes, []byte(m.Type))
|
||||
|
||||
if m.Key != "" {
|
||||
c.Append(apiSeperatorBytes)
|
||||
c.Append([]byte(m.Key))
|
||||
if len(m.RawValue) > 0 {
|
||||
c.Append(apiSeperatorBytes)
|
||||
c.Append(m.RawValue)
|
||||
} else if m.Value != nil {
|
||||
var err error
|
||||
m.RawValue, err = dsd.Dump(m.Value, dsd.JSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Append(apiSeperatorBytes)
|
||||
c.Append(m.RawValue)
|
||||
}
|
||||
}
|
||||
|
||||
return c.CompileData(), nil
|
||||
}
|
||||
121
base/api/client/websocket.go
Normal file
121
base/api/client/websocket.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
type wsState struct {
|
||||
wsConn *websocket.Conn
|
||||
wg sync.WaitGroup
|
||||
failing *abool.AtomicBool
|
||||
failSignal chan struct{}
|
||||
}
|
||||
|
||||
func (c *Client) wsConnect() error {
|
||||
state := &wsState{
|
||||
failing: abool.NewBool(false),
|
||||
failSignal: make(chan struct{}),
|
||||
}
|
||||
|
||||
var err error
|
||||
state.wsConn, _, err = websocket.DefaultDialer.Dial(fmt.Sprintf("ws://%s/api/database/v1", c.server), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.signalOnline()
|
||||
|
||||
state.wg.Add(2)
|
||||
go c.wsReader(state)
|
||||
go c.wsWriter(state)
|
||||
|
||||
// wait for end of connection
|
||||
select {
|
||||
case <-state.failSignal:
|
||||
case <-c.shutdownSignal:
|
||||
state.Error("")
|
||||
}
|
||||
_ = state.wsConn.Close()
|
||||
state.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) wsReader(state *wsState) {
|
||||
defer state.wg.Done()
|
||||
for {
|
||||
_, data, err := state.wsConn.ReadMessage()
|
||||
log.Tracef("client: read message")
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
state.Error(fmt.Sprintf("client: read error: %s", err))
|
||||
} else {
|
||||
state.Error("client: connection closed by server")
|
||||
}
|
||||
return
|
||||
}
|
||||
log.Tracef("client: received message: %s", string(data))
|
||||
m, err := ParseMessage(data)
|
||||
if err != nil {
|
||||
log.Warningf("client: failed to parse message: %s", err)
|
||||
} else {
|
||||
select {
|
||||
case c.recv <- m:
|
||||
case <-state.failSignal:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) wsWriter(state *wsState) {
|
||||
defer state.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-state.failSignal:
|
||||
return
|
||||
case m := <-c.resend:
|
||||
data, err := m.Pack()
|
||||
if err == nil {
|
||||
err = state.wsConn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
if err != nil {
|
||||
state.Error(fmt.Sprintf("client: write error: %s", err))
|
||||
return
|
||||
}
|
||||
log.Tracef("client: sent message: %s", string(data))
|
||||
if m.sent != nil {
|
||||
m.sent.Set()
|
||||
}
|
||||
case m := <-c.send:
|
||||
data, err := m.Pack()
|
||||
if err == nil {
|
||||
err = state.wsConn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
if err != nil {
|
||||
c.resend <- m
|
||||
state.Error(fmt.Sprintf("client: write error: %s", err))
|
||||
return
|
||||
}
|
||||
log.Tracef("client: sent message: %s", string(data))
|
||||
if m.sent != nil {
|
||||
m.sent.Set()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (state *wsState) Error(message string) {
|
||||
if state.failing.SetToIf(false, true) {
|
||||
close(state.failSignal)
|
||||
if message != "" {
|
||||
log.Warning(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
91
base/api/config.go
Normal file
91
base/api/config.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
"github.com/safing/portmaster/base/config"
|
||||
)
|
||||
|
||||
// Config Keys.
|
||||
const (
|
||||
CfgDefaultListenAddressKey = "core/listenAddress"
|
||||
CfgAPIKeys = "core/apiKeys"
|
||||
)
|
||||
|
||||
var (
|
||||
listenAddressFlag string
|
||||
listenAddressConfig config.StringOption
|
||||
defaultListenAddress string
|
||||
|
||||
configuredAPIKeys config.StringArrayOption
|
||||
|
||||
devMode config.BoolOption
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(
|
||||
&listenAddressFlag,
|
||||
"api-address",
|
||||
"",
|
||||
"set api listen address; configuration is stronger",
|
||||
)
|
||||
}
|
||||
|
||||
func getDefaultListenAddress() string {
|
||||
// check if overridden
|
||||
if listenAddressFlag != "" {
|
||||
return listenAddressFlag
|
||||
}
|
||||
// return internal default
|
||||
return defaultListenAddress
|
||||
}
|
||||
|
||||
func registerConfig() error {
|
||||
err := config.Register(&config.Option{
|
||||
Name: "API Listen Address",
|
||||
Key: CfgDefaultListenAddressKey,
|
||||
Description: "Defines the IP address and port on which the internal API listens.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
DefaultValue: getDefaultListenAddress(),
|
||||
ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$",
|
||||
RequiresRestart: true,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: 513,
|
||||
config.CategoryAnnotation: "Development",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
listenAddressConfig = config.GetAsString(CfgDefaultListenAddressKey, getDefaultListenAddress())
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "API Keys",
|
||||
Key: CfgAPIKeys,
|
||||
Description: "Define API keys for privileged access to the API. Every entry is a separate API key with respective permissions. Format is `<key>?read=<perm>&write=<perm>`. Permissions are `anyone`, `user` and `admin`, and may be omitted.",
|
||||
Sensitive: true,
|
||||
OptType: config.OptTypeStringArray,
|
||||
ExpertiseLevel: config.ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
DefaultValue: []string{},
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: 514,
|
||||
config.CategoryAnnotation: "Development",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configuredAPIKeys = config.GetAsStringArray(CfgAPIKeys, []string{})
|
||||
|
||||
devMode = config.Concurrent.GetAsBool(config.CfgDevModeKey, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultAPIListenAddress sets the default listen address for the API.
|
||||
func SetDefaultAPIListenAddress(address string) {
|
||||
defaultListenAddress = address
|
||||
}
|
||||
698
base/api/database.go
Normal file
698
base/api/database.go
Normal file
@@ -0,0 +1,698 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/tevino/abool"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/safing/portmaster/base/database"
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
"github.com/safing/structures/container"
|
||||
"github.com/safing/structures/dsd"
|
||||
"github.com/safing/structures/varint"
|
||||
)
|
||||
|
||||
const (
|
||||
dbMsgTypeOk = "ok"
|
||||
dbMsgTypeError = "error"
|
||||
dbMsgTypeDone = "done"
|
||||
dbMsgTypeSuccess = "success"
|
||||
dbMsgTypeUpd = "upd"
|
||||
dbMsgTypeNew = "new"
|
||||
dbMsgTypeDel = "del"
|
||||
dbMsgTypeWarning = "warning"
|
||||
|
||||
dbAPISeperator = "|"
|
||||
emptyString = ""
|
||||
)
|
||||
|
||||
var (
|
||||
dbAPISeperatorBytes = []byte(dbAPISeperator)
|
||||
dbCompatibilityPermission = PermitAdmin
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterHandler("/api/database/v1", WrapInAuthHandler(
|
||||
startDatabaseWebsocketAPI,
|
||||
// Default to admin read/write permissions until the database gets support
|
||||
// for api permissions.
|
||||
dbCompatibilityPermission,
|
||||
dbCompatibilityPermission,
|
||||
))
|
||||
}
|
||||
|
||||
// DatabaseAPI is a generic database API interface.
|
||||
type DatabaseAPI struct {
|
||||
queriesLock sync.Mutex
|
||||
queries map[string]*iterator.Iterator
|
||||
|
||||
subsLock sync.Mutex
|
||||
subs map[string]*database.Subscription
|
||||
|
||||
shutdownSignal chan struct{}
|
||||
shuttingDown *abool.AtomicBool
|
||||
db *database.Interface
|
||||
|
||||
sendBytes func(data []byte)
|
||||
}
|
||||
|
||||
// DatabaseWebsocketAPI is a database websocket API interface.
|
||||
type DatabaseWebsocketAPI struct {
|
||||
DatabaseAPI
|
||||
|
||||
sendQueue chan []byte
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
func allowAnyOrigin(r *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateDatabaseAPI creates a new database interface.
|
||||
func CreateDatabaseAPI(sendFunction func(data []byte)) DatabaseAPI {
|
||||
return DatabaseAPI{
|
||||
queries: make(map[string]*iterator.Iterator),
|
||||
subs: make(map[string]*database.Subscription),
|
||||
shutdownSignal: make(chan struct{}),
|
||||
shuttingDown: abool.NewBool(false),
|
||||
db: database.NewInterface(nil),
|
||||
sendBytes: sendFunction,
|
||||
}
|
||||
}
|
||||
|
||||
func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) {
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: allowAnyOrigin,
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 65536,
|
||||
}
|
||||
wsConn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("could not upgrade: %s", err)
|
||||
log.Error(errMsg)
|
||||
http.Error(w, errMsg, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newDBAPI := &DatabaseWebsocketAPI{
|
||||
DatabaseAPI: DatabaseAPI{
|
||||
queries: make(map[string]*iterator.Iterator),
|
||||
subs: make(map[string]*database.Subscription),
|
||||
shutdownSignal: make(chan struct{}),
|
||||
shuttingDown: abool.NewBool(false),
|
||||
db: database.NewInterface(nil),
|
||||
},
|
||||
|
||||
sendQueue: make(chan []byte, 100),
|
||||
conn: wsConn,
|
||||
}
|
||||
|
||||
newDBAPI.sendBytes = func(data []byte) {
|
||||
newDBAPI.sendQueue <- data
|
||||
}
|
||||
|
||||
module.mgr.Go("database api handler", newDBAPI.handler)
|
||||
module.mgr.Go("database api writer", newDBAPI.writer)
|
||||
|
||||
log.Tracer(r.Context()).Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI)
|
||||
}
|
||||
|
||||
func (api *DatabaseWebsocketAPI) handler(_ *mgr.WorkerCtx) error {
|
||||
defer func() {
|
||||
_ = api.shutdown(nil)
|
||||
}()
|
||||
|
||||
for {
|
||||
_, msg, err := api.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return api.shutdown(err)
|
||||
}
|
||||
|
||||
api.Handle(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *DatabaseWebsocketAPI) writer(ctx *mgr.WorkerCtx) error {
|
||||
defer func() {
|
||||
_ = api.shutdown(nil)
|
||||
}()
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
for {
|
||||
select {
|
||||
// prioritize direct writes
|
||||
case data = <-api.sendQueue:
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-api.shutdownSignal:
|
||||
return nil
|
||||
}
|
||||
|
||||
// log.Tracef("api: sending %s", string(*msg))
|
||||
err = api.conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
if err != nil {
|
||||
return api.shutdown(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (api *DatabaseWebsocketAPI) shutdown(err error) error {
|
||||
// Check if we are the first to shut down.
|
||||
if !api.shuttingDown.SetToIf(false, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check the given error.
|
||||
if err != nil {
|
||||
if websocket.IsCloseError(err,
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseAbnormalClosure,
|
||||
) {
|
||||
log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
|
||||
} else {
|
||||
log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger shutdown.
|
||||
close(api.shutdownSignal)
|
||||
_ = api.conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle handles a message for the database API.
|
||||
func (api *DatabaseAPI) Handle(msg []byte) {
|
||||
// 123|get|<key>
|
||||
// 123|ok|<key>|<data>
|
||||
// 123|error|<message>
|
||||
// 124|query|<query>
|
||||
// 124|ok|<key>|<data>
|
||||
// 124|done
|
||||
// 124|error|<message>
|
||||
// 124|warning|<message> // error with single record, operation continues
|
||||
// 124|cancel
|
||||
// 125|sub|<query>
|
||||
// 125|upd|<key>|<data>
|
||||
// 125|new|<key>|<data>
|
||||
// 127|del|<key>
|
||||
// 125|warning|<message> // error with single record, operation continues
|
||||
// 125|cancel
|
||||
// 127|qsub|<query>
|
||||
// 127|ok|<key>|<data>
|
||||
// 127|done
|
||||
// 127|error|<message>
|
||||
// 127|upd|<key>|<data>
|
||||
// 127|new|<key>|<data>
|
||||
// 127|del|<key>
|
||||
// 127|warning|<message> // error with single record, operation continues
|
||||
// 127|cancel
|
||||
|
||||
// 128|create|<key>|<data>
|
||||
// 128|success
|
||||
// 128|error|<message>
|
||||
// 129|update|<key>|<data>
|
||||
// 129|success
|
||||
// 129|error|<message>
|
||||
// 130|insert|<key>|<data>
|
||||
// 130|success
|
||||
// 130|error|<message>
|
||||
// 131|delete|<key>
|
||||
// 131|success
|
||||
// 131|error|<message>
|
||||
|
||||
parts := bytes.SplitN(msg, []byte("|"), 3)
|
||||
|
||||
// Handle special command "cancel"
|
||||
if len(parts) == 2 && string(parts[1]) == "cancel" {
|
||||
// 124|cancel
|
||||
// 125|cancel
|
||||
// 127|cancel
|
||||
go api.handleCancel(parts[0])
|
||||
return
|
||||
}
|
||||
|
||||
if len(parts) != 3 {
|
||||
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
|
||||
return
|
||||
}
|
||||
|
||||
switch string(parts[1]) {
|
||||
case "get":
|
||||
// 123|get|<key>
|
||||
go api.handleGet(parts[0], string(parts[2]))
|
||||
case "query":
|
||||
// 124|query|<query>
|
||||
go api.handleQuery(parts[0], string(parts[2]))
|
||||
case "sub":
|
||||
// 125|sub|<query>
|
||||
go api.handleSub(parts[0], string(parts[2]))
|
||||
case "qsub":
|
||||
// 127|qsub|<query>
|
||||
go api.handleQsub(parts[0], string(parts[2]))
|
||||
case "create", "update", "insert":
|
||||
// split key and payload
|
||||
dataParts := bytes.SplitN(parts[2], []byte("|"), 2)
|
||||
if len(dataParts) != 2 {
|
||||
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
|
||||
return
|
||||
}
|
||||
|
||||
switch string(parts[1]) {
|
||||
case "create":
|
||||
// 128|create|<key>|<data>
|
||||
go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true)
|
||||
case "update":
|
||||
// 129|update|<key>|<data>
|
||||
go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false)
|
||||
case "insert":
|
||||
// 130|insert|<key>|<data>
|
||||
go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1])
|
||||
}
|
||||
case "delete":
|
||||
// 131|delete|<key>
|
||||
go api.handleDelete(parts[0], string(parts[2]))
|
||||
default:
|
||||
api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data []byte) {
|
||||
c := container.New(opID)
|
||||
c.Append(dbAPISeperatorBytes)
|
||||
c.Append([]byte(msgType))
|
||||
|
||||
if msgOrKey != emptyString {
|
||||
c.Append(dbAPISeperatorBytes)
|
||||
c.Append([]byte(msgOrKey))
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
c.Append(dbAPISeperatorBytes)
|
||||
c.Append(data)
|
||||
}
|
||||
|
||||
api.sendBytes(c.CompileData())
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleGet(opID []byte, key string) {
|
||||
// 123|get|<key>
|
||||
// 123|ok|<key>|<data>
|
||||
// 123|error|<message>
|
||||
|
||||
var data []byte
|
||||
|
||||
r, err := api.db.Get(key)
|
||||
if err == nil {
|
||||
data, err = MarshalRecord(r, true)
|
||||
}
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
api.send(opID, dbMsgTypeOk, r.Key(), data)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleQuery(opID []byte, queryText string) {
|
||||
// 124|query|<query>
|
||||
// 124|ok|<key>|<data>
|
||||
// 124|done
|
||||
// 124|warning|<message>
|
||||
// 124|error|<message>
|
||||
// 124|warning|<message> // error with single record, operation continues
|
||||
// 124|cancel
|
||||
|
||||
var err error
|
||||
|
||||
q, err := query.ParseQuery(queryText)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
api.processQuery(opID, q)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
|
||||
it, err := api.db.Query(q)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return false
|
||||
}
|
||||
|
||||
// Save query iterator.
|
||||
api.queriesLock.Lock()
|
||||
api.queries[string(opID)] = it
|
||||
api.queriesLock.Unlock()
|
||||
|
||||
// Remove query iterator after it ended.
|
||||
defer func() {
|
||||
api.queriesLock.Lock()
|
||||
defer api.queriesLock.Unlock()
|
||||
delete(api.queries, string(opID))
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-api.shutdownSignal:
|
||||
// cancel query and return
|
||||
it.Cancel()
|
||||
return false
|
||||
case r := <-it.Next:
|
||||
// process query feed
|
||||
if r != nil {
|
||||
// process record
|
||||
data, err := MarshalRecord(r, true)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
|
||||
continue
|
||||
}
|
||||
api.send(opID, dbMsgTypeOk, r.Key(), data)
|
||||
} else {
|
||||
// sub feed ended
|
||||
if it.Err() != nil {
|
||||
api.send(opID, dbMsgTypeError, it.Err().Error(), nil)
|
||||
return false
|
||||
}
|
||||
api.send(opID, dbMsgTypeDone, emptyString, nil)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// func (api *DatabaseWebsocketAPI) runQuery()
|
||||
|
||||
func (api *DatabaseAPI) handleSub(opID []byte, queryText string) {
|
||||
// 125|sub|<query>
|
||||
// 125|upd|<key>|<data>
|
||||
// 125|new|<key>|<data>
|
||||
// 125|delete|<key>
|
||||
// 125|warning|<message> // error with single record, operation continues
|
||||
// 125|cancel
|
||||
var err error
|
||||
|
||||
q, err := query.ParseQuery(queryText)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
sub, ok := api.registerSub(opID, q)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
api.processSub(opID, sub)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) registerSub(opID []byte, q *query.Query) (sub *database.Subscription, ok bool) {
|
||||
var err error
|
||||
sub, err = api.db.Subscribe(q)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return sub, true
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
|
||||
// Save subscription.
|
||||
api.subsLock.Lock()
|
||||
api.subs[string(opID)] = sub
|
||||
api.subsLock.Unlock()
|
||||
|
||||
// Remove subscription after it ended.
|
||||
defer func() {
|
||||
api.subsLock.Lock()
|
||||
defer api.subsLock.Unlock()
|
||||
delete(api.subs, string(opID))
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-api.shutdownSignal:
|
||||
// cancel sub and return
|
||||
_ = sub.Cancel()
|
||||
return
|
||||
case r := <-sub.Feed:
|
||||
// process sub feed
|
||||
if r != nil {
|
||||
// process record
|
||||
data, err := MarshalRecord(r, true)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
|
||||
continue
|
||||
}
|
||||
// TODO: use upd, new and delete msgTypes
|
||||
r.Lock()
|
||||
isDeleted := r.Meta().IsDeleted()
|
||||
isNew := r.Meta().Created == r.Meta().Modified
|
||||
r.Unlock()
|
||||
switch {
|
||||
case isDeleted:
|
||||
api.send(opID, dbMsgTypeDel, r.Key(), nil)
|
||||
case isNew:
|
||||
api.send(opID, dbMsgTypeNew, r.Key(), data)
|
||||
default:
|
||||
api.send(opID, dbMsgTypeUpd, r.Key(), data)
|
||||
}
|
||||
} else {
|
||||
// sub feed ended
|
||||
api.send(opID, dbMsgTypeDone, "", nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleQsub(opID []byte, queryText string) {
|
||||
// 127|qsub|<query>
|
||||
// 127|ok|<key>|<data>
|
||||
// 127|done
|
||||
// 127|error|<message>
|
||||
// 127|upd|<key>|<data>
|
||||
// 127|new|<key>|<data>
|
||||
// 127|delete|<key>
|
||||
// 127|warning|<message> // error with single record, operation continues
|
||||
// 127|cancel
|
||||
|
||||
var err error
|
||||
|
||||
q, err := query.ParseQuery(queryText)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
sub, ok := api.registerSub(opID, q)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ok = api.processQuery(opID, q)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
api.processSub(opID, sub)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleCancel(opID []byte) {
|
||||
api.cancelQuery(opID)
|
||||
api.cancelSub(opID)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) cancelQuery(opID []byte) {
|
||||
api.queriesLock.Lock()
|
||||
defer api.queriesLock.Unlock()
|
||||
|
||||
// Get subscription from api.
|
||||
it, ok := api.queries[string(opID)]
|
||||
if !ok {
|
||||
// Fail silently as quries end by themselves when finished.
|
||||
return
|
||||
}
|
||||
|
||||
// End query.
|
||||
it.Cancel()
|
||||
|
||||
// The query handler will end the communication with a done message.
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) cancelSub(opID []byte) {
|
||||
api.subsLock.Lock()
|
||||
defer api.subsLock.Unlock()
|
||||
|
||||
// Get subscription from api.
|
||||
sub, ok := api.subs[string(opID)]
|
||||
if !ok {
|
||||
api.send(opID, dbMsgTypeError, "could not find subscription", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// End subscription.
|
||||
err := sub.Cancel()
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, fmt.Sprintf("failed to cancel subscription: %s", err), nil)
|
||||
}
|
||||
|
||||
// The subscription handler will end the communication with a done message.
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handlePut(opID []byte, key string, data []byte, create bool) {
|
||||
// 128|create|<key>|<data>
|
||||
// 128|success
|
||||
// 128|error|<message>
|
||||
|
||||
// 129|update|<key>|<data>
|
||||
// 129|success
|
||||
// 129|error|<message>
|
||||
|
||||
if len(data) < 2 {
|
||||
api.send(opID, dbMsgTypeError, "bad request: malformed message", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO - staged for deletion: remove transition code
|
||||
// if data[0] != dsd.JSON {
|
||||
// typedData := make([]byte, len(data)+1)
|
||||
// typedData[0] = dsd.JSON
|
||||
// copy(typedData[1:], data)
|
||||
// data = typedData
|
||||
// }
|
||||
|
||||
r, err := record.NewWrapper(key, nil, data[0], data[1:])
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
if create {
|
||||
err = api.db.PutNew(r)
|
||||
} else {
|
||||
err = api.db.Put(r)
|
||||
}
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
api.send(opID, dbMsgTypeSuccess, emptyString, nil)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleInsert(opID []byte, key string, data []byte) {
|
||||
// 130|insert|<key>|<data>
|
||||
// 130|success
|
||||
// 130|error|<message>
|
||||
|
||||
r, err := api.db.Get(key)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
acc := r.GetAccessor(r)
|
||||
|
||||
result := gjson.ParseBytes(data)
|
||||
anythingPresent := false
|
||||
var insertError error
|
||||
result.ForEach(func(key gjson.Result, value gjson.Result) bool {
|
||||
anythingPresent = true
|
||||
if !key.Exists() {
|
||||
insertError = errors.New("values must be in a map")
|
||||
return false
|
||||
}
|
||||
if key.Type != gjson.String {
|
||||
insertError = errors.New("keys must be strings")
|
||||
return false
|
||||
}
|
||||
if !value.Exists() {
|
||||
insertError = errors.New("non-existent value")
|
||||
return false
|
||||
}
|
||||
insertError = acc.Set(key.String(), value.Value())
|
||||
return insertError == nil
|
||||
})
|
||||
|
||||
if insertError != nil {
|
||||
api.send(opID, dbMsgTypeError, insertError.Error(), nil)
|
||||
return
|
||||
}
|
||||
if !anythingPresent {
|
||||
api.send(opID, dbMsgTypeError, "could not find any valid values", nil)
|
||||
return
|
||||
}
|
||||
|
||||
err = api.db.Put(r)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
api.send(opID, dbMsgTypeSuccess, emptyString, nil)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleDelete(opID []byte, key string) {
|
||||
// 131|delete|<key>
|
||||
// 131|success
|
||||
// 131|error|<message>
|
||||
|
||||
err := api.db.Delete(key)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
api.send(opID, dbMsgTypeSuccess, emptyString, nil)
|
||||
}
|
||||
|
||||
// MarshalRecord locks and marshals the given record, additionally adding
|
||||
// metadata and returning it as json.
|
||||
func MarshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
// Pour record into JSON.
|
||||
jsonData, err := r.Marshal(r, dsd.JSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove JSON identifier for manual editing.
|
||||
jsonData = bytes.TrimPrefix(jsonData, varint.Pack8(dsd.JSON))
|
||||
|
||||
// Add metadata.
|
||||
jsonData, err = sjson.SetBytes(jsonData, "_meta", r.Meta())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add database key.
|
||||
jsonData, err = sjson.SetBytes(jsonData, "_meta.Key", r.Key())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add JSON identifier again.
|
||||
if withDSDIdentifier {
|
||||
formatID := varint.Pack8(dsd.JSON)
|
||||
finalData := make([]byte, 0, len(formatID)+len(jsonData))
|
||||
finalData = append(finalData, formatID...)
|
||||
finalData = append(finalData, jsonData...)
|
||||
return finalData, nil
|
||||
}
|
||||
return jsonData, nil
|
||||
}
|
||||
10
base/api/doc.go
Normal file
10
base/api/doc.go
Normal file
@@ -0,0 +1,10 @@
|
||||
/*
|
||||
Package api provides an API for integration with other components of the same software package and also third party components.
|
||||
|
||||
It provides direct database access as well as a simpler way to register API endpoints. You can of course also register raw `http.Handler`s directly.
|
||||
|
||||
Optional authentication guards registered handlers. This is achieved by attaching functions to the `http.Handler`s that are registered, which allow them to specify the required permissions for the handler.
|
||||
|
||||
The permissions are divided into the roles and assume a single user per host. The Roles are User, Admin and Self. User roles are expected to have mostly read access and react to notifications or system events, like a system tray program. The Admin role is meant for advanced components that also change settings, but are restricted so they cannot break the software. Self is reserved for internal use with full access.
|
||||
*/
|
||||
package api
|
||||
521
base/api/endpoints.go
Normal file
521
base/api/endpoints.go
Normal file
@@ -0,0 +1,521 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
// Endpoint describes an API Endpoint.
|
||||
// Path and at least one permission are required.
|
||||
// As is exactly one function.
|
||||
type Endpoint struct { //nolint:maligned
|
||||
// Name is the human reabable name of the endpoint.
|
||||
Name string
|
||||
// Description is the human readable description and documentation of the endpoint.
|
||||
Description string
|
||||
// Parameters is the parameter documentation.
|
||||
Parameters []Parameter `json:",omitempty"`
|
||||
|
||||
// Path describes the URL path of the endpoint.
|
||||
Path string
|
||||
|
||||
// MimeType defines the content type of the returned data.
|
||||
MimeType string
|
||||
|
||||
// Read defines the required read permission.
|
||||
Read Permission `json:",omitempty"`
|
||||
|
||||
// ReadMethod sets the required read method for the endpoint.
|
||||
// Available methods are:
|
||||
// GET: Returns data only, no action is taken, nothing is changed.
|
||||
// If omitted, defaults to GET.
|
||||
//
|
||||
// This field is currently being introduced and will only warn and not deny
|
||||
// access if the write method does not match.
|
||||
ReadMethod string `json:",omitempty"`
|
||||
|
||||
// Write defines the required write permission.
|
||||
Write Permission `json:",omitempty"`
|
||||
|
||||
// WriteMethod sets the required write method for the endpoint.
|
||||
// Available methods are:
|
||||
// POST: Create a new resource; Change a status; Execute a function
|
||||
// PUT: Update an existing resource
|
||||
// DELETE: Remove an existing resource
|
||||
// If omitted, defaults to POST.
|
||||
//
|
||||
// This field is currently being introduced and will only warn and not deny
|
||||
// access if the write method does not match.
|
||||
WriteMethod string `json:",omitempty"`
|
||||
|
||||
// ActionFunc is for simple actions with a return message for the user.
|
||||
ActionFunc ActionFunc `json:"-"`
|
||||
|
||||
// DataFunc is for returning raw data that the caller for further processing.
|
||||
DataFunc DataFunc `json:"-"`
|
||||
|
||||
// StructFunc is for returning any kind of struct.
|
||||
StructFunc StructFunc `json:"-"`
|
||||
|
||||
// RecordFunc is for returning a database record. It will be properly locked
|
||||
// and marshalled including metadata.
|
||||
RecordFunc RecordFunc `json:"-"`
|
||||
|
||||
// HandlerFunc is the raw http handler.
|
||||
HandlerFunc http.HandlerFunc `json:"-"`
|
||||
}
|
||||
|
||||
// Parameter describes a parameterized variation of an endpoint.
|
||||
type Parameter struct {
|
||||
Method string
|
||||
Field string
|
||||
Value string
|
||||
Description string
|
||||
}
|
||||
|
||||
// HTTPStatusProvider is an interface for errors to provide a custom HTTP
|
||||
// status code.
|
||||
type HTTPStatusProvider interface {
|
||||
HTTPStatus() int
|
||||
}
|
||||
|
||||
// HTTPStatusError represents an error with an HTTP status code.
|
||||
type HTTPStatusError struct {
|
||||
err error
|
||||
code int
|
||||
}
|
||||
|
||||
// Error returns the error message.
|
||||
func (e *HTTPStatusError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
// Unwrap return the wrapped error.
|
||||
func (e *HTTPStatusError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// HTTPStatus returns the HTTP status code this error.
|
||||
func (e *HTTPStatusError) HTTPStatus() int {
|
||||
return e.code
|
||||
}
|
||||
|
||||
// ErrorWithStatus adds the HTTP status code to the error.
|
||||
func ErrorWithStatus(err error, code int) error {
|
||||
return &HTTPStatusError{
|
||||
err: err,
|
||||
code: code,
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
// ActionFunc is for simple actions with a return message for the user.
|
||||
ActionFunc func(ar *Request) (msg string, err error)
|
||||
|
||||
// DataFunc is for returning raw data that the caller for further processing.
|
||||
DataFunc func(ar *Request) (data []byte, err error)
|
||||
|
||||
// StructFunc is for returning any kind of struct.
|
||||
StructFunc func(ar *Request) (i interface{}, err error)
|
||||
|
||||
// RecordFunc is for returning a database record. It will be properly locked
|
||||
// and marshalled including metadata.
|
||||
RecordFunc func(ar *Request) (r record.Record, err error)
|
||||
)
|
||||
|
||||
// MIME Types.
|
||||
const (
|
||||
MimeTypeJSON string = "application/json"
|
||||
MimeTypeText string = "text/plain"
|
||||
|
||||
apiV1Path = "/api/v1/"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterHandler(apiV1Path+"{endpointPath:.+}", &endpointHandler{})
|
||||
}
|
||||
|
||||
var (
|
||||
endpoints = make(map[string]*Endpoint)
|
||||
endpointsMux = mux.NewRouter()
|
||||
endpointsLock sync.RWMutex
|
||||
|
||||
// ErrInvalidEndpoint is returned when an invalid endpoint is registered.
|
||||
ErrInvalidEndpoint = errors.New("endpoint is invalid")
|
||||
|
||||
// ErrAlreadyRegistered is returned when there already is an endpoint with
|
||||
// the same path registered.
|
||||
ErrAlreadyRegistered = errors.New("an endpoint for this path is already registered")
|
||||
)
|
||||
|
||||
func getAPIContext(r *http.Request) (apiEndpoint *Endpoint, apiRequest *Request) {
|
||||
// Get request context and check if we already have an action cached.
|
||||
apiRequest = GetAPIRequest(r)
|
||||
if apiRequest == nil {
|
||||
return nil, nil
|
||||
}
|
||||
var ok bool
|
||||
apiEndpoint, ok = apiRequest.HandlerCache.(*Endpoint)
|
||||
if ok {
|
||||
return apiEndpoint, apiRequest
|
||||
}
|
||||
|
||||
endpointsLock.RLock()
|
||||
defer endpointsLock.RUnlock()
|
||||
|
||||
// Get handler for request.
|
||||
// Gorilla does not support handling this on our own very well.
|
||||
// See github.com/gorilla/mux.ServeHTTP for reference.
|
||||
var match mux.RouteMatch
|
||||
var handler http.Handler
|
||||
if endpointsMux.Match(r, &match) {
|
||||
handler = match.Handler
|
||||
apiRequest.Route = match.Route
|
||||
// Add/Override variables instead of replacing.
|
||||
for k, v := range match.Vars {
|
||||
apiRequest.URLVars[k] = v
|
||||
}
|
||||
} else {
|
||||
return nil, apiRequest
|
||||
}
|
||||
|
||||
apiEndpoint, ok = handler.(*Endpoint)
|
||||
if ok {
|
||||
// Cache for next operation.
|
||||
apiRequest.HandlerCache = apiEndpoint
|
||||
}
|
||||
return apiEndpoint, apiRequest
|
||||
}
|
||||
|
||||
// RegisterEndpoint registers a new endpoint. An error will be returned if it
|
||||
// does not pass the sanity checks.
|
||||
func RegisterEndpoint(e Endpoint) error {
|
||||
if err := e.check(); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrInvalidEndpoint, err)
|
||||
}
|
||||
|
||||
endpointsLock.Lock()
|
||||
defer endpointsLock.Unlock()
|
||||
|
||||
_, ok := endpoints[e.Path]
|
||||
if ok {
|
||||
return ErrAlreadyRegistered
|
||||
}
|
||||
|
||||
endpoints[e.Path] = &e
|
||||
endpointsMux.Handle(apiV1Path+e.Path, &e)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEndpointByPath returns the endpoint registered with the given path.
|
||||
func GetEndpointByPath(path string) (*Endpoint, error) {
|
||||
endpointsLock.Lock()
|
||||
defer endpointsLock.Unlock()
|
||||
endpoint, ok := endpoints[path]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no registered endpoint on path: %q", path)
|
||||
}
|
||||
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
func (e *Endpoint) check() error {
|
||||
// Check path.
|
||||
if strings.TrimSpace(e.Path) == "" {
|
||||
return errors.New("path is missing")
|
||||
}
|
||||
|
||||
// Check permissions.
|
||||
if e.Read < Dynamic || e.Read > PermitSelf {
|
||||
return errors.New("invalid read permission")
|
||||
}
|
||||
if e.Write < Dynamic || e.Write > PermitSelf {
|
||||
return errors.New("invalid write permission")
|
||||
}
|
||||
|
||||
// Check methods.
|
||||
if e.Read != NotSupported {
|
||||
switch e.ReadMethod {
|
||||
case http.MethodGet:
|
||||
// All good.
|
||||
case "":
|
||||
// Set to default.
|
||||
e.ReadMethod = http.MethodGet
|
||||
default:
|
||||
return errors.New("invalid read method")
|
||||
}
|
||||
} else {
|
||||
e.ReadMethod = ""
|
||||
}
|
||||
if e.Write != NotSupported {
|
||||
switch e.WriteMethod {
|
||||
case http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodDelete:
|
||||
// All good.
|
||||
case "":
|
||||
// Set to default.
|
||||
e.WriteMethod = http.MethodPost
|
||||
default:
|
||||
return errors.New("invalid write method")
|
||||
}
|
||||
} else {
|
||||
e.WriteMethod = ""
|
||||
}
|
||||
|
||||
// Check functions.
|
||||
var defaultMimeType string
|
||||
fnCnt := 0
|
||||
if e.ActionFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeText
|
||||
}
|
||||
if e.DataFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeText
|
||||
}
|
||||
if e.StructFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeJSON
|
||||
}
|
||||
if e.RecordFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeJSON
|
||||
}
|
||||
if e.HandlerFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeText
|
||||
}
|
||||
if fnCnt != 1 {
|
||||
return errors.New("only one function may be set")
|
||||
}
|
||||
|
||||
// Set default mime type.
|
||||
if e.MimeType == "" {
|
||||
e.MimeType = defaultMimeType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExportEndpoints exports the registered endpoints. The returned data must be
|
||||
// treated as immutable.
|
||||
func ExportEndpoints() []*Endpoint {
|
||||
endpointsLock.RLock()
|
||||
defer endpointsLock.RUnlock()
|
||||
|
||||
// Copy the map into a slice.
|
||||
eps := make([]*Endpoint, 0, len(endpoints))
|
||||
for _, ep := range endpoints {
|
||||
eps = append(eps, ep)
|
||||
}
|
||||
|
||||
sort.Sort(sortByPath(eps))
|
||||
return eps
|
||||
}
|
||||
|
||||
type sortByPath []*Endpoint
|
||||
|
||||
func (eps sortByPath) Len() int { return len(eps) }
|
||||
func (eps sortByPath) Less(i, j int) bool { return eps[i].Path < eps[j].Path }
|
||||
func (eps sortByPath) Swap(i, j int) { eps[i], eps[j] = eps[j], eps[i] }
|
||||
|
||||
type endpointHandler struct{}
|
||||
|
||||
var _ AuthenticatedHandler = &endpointHandler{} // Compile time interface check.
|
||||
|
||||
// ReadPermission returns the read permission for the handler.
|
||||
func (eh *endpointHandler) ReadPermission(r *http.Request) Permission {
|
||||
apiEndpoint, _ := getAPIContext(r)
|
||||
if apiEndpoint != nil {
|
||||
return apiEndpoint.Read
|
||||
}
|
||||
return NotFound
|
||||
}
|
||||
|
||||
// WritePermission returns the write permission for the handler.
|
||||
func (eh *endpointHandler) WritePermission(r *http.Request) Permission {
|
||||
apiEndpoint, _ := getAPIContext(r)
|
||||
if apiEndpoint != nil {
|
||||
return apiEndpoint.Write
|
||||
}
|
||||
return NotFound
|
||||
}
|
||||
|
||||
// ServeHTTP handles the http request.
|
||||
func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
apiEndpoint, apiRequest := getAPIContext(r)
|
||||
if apiEndpoint == nil || apiRequest == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
apiEndpoint.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// ServeHTTP handles the http request.
|
||||
func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
_, apiRequest := getAPIContext(r)
|
||||
if apiRequest == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Return OPTIONS request before starting to handle normal requests.
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
eMethod, readMethod, ok := getEffectiveMethod(r)
|
||||
if !ok {
|
||||
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if readMethod {
|
||||
if eMethod != e.ReadMethod {
|
||||
log.Tracer(r.Context()).Warningf(
|
||||
"api: method %q does not match required read method %q%s",
|
||||
r.Method,
|
||||
e.ReadMethod,
|
||||
" - this will be an error and abort the request in the future",
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if eMethod != e.WriteMethod {
|
||||
log.Tracer(r.Context()).Warningf(
|
||||
"api: method %q does not match required write method %q%s",
|
||||
r.Method,
|
||||
e.WriteMethod,
|
||||
" - this will be an error and abort the request in the future",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch eMethod {
|
||||
case http.MethodGet, http.MethodDelete:
|
||||
// Nothing to do for these.
|
||||
case http.MethodPost, http.MethodPut:
|
||||
// Read body data.
|
||||
inputData, ok := readBody(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
apiRequest.InputData = inputData
|
||||
|
||||
// restore request body for any http.HandlerFunc below
|
||||
r.Body = io.NopCloser(bytes.NewReader(inputData))
|
||||
default:
|
||||
// Defensive.
|
||||
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Add response headers to request struct so that the endpoint can work with them.
|
||||
apiRequest.ResponseHeader = w.Header()
|
||||
|
||||
// Execute action function and get response data
|
||||
var responseData []byte
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case e.ActionFunc != nil:
|
||||
var msg string
|
||||
msg, err = e.ActionFunc(apiRequest)
|
||||
if !strings.HasSuffix(msg, "\n") {
|
||||
msg += "\n"
|
||||
}
|
||||
if err == nil {
|
||||
responseData = []byte(msg)
|
||||
}
|
||||
|
||||
case e.DataFunc != nil:
|
||||
responseData, err = e.DataFunc(apiRequest)
|
||||
|
||||
case e.StructFunc != nil:
|
||||
var v interface{}
|
||||
v, err = e.StructFunc(apiRequest)
|
||||
if err == nil && v != nil {
|
||||
var mimeType string
|
||||
responseData, mimeType, _, err = dsd.MimeDump(v, r.Header.Get("Accept"))
|
||||
if err == nil {
|
||||
w.Header().Set("Content-Type", mimeType)
|
||||
}
|
||||
}
|
||||
|
||||
case e.RecordFunc != nil:
|
||||
var rec record.Record
|
||||
rec, err = e.RecordFunc(apiRequest)
|
||||
if err == nil && r != nil {
|
||||
responseData, err = MarshalRecord(rec, false)
|
||||
}
|
||||
|
||||
case e.HandlerFunc != nil:
|
||||
e.HandlerFunc(w, r)
|
||||
return
|
||||
|
||||
default:
|
||||
http.Error(w, "missing handler", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for handler error.
|
||||
if err != nil {
|
||||
var statusProvider HTTPStatusProvider
|
||||
if errors.As(err, &statusProvider) {
|
||||
http.Error(w, err.Error(), statusProvider.HTTPStatus())
|
||||
} else {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Return no content if there is none, or if request is HEAD.
|
||||
if len(responseData) == 0 || r.Method == http.MethodHead {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
// Set content type if not yet set.
|
||||
if w.Header().Get("Content-Type") == "" {
|
||||
w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8")
|
||||
}
|
||||
|
||||
// Write response.
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(responseData)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err = w.Write(responseData)
|
||||
if err != nil {
|
||||
log.Tracer(r.Context()).Warningf("api: failed to write response: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func readBody(w http.ResponseWriter, r *http.Request) (inputData []byte, ok bool) {
|
||||
// Check for too long content in order to prevent death.
|
||||
if r.ContentLength > 20000000 { // 20MB
|
||||
http.Error(w, "too much input data", http.StatusRequestEntityTooLarge)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Read and close body.
|
||||
inputData, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to read body"+err.Error(), http.StatusInternalServerError)
|
||||
return nil, false
|
||||
}
|
||||
return inputData, true
|
||||
}
|
||||
24
base/api/endpoints_config.go
Normal file
24
base/api/endpoints_config.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/base/config"
|
||||
)
|
||||
|
||||
func registerConfigEndpoints() error {
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "config/options",
|
||||
Read: PermitAnyone,
|
||||
MimeType: MimeTypeJSON,
|
||||
StructFunc: listConfig,
|
||||
Name: "Export Configuration Options",
|
||||
Description: "Returns a list of all registered configuration options and their metadata. This does not include the current active or default settings.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func listConfig(ar *Request) (i interface{}, err error) {
|
||||
return config.ExportOptions(), nil
|
||||
}
|
||||
249
base/api/endpoints_debug.go
Normal file
249
base/api/endpoints_debug.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/info"
|
||||
"github.com/safing/portmaster/base/utils/debug"
|
||||
)
|
||||
|
||||
func registerDebugEndpoints() error {
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "ping",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: ping,
|
||||
Name: "Ping",
|
||||
Description: "Pong.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "ready",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: ready,
|
||||
Name: "Ready",
|
||||
Description: "Check if Portmaster has completed starting and is ready.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/stack",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: getStack,
|
||||
Name: "Get Goroutine Stack",
|
||||
Description: "Returns the current goroutine stack.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/stack/print",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: printStack,
|
||||
Name: "Print Goroutine Stack",
|
||||
Description: "Prints the current goroutine stack to stdout.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/cpu",
|
||||
MimeType: "application/octet-stream",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: handleCPUProfile,
|
||||
Name: "Get CPU Profile",
|
||||
Description: strings.ReplaceAll(`Gather and return the CPU profile.
|
||||
This data needs to gathered over a period of time, which is specified using the duration parameter.
|
||||
|
||||
You can easily view this data in your browser with this command (with Go installed):
|
||||
"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/cpu"
|
||||
`, `"`, "`"),
|
||||
Parameters: []Parameter{{
|
||||
Method: http.MethodGet,
|
||||
Field: "duration",
|
||||
Value: "10s",
|
||||
Description: "Specify the formatting style. The default is simple markdown formatting.",
|
||||
}},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/heap",
|
||||
MimeType: "application/octet-stream",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: handleHeapProfile,
|
||||
Name: "Get Heap Profile",
|
||||
Description: strings.ReplaceAll(`Gather and return the heap memory profile.
|
||||
|
||||
You can easily view this data in your browser with this command (with Go installed):
|
||||
"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/heap"
|
||||
`, `"`, "`"),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/allocs",
|
||||
MimeType: "application/octet-stream",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: handleAllocsProfile,
|
||||
Name: "Get Allocs Profile",
|
||||
Description: strings.ReplaceAll(`Gather and return the memory allocation profile.
|
||||
|
||||
You can easily view this data in your browser with this command (with Go installed):
|
||||
"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/allocs"
|
||||
`, `"`, "`"),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/info",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: debugInfo,
|
||||
Name: "Get Debug Information",
|
||||
Description: "Returns debugging information, including the version and platform info, errors, logs and the current goroutine stack.",
|
||||
Parameters: []Parameter{{
|
||||
Method: http.MethodGet,
|
||||
Field: "style",
|
||||
Value: "github",
|
||||
Description: "Specify the formatting style. The default is simple markdown formatting.",
|
||||
}},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ping responds with pong.
|
||||
func ping(ar *Request) (msg string, err error) {
|
||||
return "Pong.", nil
|
||||
}
|
||||
|
||||
// ready checks if Portmaster has completed starting.
|
||||
func ready(ar *Request) (msg string, err error) {
|
||||
if module.instance.Ready() {
|
||||
return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly)
|
||||
}
|
||||
return "Portmaster is ready.", nil
|
||||
}
|
||||
|
||||
// getStack returns the current goroutine stack.
|
||||
func getStack(_ *Request) (data []byte, err error) {
|
||||
buf := &bytes.Buffer{}
|
||||
err = pprof.Lookup("goroutine").WriteTo(buf, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// printStack prints the current goroutine stack to stderr.
|
||||
func printStack(_ *Request) (msg string, err error) {
|
||||
_, err = fmt.Fprint(os.Stderr, "===== PRINTING STACK =====\n")
|
||||
if err == nil {
|
||||
err = pprof.Lookup("goroutine").WriteTo(os.Stderr, 1)
|
||||
}
|
||||
if err == nil {
|
||||
_, err = fmt.Fprint(os.Stderr, "===== END OF STACK =====\n")
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "stack printed to stdout", nil
|
||||
}
|
||||
|
||||
// handleCPUProfile returns the CPU profile.
|
||||
func handleCPUProfile(ar *Request) (data []byte, err error) {
|
||||
// Parse duration.
|
||||
duration := 10 * time.Second
|
||||
if durationOption := ar.Request.URL.Query().Get("duration"); durationOption != "" {
|
||||
parsedDuration, err := time.ParseDuration(durationOption)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration: %w", err)
|
||||
}
|
||||
duration = parsedDuration
|
||||
}
|
||||
|
||||
// Indicate download and filename.
|
||||
ar.ResponseHeader.Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(`attachment; filename="portmaster-cpu-profile_v%s.pprof"`, info.Version()),
|
||||
)
|
||||
|
||||
// Start CPU profiling.
|
||||
buf := new(bytes.Buffer)
|
||||
if err := pprof.StartCPUProfile(buf); err != nil {
|
||||
return nil, fmt.Errorf("failed to start cpu profile: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the specified duration.
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
case <-ar.Context().Done():
|
||||
pprof.StopCPUProfile()
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Stop CPU profiling and return data.
|
||||
pprof.StopCPUProfile()
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// handleHeapProfile returns the Heap profile.
|
||||
func handleHeapProfile(ar *Request) (data []byte, err error) {
|
||||
// Indicate download and filename.
|
||||
ar.ResponseHeader.Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(`attachment; filename="portmaster-memory-heap-profile_v%s.pprof"`, info.Version()),
|
||||
)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := pprof.Lookup("heap").WriteTo(buf, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to write heap profile: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// handleAllocsProfile returns the Allocs profile.
|
||||
func handleAllocsProfile(ar *Request) (data []byte, err error) {
|
||||
// Indicate download and filename.
|
||||
ar.ResponseHeader.Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(`attachment; filename="portmaster-memory-allocs-profile_v%s.pprof"`, info.Version()),
|
||||
)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := pprof.Lookup("allocs").WriteTo(buf, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to write allocs profile: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// debugInfo returns the debugging information for support requests.
|
||||
func debugInfo(ar *Request) (data []byte, err error) {
|
||||
// Create debug information helper.
|
||||
di := new(debug.Info)
|
||||
di.Style = ar.Request.URL.Query().Get("style")
|
||||
|
||||
// Add debug information.
|
||||
di.AddVersionInfo()
|
||||
di.AddPlatformInfo(ar.Context())
|
||||
di.AddLastUnexpectedLogs()
|
||||
di.AddGoroutineStack()
|
||||
|
||||
// Return data.
|
||||
return di.Bytes(), nil
|
||||
}
|
||||
140
base/api/endpoints_meta.go
Normal file
140
base/api/endpoints_meta.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func registerMetaEndpoints() error {
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "endpoints",
|
||||
Read: PermitAnyone,
|
||||
MimeType: MimeTypeJSON,
|
||||
DataFunc: listEndpoints,
|
||||
Name: "Export API Endpoints",
|
||||
Description: "Returns a list of all registered endpoints and their metadata.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "auth/permissions",
|
||||
Read: Dynamic,
|
||||
StructFunc: permissions,
|
||||
Name: "View Current Permissions",
|
||||
Description: "Returns the current permissions assigned to the request.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "auth/bearer",
|
||||
Read: Dynamic,
|
||||
HandlerFunc: authBearer,
|
||||
Name: "Request HTTP Bearer Auth",
|
||||
Description: "Returns an HTTP Bearer Auth request, if not authenticated.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "auth/basic",
|
||||
Read: Dynamic,
|
||||
HandlerFunc: authBasic,
|
||||
Name: "Request HTTP Basic Auth",
|
||||
Description: "Returns an HTTP Basic Auth request, if not authenticated.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "auth/reset",
|
||||
Read: PermitAnyone,
|
||||
HandlerFunc: authReset,
|
||||
Name: "Reset Authenticated Session",
|
||||
Description: "Resets authentication status internally and in the browser.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func listEndpoints(ar *Request) (data []byte, err error) {
|
||||
data, err = json.Marshal(ExportEndpoints())
|
||||
return
|
||||
}
|
||||
|
||||
func permissions(ar *Request) (i interface{}, err error) {
|
||||
if ar.AuthToken == nil {
|
||||
return nil, errors.New("authentication token missing")
|
||||
}
|
||||
|
||||
return struct {
|
||||
Read Permission
|
||||
Write Permission
|
||||
ReadRole string
|
||||
WriteRole string
|
||||
}{
|
||||
Read: ar.AuthToken.Read,
|
||||
Write: ar.AuthToken.Write,
|
||||
ReadRole: ar.AuthToken.Read.Role(),
|
||||
WriteRole: ar.AuthToken.Write.Role(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authBearer(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if authenticated by checking read permission.
|
||||
ar := GetAPIRequest(r)
|
||||
if ar.AuthToken.Read != PermitAnyone {
|
||||
TextResponse(w, r, "Authenticated.")
|
||||
return
|
||||
}
|
||||
|
||||
// Respond with desired authentication header.
|
||||
w.Header().Set(
|
||||
"WWW-Authenticate",
|
||||
`Bearer realm="Portmaster API" domain="/"`,
|
||||
)
|
||||
http.Error(w, "Authorization required.", http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func authBasic(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if authenticated by checking read permission.
|
||||
ar := GetAPIRequest(r)
|
||||
if ar.AuthToken.Read != PermitAnyone {
|
||||
TextResponse(w, r, "Authenticated.")
|
||||
return
|
||||
}
|
||||
|
||||
// Respond with desired authentication header.
|
||||
w.Header().Set(
|
||||
"WWW-Authenticate",
|
||||
`Basic realm="Portmaster API" domain="/"`,
|
||||
)
|
||||
http.Error(w, "Authorization required.", http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func authReset(w http.ResponseWriter, r *http.Request) {
|
||||
// Get session cookie from request and delete session if exists.
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err == nil {
|
||||
deleteSession(c.Value)
|
||||
}
|
||||
|
||||
// Delete session and cookie.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
MaxAge: -1, // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
|
||||
})
|
||||
|
||||
// Request client to also reset all data.
|
||||
w.Header().Set("Clear-Site-Data", "*")
|
||||
|
||||
// Set HTTP Auth Realm without requesting authorization.
|
||||
w.Header().Set("WWW-Authenticate", `None realm="Portmaster API"`)
|
||||
|
||||
// Reply with 401 Unauthorized in order to clear HTTP Basic Auth data.
|
||||
http.Error(w, "Session deleted.", http.StatusUnauthorized)
|
||||
}
|
||||
161
base/api/endpoints_test.go
Normal file
161
base/api/endpoints_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
const (
|
||||
successMsg = "endpoint api success"
|
||||
failedMsg = "endpoint api failed"
|
||||
)
|
||||
|
||||
type actionTestRecord struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
Msg string
|
||||
}
|
||||
|
||||
func TestEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testHandler := &mainHandler{
|
||||
mux: mainMux,
|
||||
}
|
||||
|
||||
// ActionFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/action",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: func(_ *Request) (msg string, err error) {
|
||||
return successMsg, nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/action-err",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: func(_ *Request) (msg string, err error) {
|
||||
return "", errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action-err", nil, failedMsg)
|
||||
|
||||
// DataFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/data",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: func(_ *Request) (data []byte, err error) {
|
||||
return []byte(successMsg), nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/data-err",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: func(_ *Request) (data []byte, err error) {
|
||||
return nil, errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data-err", nil, failedMsg)
|
||||
|
||||
// StructFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/struct",
|
||||
Read: PermitAnyone,
|
||||
StructFunc: func(_ *Request) (i interface{}, err error) {
|
||||
return &actionTestRecord{
|
||||
Msg: successMsg,
|
||||
}, nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/struct-err",
|
||||
Read: PermitAnyone,
|
||||
StructFunc: func(_ *Request) (i interface{}, err error) {
|
||||
return nil, errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct-err", nil, failedMsg)
|
||||
|
||||
// RecordFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/record",
|
||||
Read: PermitAnyone,
|
||||
RecordFunc: func(_ *Request) (r record.Record, err error) {
|
||||
r = &actionTestRecord{
|
||||
Msg: successMsg,
|
||||
}
|
||||
r.CreateMeta()
|
||||
return r, nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/record-err",
|
||||
Read: PermitAnyone,
|
||||
RecordFunc: func(_ *Request) (r record.Record, err error) {
|
||||
return nil, errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record-err", nil, failedMsg)
|
||||
}
|
||||
|
||||
func TestActionRegistration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Read: NotFound,
|
||||
}))
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Read: PermitSelf + 1,
|
||||
}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Write: NotFound,
|
||||
}))
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Write: PermitSelf + 1,
|
||||
}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
ActionFunc: func(_ *Request) (msg string, err error) {
|
||||
return successMsg, nil
|
||||
},
|
||||
DataFunc: func(_ *Request) (data []byte, err error) {
|
||||
return []byte(successMsg), nil
|
||||
},
|
||||
}))
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
ActionFunc: func(_ *Request) (msg string, err error) {
|
||||
return successMsg, nil
|
||||
},
|
||||
}))
|
||||
}
|
||||
68
base/api/enriched-response.go
Normal file
68
base/api/enriched-response.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
// LoggingResponseWriter is a wrapper for http.ResponseWriter for better request logging.
|
||||
type LoggingResponseWriter struct {
|
||||
ResponseWriter http.ResponseWriter
|
||||
Request *http.Request
|
||||
Status int
|
||||
}
|
||||
|
||||
// NewLoggingResponseWriter wraps a http.ResponseWriter.
|
||||
func NewLoggingResponseWriter(w http.ResponseWriter, r *http.Request) *LoggingResponseWriter {
|
||||
return &LoggingResponseWriter{
|
||||
ResponseWriter: w,
|
||||
Request: r,
|
||||
}
|
||||
}
|
||||
|
||||
// Header wraps the original Header method.
|
||||
func (lrw *LoggingResponseWriter) Header() http.Header {
|
||||
return lrw.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
// Write wraps the original Write method.
|
||||
func (lrw *LoggingResponseWriter) Write(b []byte) (int, error) {
|
||||
return lrw.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// WriteHeader wraps the original WriteHeader method to extract information.
|
||||
func (lrw *LoggingResponseWriter) WriteHeader(code int) {
|
||||
lrw.Status = code
|
||||
lrw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Hijack wraps the original Hijack method, if available.
|
||||
func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := lrw.ResponseWriter.(http.Hijacker)
|
||||
if ok {
|
||||
c, b, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
log.Tracer(lrw.Request.Context()).Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI)
|
||||
return c, b, nil
|
||||
}
|
||||
return nil, nil, errors.New("response does not implement http.Hijacker")
|
||||
}
|
||||
|
||||
// RequestLogger is a logging middleware.
|
||||
func RequestLogger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracer(r.Context()).Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
|
||||
lrw := NewLoggingResponseWriter(w, r)
|
||||
next.ServeHTTP(lrw, r)
|
||||
if lrw.Status != 0 {
|
||||
// request may have been hijacked
|
||||
log.Tracer(r.Context()).Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI)
|
||||
}
|
||||
})
|
||||
}
|
||||
38
base/api/init_test.go
Normal file
38
base/api/init_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/base/config"
|
||||
)
|
||||
|
||||
type testInstance struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
var _ instance = &testInstance{}
|
||||
|
||||
func (stub *testInstance) Config() *config.Config {
|
||||
return stub.config
|
||||
}
|
||||
|
||||
func (stub *testInstance) SetCmdLineOperation(f func() error) {}
|
||||
|
||||
func (stub *testInstance) Ready() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
SetDefaultAPIListenAddress("0.0.0.0:8080")
|
||||
instance := &testInstance{}
|
||||
var err error
|
||||
module, err = New(instance)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = SetAuthenticator(testAuthenticator)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
m.Run()
|
||||
}
|
||||
82
base/api/main.go
Normal file
82
base/api/main.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
var exportEndpoints bool
|
||||
|
||||
// API Errors.
|
||||
var (
|
||||
ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set")
|
||||
ErrAuthenticationImmutable = errors.New("the authentication function can only be set before the api has started")
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&exportEndpoints, "export-api-endpoints", false, "export api endpoint registry and exit")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
// Register endpoints.
|
||||
if err := registerConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registerDebugEndpoints(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registerConfigEndpoints(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registerMetaEndpoints(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exportEndpoints {
|
||||
module.instance.SetCmdLineOperation(exportEndpointsCmd)
|
||||
return mgr.ErrExecuteCmdLineOp
|
||||
}
|
||||
|
||||
if getDefaultListenAddress() == "" {
|
||||
return errors.New("no default listen address for api available")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
startServer()
|
||||
|
||||
updateAPIKeys()
|
||||
module.instance.Config().EventConfigChange.AddCallback("update API keys",
|
||||
func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) {
|
||||
updateAPIKeys()
|
||||
return false, nil
|
||||
})
|
||||
|
||||
// start api auth token cleaner
|
||||
if authFnSet.IsSet() {
|
||||
_ = module.mgr.Repeat("clean api sessions", 5*time.Minute, cleanSessions)
|
||||
}
|
||||
|
||||
return registerEndpointBridgeDB()
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
return stopServer()
|
||||
}
|
||||
|
||||
func exportEndpointsCmd() error {
|
||||
data, err := json.MarshalIndent(ExportEndpoints(), "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = os.Stdout.Write(data)
|
||||
return err
|
||||
}
|
||||
65
base/api/module.go
Normal file
65
base/api/module.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/safing/portmaster/base/config"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
// API is the HTTP/Websockets API module.
|
||||
type API struct {
|
||||
mgr *mgr.Manager
|
||||
instance instance
|
||||
|
||||
online atomic.Bool
|
||||
}
|
||||
|
||||
func (api *API) Manager() *mgr.Manager {
|
||||
return api.mgr
|
||||
}
|
||||
|
||||
// Start starts the module.
|
||||
func (api *API) Start() error {
|
||||
if err := start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
api.online.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the module.
|
||||
func (api *API) Stop() error {
|
||||
defer api.online.Store(false)
|
||||
return stop()
|
||||
}
|
||||
|
||||
var (
|
||||
shimLoaded atomic.Bool
|
||||
module *API
|
||||
)
|
||||
|
||||
// New returns a new UI module.
|
||||
func New(instance instance) (*API, error) {
|
||||
if !shimLoaded.CompareAndSwap(false, true) {
|
||||
return nil, errors.New("only one instance allowed")
|
||||
}
|
||||
m := mgr.New("API")
|
||||
module = &API{
|
||||
mgr: m,
|
||||
instance: instance,
|
||||
}
|
||||
|
||||
if err := prep(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return module, nil
|
||||
}
|
||||
|
||||
type instance interface {
|
||||
Config() *config.Config
|
||||
SetCmdLineOperation(f func() error)
|
||||
Ready() bool
|
||||
}
|
||||
60
base/api/request.go
Normal file
60
base/api/request.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
// Request is a support struct to pool more request related information.
|
||||
type Request struct {
|
||||
// Request is the http request.
|
||||
*http.Request
|
||||
|
||||
// InputData contains the request body for write operations.
|
||||
InputData []byte
|
||||
|
||||
// Route of this request.
|
||||
Route *mux.Route
|
||||
|
||||
// URLVars contains the URL variables extracted by the gorilla mux.
|
||||
URLVars map[string]string
|
||||
|
||||
// AuthToken is the request-side authentication token assigned.
|
||||
AuthToken *AuthToken
|
||||
|
||||
// ResponseHeader holds the response header.
|
||||
ResponseHeader http.Header
|
||||
|
||||
// HandlerCache can be used by handlers to cache data between handlers within a request.
|
||||
HandlerCache interface{}
|
||||
}
|
||||
|
||||
// apiRequestContextKey is a key used for the context key/value storage.
|
||||
type apiRequestContextKey struct{}
|
||||
|
||||
// RequestContextKey is the key used to add the API request to the context.
|
||||
var RequestContextKey = apiRequestContextKey{}
|
||||
|
||||
// GetAPIRequest returns the API Request of the given http request.
|
||||
func GetAPIRequest(r *http.Request) *Request {
|
||||
ar, ok := r.Context().Value(RequestContextKey).(*Request)
|
||||
if ok {
|
||||
return ar
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TextResponse writes a text response.
|
||||
func TextResponse(w http.ResponseWriter, r *http.Request, text string) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := fmt.Fprintln(w, text)
|
||||
if err != nil {
|
||||
log.Tracer(r.Context()).Warningf("api: failed to write text response: %s", err)
|
||||
}
|
||||
}
|
||||
329
base/api/router.go
Normal file
329
base/api/router.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
// EnableServer defines if the HTTP server should be started.
|
||||
var EnableServer = true
|
||||
|
||||
var (
|
||||
// mainMux is the main mux router.
|
||||
mainMux = mux.NewRouter()
|
||||
|
||||
// server is the main server.
|
||||
server = &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
handlerLock sync.RWMutex
|
||||
|
||||
allowedDevCORSOrigins = []string{
|
||||
"127.0.0.1",
|
||||
"localhost",
|
||||
}
|
||||
)
|
||||
|
||||
// RegisterHandler registers a handler with the API endpoint.
|
||||
func RegisterHandler(path string, handler http.Handler) *mux.Route {
|
||||
handlerLock.Lock()
|
||||
defer handlerLock.Unlock()
|
||||
return mainMux.Handle(path, handler)
|
||||
}
|
||||
|
||||
// RegisterHandleFunc registers a handle function with the API endpoint.
|
||||
func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route {
|
||||
handlerLock.Lock()
|
||||
defer handlerLock.Unlock()
|
||||
return mainMux.HandleFunc(path, handleFunc)
|
||||
}
|
||||
|
||||
func startServer() {
|
||||
// Check if server is enabled.
|
||||
if !EnableServer {
|
||||
return
|
||||
}
|
||||
|
||||
// Configure server.
|
||||
server.Addr = listenAddressConfig()
|
||||
server.Handler = &mainHandler{
|
||||
// TODO: mainMux should not be modified anymore.
|
||||
mux: mainMux,
|
||||
}
|
||||
|
||||
// Start server manager.
|
||||
module.mgr.Go("http server manager", serverManager)
|
||||
}
|
||||
|
||||
func stopServer() error {
|
||||
// Check if server is enabled.
|
||||
if !EnableServer {
|
||||
return nil
|
||||
}
|
||||
|
||||
if server.Addr != "" {
|
||||
return server.Shutdown(context.Background())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts serving the API endpoint.
|
||||
func serverManager(ctx *mgr.WorkerCtx) error {
|
||||
// start serving
|
||||
log.Infof("api: starting to listen on %s", server.Addr)
|
||||
backoffDuration := 10 * time.Second
|
||||
for {
|
||||
err := module.mgr.Do("http server", func(ctx *mgr.WorkerCtx) error {
|
||||
err := server.ListenAndServe()
|
||||
// return on shutdown error
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// log error and restart
|
||||
log.Errorf("api: http endpoint failed: %s - restarting in %s", err, backoffDuration)
|
||||
time.Sleep(backoffDuration)
|
||||
}
|
||||
}
|
||||
|
||||
type mainHandler struct {
|
||||
mux *mux.Router
|
||||
}
|
||||
|
||||
func (mh *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
_ = module.mgr.Do("http request", func(_ *mgr.WorkerCtx) error {
|
||||
return mh.handle(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
|
||||
// Setup context trace logging.
|
||||
ctx, tracer := log.AddTracer(r.Context())
|
||||
// Add request context.
|
||||
apiRequest := &Request{
|
||||
Request: r,
|
||||
}
|
||||
ctx = context.WithValue(ctx, RequestContextKey, apiRequest)
|
||||
// Add context back to request.
|
||||
r = r.WithContext(ctx)
|
||||
lrw := NewLoggingResponseWriter(w, r)
|
||||
|
||||
tracer.Tracef("api request: %s ___ %s %s", r.RemoteAddr, lrw.Request.Method, r.RequestURI)
|
||||
defer func() {
|
||||
// Log request status.
|
||||
if lrw.Status != 0 {
|
||||
// If lrw.Status is 0, the request may have been hijacked.
|
||||
tracer.Debugf("api request: %s %d %s %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.Method, lrw.Request.RequestURI)
|
||||
}
|
||||
tracer.Submit()
|
||||
}()
|
||||
|
||||
// Add security headers.
|
||||
w.Header().Set("Referrer-Policy", "same-origin")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "deny")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("X-DNS-Prefetch-Control", "off")
|
||||
|
||||
// Add CSP Header in production mode.
|
||||
if !devMode() {
|
||||
w.Header().Set(
|
||||
"Content-Security-Policy",
|
||||
"default-src 'self'; "+
|
||||
"connect-src https://*.safing.io 'self'; "+
|
||||
"style-src 'self' 'unsafe-inline'; "+
|
||||
"img-src 'self' data: blob:",
|
||||
)
|
||||
}
|
||||
|
||||
// Check Cross-Origin Requests.
|
||||
origin := r.Header.Get("Origin")
|
||||
isPreflighCheck := false
|
||||
if origin != "" {
|
||||
|
||||
// Parse origin URL.
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
tracer.Warningf("api: denied request from %s: failed to parse origin header: %s", r.RemoteAddr, err)
|
||||
http.Error(lrw, "Invalid Origin.", http.StatusForbidden)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the Origin matches the Host.
|
||||
switch {
|
||||
case originURL.Host == r.Host:
|
||||
// Origin (with port) matches Host.
|
||||
case originURL.Hostname() == r.Host:
|
||||
// Origin (without port) matches Host.
|
||||
case originURL.Scheme == "chrome-extension":
|
||||
// Allow access for the browser extension
|
||||
// TODO(ppacher):
|
||||
// This currently allows access from any browser extension.
|
||||
// Can we reduce that to only our browser extension?
|
||||
// Also, what do we need to support Firefox?
|
||||
case devMode() &&
|
||||
utils.StringInSlice(allowedDevCORSOrigins, originURL.Hostname()):
|
||||
// We are in dev mode and the request is coming from the allowed
|
||||
// development origins.
|
||||
default:
|
||||
// Origin and Host do NOT match!
|
||||
tracer.Warningf("api: denied request from %s: Origin (`%s`) and Host (`%s`) do not match", r.RemoteAddr, origin, r.Host)
|
||||
http.Error(lrw, "Cross-Origin Request Denied.", http.StatusForbidden)
|
||||
return nil
|
||||
|
||||
// If the Host header has a port, and the Origin does not, requests will
|
||||
// also end up here, as we cannot properly check for equality.
|
||||
}
|
||||
|
||||
// Add Cross-Site Headers now as we need them in any case now.
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "*")
|
||||
w.Header().Set("Access-Control-Max-Age", "60")
|
||||
w.Header().Add("Vary", "Origin")
|
||||
|
||||
// if there's a Access-Control-Request-Method header this is a Preflight check.
|
||||
// In that case, we will just check if the preflighMethod is allowed and then return
|
||||
// success here
|
||||
if preflighMethod := r.Header.Get("Access-Control-Request-Method"); r.Method == http.MethodOptions && preflighMethod != "" {
|
||||
isPreflighCheck = true
|
||||
}
|
||||
}
|
||||
|
||||
// Clean URL.
|
||||
cleanedRequestPath := cleanRequestPath(r.URL.Path)
|
||||
|
||||
// If the cleaned URL differs from the original one, redirect to there.
|
||||
if r.URL.Path != cleanedRequestPath {
|
||||
redirURL := *r.URL
|
||||
redirURL.Path = cleanedRequestPath
|
||||
http.Redirect(lrw, r, redirURL.String(), http.StatusMovedPermanently)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get handler for request.
|
||||
// Gorilla does not support handling this on our own very well.
|
||||
// See github.com/gorilla/mux.ServeHTTP for reference.
|
||||
var match mux.RouteMatch
|
||||
var handler http.Handler
|
||||
if mh.mux.Match(r, &match) {
|
||||
handler = match.Handler
|
||||
apiRequest.Route = match.Route
|
||||
apiRequest.URLVars = match.Vars
|
||||
}
|
||||
switch {
|
||||
case match.MatchErr == nil:
|
||||
// All good.
|
||||
case errors.Is(match.MatchErr, mux.ErrMethodMismatch):
|
||||
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
default:
|
||||
tracer.Debug("api: no handler registered for this path")
|
||||
http.Error(lrw, "Not found.", http.StatusNotFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Be sure that URLVars always is a map.
|
||||
if apiRequest.URLVars == nil {
|
||||
apiRequest.URLVars = make(map[string]string)
|
||||
}
|
||||
|
||||
// Check method.
|
||||
_, readMethod, ok := getEffectiveMethod(r)
|
||||
if !ok {
|
||||
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// At this point we know the method is allowed and there's a handler for the request.
|
||||
// If this is just a CORS-Preflight, we'll accept the request with StatusOK now.
|
||||
// There's no point in trying to authenticate the request because the Browser will
|
||||
// not send authentication along a preflight check.
|
||||
if isPreflighCheck && handler != nil {
|
||||
lrw.WriteHeader(http.StatusOK)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check authentication.
|
||||
apiRequest.AuthToken = authenticateRequest(lrw, r, handler, readMethod)
|
||||
if apiRequest.AuthToken == nil {
|
||||
// Authenticator already replied.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we have a handler.
|
||||
if handler == nil {
|
||||
http.Error(lrw, "Not found.", http.StatusNotFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Format panics in handler.
|
||||
defer func() {
|
||||
if panicValue := recover(); panicValue != nil {
|
||||
// Log failure.
|
||||
log.Errorf("api: handler panic: %s", panicValue)
|
||||
// Respond with a server error.
|
||||
if devMode() {
|
||||
http.Error(
|
||||
lrw,
|
||||
fmt.Sprintf(
|
||||
"Internal Server Error: %s\n\n%s",
|
||||
panicValue,
|
||||
debug.Stack(),
|
||||
),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
} else {
|
||||
http.Error(lrw, "Internal Server Error.", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle with registered handler.
|
||||
handler.ServeHTTP(lrw, r)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanRequestPath cleans and returns a request URL.
|
||||
func cleanRequestPath(requestPath string) string {
|
||||
// If the request URL is empty, return a request for "root".
|
||||
if requestPath == "" || requestPath == "/" {
|
||||
return "/"
|
||||
}
|
||||
// If the request URL does not start with a slash, prepend it.
|
||||
if !strings.HasPrefix(requestPath, "/") {
|
||||
requestPath = "/" + requestPath
|
||||
}
|
||||
|
||||
// Clean path to remove any relative parts.
|
||||
cleanedRequestPath := path.Clean(requestPath)
|
||||
// Because path.Clean removes a trailing slash, we need to add it back here
|
||||
// if the original URL had one.
|
||||
if strings.HasSuffix(requestPath, "/") {
|
||||
cleanedRequestPath += "/"
|
||||
}
|
||||
|
||||
return cleanedRequestPath
|
||||
}
|
||||
49
base/api/testclient/root/index.html
Normal file
49
base/api/testclient/root/index.html
Normal file
@@ -0,0 +1,49 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title></title>
|
||||
<!-- <script src="https://cdn.jsdelivr.net/sockjs/1/sockjs.min.js"></script> -->
|
||||
</head>
|
||||
<body>
|
||||
<script type="text/javascript">
|
||||
|
||||
var ws = new WebSocket('ws://127.0.0.1:18/api/database/v1')
|
||||
|
||||
ws.onopen = function () {
|
||||
console.log('open');
|
||||
};
|
||||
|
||||
ws.onerror = function (error) {
|
||||
console.log('error');
|
||||
console.log(error);
|
||||
};
|
||||
|
||||
ws.onmessage = function (e) {
|
||||
reader = new FileReader()
|
||||
reader.onload = function(e) {
|
||||
console.log(e.target.result)
|
||||
}
|
||||
reader.readAsText(e.data)
|
||||
};
|
||||
|
||||
function send(text) {
|
||||
ws.send(text)
|
||||
}
|
||||
|
||||
// var sock = new SockJS("http://localhost:8080/api/v1");
|
||||
//
|
||||
// sock.onopen = function() {
|
||||
// console.log('open');
|
||||
// };
|
||||
//
|
||||
// sock.onmessage = function(e) {
|
||||
// console.log('message received: ', e.data);
|
||||
// };
|
||||
//
|
||||
// sock.onclose = function(e) {
|
||||
// console.log('close', e);
|
||||
// };
|
||||
</script>
|
||||
yeeee
|
||||
</body>
|
||||
</html>
|
||||
11
base/api/testclient/serve.go
Normal file
11
base/api/testclient/serve.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package testclient
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/safing/portmaster/base/api"
|
||||
)
|
||||
|
||||
func init() {
|
||||
api.RegisterHandler("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/"))))
|
||||
}
|
||||
Reference in New Issue
Block a user