diff --git a/gallery_dl/util.py b/gallery_dl/util.py index 19f13d63..8341e084 100644 --- a/gallery_dl/util.py +++ b/gallery_dl/util.py @@ -118,6 +118,13 @@ def advance(iterable, num): return iterator +def raises(obj): + """Returns a function that raises 'obj' as exception""" + def wrap(): + raise obj + return wrap + + def combine_dict(a, b): """Recursively combine the contents of b into a""" for key, value in b.items(): @@ -249,6 +256,7 @@ class FilterPredicate(): "safe_int": safe_int, "urlsplit": urllib.parse.urlsplit, "datetime": datetime.datetime, + "abort": raises(exception.StopExtraction()), "re": re, } @@ -258,6 +266,8 @@ class FilterPredicate(): def __call__(self, url, kwds): try: return eval(self.codeobj, self.globalsdict, kwds) + except exception.GalleryDLException: + raise except Exception as exc: raise exception.FilterError(exc) diff --git a/test/test_util.py b/test/test_util.py index 219d30d0..7c684d89 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -255,6 +255,14 @@ class TestOther(unittest.TestCase): self.assertCountEqual( util.advance(util.advance(items, 1), 2), range(3, 5)) + def test_raises(self): + self.assertRaises(Exception, util.raises(Exception())) + + func = util.raises(ValueError(1)) + self.assertRaises(Exception, func) + self.assertRaises(Exception, func) + self.assertRaises(Exception, func) + def test_combine_dict(self): self.assertEqual( util.combine_dict({}, {}),