# --- gallery_dl/job.py -------------------------------------------------------
# Method of class Job; written at column 0 because the enclosing class
# header is not part of this view.
def _prepare_predicates(self, target, skip=True):
    """Build the combined predicate for one kind of extractor message

    Reads the '<target>-unique', '<target>-filter' and '<target>-range'
    options from the extractor's config, creates one predicate for each
    option that is set, and merges them with util.predicate_build().

    Arguments:
        target: option-name prefix, e.g. "post" or "chapter"
        skip:   if true and no filter option is set, pass the extractor's
                skip() method on to util.predicate_range().
                Defaults to True so one-argument calls keep working.

    Returns:
        The callable produced by util.predicate_build().

    Invalid filter expressions and invalid range specifications are not
    raised; they are logged as warnings and the option is ignored.
    """
    extr = self.extractor
    predicates = []

    if extr.config(target + "-unique"):
        predicates.append(util.predicate_unique())

    if pfilter := extr.config(target + "-filter"):
        try:
            predicates.append(util.predicate_filter(pfilter, target))
        except (SyntaxError, ValueError, TypeError) as exc:
            extr.log.warning(exc)

    if prange := extr.config(target + "-range"):
        # The range predicate is appended after the filter predicate, so
        # it only counts items that passed the filter. Skipping raw items
        # up front is therefore only done when no filter is configured.
        skip_func = extr.skip if skip and not pfilter else None
        try:
            # NOTE(review): predicate_range() calls 'skip_func' while the
            # predicate is being built, so a ValueError raised by
            # extr.skip() would be reported as an invalid range as well.
            predicates.append(util.predicate_range(prange, skip_func))
        except ValueError as exc:
            extr.log.warning("invalid %s range: %s", target, exc)

    return util.predicate_build(predicates)


# --- gallery_dl/util.py ------------------------------------------------------
def predicate_build(predicates):
    """Combine a list of predicates into a single callable

    Returns
      - the module-level 'true' function for an empty list,
      - the predicate itself for a one-element list,
      - otherwise a function 'chain(url, kwdict)' that returns True only if
        every predicate returns a true value; evaluation happens in list
        order and stops at the first predicate that returns a false value.
    """
    if not predicates:
        return true

    if len(predicates) == 1:
        return predicates[0]

    # snapshot, so later changes to the caller's list cannot alter 'chain'
    predicates = tuple(predicates)

    def chain(url, kwdict):
        return all(pred(url, kwdict) for pred in predicates)
    return chain
def predicate_unique():
    """Predicate; True if given URL has not been encountered before

    URLs starting with "text:" are always accepted and never recorded.
    """
    urls = set()

    def _pred(url, _):
        if url.startswith("text:"):
            return True
        if url in urls:
            return False
        urls.add(url)
        return True
    return _pred


def predicate_filter(expr, target="image"):
    """Predicate; True if evaluating the given expression returns True

    'expr' is compiled once with compile_filter(); errors raised by that
    call propagate to the caller of predicate_filter().

    When the returned predicate is called, GalleryDLException instances
    raised by the expression pass through unchanged and any other
    exception is wrapped in exception.FilterError.
    """
    evaluate = compile_filter(expr, f"<{target} filter>")

    def _pred(_, kwdict):
        try:
            return evaluate(kwdict)
        except exception.GalleryDLException:
            raise
        except Exception as exc:
            raise exception.FilterError(exc)
    return _pred


def predicate_range(ranges, skip=None):
    """Predicate; True if the current index is in the given range(s)

    The returned predicate keeps a 1-based counter that is incremented on
    every call and raises exception.StopExtraction once the counter has
    reached the upper bound of all ranges. An empty range specification
    raises StopExtraction on the very first call.

    Arguments:
        ranges: range specification as accepted by predicate_range_parse()
        skip:   optional callable. If the lowest range start is greater
                than 1, it is called once with the number of leading items
                that lie before every range ('lowest start - 1') and its
                return value becomes the initial counter value.
                (presumably the number of items that were actually
                skipped - confirm against the extractor's skip() method)

    Raises:
        ValueError: if the range specification cannot be parsed
    """
    parsed = predicate_range_parse(ranges)

    if parsed:
        # For ranges with a step > 1, 'stop - 1' can be larger than the
        # last index the range really contains (range(4, 8, 2) ends at 6,
        # not 7). That only delays StopExtraction a little, and evaluating
        # the exact min/max of a large range would be slow.
        upper = max(r.stop for r in parsed) - 1
        first = min(r.start for r in parsed)
        # Items 1 .. first-1 are outside of every range. Only those may be
        # skipped; item number 'first' itself must still be delivered.
        if skip is not None and first > 1:
            index = skip(first - 1)
        else:
            index = 0
    else:
        index = upper = 0

    def _pred(_url, _kwdict):
        nonlocal index

        if index >= upper:
            raise exception.StopExtraction()
        index += 1

        return any(index in r for r in parsed)
    return _pred


def predicate_range_parse(rangespec):
    """Parse an integer range specification and return a list of ranges

    'rangespec' is a comma-separated string, an iterable of group strings,
    or a single int. Empty groups are ignored. Each remaining group is
      - "A:B:C"  slice syntax, 'B' is exclusive; every part is optional
      - "A-B"    'B' is inclusive; both parts are optional
      - "N"      a single index
    A missing start defaults to 1, a missing stop to sys.maxsize.

    Examples:
        predicate_range_parse("-2,4,6-8,10-")
            -> [range(1, 3), range(4, 5), range(6, 9),
                range(10, sys.maxsize)]
        predicate_range_parse(" - 3 , 4-  4, 2-6")
            -> [range(1, 4), range(4, 5), range(2, 7)]
        predicate_range_parse("1:2,4:8:2")
            -> [range(1, 2), range(4, 8, 2)]

    Raises:
        ValueError: if a part of a group is not a valid integer
    """
    if isinstance(rangespec, str):
        rangespec = rangespec.split(",")
    elif isinstance(rangespec, int):
        rangespec = (str(rangespec),)

    result = []
    for group in rangespec:
        if not group:
            continue

        if ":" in group:
            start, _, rest = group.partition(":")
            stop, _, step = rest.partition(":")
            result.append(range(
                int(start) if start.strip() else 1,
                int(stop) if stop.strip() else sys.maxsize,
                int(step) if step.strip() else 1,
            ))
        elif "-" in group:
            start, _, stop = group.partition("-")
            result.append(range(
                int(start) if start.strip() else 1,
                int(stop) + 1 if stop.strip() else sys.maxsize,
            ))
        else:
            value = int(group)
            result.append(range(value, value + 1))

    return result
TestRange(unittest.TestCase): ) def test_parse_slice(self): - f = self.predicate._parse + f = util.predicate_range_parse self.assertEqual(f("2:4") , [range(2, 4)]) self.assertEqual(f("3::") , [range(3, sys.maxsize)]) @@ -106,16 +103,16 @@ class TestRange(unittest.TestCase): class TestPredicate(unittest.TestCase): - def test_range_predicate(self): + def test_predicate_range(self): dummy = None - pred = util.RangePredicate(" - 3 , 4- 4, 2-6") + pred = util.predicate_range(" - 3 , 4- 4, 2-6") for i in range(6): self.assertTrue(pred(dummy, dummy)) with self.assertRaises(exception.StopExtraction): pred(dummy, dummy) - pred = util.RangePredicate("1, 3, 5") + pred = util.predicate_range("1, 3, 5") self.assertTrue(pred(dummy, dummy)) self.assertFalse(pred(dummy, dummy)) self.assertTrue(pred(dummy, dummy)) @@ -124,13 +121,13 @@ class TestPredicate(unittest.TestCase): with self.assertRaises(exception.StopExtraction): pred(dummy, dummy) - pred = util.RangePredicate("") + pred = util.predicate_range("") with self.assertRaises(exception.StopExtraction): pred(dummy, dummy) - def test_unique_predicate(self): + def test_predicate_unique(self): dummy = None - pred = util.UniquePredicate() + pred = util.predicate_unique() # no duplicates self.assertTrue(pred("1", dummy)) @@ -145,22 +142,22 @@ class TestPredicate(unittest.TestCase): self.assertTrue(pred("text:123", dummy)) self.assertTrue(pred("text:123", dummy)) - def test_filter_predicate(self): + def test_predicate_filter(self): url = "" - pred = util.FilterPredicate("a < 3") + pred = util.predicate_filter("a < 3") self.assertTrue(pred(url, {"a": 2})) self.assertFalse(pred(url, {"a": 3})) with self.assertRaises(SyntaxError): - util.FilterPredicate("(") + util.predicate_filter("(") self.assertFalse( - util.FilterPredicate("a > 1")(url, {"a": None})) + util.predicate_filter("a > 1")(url, {"a": None})) self.assertFalse( - util.FilterPredicate("b > 1")(url, {"a": 2})) + util.predicate_filter("b > 1")(url, {"a": 2})) - pred = 
def test_predicate_build(self):
    """predicate_build() returns the simplest callable for its input"""
    # no predicates: the module-level always-true function is returned
    pred = util.predicate_build([])
    self.assertIsInstance(pred, type(lambda: True))
    self.assertIs(pred, util.true)

    # one predicate: it is returned as is, without a wrapper
    unique = util.predicate_unique()
    pred = util.predicate_build([unique])
    self.assertTrue(callable(pred))
    self.assertIs(pred, unique)
    self.assertTrue(pred.__qualname__.startswith("predicate_unique."))

    # several predicates: a 'chain' closure that requires all of them
    pred = util.predicate_build([util.predicate_unique(),
                                 util.predicate_unique()])
    self.assertTrue(callable(pred))
    self.assertEqual(
        pred.__qualname__, "predicate_build.<locals>.chain")
    self.assertTrue(pred("1", None))
    self.assertFalse(pred("1", None))