[util] add "defaultdict" filters-environment

allows accessing undefined values without raising an exception,
but preserves other errors like TypeError, AttributeError, etc
This commit is contained in:
Mike Fährmann
2024-11-12 21:21:58 +01:00
parent cfe24a9e31
commit 0b99d9e6b9
4 changed files with 118 additions and 62 deletions

View File

@@ -6855,13 +6855,22 @@ Description
filters-environment filters-environment
------------------- -------------------
Type Type
``bool`` * ``bool``
* ``string``
Default Default
``true`` ``true``
Description Description
Evaluate filter expressions raising an exception as ``false`` Evaluate filter expressions in a special environment
instead of aborting the current extractor run preventing them from raising fatal exceptions.
by wrapping them in a `try`/`except` block.
``true`` or ``"tryexcept"``:
Wrap expressions in a `try`/`except` block;
Evaluate expressions raising an exception as ``false``
``false`` or ``"raw"``:
Do not wrap expressions in a special environment
``"defaultdict"``:
Prevent exceptions when accessing undefined variables
by using a `defaultdict <https://docs.python.org/3/library/collections.html#collections.defaultdict>`__
format-separator format-separator

View File

@@ -107,8 +107,15 @@ def main():
# filter environment # filter environment
filterenv = config.get((), "filters-environment", True) filterenv = config.get((), "filters-environment", True)
if not filterenv: if filterenv is True:
pass
elif not filterenv:
util.compile_expression = util.compile_expression_raw util.compile_expression = util.compile_expression_raw
elif isinstance(filterenv, str):
if filterenv == "raw":
util.compile_expression = util.compile_expression_raw
elif filterenv.startswith("default"):
util.compile_expression = util.compile_expression_defaultdict
# format string separator # format string separator
separator = config.get((), "format-separator") separator = config.get((), "format-separator")

View File

@@ -21,6 +21,7 @@ import datetime
import functools import functools
import itertools import itertools
import subprocess import subprocess
import collections
import urllib.parse import urllib.parse
from http.cookiejar import Cookie from http.cookiejar import Cookie
from email.utils import mktime_tz, parsedate_tz from email.utils import mktime_tz, parsedate_tz
@@ -702,6 +703,20 @@ def compile_expression_raw(expr, name="<expr>", globals=None):
return functools.partial(eval, code_object, globals or GLOBALS) return functools.partial(eval, code_object, globals or GLOBALS)
def compile_expression_defaultdict(expr, name="<expr>", globals=None):
global GLOBALS_DEFAULT
GLOBALS_DEFAULT = collections.defaultdict(lambda: NONE, GLOBALS)
global compile_expression_defaultdict
compile_expression_defaultdict = compile_expression_defaultdict_impl
return compile_expression_defaultdict_impl(expr, name, globals)
def compile_expression_defaultdict_impl(expr, name="<expr>", globals=None):
code_object = compile(expr, name, "eval")
return functools.partial(eval, code_object, globals or GLOBALS_DEFAULT)
def compile_expression_tryexcept(expr, name="<expr>", globals=None): def compile_expression_tryexcept(expr, name="<expr>", globals=None):
code_object = compile(expr, name, "eval") code_object = compile(expr, name, "eval")
@@ -711,7 +726,7 @@ def compile_expression_tryexcept(expr, name="<expr>", globals=None):
except exception.GalleryDLException: except exception.GalleryDLException:
raise raise
except Exception: except Exception:
return False return NONE
return _eval return _eval

View File

@@ -300,6 +300,87 @@ class TestCookiesTxt(unittest.TestCase):
) )
class TestCompileExpression(unittest.TestCase):
def test_compile_expression(self):
expr = util.compile_expression("1 + 2 * 3")
self.assertEqual(expr(), 7)
self.assertEqual(expr({"a": 1, "b": 2, "c": 3}), 7)
self.assertEqual(expr({"a": 9, "b": 9, "c": 9}), 7)
expr = util.compile_expression("a + b * c")
self.assertEqual(expr({"a": 1, "b": 2, "c": 3}), 7)
self.assertEqual(expr({"a": 9, "b": 9, "c": 9}), 90)
with self.assertRaises(SyntaxError):
util.compile_expression("")
with self.assertRaises(SyntaxError):
util.compile_expression("x++")
expr = util.compile_expression("1 and abort()")
with self.assertRaises(exception.StopExtraction):
expr()
def test_compile_expression_raw(self):
expr = util.compile_expression_raw("a + b * c")
with self.assertRaises(NameError):
expr()
with self.assertRaises(NameError):
expr({"a": 2})
expr = util.compile_expression_defaultdict("int.param")
with self.assertRaises(AttributeError):
expr({"a": 2})
def test_compile_expression_tryexcept(self):
expr = util.compile_expression_tryexcept("a + b * c")
self.assertIs(expr(), util.NONE)
self.assertIs(expr({"a": 2}), util.NONE)
expr = util.compile_expression_tryexcept("int.param")
self.assertIs(expr({"a": 2}), util.NONE)
def test_compile_expression_defaultdict(self):
expr = util.compile_expression_defaultdict("a + b * c")
self.assertIs(expr(), util.NONE)
self.assertIs(expr({"a": 2}), util.NONE)
expr = util.compile_expression_defaultdict("int.param")
with self.assertRaises(AttributeError):
expr({"a": 2})
def test_custom_globals(self):
value = {"v": "foobar"}
result = "8843d7f92416211de9ebb963ff4ce28125932878"
expr = util.compile_expression("hash_sha1(v)")
self.assertEqual(expr(value), result)
expr = util.compile_expression("hs(v)", globals={"hs": util.sha1})
self.assertEqual(expr(value), result)
with tempfile.TemporaryDirectory() as path:
file = path + "/module_sha1.py"
with open(file, "w") as fp:
fp.write("""
import hashlib
def hash(value):
return hashlib.sha1(value.encode()).hexdigest()
""")
module = util.import_file(file)
expr = util.compile_expression("hash(v)", globals=module.__dict__)
self.assertEqual(expr(value), result)
GLOBALS_ORIG = util.GLOBALS
try:
util.GLOBALS = module.__dict__
expr = util.compile_expression("hash(v)")
finally:
util.GLOBALS = GLOBALS_ORIG
self.assertEqual(expr(value), result)
class TestOther(unittest.TestCase): class TestOther(unittest.TestCase):
def test_bencode(self): def test_bencode(self):
@@ -434,31 +515,6 @@ class TestOther(unittest.TestCase):
self.assertEqual(util.sha1(None), self.assertEqual(util.sha1(None),
"da39a3ee5e6b4b0d3255bfef95601890afd80709") "da39a3ee5e6b4b0d3255bfef95601890afd80709")
def test_compile_expression(self):
expr = util.compile_expression("1 + 2 * 3")
self.assertEqual(expr(), 7)
self.assertEqual(expr({"a": 1, "b": 2, "c": 3}), 7)
self.assertEqual(expr({"a": 9, "b": 9, "c": 9}), 7)
expr = util.compile_expression("a + b * c")
self.assertEqual(expr({"a": 1, "b": 2, "c": 3}), 7)
self.assertEqual(expr({"a": 9, "b": 9, "c": 9}), 90)
expr = util.compile_expression_raw("a + b * c")
with self.assertRaises(NameError):
expr()
with self.assertRaises(NameError):
expr({"a": 2})
with self.assertRaises(SyntaxError):
util.compile_expression("")
with self.assertRaises(SyntaxError):
util.compile_expression("x++")
expr = util.compile_expression("1 and abort()")
with self.assertRaises(exception.StopExtraction):
expr()
def test_import_file(self): def test_import_file(self):
module = util.import_file("datetime") module = util.import_file("datetime")
self.assertIs(module, datetime) self.assertIs(module, datetime)
@@ -478,37 +534,6 @@ value = 123
self.assertEqual(module.value, 123) self.assertEqual(module.value, 123)
self.assertIs(module.datetime, datetime) self.assertIs(module.datetime, datetime)
def test_custom_globals(self):
value = {"v": "foobar"}
result = "8843d7f92416211de9ebb963ff4ce28125932878"
expr = util.compile_expression("hash_sha1(v)")
self.assertEqual(expr(value), result)
expr = util.compile_expression("hs(v)", globals={"hs": util.sha1})
self.assertEqual(expr(value), result)
with tempfile.TemporaryDirectory() as path:
file = path + "/module_sha1.py"
with open(file, "w") as fp:
fp.write("""
import hashlib
def hash(value):
return hashlib.sha1(value.encode()).hexdigest()
""")
module = util.import_file(file)
expr = util.compile_expression("hash(v)", globals=module.__dict__)
self.assertEqual(expr(value), result)
GLOBALS_ORIG = util.GLOBALS
try:
util.GLOBALS = module.__dict__
expr = util.compile_expression("hash(v)")
finally:
util.GLOBALS = GLOBALS_ORIG
self.assertEqual(expr(value), result)
def test_build_duration_func(self, f=util.build_duration_func): def test_build_duration_func(self, f=util.build_duration_func):
def test_single(df, v): def test_single(df, v):