class Formatter():
    """Custom, trimmed-down version of string.Formatter

    This string formatter implementation is a mostly performance-optimized
    variant of the original string.Formatter class. Unnecessary features have
    been removed (positional arguments, unused argument check) and new
    formatting options have been added.

    Extra Conversions:
    - "l": calls str.lower on the target value
    - "u": calls str.upper
    - "c": calls str.capitalize
    - "C": calls string.capwords
    - Example: {f!l} -> "example"; {f!u} -> "EXAMPLE"

    The standard conversions "s" (str), "r" (repr) and "a" (ascii) are
    supported as well. Any other conversion raises a KeyError when the
    format string gets applied.

    Extra Format Specifiers:
    - "?<before>/<after>/":
        Adds <before> and <after> to the actual value if it evaluates to True.
        Otherwise the whole replacement field becomes an empty string.
        Anything after the second "/" is used as regular format specifier.
        Example: {f:?-+/+-/} -> "-+Example+-" (if "f" contains "Example")
                             -> ""            (if "f" is None, 0, "")
    """
    conversions = {
        "l": str.lower,
        "u": str.upper,
        "c": str.capitalize,
        "C": string.capwords,
        "s": str,
        "r": repr,
        "a": ascii,
    }

    def __init__(self, format_string):
        """Parse 'format_string' once, so it can be applied repeatedly"""
        # tuple of (literal_text, field_name, format_spec, conversion)
        self.formatter_rules = tuple(_string.formatter_parser(format_string))

    def format_map(self, kwargs):
        """Apply 'kwargs' to the initial format_string and return its result"""
        result = []
        append = result.append

        for literal_text, field_name, format_spec, conversion in \
                self.formatter_rules:

            if literal_text:
                append(literal_text)

            # 'field_name' is None for a trailing literal without a field
            if field_name:
                obj = self.get_field(field_name, kwargs)
                if conversion:
                    obj = self.conversions[conversion](obj)
                if format_spec:
                    # resolve nested replacement fields inside the specifier
                    format_spec = format_spec.format_map(kwargs)
                    obj = self.format_field(obj, format_spec)
                else:
                    obj = str(obj)
                append(obj)

        return "".join(result)

    @staticmethod
    def format_field(value, format_spec):
        """Format 'value' according to 'format_spec'

        Raises ValueError if an optional ("?") specifier does not contain
        both of its "/" separators.
        """
        # slicing instead of indexing: 'format_spec' may be an empty string
        # after its nested replacement fields have been resolved
        if format_spec[:1] != "?":
            return format(value, format_spec)

        if not value:
            return ""

        try:
            before, after, format_spec = format_spec.split("/", 2)
        except ValueError:
            raise ValueError(
                "Invalid optional format specifier '{}': expected "
                "'?<before>/<after>/<format_spec>'".format(format_spec)
            ) from None
        # before[1:] drops the leading "?" marker
        return before[1:] + format(value, format_spec) + after

    @staticmethod
    def get_field(field_name, kwargs):
        """Return value called 'field_name' from 'kwargs'"""
        first, rest = _string.formatter_field_name_split(field_name)

        obj = kwargs[first]
        # follow the chain of ".attr" and "[key]" accesses
        for is_attr, i in rest:
            if is_attr:
                obj = getattr(obj, i)
            else:
                obj = obj[i]

        return obj
class TestFormatter(unittest.TestCase):
    """Checks for util.Formatter: conversions and optional "?" specifiers"""

    kwdict = {
        "a": "hElLo wOrLd",
        "b": "äöü",
        "name": "Name",
        "title1": "Title",
        "title2": "",
        "title3": None,
        "title4": 0,
    }

    def test_conversions(self):
        """Every supported conversion maps to its expected output"""
        original = self.kwdict["a"]
        quoted = "'" + original + "'"

        cases = (
            ("{a!l}", "hello world"),
            ("{a!u}", "HELLO WORLD"),
            ("{a!c}", "Hello world"),
            ("{a!C}", "Hello World"),
            ("{a!s}", original),
            ("{a!r}", quoted),
            ("{a!a}", quoted),
            ("{b!a}", "'\\xe4\\xf6\\xfc'"),
        )
        for format_string, expected in cases:
            self._run_test(format_string, expected)

        # an unknown conversion character is expected to raise KeyError
        with self.assertRaises(KeyError):
            self._run_test("{a!q}", "hello world")

    def test_optional(self):
        """Optional fields vanish for falsy values and get wrapped otherwise"""
        cases = (
            # truthy value: before/after text is added
            ("{name}{title1}", "NameTitle"),
            ("{name}{title1:?//}", "NameTitle"),
            ("{name}{title1:? **/''/}", "Name **Title''"),
            # empty string
            ("{name}{title2}", "Name"),
            ("{name}{title2:?//}", "Name"),
            ("{name}{title2:? **/''/}", "Name"),
            # None
            ("{name}{title3}", "NameNone"),
            ("{name}{title3:?//}", "Name"),
            ("{name}{title3:? **/''/}", "Name"),
            # zero
            ("{name}{title4}", "Name0"),
            ("{name}{title4:?//}", "Name"),
            ("{name}{title4:? **/''/}", "Name"),
        )
        for format_string, expected in cases:
            self._run_test(format_string, expected)

    def _run_test(self, format_string, result):
        """Format 'kwdict' with 'format_string' and compare against 'result'"""
        output = util.Formatter(format_string).format_map(self.kwdict)
        self.assertEqual(output, result, format_string)