update handling of extractor URL patterns

When extractor classes are loaded during 'extractor.find(…)', their
'pattern' attribute is replaced with a compiled version of itself.
This commit is contained in:
Mike Fährmann
2019-02-08 20:08:16 +01:00
parent 6284731107
commit abbd45d0f4
6 changed files with 43 additions and 41 deletions

View File

@@ -93,33 +93,34 @@ modules = [
def find(url): def find(url):
"""Find suitable extractor for the given url""" """Find a suitable extractor for the given URL"""
for pattern, klass in _list_patterns(): for cls in _list_classes():
match = pattern.match(url) match = cls.pattern.match(url)
if match and klass not in _blacklist: if match and cls not in _blacklist:
return klass(match) return cls(match)
return None return None
def add(klass): def add(cls):
"""Add 'klass' to the list of available extractors""" """Add 'cls' to the list of available extractors"""
_cache.append((re.compile(klass.pattern), klass)) cls.pattern = re.compile(cls.pattern)
_cache.append(cls)
return cls
def add_module(module): def add_module(module):
"""Add all extractors in 'module' to the list of available extractors""" """Add all extractors in 'module' to the list of available extractors"""
tuples = [ classes = _get_classes(module)
(re.compile(klass.pattern), klass) for cls in classes:
for klass in _get_classes(module) cls.pattern = re.compile(cls.pattern)
] _cache.extend(classes)
_cache.extend(tuples) return classes
return tuples
def extractors(): def extractors():
"""Yield all available extractor classes""" """Yield all available extractor classes"""
return sorted( return sorted(
set(klass for _, klass in _list_patterns()), _list_classes(),
key=lambda x: x.__name__ key=lambda x: x.__name__
) )
@@ -128,9 +129,9 @@ class blacklist():
"""Context Manager to blacklist extractor modules""" """Context Manager to blacklist extractor modules"""
def __init__(self, categories, extractors=None): def __init__(self, categories, extractors=None):
self.extractors = extractors or [] self.extractors = extractors or []
for _, klass in _list_patterns(): for cls in _list_classes():
if klass.category in categories: if cls.category in categories:
self.extractors.append(klass) self.extractors.append(cls)
def __enter__(self): def __enter__(self):
_blacklist.update(self.extractors) _blacklist.update(self.extractors)
@@ -147,20 +148,19 @@ _blacklist = set()
_module_iter = iter(modules) _module_iter = iter(modules)
def _list_patterns(): def _list_classes():
"""Yield all available (pattern, class) tuples""" """Yield all available extractor classes"""
yield from _cache yield from _cache
for module_name in _module_iter: for module_name in _module_iter:
yield from add_module( module = importlib.import_module("."+module_name, __package__)
importlib.import_module("."+module_name, __package__) yield from add_module(module)
)
def _get_classes(module): def _get_classes(module):
"""Return a list of all extractor classes in a module""" """Return a list of all extractor classes in a module"""
return [ return [
klass for klass in module.__dict__.values() if ( cls for cls in module.__dict__.values() if (
hasattr(klass, "pattern") and klass.__module__ == module.__name__ hasattr(cls, "pattern") and cls.__module__ == module.__name__
) )
] ]

View File

@@ -228,7 +228,7 @@ class PostimgImageExtractor(ImagehostImageExtractor):
class TurboimagehostImageExtractor(ImagehostImageExtractor): class TurboimagehostImageExtractor(ImagehostImageExtractor):
"""Extractor for single images from turboimagehost.com""" """Extractor for single images from www.turboimagehost.com"""
category = "turboimagehost" category = "turboimagehost"
pattern = (r"(?:https?://)?((?:www\.)?turboimagehost\.com" pattern = (r"(?:https?://)?((?:www\.)?turboimagehost\.com"
r"/p/(\d+)/[^/?&#]+\.html)") r"/p/(\d+)/[^/?&#]+\.html)")

View File

@@ -13,7 +13,6 @@ from .. import text, util, extractor, exception
from ..cache import cache from ..cache import cache
import datetime import datetime
import time import time
import re
class RedditExtractor(Extractor): class RedditExtractor(Extractor):
@@ -27,7 +26,7 @@ class RedditExtractor(Extractor):
self._visited = set() self._visited = set()
def items(self): def items(self):
subre = re.compile(RedditSubmissionExtractor.pattern) subre = RedditSubmissionExtractor.pattern
submissions = self.submissions() submissions = self.submissions()
depth = 0 depth = 0

View File

@@ -211,13 +211,15 @@ def get_domain(classes):
if hasattr(cls, "root") and cls.root: if hasattr(cls, "root") and cls.root:
return cls.root + "/" return cls.root + "/"
if hasattr(cls, "test") and cls.test: if hasattr(cls, "https"):
url = cls.test[0][0] scheme = "https" if cls.https else "http"
return url[:url.find("/", 8)+1] domain = cls.__doc__.split()[-1]
return "{}://{}/".format(scheme, domain)
scheme = "http" if hasattr(cls, "https") and not cls.https else "https" test = next(cls._get_tests(), None)
host = cls.__doc__.split()[-1] if test:
return scheme + "://" + host + "/" url = test[0]
return url[:url.find("/", 8)+1]
except (IndexError, AttributeError): except (IndexError, AttributeError):
pass pass
return "" return ""

View File

@@ -6,6 +6,7 @@ import datetime
ROOTDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ROOTDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.realpath(ROOTDIR)) sys.path.insert(0, os.path.realpath(ROOTDIR))
from gallery_dl import extractor, job, config from gallery_dl import extractor, job, config
from test.test_results import setup_test_config from test.test_results import setup_test_config
@@ -19,7 +20,7 @@ tests = [
if hasattr(extr, "test") and extr.test if hasattr(extr, "test") and extr.test
if len(sys.argv) <= 1 or extr.category in sys.argv if len(sys.argv) <= 1 or extr.category in sys.argv
for idx, (url, result) in enumerate(extr.test) for idx, (url, result) in enumerate(extr._get_tests())
if result if result
] ]

View File

@@ -66,10 +66,10 @@ class TestExtractor(unittest.TestCase):
uri = "fake:foobar" uri = "fake:foobar"
self.assertIsNone(extractor.find(uri)) self.assertIsNone(extractor.find(uri))
tuples = extractor.add_module(sys.modules[__name__]) classes = extractor.add_module(sys.modules[__name__])
self.assertEqual(len(tuples), 1) self.assertEqual(len(classes), 1)
self.assertEqual(tuples[0][0].pattern, FakeExtractor.pattern) self.assertEqual(classes[0].pattern, FakeExtractor.pattern)
self.assertEqual(tuples[0][1], FakeExtractor) self.assertEqual(classes[0], FakeExtractor)
self.assertIsInstance(extractor.find(uri), FakeExtractor) self.assertIsInstance(extractor.find(uri), FakeExtractor)
def test_blacklist(self): def test_blacklist(self):
@@ -109,13 +109,13 @@ class TestExtractor(unittest.TestCase):
matches = [] matches = []
# ... and apply all regex patterns to each one # ... and apply all regex patterns to each one
for pattern, extr2 in extractor._cache: for extr2 in extractor._cache:
# skip DirectlinkExtractor pattern if it isn't tested # skip DirectlinkExtractor pattern if it isn't tested
if extr1 != DLExtractor and extr2 == DLExtractor: if extr1 != DLExtractor and extr2 == DLExtractor:
continue continue
match = pattern.match(url) match = extr2.pattern.match(url)
if match: if match:
matches.append(match) matches.append(match)