From b0c333b799ccbf3818c9f2752894cf25b1112dba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mike=20F=C3=A4hrmann?= Date: Wed, 20 Apr 2016 08:40:41 +0200 Subject: [PATCH] rewrite cache module --- gallery_dl/__init__.py | 3 +- gallery_dl/cache.py | 216 +++++++++++++++++++++++++++++----------- test/test_extractors.py | 1 - 3 files changed, 157 insertions(+), 63 deletions(-) diff --git a/gallery_dl/__init__.py b/gallery_dl/__init__.py index bd455204..c966e9ea 100644 --- a/gallery_dl/__init__.py +++ b/gallery_dl/__init__.py @@ -16,7 +16,7 @@ __email__ = "mike_faehrmann@web.de" import os import argparse -from . import config, extractor, jobs, cache +from . import config, extractor, jobs def build_cmdline_parser(): parser = argparse.ArgumentParser( @@ -80,7 +80,6 @@ def main(): else: if not args.urls: parser.error("the following arguments are required: URL") - cache.init_database() if args.list_urls: jobtype = jobs.UrlJob elif args.list_keywords: diff --git a/gallery_dl/cache.py b/gallery_dl/cache.py index 128d142c..972bf874 100644 --- a/gallery_dl/cache.py +++ b/gallery_dl/cache.py @@ -6,82 +6,178 @@ # it under the terms of the GNU General Public License version 2 as # published by the Free Software Foundation. +"""Decorator to keep function results in a combined in-memory and database cache""" + import sqlite3 import pickle import time import tempfile import os +import functools from . import config + class CacheInvalidError(Exception): + """A cache entry is either expired or does not exist""" pass -def init_database(): - global _db - path_default = os.path.join(tempfile.gettempdir(), ".gallery-dl.cache") - path = config.get(("cache", "file"), path_default) - _db = sqlite3.connect(path, timeout=30, check_same_thread=False) - _db.execute("CREATE TABLE IF NOT EXISTS data (" - "key TEXT PRIMARY KEY," - "value TEXT," - "expires INTEGER" - ")") -def cache(maxage=3600, keyarg=None): - """decorator - cache function return values in memory and database""" - def wrap(func): - gkey = "{}.{}".format(func.__module__, func.__name__) +class CacheModule(): + """Base class for cache modules""" + def __init__(self): + self.child = None - def wrapped(*args): - timestamp = time.time() - if keyarg is not None: - key = "{}-{}".format(gkey, args[keyarg]) - else: - key = gkey + def __getitem__(self, key): + raise CacheInvalidError() + def __setitem__(self, key, item): + pass + + def __enter__(self): + pass + + def __exit__(self, *exc_info): + pass + + +class CacheChain(CacheModule): + + def __init__(self, modules=[]): + CacheModule.__init__(self) + self.modules = modules + + def __getitem__(self, key): + num = 0 + for module in self.modules: try: - result = lookup_cache(key, timestamp) + value = module[key] + break except CacheInvalidError: - try: - result = func(*args) - expires = int(timestamp+maxage) - _cache[key] = (result, expires) - _db.execute("INSERT OR REPLACE INTO data VALUES (?,?,?)", - (key, pickle.dumps(result), expires)) - finally: - _db.commit() - return result - - def lookup_cache(key, timestamp): - try: - result, expires = _cache[key] - if timestamp < expires: - return result - except KeyError: - pass - result, expires = lookup_database(key, timestamp) - _cache[key] = (result, expires) - return result - - def lookup_database(key, timestamp): - try: - cursor = _db.cursor() - cursor.execute("BEGIN EXCLUSIVE") - cursor.execute("SELECT value, expires FROM data WHERE key=?", - (key,)) - result, expires = cursor.fetchone() - if timestamp < expires: - _db.commit() - return pickle.loads(result), expires - except TypeError: - pass + num += 1 + else: raise CacheInvalidError() + while num: + num -= 1 + self.modules[num][key[0]] = value + return value - return wrapped - return wrap + def __setitem__(self, key, item): + for module in self.modules: + module.__setitem__(key, item) -# -------------------------------------------------------------------- -# internals + def __exit__(self, exc_type, exc_value, exc_traceback): + for module in self.modules: + module.__exit__(exc_type, exc_value, exc_traceback) -_db = None -_cache = {} + +class MemoryCache(CacheModule): + """In-memory cache module""" + def __init__(self): + CacheModule.__init__(self) + self.cache = {} + + def __getitem__(self, key): + key, timestamp = key + try: + value, expires = self.cache[key] + if timestamp < expires: + return value, expires + except KeyError: + pass + raise CacheInvalidError() + + def __setitem__(self, key, item): + self.cache[key] = item + + +class DatabaseCache(CacheModule): + """Database cache module""" + def __init__(self): + CacheModule.__init__(self) + path_default = os.path.join(tempfile.gettempdir(), ".gallery-dl.cache") + path = config.get(("cache", "file"), path_default) + if path is None: + raise RuntimeError() + self.db = sqlite3.connect(path, timeout=30, check_same_thread=False) + self.db.execute("CREATE TABLE IF NOT EXISTS data (" + "key TEXT PRIMARY KEY," + "value TEXT," + "expires INTEGER" + ")") + + def __getitem__(self, key): + key, timestamp = key + try: + cursor = self.db.cursor() + cursor.execute("BEGIN EXCLUSIVE") + cursor.execute("SELECT value, expires FROM data WHERE key=?", (key,)) + value, expires = cursor.fetchone() + if timestamp < expires: + self.commit() + return pickle.loads(value), expires + except TypeError: + pass + raise CacheInvalidError() + + def __setitem__(self, key, item): + value, expires = item + self.db.execute("INSERT OR REPLACE INTO data VALUES (?,?,?)", + (key, pickle.dumps(value), expires)) + + def __exit__(self, *exc_info): + self.commit() + + def commit(self): + self.db.commit() + + +class CacheDecorator(): + + def __init__(self, func, module, maxage, keyarg): + self.func = func + self.key = "%s.%s" % (func.__module__, func.__name__) + self.cache = module + self.maxage = maxage + self.keyarg = keyarg + + def __call__(self, *args, **kwargs): + timestamp = time.time() + if self.keyarg is None: + key = self.key + else: + key = "%s-%s" % (self.key, args[self.keyarg]) + try: + result, _ = self.cache[key, timestamp] + except CacheInvalidError: + with self.cache: + result = self.func(*args, **kwargs) + expires = int(timestamp + self.maxage) + self.cache[key] = result, expires + return result + + def __get__(self, obj, objtype): + """Support instance methods.""" + return functools.partial(self.__call__, obj) + + +def build_cache_decorator(*modules): + if len(modules) > 1: + module = CacheChain(modules) + else: + module = modules[0] + def decorator(maxage=3600, keyarg=None): + def wrap(func): + return CacheDecorator(func, module, maxage, keyarg) + return wrap + return decorator + + +MEMCACHE = MemoryCache() +memcache = build_cache_decorator(MEMCACHE) + +try: + DBCACHE = DatabaseCache() + cache = build_cache_decorator(MEMCACHE, DBCACHE) +except RuntimeError(): + DBCACHE = None + cache = memcache diff --git a/test/test_extractors.py b/test/test_extractors.py index f75a05bb..bdb68d67 100644 --- a/test/test_extractors.py +++ b/test/test_extractors.py @@ -15,7 +15,6 @@ class TestExtractors(unittest.TestCase): def setUp(self): config.load() config.set(("cache", "file"), ":memory:") - cache.init_database() def run_test(self, extr, url, result): hjob = jobs.HashJob(url, "content" in result)