[util] generalize 'build_duration_func'

This commit is contained in:
Mike Fährmann
2025-06-08 17:29:15 +02:00
parent cc48cddf68
commit e84df260c0
4 changed files with 46 additions and 24 deletions

View File

@@ -209,11 +209,12 @@ def _hex_to_char(match):
def parse_bytes(value, default=0, suffixes="bkmgtp"):
"""Convert a bytes-amount ("500k", "2.5M", ...) to int"""
try:
last = value[-1].lower()
except Exception:
if not value:
return default
value = str(value).strip()
last = value[-1].lower()
if last in suffixes:
mul = 1024 ** suffixes.index(last)
value = value[:-1]

View File

@@ -898,25 +898,26 @@ def import_file(path):
return __import__(name.replace("-", "_"))
def build_duration_func(duration, min=0.0):
if not duration:
def build_selection_func(value, min=0.0, conv=float):
if not value:
if min:
return lambda: min
return None
if isinstance(duration, str):
lower, _, upper = duration.partition("-")
lower = float(lower)
if isinstance(value, str):
lower, _, upper = value.partition("-")
lower = conv(lower)
else:
try:
lower, upper = duration
lower, upper = value
except TypeError:
lower, upper = duration, None
lower, upper = value, None
lower = conv(lower)
if upper:
upper = float(upper)
upper = conv(upper)
return functools.partial(
random.uniform,
random.uniform if min.__class__ is float else random.randint,
lower if lower > min else min,
upper if upper > min else min,
)
@@ -926,6 +927,9 @@ def build_duration_func(duration, min=0.0):
return lambda: lower
build_duration_func = build_selection_func
def build_extractor_filter(categories, negate=True, special=None):
"""Build a function that takes an Extractor class as argument
and returns True if that class is allowed by 'categories'

View File

@@ -359,6 +359,8 @@ class TestText(unittest.TestCase):
)
def test_parse_bytes(self, f=text.parse_bytes):
self.assertEqual(f(0), 0)
self.assertEqual(f(50), 50)
self.assertEqual(f("0"), 0)
self.assertEqual(f("50"), 50)
self.assertEqual(f("50k"), 50 * 1024**1)
@@ -366,10 +368,13 @@ class TestText(unittest.TestCase):
self.assertEqual(f("50g"), 50 * 1024**3)
self.assertEqual(f("50t"), 50 * 1024**4)
self.assertEqual(f("50p"), 50 * 1024**5)
self.assertEqual(f(" 50p "), 50 * 1024**5)
# fractions
self.assertEqual(f(123.456), 123)
self.assertEqual(f("123.456"), 123)
self.assertEqual(f("123.567"), 124)
self.assertEqual(f(" 123.89 "), 124)
self.assertEqual(f("0.5M"), round(0.5 * 1024**2))
# invalid arguments

View File

@@ -554,17 +554,21 @@ value = 123
self.assertEqual(module.value, 123)
self.assertIs(module.datetime, datetime)
def test_build_duration_func(self, f=util.build_duration_func):
def test_build_selection_func(self, f=util.build_selection_func):
def test_single(df, v):
def test_single(df, v, type=None):
for _ in range(10):
self.assertEqual(df(), v)
if type is not None:
self.assertIsInstance(df(), type)
def test_range(df, lower, upper):
def test_range(df, lower, upper, type=None):
for __ in range(10):
v = df()
self.assertGreaterEqual(v, lower)
self.assertLessEqual(v, upper)
if type is not None:
self.assertIsInstance(v, type)
for v in (0, 0.0, "", None, (), []):
self.assertIsNone(f(v))
@@ -572,16 +576,24 @@ value = 123
for v in (0, 0.0, "", None, (), []):
test_single(f(v, 1.0), 1.0)
test_single(f(3), 3)
test_single(f(3.0), 3.0)
test_single(f("3"), 3)
test_single(f("3.0-"), 3)
test_single(f(" 3 -"), 3)
test_single(f(3) , 3 , float)
test_single(f(3.0) , 3.0, float)
test_single(f("3") , 3 , float)
test_single(f("3.0-") , 3 , float)
test_single(f(" 3 -"), 3 , float)
test_range(f((2, 4)), 2, 4)
test_range(f([2, 4]), 2, 4)
test_range(f("2-4"), 2, 4)
test_range(f(" 2.0 - 4 "), 2, 4)
test_range(f((2, 4)) , 2, 4, float)
test_range(f([2.0, 4.0]) , 2, 4, float)
test_range(f("2-4") , 2, 4, float)
test_range(f(" 2.0 - 4 "), 2, 4, float)
pb = text.parse_bytes
test_single(f("3", 0, pb) , 3, int)
test_single(f("3.0-", 0, pb) , 3, int)
test_single(f(" 3 -", 0, pb), 3, int)
test_range(f("2k-4k", 0, pb) , 2048, 4096, int)
test_range(f(" 2.0k - 4k ", 0, pb), 2048, 4096, int)
def test_extractor_filter(self):
# empty