Improve geoip database downloading and do not block connections

This commit is contained in:
Patrick Pacher
2021-08-18 15:28:50 +02:00
parent 26608c2e21
commit c02f2cb593
3 changed files with 189 additions and 130 deletions

View File

@@ -1,10 +1,10 @@
package geoip package geoip
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
"time"
"github.com/tevino/abool"
maxminddb "github.com/oschwald/maxminddb-golang" maxminddb "github.com/oschwald/maxminddb-golang"
@@ -13,106 +13,205 @@ import (
"github.com/safing/portmaster/updates" "github.com/safing/portmaster/updates"
) )
var ( var worker *updateWorker
geoDBv4File *updater.File
geoDBv6File *updater.File
dbFileLock sync.Mutex
geoDBv4Reader *maxminddb.Reader func init() {
geoDBv6Reader *maxminddb.Reader worker = &updateWorker{
dbLock sync.RWMutex trigger: make(chan struct{}),
}
}
dbInUse = abool.NewBool(false) // only activate if used for first time const (
dbDoReload = abool.NewBool(true) // if database should be reloaded v4MMDBResource = "intel/geoip/geoipv4.mmdb.gz"
v6MMDBResource = "intel/geoip/geoipv6.mmdb.gz"
) )
// ReloadDatabases reloads the geoip database, if they are in use. type geoIPDB struct {
func ReloadDatabases() error { *maxminddb.Reader
// don't do anything if the database isn't actually used file *updater.File
if !dbInUse.IsSet() { }
// updateBroadcaster stores a geoIPDB and provides synchronized
// access to the MMDB reader. It also supports broadcasting to
// multiple waiters when a new database becomes available.
type updateBroadcaster struct {
rw sync.RWMutex
db *geoIPDB
waiters []chan struct{}
}
// NeedsUpdate returns true if the current broadcaster needs a
// database update.
func (ub *updateBroadcaster) NeedsUpdate() bool {
ub.rw.RLock()
defer ub.rw.RUnlock()
return ub.db == nil || ub.db.file.UpgradeAvailable()
}
// ReplaceDatabase replaces (or initialy sets) the mmdb database.
// It also notifies all waiters about the availability of the new
// database.
func (ub *updateBroadcaster) ReplaceDatabase(db *geoIPDB) {
ub.rw.Lock()
defer ub.rw.Unlock()
if ub.db != nil {
ub.db.Close()
}
ub.db = db
ub.notifyWaiters()
}
// notifyWaiters notifies and removes all waiters. Must be called
// with ub.rw locked.
func (ub *updateBroadcaster) notifyWaiters() {
waiters := ub.waiters
ub.waiters = nil
for _, c := range waiters {
close(c)
}
}
// getWaiter appends and returns a new waiter channel that gets closed
// when a new database version is available. Must be called with
// ub.rw locked.
func (ub *updateBroadcaster) getWaiter() chan struct{} {
ch := make(chan struct{})
ub.waiters = append(ub.waiters, ch)
return ch
}
type updateWorker struct {
trigger chan struct{}
once sync.Once
v4 updateBroadcaster
v6 updateBroadcaster
}
// GetReader returns a MMDB reader for either the IPv4 or the IPv6 database.
// If wait is true GetReader will wait at most 1 second for the database to
// become available. If no database is available or GetReader times-out while
// waiting nil is returned.
func (upd *updateWorker) GetReader(v6 bool, wait bool) *maxminddb.Reader {
// check which updateBroadcaster we need to use
var ub *updateBroadcaster = &upd.v4
if v6 {
ub = &upd.v6
}
// lock the updateBroadcaster and - if we are allowed to wait -
// create a new waiter channel, trigger an update and wait for at
// least 1 second for the update to complete.
ub.rw.Lock()
if ub.db == nil {
if wait {
waiter := ub.getWaiter()
ub.rw.Unlock()
upd.triggerUpdate()
select {
case <-waiter:
// call this method again but this time we don't allow
// it to wait since there must be a open database anyway ...
return upd.GetReader(v6, false)
case <-time.After(time.Second):
// we tried hard but failed so give up here
return nil
}
}
ub.rw.Unlock()
return nil return nil
} }
rd := ub.db.Reader
ub.rw.Unlock()
dbFileLock.Lock() return rd
defer dbFileLock.Unlock()
dbLock.Lock()
defer dbLock.Unlock()
dbDoReload.Set()
return doReload()
} }
func prepDatabaseForUse() error { // triggerUpdate triggers a database update check.
dbInUse.Set() func (upd *updateWorker) triggerUpdate() {
return doReload() upd.start()
select {
case upd.trigger <- struct{}{}:
default:
}
} }
func doReload() error { func (upd *updateWorker) start() {
// reload if needed upd.once.Do(func() {
if dbDoReload.SetToIf(true, false) { module.StartServiceWorker("geoip-updater", time.Second*10, upd.run)
closeDBs() })
if err := openDBs(); err != nil { }
// try again the next time
dbDoReload.SetTo(true) func (upd *updateWorker) run(ctx context.Context) error {
return err firstTime := true
for {
// update immediately if we're just got started (that happens in
// triggerUpdate() and we might get started a bit late for the
// channel notification to be sent). If not, wait for the next
// trigger our our ctx to be cancelled during shutdown.
if !firstTime {
select {
case <-ctx.Done():
return nil
case <-upd.trigger:
}
}
firstTime = false
if upd.v4.NeedsUpdate() {
if v4, err := getGeoIPDB(v4MMDBResource); err == nil {
upd.v4.ReplaceDatabase(v4)
} else {
log.Warningf("geoip: failed to get v4 database: %s", err)
}
}
if upd.v6.NeedsUpdate() {
if v6, err := getGeoIPDB(v6MMDBResource); err == nil {
upd.v6.ReplaceDatabase(v6)
} else {
log.Warningf("geoip: failed to get v6 database: %s", err)
}
} }
} }
return nil
} }
func openDBs() error { func getGeoIPDB(resource string) (*geoIPDB, error) {
var err error log.Debugf("geoip: opening database %s", resource)
geoDBv4File, err = updates.GetFile("intel/geoip/geoipv4.mmdb.gz") file, unpackedPath, err := openAndUnpack(resource)
if err != nil { if err != nil {
return fmt.Errorf("could not get GeoIP v4 database file: %s", err) return nil, err
}
unpackedV4, err := geoDBv4File.Unpack(".gz", updater.UnpackGZIP)
if err != nil {
return err
}
geoDBv4Reader, err = maxminddb.Open(unpackedV4)
if err != nil {
return err
} }
geoDBv6File, err = updates.GetFile("intel/geoip/geoipv6.mmdb.gz") reader, err := maxminddb.Open(unpackedPath)
if err != nil { if err != nil {
return fmt.Errorf("could not get GeoIP v6 database file: %s", err) return nil, fmt.Errorf("failed to open: %w", err)
}
unpackedV6, err := geoDBv6File.Unpack(".gz", updater.UnpackGZIP)
if err != nil {
return err
}
geoDBv6Reader, err = maxminddb.Open(unpackedV6)
if err != nil {
return err
} }
log.Debugf("geoip: successfully opened database %s", resource)
return nil return &geoIPDB{
Reader: reader,
file: file,
}, nil
} }
func handleError(err error) { func openAndUnpack(resource string) (*updater.File, string, error) {
log.Errorf("network/geoip: lookup failed, reloading databases: %s", err) f, err := updates.GetFile(resource)
dbDoReload.Set() if err != nil {
} return nil, "", fmt.Errorf("getting file: %w", err)
func closeDBs() {
if geoDBv4Reader != nil {
err := geoDBv4Reader.Close()
if err != nil {
log.Warningf("network/geoip: failed to close database: %s", err)
}
} }
geoDBv4Reader = nil
if geoDBv6Reader != nil { unpacked, err := f.Unpack(".gz", updater.UnpackGZIP)
err := geoDBv6Reader.Close() if err != nil {
if err != nil { return nil, "", fmt.Errorf("unpacking file: %w", err)
log.Warningf("network/geoip: failed to close database: %s", err)
}
} }
geoDBv6Reader = nil
return f, unpacked, nil
} }

View File

@@ -1,52 +1,26 @@
package geoip package geoip
import ( import (
"fmt"
"net" "net"
"github.com/oschwald/maxminddb-golang" "github.com/oschwald/maxminddb-golang"
) )
func getReader(ip net.IP) *maxminddb.Reader { func getReader(ip net.IP) *maxminddb.Reader {
if v4 := ip.To4(); v4 != nil { isV6 := ip.To4() == nil
return geoDBv4Reader return worker.GetReader(isV6, true)
}
return geoDBv6Reader
} }
// GetLocation returns Location data of an IP address // GetLocation returns Location data of an IP address
func GetLocation(ip net.IP) (record *Location, err error) { func GetLocation(ip net.IP) (*Location, error) {
dbLock.RLock()
defer dbLock.RUnlock()
err = prepDatabaseForUse()
if err != nil {
return nil, err
}
db := getReader(ip) db := getReader(ip)
if db == nil {
record = &Location{} return nil, fmt.Errorf("geoip database not available")
// fetch
err = db.Lookup(ip, record)
// retry
if err != nil {
// reprep
handleError(err)
err = prepDatabaseForUse()
if err != nil {
return nil, err
}
db = getReader(ip)
// refetch
err = db.Lookup(ip, record)
} }
var record Location
if err != nil { if err := db.Lookup(ip, &record); err != nil {
return nil, err return nil, err
} }
return &record, nil
return record, nil
} }

View File

@@ -20,23 +20,9 @@ func prep() error {
updates.ModuleName, updates.ModuleName,
updates.ResourceUpdateEvent, updates.ResourceUpdateEvent,
"Check for GeoIP database updates", "Check for GeoIP database updates",
upgradeDatabases, func(c context.Context, i interface{}) error {
worker.triggerUpdate()
return nil
},
) )
} }
func upgradeDatabases(_ context.Context, _ interface{}) error {
dbFileLock.Lock()
reload := false
if geoDBv4File != nil && geoDBv4File.UpgradeAvailable() {
reload = true
}
if geoDBv6File != nil && geoDBv6File.UpgradeAvailable() {
reload = true
}
dbFileLock.Unlock()
if reload {
return ReloadDatabases()
}
return nil
}