Restructure modules (#1572)

* Move portbase into monorepo

* Add new simple module mgr

* [WIP] Switch to new simple module mgr

* Add StateMgr and more worker variants

* [WIP] Switch more modules

* [WIP] Switch more modules

* [WIP] switch more modules

* [WIP] switch all SPN modules

* [WIP] switch all service modules

* [WIP] Convert all workers to the new module system

* [WIP] add new task system to module manager

* [WIP] Add second take for scheduling workers

* [WIP] Add FIXME for bugs in new scheduler

* [WIP] Add minor improvements to scheduler

* [WIP] Add new worker scheduler

* [WIP] Fix more bugs related to new module system

* [WIP] Fix start handling of the new module system

* [WIP] Improve startup process

* [WIP] Fix minor issues

* [WIP] Fix missing subsystem in settings

* [WIP] Initialize managers in constructor

* [WIP] Move module event initialization to constructors

* [WIP] Fix setting for enabling and disabling the SPN module

* [WIP] Move API registration into module construction

* [WIP] Update states mgr for all modules

* [WIP] Add CmdLine operation support

* Add state helper methods to module group and instance

* Add notification and module status handling to status package

* Fix starting issues

* Remove pilot widget and update security lock to new status data

* Remove debug logs

* Improve http server shutdown

* Add workaround for cleanly shutting down firewall+netquery

* Improve logging

* Add syncing states with notifications for new module system

* Improve starting, stopping, shutdown; resolve FIXMEs/TODOs

* [WIP] Fix most unit tests

* Review new module system and fix minor issues

* Push shutdown and restart events again via API

* Set sleep mode via interface

* Update example/template module

* [WIP] Fix spn/cabin unit test

* Remove deprecated UI elements

* Make log output more similar for the logging transition phase

* Switch spn hub and observer cmds to new module system

* Fix log sources

* Make worker mgr less error prone

* Fix tests and minor issues

* Fix observation hub

* Improve shutdown and restart handling

* Split up big connection.go source file

* Move varint and dsd packages to structures repo

* Improve expansion test

* Fix linter warnings

* Fix interception module on windows

* Fix linter errors

---------

Co-authored-by: Vladimir Stoilov <vladimir@safing.io>
This commit is contained in:
Daniel Hååvi
2024-08-09 17:15:48 +02:00
committed by GitHub
parent 10a77498f4
commit 80664d1a27
647 changed files with 37690 additions and 3366 deletions

105
base/utils/atomic.go Normal file
View File

@@ -0,0 +1,105 @@
package utils
import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"github.com/safing/portmaster/base/utils/renameio"
)
// AtomicFileOptions holds additional options for manipulating
// the behavior of CreateAtomic and friends.
type AtomicFileOptions struct {
	// Mode is the file mode for the new file. If
	// 0, the file mode will be set to 0600.
	Mode os.FileMode

	// TempDir is the path to the temp-directory
	// that should be used. If empty, it defaults
	// to the system temp.
	// NOTE(review): for the final rename to be atomic, this should
	// presumably be on the same filesystem as the destination —
	// confirm against the renameio documentation.
	TempDir string
}
// CreateAtomic creates or overwrites a file at dest atomically using
// data from r. Atomic means that even in case of a power outage,
// dest will never be a zero-length file. It will always either contain
// the previous data (or not exist) or the new data but never anything
// in between.
func CreateAtomic(dest string, r io.Reader, opts *AtomicFileOptions) error {
	if opts == nil {
		opts = &AtomicFileOptions{}
	}

	tmpFile, err := renameio.TempFile(opts.TempDir, dest)
	if err != nil {
		return fmt.Errorf("failed to create temp file: %w", err)
	}
	// Cleanup is a no-op if the atomic replace below succeeded.
	defer tmpFile.Cleanup() //nolint:errcheck

	if opts.Mode != 0 {
		if err := tmpFile.Chmod(opts.Mode); err != nil {
			return fmt.Errorf("failed to update mode bits of temp file: %w", err)
		}
	}

	if _, err := io.Copy(tmpFile, r); err != nil {
		return fmt.Errorf("failed to copy source file: %w", err)
	}

	if err := tmpFile.CloseAtomicallyReplace(); err != nil {
		// Bugfix: the underlying error was previously dropped from the message.
		return fmt.Errorf("failed to rename temp file to %q: %w", dest, err)
	}

	return nil
}
// CopyFileAtomic is like CreateAtomic but copies content from
// src to dest. If opts.Mode is 0 CopyFileAtomic tries to set
// the file mode of src to dest.
func CopyFileAtomic(dest string, src string, opts *AtomicFileOptions) error {
	if opts == nil {
		opts = &AtomicFileOptions{}
	}

	// Without an explicit mode, mirror the mode of the source file.
	if opts.Mode == 0 {
		srcInfo, statErr := os.Stat(src)
		if statErr != nil {
			return statErr
		}
		opts.Mode = srcInfo.Mode()
	}

	srcFile, openErr := os.Open(src)
	if openErr != nil {
		return openErr
	}
	defer func() {
		_ = srcFile.Close()
	}()

	return CreateAtomic(dest, srcFile, opts)
}
// ReplaceFileAtomic replaces the file at dest with the content from src.
// If dest exists it's file mode copied and used for the replacement. If
// not, dest will get the same file mode as src. See CopyFileAtomic and
// CreateAtomic for more information.
func ReplaceFileAtomic(dest string, src string, opts *AtomicFileOptions) error {
	if opts == nil {
		opts = &AtomicFileOptions{}
	}

	// Prefer the mode of the existing destination file, if there is one.
	if opts.Mode == 0 {
		switch destInfo, statErr := os.Stat(dest); {
		case statErr == nil:
			opts.Mode = destInfo.Mode()
		case !errors.Is(statErr, fs.ErrNotExist):
			// Any error other than "missing file" is reported.
			return statErr
		}
	}

	return CopyFileAtomic(dest, src, opts)
}

View File

@@ -0,0 +1,84 @@
package utils
import (
"sync"
"github.com/tevino/abool"
)
// BroadcastFlag is a simple system to broadcast a flag value.
type BroadcastFlag struct {
	flag   *abool.AtomicBool // current broadcast state; replaced on NotifyAndReset
	signal chan struct{}     // closed to notify listeners; replaced on NotifyAndReset
	lock   sync.Mutex        // guards flag and signal
}
// Flag receives changes from its broadcasting flag.
// A Flag must only be used in one goroutine and is not concurrency safe,
// but fast.
type Flag struct {
	flag        *abool.AtomicBool // snapshot of the broadcaster's flag (shared pointer)
	signal      chan struct{}     // snapshot of the broadcaster's signal channel
	broadcaster *BroadcastFlag    // source to copy fresh state from on Refresh
}
// NewBroadcastFlag returns a new BroadcastFlag.
// In the initial state, the flag is not set and the signal does not trigger.
func NewBroadcastFlag() *BroadcastFlag {
	bf := new(BroadcastFlag)
	bf.flag = abool.New()
	bf.signal = make(chan struct{})
	return bf
}
// NewFlag returns a new Flag that listens to this broadcasting flag.
// In the initial state, the flag is set and the signal triggers.
// You can call Refresh immediately to get the current state from the
// broadcasting flag.
func (bf *BroadcastFlag) NewFlag() *Flag {
	// Start with an already-closed signal and a set flag, so the very
	// first wait/check on the new Flag triggers immediately.
	triggered := make(chan struct{})
	close(triggered)

	return &Flag{
		flag:        abool.NewBool(true),
		signal:      triggered,
		broadcaster: bf,
	}
}
// NotifyAndReset notifies all flags of this broadcasting flag and resets the
// internal broadcast flag state.
func (bf *BroadcastFlag) NotifyAndReset() {
	bf.lock.Lock()
	defer bf.lock.Unlock()

	// Notify all flags of the change: set the shared flag and close the
	// shared signal channel. Flags holding the old pointers observe both.
	bf.flag.Set()
	close(bf.signal)

	// Reset: install a fresh flag and signal for the next broadcast round.
	// Existing Flags keep the old (set/closed) state until they Refresh.
	bf.flag = abool.New()
	bf.signal = make(chan struct{})
}
// Signal returns a channel that waits for the flag to be set. This does not
// reset the Flag itself, you'll need to call Refresh for that.
func (f *Flag) Signal() <-chan struct{} {
	return f.signal
}
// IsSet returns whether the flag was set since the last Refresh.
// This does not reset the Flag itself, you'll need to call Refresh for that.
func (f *Flag) IsSet() bool {
	return f.flag.IsSet()
}
// Refresh fetches the current state from the broadcasting flag.
func (f *Flag) Refresh() {
	f.broadcaster.lock.Lock()
	defer f.broadcaster.lock.Unlock()

	// Copy the current flag and signal from the broadcasting flag. Both
	// are replaced wholesale by NotifyAndReset, so copying them under the
	// broadcaster's lock guarantees a consistent pair.
	f.flag = f.broadcaster.flag
	f.signal = f.broadcaster.signal
}

View File

@@ -0,0 +1,87 @@
package utils
import (
"sync"
"sync/atomic"
"time"
)
// CallLimiter bundles concurrent calls and optionally limits how fast a function is called.
type CallLimiter struct {
	pause time.Duration // minimum time between executions; 0 disables the pause

	inLock   sync.Mutex   // gates entry of new waiter batches
	lastExec time.Time    // when the last execution finished (UTC)
	waiters  atomic.Int32 // number of goroutines currently inside Do

	outLock sync.Mutex // held by the leader while f executes; waiters block on it
}
// NewCallLimiter returns a new call limiter.
// Set minPause to zero to disable the minimum pause between calls.
func NewCallLimiter(minPause time.Duration) *CallLimiter {
	limiter := new(CallLimiter)
	limiter.pause = minPause
	return limiter
}
// Do executes the given function.
// All concurrent calls to Do are bundled and return when f() finishes.
// Waits until the minimum pause is over before executing f() again.
func (l *CallLimiter) Do(f func()) {
	// Wait for the previous waiters to exit.
	l.inLock.Lock()

	// Defer final unlock to safeguard from panics.
	defer func() {
		// Execution is finished - leave.
		// If we are the last waiter, let the next batch in.
		if l.waiters.Add(-1) == 0 {
			l.inLock.Unlock()
		}
	}()

	// Check if we are the first waiter.
	if l.waiters.Add(1) == 1 {
		// Take the lead on this execution run.
		// lead() releases inLock while f runs (so others can join this
		// batch) and re-acquires it before returning — see lead's defer.
		l.lead(f)
	} else {
		// We are not the first waiter, let others in.
		l.inLock.Unlock()
	}

	// Wait for execution to complete: the leader holds outLock while f
	// runs, so locking it here blocks until the leader is done.
	l.outLock.Lock()
	l.outLock.Unlock() //nolint:staticcheck
	// Last statement is in defer above.
}
// lead is run by the first waiter of a batch: it enforces the minimum
// pause, executes f exactly once, and then hands the locks back so the
// batch can drain.
func (l *CallLimiter) lead(f func()) {
	// Make all others wait while we execute the function.
	l.outLock.Lock()

	// Unlock in lock until execution is finished, so that later callers
	// can join this batch and block on the out lock.
	l.inLock.Unlock()

	// Transition from out lock to in lock when done.
	defer func() {
		// Update last execution time.
		l.lastExec = time.Now().UTC()
		// Stop newcomers from waiting on previous execution.
		l.inLock.Lock()
		// Allow waiters to leave.
		l.outLock.Unlock()
	}()

	// Wait for the minimum duration between executions.
	if l.pause > 0 {
		sinceLastExec := time.Since(l.lastExec)
		if sinceLastExec < l.pause {
			time.Sleep(l.pause - sinceLastExec)
		}
	}

	// Execute.
	f()
}

View File

@@ -0,0 +1,91 @@
package utils
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/tevino/abool"
)
// TestCallLimiter exercises batching, the minimum pause, and panic
// handling of CallLimiter.
// NOTE(review): the execution-count bounds below are timing-sensitive
// and may flake on heavily loaded machines.
func TestCallLimiter(t *testing.T) {
	t.Parallel()

	pause := 10 * time.Millisecond
	oa := NewCallLimiter(pause)
	executed := abool.New()
	var testWg sync.WaitGroup

	// One execution should gobble up the whole batch.
	// We are doing this without sleep in function, so dummy exec first to trigger first pause.
	oa.Do(func() {})

	// Start
	for range 10 {
		testWg.Add(100)
		for range 100 {
			go func() {
				oa.Do(func() {
					if !executed.SetToIf(false, true) {
						t.Errorf("concurrent execution!")
					}
				})
				testWg.Done()
			}()
		}
		testWg.Wait()

		// Check if function was executed at least once.
		if executed.IsNotSet() {
			t.Errorf("no execution!")
		}
		executed.UnSet() // reset check
	}

	// Wait for pause to reset.
	time.Sleep(pause)

	// Continuous use with re-execution.
	// Choose values so that about 10 executions are expected
	var execs uint32
	testWg.Add(200)
	for range 200 {
		go func() {
			oa.Do(func() {
				atomic.AddUint32(&execs, 1)
				time.Sleep(10 * time.Millisecond)
			})
			testWg.Done()
		}()

		// Start one goroutine every 1ms.
		time.Sleep(1 * time.Millisecond)
	}
	testWg.Wait()

	if execs <= 5 {
		t.Errorf("unexpected low exec count: %d", execs)
	}
	if execs >= 15 {
		t.Errorf("unexpected high exec count: %d", execs)
	}

	// Wait for pause to reset.
	time.Sleep(pause)

	// Check if the limiter correctly handles panics.
	testWg.Add(100)
	for range 100 {
		go func() {
			defer func() {
				_ = recover()
				testWg.Done()
			}()
			oa.Do(func() {
				time.Sleep(1 * time.Millisecond)
				panic("test")
			})
		}()
		time.Sleep(100 * time.Microsecond)
	}
	testWg.Wait()
}

131
base/utils/debug/debug.go Normal file
View File

@@ -0,0 +1,131 @@
package debug
import (
"bytes"
"fmt"
"runtime/pprof"
"time"
"github.com/safing/portmaster/base/info"
"github.com/safing/portmaster/base/log"
)
// Info gathers debugging information and stores everything in a buffer in
// order to write it to somewhere later. It directly inherits a bytes.Buffer,
// so you can also use all these functions too.
type Info struct {
	bytes.Buffer

	// Style selects the section output format: "github" wraps sections in
	// collapsible <details> blocks, anything else uses plain markdown.
	Style string
}
// InfoFlag defines possible options for adding sections to a Info.
// Flags are bit values and may be combined with "|".
type InfoFlag int

const (
	// NoFlags does nothing.
	NoFlags InfoFlag = 0

	// UseCodeSection wraps the section content in a markdown code section.
	UseCodeSection InfoFlag = 1

	// AddContentLineBreaks adds a line breaks after each line of content,
	// except for the last.
	AddContentLineBreaks InfoFlag = 2
)
// useCodeSection reports whether the UseCodeSection bit is set in flags.
func useCodeSection(flags InfoFlag) bool {
	return flags&UseCodeSection != 0
}
// addContentLineBreaks reports whether the AddContentLineBreaks bit is set in flags.
func addContentLineBreaks(flags InfoFlag) bool {
	return flags&AddContentLineBreaks != 0
}
// AddSection adds a debug section to the Info. The result is directly
// written into the buffer.
func (di *Info) AddSection(name string, flags InfoFlag, content ...string) {
	// Separate from any previously written section.
	if di.Len() > 0 {
		_, _ = di.WriteString("\n\n")
	}

	// Write the section header.
	if di.Style == "github" {
		_, _ = fmt.Fprintf(di, "<details>\n<summary>%s</summary>\n\n", name)
	} else {
		_, _ = fmt.Fprintf(di, "**%s**:\n\n", name)
	}

	// Open a markdown code fence, if requested.
	if useCodeSection(flags) {
		// Write code header: Needs one empty line between previous data.
		_, _ = di.WriteString("```\n")
	}

	// Write the content parts, optionally separated by line breaks.
	for idx, part := range content {
		_, _ = di.WriteString(part)
		if addContentLineBreaks(flags) && idx < len(content)-1 {
			_, _ = di.WriteString("\n")
		}
	}

	// Close the code fence again.
	if useCodeSection(flags) {
		// Write code footer: Needs one empty line between next data.
		_, _ = di.WriteString("\n```\n")
	}

	// Write the section footer.
	if di.Style == "github" {
		_, _ = di.WriteString("\n</details>")
	}
}
// AddVersionInfo adds version information from the info pkg.
func (di *Info) AddVersionInfo() {
	di.AddSection(
		"Version "+info.Version(),
		UseCodeSection,
		info.FullVersion(),
	)
}
// AddGoroutineStack adds the current goroutine stack.
func (di *Info) AddGoroutineStack() {
	// Collect the stack dump of all goroutines (debug level 1).
	var dump bytes.Buffer
	if err := pprof.Lookup("goroutine").WriteTo(&dump, 1); err != nil {
		di.AddSection(
			"Goroutine Stack",
			NoFlags,
			fmt.Sprintf("Failed to get: %s", err),
		)
		return
	}

	// Add section.
	di.AddSection(
		"Goroutine Stack",
		UseCodeSection,
		dump.String(),
	)
}
// AddLastUnexpectedLogs adds the last 10 unexpected log lines, if any.
func (di *Info) AddLastUnexpectedLogs() {
	lines := log.GetLastUnexpectedLogs()

	// Nothing to report — record that fact as its own section.
	if len(lines) == 0 {
		di.AddSection("No Unexpected Logs", NoFlags)
		return
	}

	// Append a marker with the current time so the age of the log lines
	// can be judged.
	marker := fmt.Sprintf("%s CURRENT TIME", time.Now().Format("060102 15:04:05.000"))
	di.AddSection(
		"Unexpected Logs",
		UseCodeSection|AddContentLineBreaks,
		append(lines, marker)...,
	)
}

View File

@@ -0,0 +1,31 @@
package debug
import (
"context"
"fmt"
"github.com/safing/portmaster-android/go/app_interface"
)
// AddPlatformInfo adds OS and platform information.
func (di *Info) AddPlatformInfo(_ context.Context) {
	// Get information from the system.
	platformInfo, err := app_interface.GetPlatformInfo()
	if err != nil {
		di.AddSection(
			"Platform Information",
			NoFlags,
			fmt.Sprintf("Failed to get: %s", err),
		)
		return
	}

	// Add section.
	// Fix: the title is a constant string; the previous fmt.Sprintf call
	// had no formatting directives (staticcheck S1039).
	di.AddSection(
		"Platform: Android",
		UseCodeSection|AddContentLineBreaks,
		fmt.Sprintf("SDK: %d", platformInfo.SDK),
		fmt.Sprintf("Device: %s %s (%s)", platformInfo.Manufacturer, platformInfo.Brand, platformInfo.Board),
		fmt.Sprintf("App: %s: %s %s", platformInfo.ApplicationID, platformInfo.VersionName, platformInfo.BuildType))
}

View File

@@ -0,0 +1,43 @@
//go:build !android
package debug
import (
"context"
"fmt"
"github.com/shirou/gopsutil/host"
)
// AddPlatformInfo adds OS and platform information.
func (di *Info) AddPlatformInfo(ctx context.Context) {
	// Get information from the system.
	hostInfo, err := host.InfoWithContext(ctx)
	if err != nil {
		di.AddSection(
			"Platform Information",
			NoFlags,
			fmt.Sprintf("Failed to get: %s", err),
		)
		return
	}

	// Check if we want to add virtualization information.
	var virtInfo string
	if hostInfo.VirtualizationRole == "guest" {
		switch hostInfo.VirtualizationSystem {
		case "":
			virtInfo = "VM: unidentified"
		default:
			virtInfo = fmt.Sprintf("VM: %s", hostInfo.VirtualizationSystem)
		}
	}

	// Add section.
	di.AddSection(
		fmt.Sprintf("Platform: %s %s", hostInfo.Platform, hostInfo.PlatformVersion),
		UseCodeSection|AddContentLineBreaks,
		fmt.Sprintf("System: %s %s (%s) %s", hostInfo.Platform, hostInfo.OS, hostInfo.PlatformFamily, hostInfo.PlatformVersion),
		fmt.Sprintf("Kernel: %s %s", hostInfo.KernelVersion, hostInfo.KernelArch),
		virtInfo,
	)
}

51
base/utils/fs.go Normal file
View File

@@ -0,0 +1,51 @@
package utils
import (
"errors"
"fmt"
"io/fs"
"os"
"runtime"
)
const isWindows = runtime.GOOS == "windows"
// EnsureDirectory ensures that the given directory exists and that is has the given permissions set.
// If path is a file, it is deleted and a directory created.
func EnsureDirectory(path string, perm os.FileMode) error {
// open path
f, err := os.Stat(path)
if err == nil {
// file exists
if f.IsDir() {
// directory exists, check permissions
if isWindows {
// TODO: set correct permission on windows
// acl.Chmod(path, perm)
} else if f.Mode().Perm() != perm {
return os.Chmod(path, perm)
}
return nil
}
err = os.Remove(path)
if err != nil {
return fmt.Errorf("could not remove file %s to place dir: %w", path, err)
}
}
// file does not exist (or has been deleted)
if err == nil || errors.Is(err, fs.ErrNotExist) {
err = os.Mkdir(path, perm)
if err != nil {
return fmt.Errorf("could not create dir %s: %w", path, err)
}
return os.Chmod(path, perm)
}
// other error opening path
return fmt.Errorf("failed to access %s: %w", path, err)
}
// PathExists returns whether the given path (file or dir) exists.
func PathExists(path string) bool {
_, err := os.Stat(path)
return err == nil || errors.Is(err, fs.ErrExist)
}

78
base/utils/mimetypes.go Normal file
View File

@@ -0,0 +1,78 @@
package utils
import "strings"
// Do not depend on the OS for mimetypes.
// A Windows update screwed us over here and broke all the automatic mime
// typing via Go in April 2021.
// MimeTypeByExtension returns a mimetype for the given file name extension,
// which must include the leading dot.
// If the extension is not known, the call returns with ok=false and,
// additionally, a default "application/octet-stream" mime type is returned.
func MimeTypeByExtension(ext string) (mimeType string, ok bool) {
	// Lookup is case-insensitive.
	if mt, found := mimeTypes[strings.ToLower(ext)]; found {
		return mt, true
	}
	return defaultMimeType, false
}
var (
	// defaultMimeType is returned for extensions missing from the table below.
	defaultMimeType = "application/octet-stream"

	// mimeTypes maps lower-case file name extensions (including the
	// leading dot) to their mime type. Maintained locally to stay
	// independent of the OS mime database.
	mimeTypes = map[string]string{
		".7z":    "application/x-7z-compressed",
		".atom":  "application/atom+xml",
		".css":   "text/css; charset=utf-8",
		".csv":   "text/csv; charset=utf-8",
		".deb":   "application/x-debian-package",
		".epub":  "application/epub+zip",
		".es":    "application/ecmascript",
		".flv":   "video/x-flv",
		".gif":   "image/gif",
		".gz":    "application/gzip",
		".htm":   "text/html; charset=utf-8",
		".html":  "text/html; charset=utf-8",
		".jpeg":  "image/jpeg",
		".jpg":   "image/jpeg",
		".js":    "text/javascript; charset=utf-8",
		".json":  "application/json; charset=utf-8",
		".m3u":   "audio/mpegurl",
		".m4a":   "audio/mpeg",
		".md":    "text/markdown; charset=utf-8",
		".mjs":   "text/javascript; charset=utf-8",
		".mov":   "video/quicktime",
		".mp3":   "audio/mpeg",
		".mp4":   "video/mp4",
		".mpeg":  "video/mpeg",
		".mpg":   "video/mpeg",
		".ogg":   "audio/ogg",
		".ogv":   "video/ogg",
		".otf":   "font/otf",
		".pdf":   "application/pdf",
		".png":   "image/png",
		".qt":    "video/quicktime",
		".rar":   "application/rar",
		".rtf":   "application/rtf",
		".svg":   "image/svg+xml",
		".tar":   "application/x-tar",
		".tiff":  "image/tiff",
		".ts":    "video/MP2T",
		".ttc":   "font/collection",
		".ttf":   "font/ttf",
		".txt":   "text/plain; charset=utf-8",
		".wasm":  "application/wasm",
		".wav":   "audio/x-wav",
		".webm":  "video/webm",
		".webp":  "image/webp",
		".woff":  "font/woff",
		".woff2": "font/woff2",
		".xml":   "text/xml; charset=utf-8",
		".xz":    "application/x-xz",
		".yaml":  "application/yaml; charset=utf-8",
		".yml":   "application/yaml; charset=utf-8",
		".zip":   "application/zip",
	}
)

86
base/utils/onceagain.go Normal file
View File

@@ -0,0 +1,86 @@
package utils
// This file is forked from https://github.com/golang/go/blob/bc593eac2dc63d979a575eccb16c7369a5ff81e0/src/sync/once.go.
import (
"sync"
"sync/atomic"
)
// OnceAgain is an object that will perform only one action "in flight". It's
// basically the same as sync.Once, but is automatically reused when the
// function was executed and everyone who waited has left.
// Important: This is somewhat racy when used heavily as it only resets _after_
// everyone who waited has left. So, while some goroutines are waiting to be
// activated again to leave the waiting state, other goroutines will call Do()
// without executing the function again.
type OnceAgain struct {
	// done indicates whether the action has been performed.
	// It is first in the struct because it is used in the hot path.
	// The hot path is inlined at every call site.
	// Placing done first allows more compact instructions on some architectures (amd64/x86),
	// and fewer instructions (to calculate offset) on other architectures.
	done uint32

	// waiters is the number of callers currently inside Do; the last
	// waiter to leave resets done so the action can run again.
	waiters int32

	// m serializes the slow path.
	m sync.Mutex
}
// Do calls the function f if and only if no other call to Do is currently
// "in flight" for this instance of OnceAgain. Given
//
//	var once OnceAgain
//
// concurrent calls to once.Do(f) invoke f exactly once; the other callers
// block until that invocation returns. After all waiters have left, the
// instance resets and a later Do will invoke f again (unlike sync.Once).
//
// Since f is niladic, it may be necessary to use a function literal to
// capture the arguments to a function to be invoked by Do:
//
//	config.once.Do(func() { config.init(filename) })
//
// Because no call to Do returns until the one call to f returns, if f causes
// Do to be called, it will deadlock.
//
// If f panics, Do considers it to have returned; concurrent calls of Do
// return without calling f.
func (o *OnceAgain) Do(f func()) {
	// Note: Here is an incorrect implementation of Do:
	//
	//	if atomic.CompareAndSwapUint32(&o.done, 0, 1) {
	//		f()
	//	}
	//
	// Do guarantees that when it returns, f has finished.
	// This implementation would not implement that guarantee:
	// given two simultaneous calls, the winner of the cas would
	// call f, and the second would return immediately, without
	// waiting for the first's call to f to complete.
	// This is why the slow path falls back to a mutex, and why
	// the atomic.StoreUint32 must be delayed until after f returns.
	if atomic.LoadUint32(&o.done) == 0 {
		// Outlined slow-path to allow inlining of the fast-path.
		o.doSlow(f)
	}
}
// doSlow is the mutex-guarded slow path of Do. It tracks waiters so that
// the last one to leave resets done, re-arming the OnceAgain.
func (o *OnceAgain) doSlow(f func()) {
	atomic.AddInt32(&o.waiters, 1)
	defer func() {
		if atomic.AddInt32(&o.waiters, -1) == 0 {
			atomic.StoreUint32(&o.done, 0) // reset
		}
	}()

	o.m.Lock()
	defer o.m.Unlock()

	if o.done == 0 {
		// Mark done only after f returns (even on panic), so concurrent
		// callers keep blocking on the mutex until f has finished.
		defer atomic.StoreUint32(&o.done, 1)
		f()
	}
}

View File

@@ -0,0 +1,60 @@
package utils
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/tevino/abool"
)
// TestOnceAgain exercises batching and re-arming of OnceAgain.
// NOTE(review): the execution-count bounds below are timing-sensitive and
// may flake on heavily loaded machines.
func TestOnceAgain(t *testing.T) {
	t.Parallel()

	oa := OnceAgain{}
	executed := abool.New()
	var testWg sync.WaitGroup

	// One execution should gobble up the whole batch.
	for range 10 {
		testWg.Add(100)
		for range 100 {
			go func() {
				oa.Do(func() {
					if !executed.SetToIf(false, true) {
						t.Errorf("concurrent execution!")
					}
					time.Sleep(10 * time.Millisecond)
				})
				testWg.Done()
			}()
		}
		testWg.Wait()
		executed.UnSet() // reset check
	}

	// Continuous use with re-execution.
	// Choose values so that about 10 executions are expected
	var execs uint32
	testWg.Add(100)
	for range 100 {
		go func() {
			oa.Do(func() {
				atomic.AddUint32(&execs, 1)
				time.Sleep(10 * time.Millisecond)
			})
			testWg.Done()
		}()
		time.Sleep(1 * time.Millisecond)
	}
	testWg.Wait()

	if execs <= 8 {
		t.Errorf("unexpected low exec count: %d", execs)
	}
	if execs >= 12 {
		t.Errorf("unexpected high exec count: %d", execs)
	}
}

View File

@@ -0,0 +1,51 @@
package osdetail
import (
"sync"
"golang.org/x/sys/windows"
)
var (
	colorSupport         bool       // cached result of enableColorSupport
	colorSupportChecked  bool       // whether colorSupport has been determined yet
	checkingColorSupport sync.Mutex // guards the two variables above
)
// EnableColorSupport tries to enable color support for cmd on windows and returns whether it is enabled.
func EnableColorSupport() bool {
	checkingColorSupport.Lock()
	defer checkingColorSupport.Unlock()

	// Determine support only once and cache the result.
	if colorSupportChecked {
		return colorSupport
	}
	colorSupport = enableColorSupport()
	colorSupportChecked = true
	return colorSupport
}
// enableColorSupport attempts to switch stdout into virtual terminal
// (ANSI escape) mode and reports whether colors can be used.
func enableColorSupport() bool {
	// Virtual terminal processing is only available on Windows 10+.
	if IsAtLeastWindowsNTVersionWithDefault("10", false) {
		// check if windows.Stdout is file
		// NOTE(review): if GetFileInformationByHandle succeeds, stdout is
		// presumably redirected to a file and colors are disabled —
		// confirm this is the intended check.
		if windows.GetFileInformationByHandle(windows.Stdout, &windows.ByHandleFileInformation{}) == nil {
			return false
		}

		var mode uint32
		err := windows.GetConsoleMode(windows.Stdout, &mode)
		if err == nil {
			// Enable virtual terminal processing if not already active.
			if mode&windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING == 0 {
				mode |= windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING
				err = windows.SetConsoleMode(windows.Stdout, mode)
				if err != nil {
					return false
				}
			}
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,51 @@
package osdetail
import (
"bytes"
"errors"
"os/exec"
"strings"
)
// RunCmd runs the given command and run error checks on the output.
func RunCmd(command ...string) (output []byte, err error) {
	if len(command) == 0 {
		return nil, errors.New("no command supplied")
	}

	// Build the command; a one-element slice simply yields no extra args.
	cmd := exec.Command(command[0], command[1:]...)

	// Capture stdout and stderr separately.
	var stdoutBuf, stderrBuf bytes.Buffer
	cmd.Stdout = &stdoutBuf
	cmd.Stderr = &stderrBuf

	// Run command and collect output.
	if runErr := cmd.Run(); runErr != nil {
		return nil, runErr
	}

	// Command might not return an error, but just write to stdout instead.
	// Report the first stderr line as the error in that case.
	if stderrBuf.Len() > 0 {
		firstLine, _, _ := strings.Cut(stderrBuf.String(), "\n")
		return nil, errors.New(firstLine)
	}

	// Finalize stdout.
	cleanedOutput := bytes.TrimSpace(stdoutBuf.Bytes())
	if len(cleanedOutput) == 0 {
		return nil, ErrEmptyOutput
	}
	return cleanedOutput, nil
}

View File

@@ -0,0 +1,17 @@
package osdetail
import (
"os/exec"
)
// EnableDNSCache enables the Windows Service "DNS Client" by setting the registry value "HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\Dnscache" to 2 (Automatic).
// A reboot is required for this setting to take effect.
func EnableDNSCache() error {
	args := []string{
		"add", "HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\services\\Dnscache",
		"/v", "Start", "/t", "REG_DWORD", "/d", "2", "/f",
	}
	return exec.Command("reg", args...).Run()
}
// DisableDNSCache disables the Windows Service "DNS Client" by setting the registry value "HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\Dnscache" to 4 (Disabled).
// A reboot is required for this setting to take effect.
func DisableDNSCache() error {
	args := []string{
		"add", "HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\services\\Dnscache",
		"/v", "Start", "/t", "REG_DWORD", "/d", "4", "/f",
	}
	return exec.Command("reg", args...).Run()
}

View File

@@ -0,0 +1,12 @@
package osdetail
import "errors"
var (
	// ErrNotSupported is returned when an operation is not supported on the current platform.
	ErrNotSupported = errors.New("not supported")

	// ErrNotFound is returned when the desired data is not found.
	ErrNotFound = errors.New("not found")

	// ErrEmptyOutput is a special error that is returned when an operation
	// succeeded without error but also returned no data.
	ErrEmptyOutput = errors.New("command succeeded with empty output")
)

View File

@@ -0,0 +1,112 @@
package osdetail
import (
"errors"
"fmt"
"os/exec"
"strings"
"time"
)
// Service status values reported by GetServiceStatus.
const (
	StatusUnknown uint8 = iota
	StatusRunningStoppable
	StatusRunningNotStoppable
	StatusStartPending
	StatusStopPending
	StatusStopped
)

// Exported errors.
var (
	// ErrServiceNotStoppable is returned when a stop is requested for a
	// service that reports as not stoppable.
	ErrServiceNotStoppable = errors.New("the service is not stoppable")
)

// GetServiceStatus returns the current status of a Windows Service (limited implementation).
// It shells out to "sc query" and matches keywords in the output.
func GetServiceStatus(name string) (status uint8, err error) {
	output, err := exec.Command("sc", "query", name).Output()
	if err != nil {
		return StatusUnknown, fmt.Errorf("failed to query service: %w", err)
	}
	outputString := string(output)

	switch {
	case strings.Contains(outputString, "RUNNING"):
		if strings.Contains(outputString, "NOT_STOPPABLE") {
			return StatusRunningNotStoppable, nil
		}
		return StatusRunningStoppable, nil
	case strings.Contains(outputString, "STOP_PENDING"):
		return StatusStopPending, nil
	case strings.Contains(outputString, "STOPPED"):
		return StatusStopped, nil
	case strings.Contains(outputString, "START_PENDING"):
		// Bugfix: this case previously returned StatusStopPending.
		return StatusStartPending, nil
	}
	return StatusUnknown, errors.New("unknown service status")
}
// StopService stops a Windows Service.
// It polls the service status every 100ms and returns once the service
// reports as stopped, or with an error if it stays pending for ~5s.
func StopService(name string) (err error) {
	pendingCnt := 0
	for {
		// get status
		status, err := GetServiceStatus(name)
		if err != nil {
			return err
		}
		switch status {
		case StatusRunningStoppable:
			// Request the stop; keep polling until STOPPED is reported.
			if err := exec.Command("sc", "stop", name).Run(); err != nil {
				return fmt.Errorf("failed to stop service: %w", err)
			}
		case StatusRunningNotStoppable:
			return ErrServiceNotStoppable
		case StatusStartPending, StatusStopPending:
			pendingCnt++
			if pendingCnt > 50 {
				return errors.New("service stuck in pending status (5s)")
			}
		case StatusStopped:
			return nil
		}
		time.Sleep(100 * time.Millisecond)
	}
}
// SartService starts a Windows Service.
// It polls the service status every 100ms and returns once the service
// reports as running, or with an error if it stays pending for ~5s.
//
// NOTE(review): the name is a typo ("Sart"); it is kept as-is for
// backward compatibility with existing callers.
func SartService(name string) (err error) {
	pendingCnt := 0
	for {
		// get status
		status, err := GetServiceStatus(name)
		if err != nil {
			return err
		}
		switch status {
		case StatusRunningStoppable, StatusRunningNotStoppable:
			return nil
		case StatusStartPending, StatusStopPending:
			pendingCnt++
			if pendingCnt > 50 {
				return errors.New("service stuck in pending status (5s)")
			}
		case StatusStopped:
			// Bugfix: the error message previously said "stop".
			if err := exec.Command("sc", "start", name).Run(); err != nil {
				return fmt.Errorf("failed to start service: %w", err)
			}
		}
		time.Sleep(100 * time.Millisecond)
	}
}

View File

@@ -0,0 +1,49 @@
package osdetail
import (
"bytes"
"errors"
)
// RunPowershellCmd runs a powershell command and returns its output.
func RunPowershellCmd(script string) (output []byte, err error) {
	// Force UTF-8 console output before running the actual script.
	fullScript := "[System.Console]::OutputEncoding = [System.Text.Encoding]::UTF8\n" + script
	return RunCmd(
		"powershell.exe",
		"-ExecutionPolicy", "Bypass",
		"-NoProfile",
		"-NonInteractive",
		fullScript,
	)
}
// outputSeparator is an arbitrary marker echoed before the actual command
// output so that post-processing can strip the noise emitted by chcp.
const outputSeparator = "pwzzhtuvpwdgozhzbnjj"

// RunTerminalCmd runs a Windows cmd command and returns its output.
// It sets the output of the cmd to UTF-8 in order to avoid encoding errors.
func RunTerminalCmd(command ...string) (output []byte, err error) {
	output, err = RunCmd(append([]string{
		"cmd.exe",
		"/c",
		"chcp",  // Set output encoding...
		"65001", // ...to UTF-8.
		"&",
		"echo",
		outputSeparator,
		"&",
	},
		command...,
	)...)
	if err != nil {
		return nil, err
	}

	// Find correct start of output and shift start.
	// Bugfix: bytes.IndexAny matches any single character from its second
	// argument, not the substring — use bytes.Index to locate the marker.
	index := bytes.Index(output, []byte(outputSeparator))
	if index < 0 {
		return nil, errors.New("failed to post-process output: could not find output separator")
	}
	// Skip the marker plus the "\r\n" that echo appends.
	output = output[index+len(outputSeparator)+2:]

	return output, nil
}

View File

@@ -0,0 +1,120 @@
package osdetail
import (
"bufio"
"bytes"
"errors"
"fmt"
"os/exec"
"strconv"
"strings"
"sync"
)
var (
	// serviceNames caches the PID -> service-names mapping; it is nil
	// until the first successful GetAllServiceNames run.
	serviceNames     map[int32][]string
	serviceNamesLock sync.Mutex // guards serviceNames
)

// Errors
var (
	// ErrServiceNotFound is returned when no service with the given PID was found.
	ErrServiceNotFound = errors.New("no service with the given PID was found")
)
// GetServiceNames returns all service names associated with a svchost.exe process on Windows.
// Results are cached; a cache miss triggers a full refresh via GetAllServiceNames.
func GetServiceNames(pid int32) ([]string, error) {
	serviceNamesLock.Lock()
	defer serviceNamesLock.Unlock()

	// Serve from the cache, if possible.
	if serviceNames != nil {
		if names, ok := serviceNames[pid]; ok {
			return names, nil
		}
	}

	// Refresh the cache.
	// Bugfix: ":=" previously shadowed the package-level serviceNames
	// variable here, so the cache was never actually populated.
	fresh, err := GetAllServiceNames()
	if err != nil {
		return nil, err
	}
	serviceNames = fresh

	if names, ok := serviceNames[pid]; ok {
		return names, nil
	}
	return nil, ErrServiceNotFound
}
// GetAllServiceNames returns a list of service names associated with svchost.exe processes on Windows.
// It parses the output of "tasklist /svc" into a PID -> service-names map.
func GetAllServiceNames() (map[int32][]string, error) {
	output, err := exec.Command("tasklist", "/svc", "/fi", "imagename eq svchost.exe").Output()
	if err != nil {
		return nil, fmt.Errorf("failed to get svchost tasklist: %w", err)
	}

	// Scan the tasklist output line by line.
	scanner := bufio.NewScanner(bytes.NewReader(output))
	scanner.Split(bufio.ScanLines)

	// Skip the output header (it ends with a "=====" separator line).
	for scanner.Scan() {
		if strings.HasPrefix(scanner.Text(), "=") {
			break
		}
	}

	var (
		pid        int32
		services   []string
		collection = make(map[int32][]string)
	)
	for scanner.Scan() {
		// get fields of line
		fields := strings.Fields(scanner.Text())
		if len(fields) == 0 {
			continue
		}

		// A line starting with "svchost.exe" begins a new process entry;
		// continuation lines only carry additional service names.
		if fields[0] == "svchost.exe" {
			// save old entry
			if pid != 0 {
				collection[pid] = services
			}
			// reset PID
			pid = 0
			services = make([]string, 0, len(fields))

			if len(fields) < 3 {
				continue
			}
			// get pid
			i, err := strconv.ParseInt(fields[1], 10, 32)
			if err != nil {
				continue
			}
			pid = int32(i)
			// skip used fields
			fields = fields[2:]
		}

		// add service names, trimming separators
		for _, field := range fields {
			services = append(services, strings.Trim(strings.TrimSpace(field), ","))
		}
	}
	// Fix: report scanner errors (previously unchecked), e.g. an overlong line.
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("failed to parse tasklist output: %w", err)
	}

	if pid != 0 {
		// save last entry
		collection[pid] = services
	}
	return collection, nil
}

View File

@@ -0,0 +1,99 @@
package osdetail
import (
"fmt"
"strings"
"sync"
"github.com/hashicorp/go-version"
"github.com/shirou/gopsutil/host"
)
var (
	// versionRe = regexp.MustCompile(`[0-9\.]+`)

	windowsNTVersion       string           // cached version string, set once by WindowsNTVersion
	windowsNTVersionForCmp *version.Version // parsed form of windowsNTVersion, used for comparisons
	fetching               sync.Mutex       // guards all fields above and the fetch itself
	fetched                bool             // whether the version was successfully fetched and parsed
)
// WindowsNTVersion returns the current Windows NT version (e.g. "10.0").
// The value is fetched and parsed once and cached for all later calls.
func WindowsNTVersion() (string, error) {
	fetching.Lock()
	defer fetching.Unlock()

	if !fetched {
		var err error
		_, _, windowsNTVersion, err = host.PlatformInformation()
		// Check the error before using the result; the previous code split
		// the (possibly invalid) string first and checked the error after.
		if err != nil {
			return "", fmt.Errorf("failed to obtain Windows-Version: %s", err)
		}
		// Only keep the numeric part before the first space.
		windowsNTVersion = strings.SplitN(windowsNTVersion, " ", 2)[0]

		windowsNTVersionForCmp, err = version.NewVersion(windowsNTVersion)
		if err != nil {
			return "", fmt.Errorf("failed to parse Windows-Version %s: %s", windowsNTVersion, err)
		}
		fetched = true
	}

	return windowsNTVersion, nil
}
// IsAtLeastWindowsNTVersion reports whether the current Windows NT version is
// the same as or newer than the given version.
func IsAtLeastWindowsNTVersion(v string) (bool, error) {
	// Ensure the cached version globals are populated.
	if _, err := WindowsNTVersion(); err != nil {
		return false, err
	}

	other, err := version.NewVersion(v)
	if err != nil {
		return false, err
	}

	return windowsNTVersionForCmp.GreaterThanOrEqual(other), nil
}
// IsAtLeastWindowsNTVersionWithDefault is like IsAtLeastWindowsNTVersion(),
// but swallows any error and returns the given default value instead.
func IsAtLeastWindowsNTVersionWithDefault(v string, defaultValue bool) bool {
	atLeast, err := IsAtLeastWindowsNTVersion(v)
	if err != nil {
		return defaultValue
	}
	return atLeast
}
// IsAtLeastWindowsVersion reports whether the current Windows version is the
// same as or newer than the given (marketing) version.
func IsAtLeastWindowsVersion(v string) (bool, error) {
	// Translate the marketing version to its NT version counterpart.
	var ntVersion string
	switch v {
	case "7":
		ntVersion = "6.1"
	case "8":
		ntVersion = "6.2"
	case "8.1":
		ntVersion = "6.3"
	case "10":
		ntVersion = "10"
	default:
		return false, fmt.Errorf("failed to compare Windows-Version: Windows %s is unknown", v)
	}

	return IsAtLeastWindowsNTVersion(ntVersion)
}
// IsAtLeastWindowsVersionWithDefault is like IsAtLeastWindowsVersion(), but
// swallows any error and returns the given default value instead.
func IsAtLeastWindowsVersionWithDefault(v string, defaultValue bool) bool {
	atLeast, err := IsAtLeastWindowsVersion(v)
	if err != nil {
		return defaultValue
	}
	return atLeast
}

View File

@@ -0,0 +1,29 @@
package osdetail
import "testing"
// TestWindowsNTVersion checks that the NT version can be fetched at all and
// is non-empty. Only meaningful on Windows.
func TestWindowsNTVersion(t *testing.T) {
	if str, err := WindowsNTVersion(); str == "" || err != nil {
		t.Fatalf("failed to obtain windows version: %s", err)
	}
}
// TestIsAtLeastWindowsNTVersion checks the NT version comparison against a
// floor (NT 6 = Vista) that any supported system should satisfy.
func TestIsAtLeastWindowsNTVersion(t *testing.T) {
	ret, err := IsAtLeastWindowsNTVersion("6")
	if err != nil {
		t.Fatalf("failed to compare windows versions: %s", err)
	}
	if !ret {
		t.Fatalf("WindowsNTVersion is less than 6 (Vista)")
	}
}
// TestIsAtLeastWindowsVersion checks the marketing-version comparison against
// a floor (Windows 7) that any supported system should satisfy.
func TestIsAtLeastWindowsVersion(t *testing.T) {
	ret, err := IsAtLeastWindowsVersion("7")
	if err != nil {
		t.Fatalf("failed to compare windows versions: %s", err)
	}
	if !ret {
		t.Fatalf("WindowsVersion is less than 7")
	}
}

202
base/utils/renameio/LICENSE Normal file
View File

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,55 @@
This is a fork of the github.com/google/renameio Go package at commit 353f8196982447d8b12c64f69530e657331e3dbc.
The initial commit of this package carries the original package contents.
The Original License is the Apache License in Version 2.0 and the copyright of the forked package is held by Google Inc.
Any changes are recorded in the git history, which is part of this project.
---
The `renameio` Go package provides a way to atomically create or replace a file or
symbolic link.
## Atomicity vs durability
`renameio` concerns itself *only* with atomicity, i.e. making sure applications
never see unexpected file content (a half-written file, or a 0-byte file).
As a practical example, consider https://manpages.debian.org/: if there is a
power outage while the site is updating, we are okay with losing the manpages
which were being rendered at the time of the power outage. They will be added in
a later run of the software. We are not okay with having a manpage replaced by a
0-byte file under any circumstances, though.
## Advantages of this package
There are other packages for atomically replacing files, and sometimes ad-hoc
implementations can be found in programs.
A naive approach to the problem is to create a temporary file followed by a call
to `os.Rename()`. However, there are a number of subtleties which make the
correct sequence of operations hard to identify:
* The temporary file should be removed when an error occurs, but a remove must
not be attempted if the rename succeeded, as a new file might have been
created with the same name. This renders a throwaway `defer
os.Remove(t.Name())` insufficient; state must be kept.
* The temporary file must be created on the same file system (same mount point)
for the rename to work, but the TMPDIR environment variable should still be
respected, e.g. to direct temporary files into a separate directory outside of
the webserver's document root but on the same file system.
* On POSIX operating systems, the
[`fsync`](https://manpages.debian.org/stretch/manpages-dev/fsync.2) system
call must be used to ensure that the `os.Rename()` call will not result in a
0-length file.
This package attempts to get all of these details right, provides an intuitive,
yet flexible API and caters to use-cases where high performance is required.
## Disclaimer
This is not an official Google product (experimental or otherwise), it
is just code that happens to be owned by Google.
This project is not affiliated with the Go project.

View File

@@ -0,0 +1,7 @@
// Package renameio provides a way to atomically create or replace a file or
// symbolic link.
//
// Caveat: this package requires the file system rename(2) implementation to be
// atomic. Notably, this is not the case when using NFS with multiple clients:
// https://stackoverflow.com/a/41396801
package renameio

View File

@@ -0,0 +1,57 @@
package renameio_test
import (
"fmt"
"log"
"github.com/safing/portmaster/base/utils/renameio"
)
// ExampleTempFile_justone demonstrates atomically persisting a single file:
// write to a temporary file, then atomically rename it over the destination.
func ExampleTempFile_justone() { //nolint:testableexamples
	persist := func(temperature float64) error {
		t, err := renameio.TempFile("", "/srv/www/metrics.txt")
		if err != nil {
			return err
		}
		// Cleanup is a no-op once CloseAtomicallyReplace succeeded.
		defer func() {
			_ = t.Cleanup()
		}()
		if _, err := fmt.Fprintf(t, "temperature_degc %f\n", temperature); err != nil {
			return err
		}
		return t.CloseAtomicallyReplace()
	}

	// Thanks to the write package, a webserver exposing /srv/www never
	// serves an incomplete or missing file.
	if err := persist(31.2); err != nil {
		log.Fatal(err)
	}
}
// ExampleTempFile_many demonstrates caching the TempDir lookup when writing
// many files to the same destination directory.
func ExampleTempFile_many() { //nolint:testableexamples
	// Prepare for writing files to /srv/www, effectively caching calls to
	// TempDir which TempFile would otherwise need to make.
	dir := renameio.TempDir("/srv/www")
	persist := func(temperature float64) error {
		t, err := renameio.TempFile(dir, "/srv/www/metrics.txt")
		if err != nil {
			return err
		}
		// Cleanup is a no-op once CloseAtomicallyReplace succeeded.
		defer func() {
			_ = t.Cleanup()
		}()
		if _, err := fmt.Fprintf(t, "temperature_degc %f\n", temperature); err != nil {
			return err
		}
		return t.CloseAtomicallyReplace()
	}

	// Imagine this was an endless loop, reading temperature sensor values.
	// Thanks to the write package, a webserver exposing /srv/www never
	// serves an incomplete or missing file.
	for {
		if err := persist(31.2); err != nil {
			log.Fatal(err)
		}
	}
}

View File

@@ -0,0 +1,41 @@
//go:build darwin || dragonfly || freebsd || linux || nacl || netbsd || openbsd || solaris || windows
package renameio
import (
"bytes"
"os"
"path/filepath"
"testing"
)
// TestSymlink verifies that Symlink creates a working symlink and — by
// running the loop twice — that it atomically replaces an existing one.
func TestSymlink(t *testing.T) {
	t.Parallel()

	d, err := os.MkdirTemp("", "test-renameio-testsymlink")
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		_ = os.RemoveAll(d)
	})

	want := []byte("Hello World")
	if err := os.WriteFile(filepath.Join(d, "hello.txt"), want, 0o0600); err != nil {
		t.Fatal(err)
	}

	// Second iteration exercises the replace path: hi.txt already exists.
	for range 2 {
		if err := Symlink("hello.txt", filepath.Join(d, "hi.txt")); err != nil {
			t.Fatal(err)
		}
		got, err := os.ReadFile(filepath.Join(d, "hi.txt"))
		if err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(got, want) {
			t.Fatalf("unexpected content: got %q, want %q", string(got), string(want))
		}
	}
}

View File

@@ -0,0 +1,170 @@
package renameio
import (
"errors"
"io/fs"
"os"
"path/filepath"
)
// TempDir checks whether os.TempDir() can be used as a temporary directory
// for later atomically replacing files within dest. If not (os.TempDir()
// resides on a different mount point), dest is returned.
//
// Note that the returned value ceases to be valid once either os.TempDir()
// changes (e.g. on Linux, once the TMPDIR environment variable changes) or the
// file system is unmounted.
func TempDir(dest string) string {
	probe := filepath.Join(dest, "renameio-TempDir")
	return tempDir("", probe)
}
// tempDir decides which directory to place temporary files in for atomically
// replacing dest. A non-empty caller-specified dir always wins; otherwise
// os.TempDir() is probed and used if a rename from there into dest's
// directory works, falling back to dest's own directory.
func tempDir(dir, dest string) string {
	// A caller-specified directory always wins.
	if dir != "" {
		return dir
	}

	// Fall back to the destination's directory: the temporary file and the
	// destination must live on the same mount point for rename to succeed.
	fallback := filepath.Dir(dest)

	// os.TempDir() honors a TMPDIR override set by the user.
	tmpdir := os.TempDir()

	// Probe whether renaming from tmpdir into the destination directory
	// works by creating a scratch file in each and renaming one onto the
	// other.
	prefix := "." + filepath.Base(dest)

	src, err := os.CreateTemp(tmpdir, prefix)
	if err != nil {
		return fallback
	}
	_ = src.Close()
	srcMoved := false
	defer func() {
		if !srcMoved {
			_ = os.Remove(src.Name())
		}
	}()

	dst, err := os.CreateTemp(fallback, prefix)
	if err != nil {
		return fallback
	}
	_ = dst.Close()
	defer func() {
		_ = os.Remove(dst.Name())
	}()

	if err := os.Rename(src.Name(), dst.Name()); err != nil {
		return fallback
	}
	srcMoved = true // the rename consumed src; only dst remains to clean up
	return tmpdir
}
// PendingFile is a pending temporary file, waiting to replace the destination
// path in a call to CloseAtomicallyReplace.
type PendingFile struct {
	*os.File // the underlying temporary file; read/write methods are promoted

	path   string // final destination path to atomically replace on close
	done   bool   // set once CloseAtomicallyReplace succeeded; makes Cleanup a no-op
	closed bool   // set once the file was closed, so Cleanup does not double-close
}
// Cleanup is a no-op if CloseAtomicallyReplace succeeded, and otherwise closes
// and removes the temporary file.
func (t *PendingFile) Cleanup() error {
	if t.done {
		return nil
	}

	// Something went wrong earlier: close the tempfile (unless already
	// closed) and remove it. Errors are returned for reporting only; there
	// is nothing the caller can recover here.
	var closeErr error
	if !t.closed {
		closeErr = t.Close()
	}
	if removeErr := os.Remove(t.Name()); removeErr != nil {
		return removeErr
	}
	return closeErr
}
// CloseAtomicallyReplace closes the temporary file and atomically replaces
// the destination file with it, i.e., a concurrent open(2) call will either
// open the file previously located at the destination path (if any), or the
// just written file, but the file will always be present.
func (t *PendingFile) CloseAtomicallyReplace() error {
	// Even on an ordered file system (e.g. ext4 with data=ordered) or file
	// systems with write barriers, we cannot skip the fsync(2) call as per
	// Theodore Ts'o (ext2/3/4 lead developer):
	//
	// > data=ordered only guarantees the avoidance of stale data (e.g., the previous
	// > contents of a data block showing up after a crash, where the previous data
	// > could be someone's love letters, medical records, etc.). Without the fsync(2)
	// > a zero-length file is a valid and possible outcome after the rename.
	if err := t.Sync(); err != nil {
		return err
	}
	// Mark closed before Close so Cleanup never double-closes, even if
	// Close itself returns an error.
	t.closed = true
	if err := t.Close(); err != nil {
		return err
	}
	// The rename is the atomic commit point.
	if err := os.Rename(t.Name(), t.path); err != nil {
		return err
	}
	// Success: Cleanup becomes a no-op from here on.
	t.done = true
	return nil
}
// TempFile wraps os.CreateTemp for the use case of atomically creating or
// replacing the destination file at path.
//
// If dir is the empty string, TempDir(filepath.Base(path)) is used. If you are
// going to write a large number of files to the same file system, store the
// result of TempDir(filepath.Base(path)) and pass it instead of the empty
// string.
//
// The file's permissions will be 0600 by default. You can change these by
// explicitly calling Chmod on the returned PendingFile.
func TempFile(dir, path string) (*PendingFile, error) {
	pattern := "." + filepath.Base(path)
	f, err := os.CreateTemp(tempDir(dir, path), pattern)
	if err != nil {
		return nil, err
	}

	return &PendingFile{
		File: f,
		path: path,
	}, nil
}
// Symlink wraps os.Symlink, replacing an existing symlink with the same name
// atomically (os.Symlink fails when newname already exists, at least on Linux).
func Symlink(oldname, newname string) error {
// Fast path: if newname does not exist yet, we can skip the whole dance
// below.
if err := os.Symlink(oldname, newname); err == nil || !errors.Is(err, fs.ErrExist) {
return err
}
// We need to use os.MkdirTemp, as we cannot overwrite a os.CreateTemp,
// and removing+symlinking creates a TOCTOU race.
d, err := os.MkdirTemp(filepath.Dir(newname), "."+filepath.Base(newname))
if err != nil {
return err
}
cleanup := true
defer func() {
if cleanup {
_ = os.RemoveAll(d)
}
}()
symlink := filepath.Join(d, "tmp.symlink")
if err := os.Symlink(oldname, symlink); err != nil {
return err
}
if err := os.Rename(symlink, newname); err != nil {
return err
}
cleanup = false
return os.RemoveAll(d)
}

View File

@@ -0,0 +1,113 @@
//go:build linux
package renameio
import (
"os"
"path/filepath"
"syscall"
"testing"
)
// TestTempDir exercises tempDir's mount-point probing under various TMPDIR
// settings, using two tmpfs mounts (skipped when mounting is not permitted).
//
// NOTE: This test mutates the process-global TMPDIR environment variable, so
// neither the test itself nor its subtests may run in parallel: the previous
// t.Parallel() calls made the subtests race each other on TMPDIR (and, before
// Go 1.22, share a single loop variable).
func TestTempDir(t *testing.T) {
	// Restore the original TMPDIR state when the test finishes.
	if tmpdir, ok := os.LookupEnv("TMPDIR"); ok {
		t.Cleanup(func() {
			_ = os.Setenv("TMPDIR", tmpdir) // restore
		})
	} else {
		t.Cleanup(func() {
			_ = os.Unsetenv("TMPDIR") // restore
		})
	}

	mount1, err := os.MkdirTemp("", "test-renameio-testtempdir1")
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		_ = os.RemoveAll(mount1)
	})

	mount2, err := os.MkdirTemp("", "test-renameio-testtempdir2")
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		_ = os.RemoveAll(mount2)
	})

	// Put each directory on its own mount point so renames across them fail.
	if err := syscall.Mount("tmpfs", mount1, "tmpfs", 0, ""); err != nil {
		t.Skipf("cannot mount tmpfs on %s: %v", mount1, err)
	}
	t.Cleanup(func() {
		_ = syscall.Unmount(mount1, 0)
	})

	if err := syscall.Mount("tmpfs", mount2, "tmpfs", 0, ""); err != nil {
		t.Skipf("cannot mount tmpfs on %s: %v", mount2, err)
	}
	t.Cleanup(func() {
		_ = syscall.Unmount(mount2, 0)
	})

	tests := []struct {
		name   string
		dir    string
		path   string
		TMPDIR string
		want   string
	}{
		{
			name: "implicit TMPDIR",
			path: filepath.Join(os.TempDir(), "foo.txt"),
			want: os.TempDir(),
		},
		{
			name:   "explicit TMPDIR",
			path:   filepath.Join(mount1, "foo.txt"),
			TMPDIR: mount1,
			want:   mount1,
		},
		{
			name:   "explicit unsuitable TMPDIR",
			path:   filepath.Join(mount1, "foo.txt"),
			TMPDIR: mount2,
			want:   mount1,
		},
		{
			name:   "nonexistant TMPDIR",
			path:   filepath.Join(mount1, "foo.txt"),
			TMPDIR: "/nonexistant",
			want:   mount1,
		},
		{
			name:   "caller-specified",
			dir:    "/overridden",
			path:   filepath.Join(mount1, "foo.txt"),
			TMPDIR: "/nonexistant",
			want:   "/overridden",
		},
	}

	for _, testCase := range tests {
		t.Run(testCase.name, func(t *testing.T) {
			// Intentionally NOT parallel: each subtest mutates the shared
			// TMPDIR environment variable.
			if testCase.TMPDIR == "" {
				_ = os.Unsetenv("TMPDIR")
			} else {
				_ = os.Setenv("TMPDIR", testCase.TMPDIR)
			}
			if got := tempDir(testCase.dir, testCase.path); got != testCase.want {
				t.Fatalf("tempDir(%q, %q): got %q, want %q", testCase.dir, testCase.path, got, testCase.want)
			}
		})
	}
}

View File

@@ -0,0 +1,26 @@
package renameio
import "os"
// WriteFile mirrors os.WriteFile, replacing an existing file with the same
// name atomically.
func WriteFile(filename string, data []byte, perm os.FileMode) error {
	pending, err := TempFile("", filename)
	if err != nil {
		return err
	}
	// Cleanup is a no-op once CloseAtomicallyReplace succeeded.
	defer func() {
		_ = pending.Cleanup()
	}()

	// Set permissions before writing data, in case the data is sensitive.
	if err := pending.Chmod(perm); err != nil {
		return err
	}
	if _, err := pending.Write(data); err != nil {
		return err
	}
	return pending.CloseAtomicallyReplace()
}

View File

@@ -0,0 +1,46 @@
//go:build darwin || dragonfly || freebsd || linux || nacl || netbsd || openbsd || solaris || windows
package renameio
import (
"bytes"
"os"
"path/filepath"
"testing"
)
// TestWriteFile verifies that WriteFile writes the expected content and
// applies the requested permissions.
func TestWriteFile(t *testing.T) {
	t.Parallel()

	d, err := os.MkdirTemp("", "test-renameio-testwritefile")
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		_ = os.RemoveAll(d)
	}()

	filename := filepath.Join(d, "hello.sh")

	wantData := []byte("#!/bin/sh\necho \"Hello World\"\n")
	wantPerm := os.FileMode(0o0600)
	if err := WriteFile(filename, wantData, wantPerm); err != nil {
		t.Fatal(err)
	}

	gotData, err := os.ReadFile(filename)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(gotData, wantData) {
		t.Errorf("got data %v, want data %v", gotData, wantData)
	}

	// Only compare the permission bits, not file type bits.
	fi, err := os.Stat(filename)
	if err != nil {
		t.Fatal(err)
	}
	if gotPerm := fi.Mode() & os.ModePerm; gotPerm != wantPerm {
		t.Errorf("got permissions 0%o, want permissions 0%o", gotPerm, wantPerm)
	}
}

23
base/utils/safe.go Normal file
View File

@@ -0,0 +1,23 @@
package utils
import (
"encoding/hex"
"strings"
)
// SafeFirst16Bytes returns the first 16 bytes of the given data in safe,
// printable form: the first line of a hex dump (hex bytes plus ASCII
// rendering) with the leading offset column removed.
func SafeFirst16Bytes(data []byte) string {
	if len(data) == 0 {
		return "<empty>"
	}

	// hex.Dump separates the offset column from the hex bytes with TWO
	// spaces ("00000000  47 6f ..."); trimming a single-space prefix would
	// never match and leave the offset in the output.
	return strings.TrimPrefix(
		strings.SplitN(hex.Dump(data), "\n", 2)[0],
		"00000000  ",
	)
}
// SafeFirst16Chars returns the first 16 characters of the given string in
// safe, printable form; see SafeFirst16Bytes.
func SafeFirst16Chars(s string) string {
	raw := []byte(s)
	return SafeFirst16Bytes(raw)
}

29
base/utils/safe_test.go Normal file
View File

@@ -0,0 +1,29 @@
package utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestSafeFirst16 checks both byte and string variants against a known hex
// dump line and the empty-input sentinel.
// NOTE(review): the expected strings must match hex.Dump's exact spacing
// (two spaces between the 8-byte groups and before the ASCII column) —
// verify the spacing survived formatting/copy-paste.
func TestSafeFirst16(t *testing.T) {
	t.Parallel()

	assert.Equal(t,
		"47 6f 20 69 73 20 61 6e 20 6f 70 65 6e 20 73 6f |Go is an open so|",
		SafeFirst16Bytes([]byte("Go is an open source programming language.")),
	)
	assert.Equal(t,
		"47 6f 20 69 73 20 61 6e 20 6f 70 65 6e 20 73 6f |Go is an open so|",
		SafeFirst16Chars("Go is an open source programming language."),
	)
	assert.Equal(t,
		"<empty>",
		SafeFirst16Bytes(nil),
	)
	assert.Equal(t,
		"<empty>",
		SafeFirst16Chars(""),
	)
}

52
base/utils/slices.go Normal file
View File

@@ -0,0 +1,52 @@
package utils
// IndexOfString returns the index of the first occurrence of s in a, or -1 if
// s is not part of the slice.
func IndexOfString(a []string, s string) int {
	for idx := range a {
		if a[idx] == s {
			return idx
		}
	}
	return -1
}
// StringInSlice reports whether the given string is present in the slice.
func StringInSlice(a []string, s string) bool {
	return IndexOfString(a, s) != -1
}
// RemoveFromStringSlice removes the first occurrence of the given string from
// the slice and returns the resulting slice.
//
// Note: the backing array of the given slice is modified in place; pass a
// copy (see DuplicateStrings) if the original must stay intact.
func RemoveFromStringSlice(a []string, s string) []string {
	i := IndexOfString(a, s)
	// Use >= 0 so an element at index 0 is removed as well; the previous
	// check (i > 0) silently skipped removing the first element.
	if i >= 0 {
		a = append(a[:i], a[i+1:]...)
	}
	return a
}
// DuplicateStrings returns a new copy of the given string slice. The copy
// shares no backing storage with the original.
func DuplicateStrings(a []string) []string {
	duplicate := make([]string, len(a))
	copy(duplicate, a)
	return duplicate
}
// StringSliceEqual reports whether the given string slices have the same
// length and identical elements in the same order.
func StringSliceEqual(a []string, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}
// DuplicateBytes returns a new copy of the given byte slice. The copy shares
// no backing storage with the original.
func DuplicateBytes(a []byte) []byte {
	duplicate := make([]byte, len(a))
	copy(duplicate, a)
	return duplicate
}

91
base/utils/slices_test.go Normal file
View File

@@ -0,0 +1,91 @@
package utils
import (
"bytes"
"testing"
)
// Shared test fixtures.
var (
	stringTestSlice  = []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} // all distinct
	stringTestSlice2 = []string{"a", "x", "x", "x", "x", "x", "x", "x", "x", "j"} // same length, different content
	stringTestSlice3 = []string{"a", "x"}                                         // different length
	byteTestSlice    = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
)
// TestStringInSlice checks membership for present elements (first, middle,
// last) and for several absent ones.
func TestStringInSlice(t *testing.T) {
	t.Parallel()

	if !StringInSlice(stringTestSlice, "a") {
		t.Fatal("string reported not in slice (1), but it is")
	}
	if !StringInSlice(stringTestSlice, "d") {
		t.Fatal("string reported not in slice (2), but it is")
	}
	if !StringInSlice(stringTestSlice, "j") {
		t.Fatal("string reported not in slice (3), but it is")
	}

	if StringInSlice(stringTestSlice, "0") {
		t.Fatal("string reported in slice (1), but is not")
	}
	if StringInSlice(stringTestSlice, "x") {
		t.Fatal("string reported in slice (2), but is not")
	}
	if StringInSlice(stringTestSlice, "k") {
		t.Fatal("string reported in slice (3), but is not")
	}
}
// TestRemoveFromStringSlice checks removal of an element and that removing an
// already-removed element is harmless.
// NOTE(review): only removal at index 1 ("b") is covered; removal of the
// first element (index 0) is untested — consider adding a case for it.
func TestRemoveFromStringSlice(t *testing.T) {
	t.Parallel()

	// Work on a copy so the shared fixture stays intact.
	test1 := DuplicateStrings(stringTestSlice)
	test1 = RemoveFromStringSlice(test1, "b")
	if StringInSlice(test1, "b") {
		t.Fatal("string reported in slice, but was removed")
	}
	if len(test1) != len(stringTestSlice)-1 {
		t.Fatalf("new string slice length not as expected: is %d, should be %d\nnew slice is %v", len(test1), len(stringTestSlice)-1, test1)
	}

	// Removing a string that is no longer present must not panic.
	RemoveFromStringSlice(test1, "b")
}
// TestDuplicateStrings checks that the copy is equal to the original and
// does not share backing storage with it.
func TestDuplicateStrings(t *testing.T) {
	t.Parallel()

	a := DuplicateStrings(stringTestSlice)
	if !StringSliceEqual(a, stringTestSlice) {
		t.Fatal("copied string slice is not equal")
	}

	// Mutating the copy must not affect the original.
	a[0] = "x"
	if StringSliceEqual(a, stringTestSlice) {
		t.Fatal("copied string slice is not a real copy")
	}
}
// TestStringSliceEqual checks equality for identical slices and inequality
// for slices differing in content or length.
func TestStringSliceEqual(t *testing.T) {
	t.Parallel()

	if !StringSliceEqual(stringTestSlice, stringTestSlice) {
		t.Fatal("strings are equal, but are reported as not")
	}
	// Same length, different content.
	if StringSliceEqual(stringTestSlice, stringTestSlice2) {
		t.Fatal("strings are not equal (1), but are reported as equal")
	}
	// Different length.
	if StringSliceEqual(stringTestSlice, stringTestSlice3) {
		t.Fatal("strings are not equal (1), but are reported as equal")
	}
}
// TestDuplicateBytes checks that the byte copy is equal to the original and
// does not share backing storage with it.
func TestDuplicateBytes(t *testing.T) {
	t.Parallel()

	a := DuplicateBytes(byteTestSlice)
	if !bytes.Equal(a, byteTestSlice) {
		t.Fatal("copied bytes slice is not equal")
	}

	// Mutating the copy must not affect the original.
	a[0] = 0xff
	if bytes.Equal(a, byteTestSlice) {
		t.Fatal("copied bytes slice is not a real copy")
	}
}

118
base/utils/stablepool.go Normal file
View File

@@ -0,0 +1,118 @@
package utils
import "sync"
// A StablePool is a drop-in replacement for sync.Pool that is slower, but
// predictable.
// A StablePool is a set of temporary objects that may be individually saved and
// retrieved.
//
// In contrast to sync.Pool, items are not removed automatically. Every item
// will be returned at some point. Items are returned in a FIFO manner in order
// to evenly distribute usage of a set of items.
//
// A StablePool is safe for use by multiple goroutines simultaneously and must
// not be copied after first use.
type StablePool struct {
	lock sync.Mutex

	pool []interface{} // fixed-position slots; nil slots are free
	cnt  int           // number of non-nil items currently in pool

	getIndex int // next slot to start scanning from in Get
	putIndex int // next slot to start scanning from in Put

	// New optionally specifies a function to generate
	// a value when Get would otherwise return nil.
	// It may not be changed concurrently with calls to Get.
	New func() interface{}
}

// Put adds x to the pool.
func (p *StablePool) Put(x interface{}) {
	if x == nil {
		return
	}

	p.lock.Lock()
	defer p.lock.Unlock()

	// If the pool is full (or uninitialized), grow it by appending.
	if p.cnt == len(p.pool) {
		p.pool = append(p.pool, x)
		p.cnt++
		p.putIndex = p.cnt
		return
	}

	// Correct putIndex after wrap-around or growth.
	p.putIndex %= len(p.pool)
	// Scan every slot exactly once, starting at putIndex, to find a free
	// spot. (The previous `stopAt := putIndex - 1` loop terminated before
	// visiting slot putIndex-1, so an item could be silently dropped when
	// that was the only free slot.)
	for off := 0; off < len(p.pool); off++ {
		i := (p.putIndex + off) % len(p.pool)
		if p.pool[i] == nil {
			p.pool[i] = x
			p.cnt++
			p.putIndex = i + 1
			return
		}
	}
}

// Get returns the next item from the Pool, removes it from the Pool, and
// returns it to the caller.
// In contrast to sync.Pool, Get never ignores the pool.
// Callers should not assume any relation between values passed to Put and
// the values returned by Get.
//
// If Get would otherwise return nil and p.New is non-nil, Get returns
// the result of calling p.New.
func (p *StablePool) Get() interface{} {
	p.lock.Lock()
	defer p.lock.Unlock()

	// Empty pool: fall back to New (or nil).
	if p.cnt == 0 {
		if p.New != nil {
			return p.New()
		}
		return nil
	}

	// Correct getIndex after wrap-around.
	p.getIndex %= len(p.pool)
	// Scan every slot exactly once, starting at getIndex, to find the next
	// item. (The previous `stopAt := getIndex - 1` loop skipped slot
	// getIndex-1 and could miss the only remaining item.)
	for off := 0; off < len(p.pool); off++ {
		i := (p.getIndex + off) % len(p.pool)
		if p.pool[i] != nil {
			x := p.pool[i]
			p.pool[i] = nil
			p.cnt--
			p.getIndex = i + 1
			return x
		}
	}

	// Unreachable while cnt is accurate; fall back to New for safety.
	if p.New != nil {
		return p.New()
	}
	return nil
}
// Size returns the amount of items the pool currently holds.
func (p *StablePool) Size() int {
p.lock.Lock()
defer p.lock.Unlock()
return p.cnt
}
// Max returns the amount of items the pool held at maximum.
func (p *StablePool) Max() int {
p.lock.Lock()
defer p.lock.Unlock()
return len(p.pool)
}

View File

@@ -0,0 +1,118 @@
package utils
import (
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestStablePoolRealWorld simulates concurrent pool usage: 100 rounds of 100
// workers that each take an item, "work" briefly, and usually put it back.
func TestStablePoolRealWorld(t *testing.T) {
	t.Parallel()

	// "real world" simulation
	// cnt counts how many values New created; New is invoked by Get while
	// the pool lock is held, so incrementing it here is serialized.
	cnt := 0
	testPool := &StablePool{
		New: func() interface{} {
			cnt++
			return cnt
		},
	}
	var testWg sync.WaitGroup       // gates the start of each round
	var testWorkerWg sync.WaitGroup // tracks the workers of the current round

	// for i := 0; i < 100; i++ {
	// 	cnt++
	// 	testPool.Put(cnt)
	// }

	for range 100 {
		// block round
		testWg.Add(1)
		// add workers
		testWorkerWg.Add(100)
		for j := range 100 {
			go func() {
				// wait for round to start
				testWg.Wait()
				// get value
				x := testPool.Get()
				// fmt.Println(x)
				// "work"
				time.Sleep(5 * time.Microsecond)
				// re-insert 99% (only the worker with j == 0 drops its item)
				if j%100 > 0 {
					testPool.Put(x)
				}
				// mark as finished
				testWorkerWg.Done()
			}()
		}
		// start round
		testWg.Done()
		// wait for round to finish
		testWorkerWg.Wait()
	}
	t.Logf("real world simulation: cnt=%d p.cnt=%d p.max=%d\n", cnt, testPool.Size(), testPool.Max())
	// Each of the 100 rounds drops at most one item, so the pool should never
	// have needed more than roughly two hundred distinct values.
	assert.GreaterOrEqual(t, 200, cnt, "should not use more than 200 values")
	assert.GreaterOrEqual(t, 100, testPool.Max(), "pool should have at most this max size")

	// optimal usage test: strict FIFO round-trips without any concurrency
	optPool := &StablePool{}
	for range 1000 {
		for j := range 100 {
			optPool.Put(j)
		}
		for k := range 100 {
			assert.Equal(t, k, optPool.Get(), "should match")
		}
	}
	assert.Equal(t, 100, optPool.Max(), "pool should have exactly this max size")
}
// TestStablePoolFuzzing hammers the pool with 1000 concurrent Put/Get pairs
// that are all released at the same moment.
func TestStablePoolFuzzing(t *testing.T) {
	t.Parallel()

	fuzzPool := &StablePool{}
	var startGate sync.WaitGroup
	var workers sync.WaitGroup

	// Launch all goroutines first; they block on the gate until kick-off.
	startGate.Add(1)
	for i := range 1000 {
		workers.Add(2)
		go func() {
			defer workers.Done()
			startGate.Wait()
			fuzzPool.Put(i)
		}()
		go func() {
			defer workers.Done()
			startGate.Wait()
			fmt.Print(fuzzPool.Get())
		}()
	}

	// Kick off all goroutines at once, then wait for them to finish.
	startGate.Done()
	workers.Wait()
}
// TestStablePoolBreaking interleaves nil Puts with real values and checks
// that FIFO order survives, since nil items must be ignored.
func TestStablePoolBreaking(t *testing.T) {
	t.Parallel()

	breakPool := &StablePool{}
	for round := 0; round < 10; round++ {
		for j := 0; j < 100; j++ {
			breakPool.Put(nil)
			breakPool.Put(j)
			breakPool.Put(nil)
		}
		for k := 0; k < 100; k++ {
			assert.Equal(t, k, breakPool.Get(), "should match")
		}
	}
}

139
base/utils/structure.go Normal file
View File

@@ -0,0 +1,139 @@
package utils
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)
// DirStructure represents a directory structure with permissions that should be enforced.
type DirStructure struct {
	sync.Mutex // guards this node during ChildDir and ensure

	Path     string                   // full path of this directory (parent path joined with Dir)
	Dir      string                   // name of this directory; empty for the root
	Perm     os.FileMode              // permissions to enforce for this directory
	Parent   *DirStructure            // nil for the root of the structure
	Children map[string]*DirStructure // defined subdirectories, keyed by Dir
}
// NewDirStructure returns a new root DirStructure for the given path with the
// given permissions.
func NewDirStructure(path string, perm os.FileMode) *DirStructure {
	root := new(DirStructure)
	root.Path = path
	root.Perm = perm
	root.Children = make(map[string]*DirStructure)
	return root
}
// ChildDir adds a new child DirStructure and returns it. Should the child
// already exist, the existing child is returned and the permissions are updated.
func (ds *DirStructure) ChildDir(dirName string, perm os.FileMode) (child *DirStructure) {
	ds.Lock()
	defer ds.Unlock()

	// Reuse an existing child, only refreshing its permissions.
	if existing, ok := ds.Children[dirName]; ok {
		existing.Perm = perm
		return existing
	}

	// Otherwise create a new child node and register it.
	child = &DirStructure{
		Path:     filepath.Join(ds.Path, dirName),
		Dir:      dirName,
		Perm:     perm,
		Parent:   ds,
		Children: make(map[string]*DirStructure),
	}
	ds.Children[dirName] = child
	return child
}
// Ensure ensures that the specified directory structure (from the first parent on) exists.
func (ds *DirStructure) Ensure() error {
	// Delegate to the absolute-path variant with this node's own path.
	return ds.EnsureAbsPath(ds.Path)
}
// EnsureRelPath ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists.
func (ds *DirStructure) EnsureRelPath(dirPath string) error {
	// Resolve the relative path against this node before delegating.
	return ds.EnsureAbsPath(filepath.Join(ds.Path, dirPath))
}
// EnsureRelDir ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists.
// In contrast to EnsureRelPath, the relative path is given as individual
// directory names, which are joined to a path.
func (ds *DirStructure) EnsureRelDir(dirNames ...string) error {
	// Join the directory names onto this node's path and delegate.
	return ds.EnsureAbsPath(filepath.Join(append([]string{ds.Path}, dirNames...)...))
}
// EnsureAbsPath ensures that the specified directory structure (from the first parent on) and the given absolute path exists.
// If the given path is outside the DirStructure, an error will be returned.
func (ds *DirStructure) EnsureAbsPath(dirPath string) error {
	// Always operate from the root of the structure.
	if ds.Parent != nil {
		return ds.Parent.EnsureAbsPath(dirPath)
	}

	// The root itself needs no path traversal.
	if dirPath == ds.Path {
		return ds.ensure(nil)
	}

	// Build the root path with a trailing separator for the scope check.
	prefix := ds.Path
	if !strings.HasSuffix(prefix, string(filepath.Separator)) {
		prefix += string(filepath.Separator)
	}
	// Reject paths that do not live within this structure.
	if !strings.HasPrefix(dirPath, prefix) {
		return fmt.Errorf(`path "%s" is outside of DirStructure scope`, dirPath)
	}

	// Convert to a relative path and split it into its elements.
	relPath, err := filepath.Rel(ds.Path, dirPath)
	if err != nil {
		return fmt.Errorf("failed to get relative path: %w", err)
	}
	elements := strings.Split(filepath.ToSlash(relPath), "/")

	// Walk the structure, creating directories as needed.
	return ds.ensure(elements)
}
// ensure creates this node's directory and then recursively handles the
// remaining path elements, falling back to this node's permissions once the
// defined structure ends.
func (ds *DirStructure) ensure(pathDirs []string) error {
	ds.Lock()
	defer ds.Unlock()

	// Make sure this node's own directory exists with the right permissions.
	if err := EnsureDirectory(ds.Path, ds.Perm); err != nil {
		return err
	}

	// Done when no deeper path elements remain.
	if len(pathDirs) == 0 {
		return nil
	}

	// Descend into a defined child when one matches the next element.
	if child, ok := ds.Children[pathDirs[0]]; ok {
		return child.ensure(pathDirs[1:])
	}

	// Past the defined structure: create all remaining directories,
	// inheriting this node's permissions.
	current := ds.Path
	for _, dir := range pathDirs {
		current = filepath.Join(current, dir)
		if err := EnsureDirectory(current, ds.Perm); err != nil {
			return err
		}
	}
	return nil
}

View File

@@ -0,0 +1,73 @@
//go:build !windows
package utils
import (
"fmt"
"os"
"path/filepath"
"strings"
)
// ExampleDirStructure demonstrates building a DirStructure tree and ensuring
// various parts of it on disk.
//
// NOTE(review): `go test` only verifies an example's output when the
// "Output:" comment is the trailing comment of the function body; placed at
// the top like this it serves as documentation only — confirm this is
// intended.
func ExampleDirStructure() {
	// output:
	// / [755]
	// /repo [777]
	// /repo/b [707]
	// /repo/b/c [750]
	// /repo/b/d [707]
	// /repo/b/d/e [707]
	// /repo/b/d/f [707]
	// /repo/b/d/f/g [707]
	// /repo/b/d/f/g/h [707]
	// /secret [700]

	// Work in a fresh temporary directory.
	basePath, err := os.MkdirTemp("", "")
	if err != nil {
		fmt.Println(err)
		return
	}

	// Define the structure with per-directory permissions.
	ds := NewDirStructure(basePath, 0o0755)
	secret := ds.ChildDir("secret", 0o0700)
	repo := ds.ChildDir("repo", 0o0777)
	_ = repo.ChildDir("a", 0o0700) // defined but never ensured, so not created
	b := repo.ChildDir("b", 0o0707)
	c := b.ChildDir("c", 0o0750)

	// Ensure various nodes and paths; each call creates all missing
	// directories from the root down.
	err = ds.Ensure()
	if err != nil {
		fmt.Println(err)
	}
	err = c.Ensure()
	if err != nil {
		fmt.Println(err)
	}
	err = secret.Ensure()
	if err != nil {
		fmt.Println(err)
	}
	err = b.EnsureRelDir("d", "e")
	if err != nil {
		fmt.Println(err)
	}
	err = b.EnsureRelPath("d/f/g/h")
	if err != nil {
		fmt.Println(err)
	}

	// Print the resulting tree with permissions, relative to basePath.
	_ = filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error {
		if err == nil {
			dir := strings.TrimPrefix(path, basePath)
			if dir == "" {
				dir = "/"
			}
			fmt.Printf("%s [%o]\n", dir, info.Mode().Perm())
		}
		return nil
	})
}

45
base/utils/uuid.go Normal file
View File

@@ -0,0 +1,45 @@
package utils
import (
"encoding/binary"
"time"
"github.com/gofrs/uuid"
)
var (
	// constantUUID is the fixed namespace used by DerivedUUID so that derived
	// UUIDs are reproducible across processes.
	constantUUID = uuid.Must(uuid.FromString("e8dba9f7-21e2-4c82-96cb-6586922c6422"))
	// instanceUUID is a random per-process namespace used by
	// DerivedInstanceUUID.
	instanceUUID = RandomUUID("instance")
)
// RandomUUID returns a new random UUID with optionally provided ns.
func RandomUUID(ns string) uuid.UUID {
	randUUID, err := uuid.NewV4()
	if err != nil {
		// Fallback that should practically never happen:
		// derive the UUID from the current time instead.
		return uuid.NewV5(uuidFromTime(), ns)
	}
	if ns != "" {
		// Mix the namespace into the random UUID.
		return uuid.NewV5(randUUID, ns)
	}
	return randUUID
}
// DerivedUUID returns a new UUID that is derived from the input only, and therefore is always reproducible.
// It is a v5 UUID in the namespace of the package-level constantUUID.
func DerivedUUID(input string) uuid.UUID {
	return uuid.NewV5(constantUUID, input)
}
// DerivedInstanceUUID returns a new UUID that is derived from the input, but is unique per instance (execution) and therefore is only reproducible with the same process.
// It is a v5 UUID in the namespace of the random per-process instanceUUID.
func DerivedInstanceUUID(input string) uuid.UUID {
	return uuid.NewV5(instanceUUID, input)
}
// uuidFromTime returns a UUID derived from the current time, used as a
// fallback namespace when random UUID generation fails.
// Only the first 8 bytes are set (little-endian UnixNano); the rest stay zero.
func uuidFromTime() uuid.UUID {
	var timeUUID uuid.UUID
	binary.LittleEndian.PutUint64(timeUUID[:], uint64(time.Now().UnixNano()))
	return timeUUID
}

71
base/utils/uuid_test.go Normal file
View File

@@ -0,0 +1,71 @@
package utils
import (
"testing"
"time"
"github.com/gofrs/uuid"
)
func TestUUID(t *testing.T) {
t.Parallel()
// check randomness
a := RandomUUID("")
a2 := RandomUUID("")
if a.String() == a2.String() {
t.Error("should not match")
}
// check with input
b := RandomUUID("b")
b2 := RandomUUID("b")
if b.String() == b2.String() {
t.Error("should not match")
}
// check with long input
c := RandomUUID("TG8UkxS+4rVrDxHtDAaNab1CBpygzmX1g5mJA37jbQ5q2uE4rVrDxHtDAaNab1CBpygzmX1g5mJA37jbQ5q2uE")
c2 := RandomUUID("TG8UkxS+4rVrDxHtDAaNab1CBpygzmX1g5mJA37jbQ5q2uE4rVrDxHtDAaNab1CBpygzmX1g5mJA37jbQ5q2uE")
if c.String() == c2.String() {
t.Error("should not match")
}
// check for nanosecond precision
d := uuidFromTime()
time.Sleep(2 * time.Nanosecond)
d2 := uuidFromTime()
if d.String() == d2.String() {
t.Error("should not match")
}
// check mixing
timeUUID := uuidFromTime()
e := uuid.NewV5(timeUUID, "e")
e2 := uuid.NewV5(timeUUID, "e2")
if e.String() == e2.String() {
t.Error("should not match")
}
// check deriving
f := DerivedUUID("f")
f2 := DerivedUUID("f")
f3 := DerivedUUID("f3")
if f.String() != f2.String() {
t.Error("should match")
}
if f.String() == f3.String() {
t.Error("should not match")
}
// check instance deriving
g := DerivedInstanceUUID("g")
g2 := DerivedInstanceUUID("g")
g3 := DerivedInstanceUUID("g3")
if g.String() != g2.String() {
t.Error("should match")
}
if g.String() == g3.String() {
t.Error("should not match")
}
}