diff --git a/intel/geoip/database.go b/intel/geoip/database.go index 57c07489..5b0059a6 100644 --- a/intel/geoip/database.go +++ b/intel/geoip/database.go @@ -1,10 +1,10 @@ package geoip import ( + "context" "fmt" "sync" - - "github.com/tevino/abool" + "time" maxminddb "github.com/oschwald/maxminddb-golang" @@ -13,106 +13,198 @@ import ( "github.com/safing/portmaster/updates" ) -var ( - geoDBv4File *updater.File - geoDBv6File *updater.File - dbFileLock sync.Mutex +var worker *updateWorker - geoDBv4Reader *maxminddb.Reader - geoDBv6Reader *maxminddb.Reader - dbLock sync.RWMutex +func init() { + worker = &updateWorker{ + trigger: make(chan struct{}), + } +} - dbInUse = abool.NewBool(false) // only activate if used for first time - dbDoReload = abool.NewBool(true) // if database should be reloaded +const ( + v4MMDBResource = "intel/geoip/geoipv4.mmdb.gz" + v6MMDBResource = "intel/geoip/geoipv6.mmdb.gz" ) -// ReloadDatabases reloads the geoip database, if they are in use. -func ReloadDatabases() error { - // don't do anything if the database isn't actually used - if !dbInUse.IsSet() { +type geoIPDB struct { + *maxminddb.Reader + file *updater.File +} + +// updateBroadcaster stores a geoIPDB and provides synchronized +// access to the MMDB reader. It also supports broadcasting to +// multiple waiters when a new database becomes available. +type updateBroadcaster struct { + rw sync.RWMutex + db *geoIPDB + + waiter chan struct{} +} + +// NeedsUpdate returns true if the current broadcaster needs a +// database update. +func (ub *updateBroadcaster) NeedsUpdate() bool { + ub.rw.RLock() + defer ub.rw.RUnlock() + + return ub.db == nil || ub.db.file.UpgradeAvailable() +} + +// ReplaceDatabase replaces (or initialy sets) the mmdb database. +// It also notifies all waiters about the availability of the new +// database. +func (ub *updateBroadcaster) ReplaceDatabase(db *geoIPDB) { + ub.rw.Lock() + defer ub.rw.Unlock() + + if ub.db != nil { + ub.db.Close() + } + ub.db = db + ub.notifyWaiters() +} + +// notifyWaiters notifies and removes all waiters. Must be called +// with ub.rw locked. +func (ub *updateBroadcaster) notifyWaiters() { + waiter := ub.waiter + ub.waiter = nil + close(waiter) +} + +// getWaiter appends and returns a new waiter channel that gets closed +// when a new database version is available. Must be called with +// ub.rw locked. +func (ub *updateBroadcaster) getWaiter() chan struct{} { + if ub.waiter != nil { + return ub.waiter + } + + ub.waiter = make(chan struct{}) + return ub.waiter +} + +type updateWorker struct { + trigger chan struct{} + once sync.Once + + v4 updateBroadcaster + v6 updateBroadcaster +} + +// GetReader returns a MMDB reader for either the IPv4 or the IPv6 database. +// If wait is true GetReader will wait at most 1 second for the database to +// become available. If no database is available or GetReader times-out while +// waiting nil is returned. +func (upd *updateWorker) GetReader(v6 bool, wait bool) *maxminddb.Reader { + // check which updateBroadcaster we need to use + var ub *updateBroadcaster = &upd.v4 + if v6 { + ub = &upd.v6 + } + + // lock the updateBroadcaster and - if we are allowed to wait - + // create a new waiter channel, trigger an update and wait for at + // least 1 second for the update to complete. + ub.rw.Lock() + if ub.db == nil { + if wait { + waiter := ub.getWaiter() + ub.rw.Unlock() + + upd.triggerUpdate() + + select { + case <-waiter: + // call this method again but this time we don't allow + // it to wait since there must be a open database anyway ... + return upd.GetReader(v6, false) + case <-time.After(time.Second): + // we tried hard but failed so give up here + return nil + } + } + ub.rw.Unlock() return nil } + rd := ub.db.Reader + ub.rw.Unlock() - dbFileLock.Lock() - defer dbFileLock.Unlock() - - dbLock.Lock() - defer dbLock.Unlock() - - dbDoReload.Set() - return doReload() + return rd } -func prepDatabaseForUse() error { - dbInUse.Set() - return doReload() +// triggerUpdate triggers a database update check. +func (upd *updateWorker) triggerUpdate() { + upd.start() + + select { + case upd.trigger <- struct{}{}: + default: + } } -func doReload() error { - // reload if needed - if dbDoReload.SetToIf(true, false) { - closeDBs() - if err := openDBs(); err != nil { - // try again the next time - dbDoReload.SetTo(true) - return err +func (upd *updateWorker) start() { + upd.once.Do(func() { + module.StartServiceWorker("geoip-updater", time.Second*10, upd.run) + }) +} + +func (upd *updateWorker) run(ctx context.Context) error { + for { + if upd.v4.NeedsUpdate() { + if v4, err := getGeoIPDB(v4MMDBResource); err == nil { + upd.v4.ReplaceDatabase(v4) + } else { + log.Warningf("geoip: failed to get v4 database: %s", err) + } + } + + if upd.v6.NeedsUpdate() { + if v6, err := getGeoIPDB(v6MMDBResource); err == nil { + upd.v6.ReplaceDatabase(v6) + } else { + log.Warningf("geoip: failed to get v6 database: %s", err) + } + } + + select { + case <-ctx.Done(): + return nil + case <-upd.trigger: } } - - return nil } -func openDBs() error { - var err error +func getGeoIPDB(resource string) (*geoIPDB, error) { + log.Debugf("geoip: opening database %s", resource) - geoDBv4File, err = updates.GetFile("intel/geoip/geoipv4.mmdb.gz") + file, unpackedPath, err := openAndUnpack(resource) if err != nil { - return fmt.Errorf("could not get GeoIP v4 database file: %s", err) - } - unpackedV4, err := geoDBv4File.Unpack(".gz", updater.UnpackGZIP) - if err != nil { - return err - } - geoDBv4Reader, err = maxminddb.Open(unpackedV4) - if err != nil { - return err + return nil, err } - geoDBv6File, err = updates.GetFile("intel/geoip/geoipv6.mmdb.gz") + reader, err := maxminddb.Open(unpackedPath) if err != nil { - return fmt.Errorf("could not get GeoIP v6 database file: %s", err) - } - unpackedV6, err := geoDBv6File.Unpack(".gz", updater.UnpackGZIP) - if err != nil { - return err - } - geoDBv6Reader, err = maxminddb.Open(unpackedV6) - if err != nil { - return err + return nil, fmt.Errorf("failed to open: %w", err) } + log.Debugf("geoip: successfully opened database %s", resource) - return nil + return &geoIPDB{ + Reader: reader, + file: file, + }, nil } -func handleError(err error) { - log.Errorf("network/geoip: lookup failed, reloading databases: %s", err) - dbDoReload.Set() -} - -func closeDBs() { - if geoDBv4Reader != nil { - err := geoDBv4Reader.Close() - if err != nil { - log.Warningf("network/geoip: failed to close database: %s", err) - } +func openAndUnpack(resource string) (*updater.File, string, error) { + f, err := updates.GetFile(resource) + if err != nil { + return nil, "", fmt.Errorf("getting file: %w", err) } - geoDBv4Reader = nil - if geoDBv6Reader != nil { - err := geoDBv6Reader.Close() - if err != nil { - log.Warningf("network/geoip: failed to close database: %s", err) - } + unpacked, err := f.Unpack(".gz", updater.UnpackGZIP) + if err != nil { + return nil, "", fmt.Errorf("unpacking file: %w", err) } - geoDBv6Reader = nil + + return f, unpacked, nil } diff --git a/intel/geoip/lookup.go b/intel/geoip/lookup.go index 16a8f59f..61c1487a 100644 --- a/intel/geoip/lookup.go +++ b/intel/geoip/lookup.go @@ -1,52 +1,26 @@ package geoip import ( + "fmt" "net" "github.com/oschwald/maxminddb-golang" ) func getReader(ip net.IP) *maxminddb.Reader { - if v4 := ip.To4(); v4 != nil { - return geoDBv4Reader - } - return geoDBv6Reader + isV6 := ip.To4() == nil + return worker.GetReader(isV6, true) } // GetLocation returns Location data of an IP address -func GetLocation(ip net.IP) (record *Location, err error) { - dbLock.RLock() - defer dbLock.RUnlock() - - err = prepDatabaseForUse() - if err != nil { - return nil, err - } - +func GetLocation(ip net.IP) (*Location, error) { db := getReader(ip) - - record = &Location{} - - // fetch - err = db.Lookup(ip, record) - - // retry - if err != nil { - // reprep - handleError(err) - err = prepDatabaseForUse() - if err != nil { - return nil, err - } - db = getReader(ip) - - // refetch - err = db.Lookup(ip, record) + if db == nil { + return nil, fmt.Errorf("geoip database not available") } - - if err != nil { + var record Location + if err := db.Lookup(ip, &record); err != nil { return nil, err } - - return record, nil + return &record, nil } diff --git a/intel/geoip/module.go b/intel/geoip/module.go index 4a528052..fb2bc00b 100644 --- a/intel/geoip/module.go +++ b/intel/geoip/module.go @@ -20,23 +20,9 @@ func prep() error { updates.ModuleName, updates.ResourceUpdateEvent, "Check for GeoIP database updates", - upgradeDatabases, + func(c context.Context, i interface{}) error { + worker.triggerUpdate() + return nil + }, ) } - -func upgradeDatabases(_ context.Context, _ interface{}) error { - dbFileLock.Lock() - reload := false - if geoDBv4File != nil && geoDBv4File.UpgradeAvailable() { - reload = true - } - if geoDBv6File != nil && geoDBv6File.UpgradeAvailable() { - reload = true - } - dbFileLock.Unlock() - - if reload { - return ReloadDatabases() - } - return nil -}