diff --git a/gallery_dl/util.py b/gallery_dl/util.py index 54e9dae7..98edd74d 100644 --- a/gallery_dl/util.py +++ b/gallery_dl/util.py @@ -714,29 +714,33 @@ def chain_predicates(predicates, url, kwdict): class RangePredicate(): - """Predicate; True if the current index is in the given range""" + """Predicate; True if the current index is in the given range(s)""" + def __init__(self, rangespec): - self.ranges = self.optimize_range(self.parse_range(rangespec)) + self.ranges = ranges = self._parse(rangespec) self.index = 0 - if self.ranges: - self.lower, self.upper = self.ranges[0][0], self.ranges[-1][1] + if ranges: + # technically wrong, but good enough for now + # and evaluating min/max for a öarge range is slow + self.lower = min(r.start for r in ranges) + self.upper = max(r.stop for r in ranges) - 1 else: - self.lower, self.upper = 0, 0 + self.lower = self.upper = 0 - def __call__(self, url, _): - self.index += 1 + def __call__(self, _url, _kwdict): + self.index = index = self.index + 1 - if self.index > self.upper: + if index > self.upper: raise exception.StopExtraction() - for lower, upper in self.ranges: - if lower <= self.index <= upper: + for range in self.ranges: + if index in range: return True return False @staticmethod - def parse_range(rangespec): + def _parse(rangespec): """Parse an integer range string and return the resulting ranges Examples: @@ -744,22 +748,29 @@ class RangePredicate(): parse_range(" - 3 , 4- 4, 2-6") -> [(1,3), (4,4), (2,6)] """ ranges = [] + append = ranges.append - for group in rangespec.split(","): + if isinstance(rangespec, str): + rangespec = rangespec.split(",") + + for group in rangespec: if not group: continue + first, sep, last = group.partition("-") - if not sep: - beg = end = int(first) + if sep: + append(range( + int(first) if first.strip() else 1, + int(last) + 1 if last.strip() else sys.maxsize, + )) else: - beg = int(first) if first.strip() else 1 - end = int(last) if last.strip() else sys.maxsize - ranges.append((beg, end) if beg <= end else (end, beg)) + v = int(first) + append(range(v, v+1)) return ranges @staticmethod - def optimize_range(ranges): + def _optimize(ranges): """Simplify/Combine a parsed list of ranges Examples: diff --git a/test/test_util.py b/test/test_util.py index 4de8ce84..24e1c3ef 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -24,39 +24,32 @@ from gallery_dl import util, text, exception # noqa E402 class TestRange(unittest.TestCase): - def test_parse_range(self, f=util.RangePredicate.parse_range): + def test_parse_range(self, f=util.RangePredicate._parse): self.assertEqual( f(""), - []) + [], + ) self.assertEqual( f("1-2"), - [(1, 2)]) + [range(1, 3)], + ) self.assertEqual( f("-"), - [(1, sys.maxsize)]) + [range(1, sys.maxsize)], + ) self.assertEqual( f("-2,4,6-8,10-"), - [(1, 2), (4, 4), (6, 8), (10, sys.maxsize)]) + [range(1, 3), + range(4, 5), + range(6, 9), + range(10, sys.maxsize)], + ) self.assertEqual( f(" - 3 , 4- 4, 2-6"), - [(1, 3), (4, 4), (2, 6)]) - - def test_optimize_range(self, f=util.RangePredicate.optimize_range): - self.assertEqual( - f([]), - []) - self.assertEqual( - f([(2, 4)]), - [(2, 4)]) - self.assertEqual( - f([(2, 4), (6, 8), (10, 12)]), - [(2, 4), (6, 8), (10, 12)]) - self.assertEqual( - f([(2, 4), (4, 6), (5, 8)]), - [(2, 8)]) - self.assertEqual( - f([(1, 1), (2, 2), (3, 6), (8, 9)]), - [(1, 6), (8, 9)]) + [range(1, 4), + range(4, 5), + range(2, 7)], + ) class TestPredicate(unittest.TestCase): @@ -68,7 +61,7 @@ class TestPredicate(unittest.TestCase): for i in range(6): self.assertTrue(pred(dummy, dummy)) with self.assertRaises(exception.StopExtraction): - bool(pred(dummy, dummy)) + pred(dummy, dummy) pred = util.RangePredicate("1, 3, 5") self.assertTrue(pred(dummy, dummy)) @@ -77,11 +70,11 @@ class TestPredicate(unittest.TestCase): self.assertFalse(pred(dummy, dummy)) self.assertTrue(pred(dummy, dummy)) with self.assertRaises(exception.StopExtraction): - bool(pred(dummy, dummy)) + pred(dummy, dummy) pred = util.RangePredicate("") with self.assertRaises(exception.StopExtraction): - bool(pred(dummy, dummy)) + pred(dummy, dummy) def test_unique_predicate(self): dummy = None