diff --git a/gallery_dl/util.py b/gallery_dl/util.py index 78663a02..5e22eac7 100644 --- a/gallery_dl/util.py +++ b/gallery_dl/util.py @@ -21,6 +21,7 @@ import sqlite3 import binascii import datetime import operator +import functools import itertools import urllib.parse from http.cookiejar import Cookie @@ -346,8 +347,6 @@ CODES = { "zh": "Chinese", } -SPECIAL_EXTRACTORS = {"oauth", "recursive", "test"} - class UniversalNone(): """None-style object that supports more operations than None itself""" @@ -373,6 +372,20 @@ class UniversalNone(): NONE = UniversalNone() WINDOWS = (os.name == "nt") SENTINEL = object() +SPECIAL_EXTRACTORS = {"oauth", "recursive", "test"} +GLOBALS = { + "parse_int": text.parse_int, + "urlsplit" : urllib.parse.urlsplit, + "datetime" : datetime.datetime, + "abort" : raises(exception.StopExtraction), + "terminate": raises(exception.TerminateExtraction), + "re" : re, +} + + +def compile_expression(expr, name="", globals=GLOBALS): + code_object = compile(expr, name, "eval") + return functools.partial(eval, code_object, globals) def build_predicate(predicates): @@ -472,20 +485,13 @@ class UniquePredicate(): class FilterPredicate(): """Predicate; True if evaluating the given expression returns True""" - def __init__(self, filterexpr, target="image"): + def __init__(self, expr, target="image"): name = "<{} filter>".format(target) - self.codeobj = compile(filterexpr, name, "eval") - self.globals = { - "parse_int": text.parse_int, - "urlsplit" : urllib.parse.urlsplit, - "datetime" : datetime.datetime, - "abort" : raises(exception.StopExtraction), - "re" : re, - } + self.expr = compile_expression(expr, name) - def __call__(self, url, kwds): + def __call__(self, _, kwdict): try: - return eval(self.codeobj, self.globals, kwds) + return self.expr(kwdict) except exception.GalleryDLException: raise except Exception as exc: diff --git a/test/test_util.py b/test/test_util.py index e2f5084d..d90d5adc 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -493,6 +493,30 @@ class TestOther(unittest.TestCase): def test_noop(self): self.assertEqual(util.noop(), None) + def test_compile_expression(self): + expr = util.compile_expression("1 + 2 * 3") + self.assertEqual(expr(), 7) + self.assertEqual(expr({"a": 1, "b": 2, "c": 3}), 7) + self.assertEqual(expr({"a": 9, "b": 9, "c": 9}), 7) + + expr = util.compile_expression("a + b * c") + self.assertEqual(expr({"a": 1, "b": 2, "c": 3}), 7) + self.assertEqual(expr({"a": 9, "b": 9, "c": 9}), 90) + + with self.assertRaises(NameError): + expr() + with self.assertRaises(NameError): + expr({"a": 2}) + + with self.assertRaises(SyntaxError): + util.compile_expression("") + with self.assertRaises(SyntaxError): + util.compile_expression("x++") + + expr = util.compile_expression("1 and abort()") + with self.assertRaises(exception.StopExtraction): + expr() + def test_generate_token(self): tokens = set() for _ in range(100):