diff --git a/gallery_dl/text.py b/gallery_dl/text.py index 2f18a616..7725ea09 100644 --- a/gallery_dl/text.py +++ b/gallery_dl/text.py @@ -209,11 +209,12 @@ def _hex_to_char(match): def parse_bytes(value, default=0, suffixes="bkmgtp"): """Convert a bytes-amount ("500k", "2.5M", ...) to int""" - try: - last = value[-1].lower() - except Exception: + if not value: return default + value = str(value).strip() + last = value[-1].lower() + if last in suffixes: mul = 1024 ** suffixes.index(last) value = value[:-1] diff --git a/gallery_dl/util.py b/gallery_dl/util.py index 1f006a85..9fff88cb 100644 --- a/gallery_dl/util.py +++ b/gallery_dl/util.py @@ -898,25 +898,26 @@ def import_file(path): return __import__(name.replace("-", "_")) -def build_duration_func(duration, min=0.0): - if not duration: +def build_selection_func(value, min=0.0, conv=float): + if not value: if min: return lambda: min return None - if isinstance(duration, str): - lower, _, upper = duration.partition("-") - lower = float(lower) + if isinstance(value, str): + lower, _, upper = value.partition("-") + lower = conv(lower) else: try: - lower, upper = duration + lower, upper = value except TypeError: - lower, upper = duration, None + lower, upper = value, None + lower = conv(lower) if upper: - upper = float(upper) + upper = conv(upper) return functools.partial( - random.uniform, + random.uniform if min.__class__ is float else random.randint, lower if lower > min else min, upper if upper > min else min, ) @@ -926,6 +927,9 @@ def build_duration_func(duration, min=0.0): return lambda: lower +build_duration_func = build_selection_func + + def build_extractor_filter(categories, negate=True, special=None): """Build a function that takes an Extractor class as argument and returns True if that class is allowed by 'categories' diff --git a/test/test_text.py b/test/test_text.py index 96656c1e..f20e3d00 100644 --- a/test/test_text.py +++ b/test/test_text.py @@ -359,6 +359,8 @@ class TestText(unittest.TestCase): ) def test_parse_bytes(self, f=text.parse_bytes): + self.assertEqual(f(0), 0) + self.assertEqual(f(50), 50) self.assertEqual(f("0"), 0) self.assertEqual(f("50"), 50) self.assertEqual(f("50k"), 50 * 1024**1) @@ -366,10 +368,13 @@ class TestText(unittest.TestCase): self.assertEqual(f("50g"), 50 * 1024**3) self.assertEqual(f("50t"), 50 * 1024**4) self.assertEqual(f("50p"), 50 * 1024**5) + self.assertEqual(f(" 50p "), 50 * 1024**5) # fractions + self.assertEqual(f(123.456), 123) self.assertEqual(f("123.456"), 123) self.assertEqual(f("123.567"), 124) + self.assertEqual(f(" 123.89 "), 124) self.assertEqual(f("0.5M"), round(0.5 * 1024**2)) # invalid arguments diff --git a/test/test_util.py b/test/test_util.py index 7764114c..4561ce73 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -554,17 +554,21 @@ value = 123 self.assertEqual(module.value, 123) self.assertIs(module.datetime, datetime) - def test_build_duration_func(self, f=util.build_duration_func): + def test_build_selection_func(self, f=util.build_selection_func): - def test_single(df, v): + def test_single(df, v, type=None): for _ in range(10): self.assertEqual(df(), v) + if type is not None: + self.assertIsInstance(df(), type) - def test_range(df, lower, upper): + def test_range(df, lower, upper, type=None): for __ in range(10): v = df() self.assertGreaterEqual(v, lower) self.assertLessEqual(v, upper) + if type is not None: + self.assertIsInstance(v, type) for v in (0, 0.0, "", None, (), []): self.assertIsNone(f(v)) @@ -572,16 +576,24 @@ value = 123 for v in (0, 0.0, "", None, (), []): test_single(f(v, 1.0), 1.0) - test_single(f(3), 3) - test_single(f(3.0), 3.0) - test_single(f("3"), 3) - test_single(f("3.0-"), 3) - test_single(f(" 3 -"), 3) + test_single(f(3) , 3 , float) + test_single(f(3.0) , 3.0, float) + test_single(f("3") , 3 , float) + test_single(f("3.0-") , 3 , float) + test_single(f(" 3 -"), 3 , float) - test_range(f((2, 4)), 2, 4) - test_range(f([2, 4]), 2, 4) - test_range(f("2-4"), 2, 4) - test_range(f(" 2.0 - 4 "), 2, 4) + test_range(f((2, 4)) , 2, 4, float) + test_range(f([2.0, 4.0]) , 2, 4, float) + test_range(f("2-4") , 2, 4, float) + test_range(f(" 2.0 - 4 "), 2, 4, float) + + pb = text.parse_bytes + test_single(f("3", 0, pb) , 3, int) + test_single(f("3.0-", 0, pb) , 3, int) + test_single(f(" 3 -", 0, pb), 3, int) + + test_range(f("2k-4k", 0, pb) , 2048, 4096, int) + test_range(f(" 2.0k - 4k ", 0, pb), 2048, 4096, int) def test_extractor_filter(self): # empty