rework post processor callbacks

This commit is contained in:
Mike Fährmann
2020-11-18 17:11:55 +01:00
parent d6986be8b0
commit 9fffa9c343
10 changed files with 109 additions and 121 deletions

View File

@@ -10,6 +10,7 @@ import sys
import time import time
import errno import errno
import logging import logging
import collections
from . import extractor, downloader, postprocessor from . import extractor, downloader, postprocessor
from . import config, text, util, output, exception from . import config, text, util, output, exception
from .extractor.message import Message from .extractor.message import Message
@@ -193,8 +194,8 @@ class DownloadJob(Job):
self.blacklist = None self.blacklist = None
self.archive = None self.archive = None
self.sleep = None self.sleep = None
self.hooks = None
self.downloaders = {} self.downloaders = {}
self.postprocessors = None
self.out = output.select() self.out = output.select()
if parent: if parent:
@@ -207,16 +208,16 @@ class DownloadJob(Job):
def handle_url(self, url, kwdict): def handle_url(self, url, kwdict):
"""Download the resource specified in 'url'""" """Download the resource specified in 'url'"""
postprocessors = self.postprocessors hooks = self.hooks
pathfmt = self.pathfmt pathfmt = self.pathfmt
archive = self.archive archive = self.archive
# prepare download # prepare download
pathfmt.set_filename(kwdict) pathfmt.set_filename(kwdict)
if postprocessors: if "prepare" in hooks:
for pp in postprocessors: for callback in hooks["prepare"]:
pp.prepare(pathfmt) callback(pathfmt)
if archive and archive.check(kwdict): if archive and archive.check(kwdict):
pathfmt.fix_extension() pathfmt.fix_extension()
@@ -255,19 +256,19 @@ class DownloadJob(Job):
return return
# run post processors # run post processors
if postprocessors: if "file" in hooks:
for pp in postprocessors: for callback in hooks["file"]:
pp.run(pathfmt) callback(pathfmt)
# download succeeded # download succeeded
pathfmt.finalize() pathfmt.finalize()
self.out.success(pathfmt.path, 0) self.out.success(pathfmt.path, 0)
self._skipcnt = 0
if archive: if archive:
archive.add(kwdict) archive.add(kwdict)
if postprocessors: if "after" in hooks:
for pp in postprocessors: for callback in hooks["after"]:
pp.run_after(pathfmt) callback(pathfmt)
self._skipcnt = 0
def handle_directory(self, kwdict): def handle_directory(self, kwdict):
"""Set and create the target directory for downloads""" """Set and create the target directory for downloads"""
@@ -275,17 +276,18 @@ class DownloadJob(Job):
self.initialize(kwdict) self.initialize(kwdict)
else: else:
self.pathfmt.set_directory(kwdict) self.pathfmt.set_directory(kwdict)
if "post" in self.hooks:
for callback in self.hooks["post"]:
callback(self.pathfmt)
def handle_metadata(self, kwdict): def handle_metadata(self, kwdict):
"""Run postprocessors with metadata from 'kwdict'""" """Run postprocessors with metadata from 'kwdict'"""
postprocessors = self.postprocessors if "metadata" in self.hooks:
if postprocessors:
kwdict["extension"] = "metadata" kwdict["extension"] = "metadata"
pathfmt = self.pathfmt pathfmt = self.pathfmt
pathfmt.set_filename(kwdict) pathfmt.set_filename(kwdict)
for pp in postprocessors: for callback in self.hooks["metadata"]:
pp.run_metadata(pathfmt) callback(pathfmt)
def handle_queue(self, url, kwdict): def handle_queue(self, url, kwdict):
if url in self.visited: if url in self.visited:
@@ -313,13 +315,17 @@ class DownloadJob(Job):
self.archive.close() self.archive.close()
if pathfmt: if pathfmt:
self.extractor._store_cookies() self.extractor._store_cookies()
if self.postprocessors: if "finalize" in self.hooks:
status = self.status status = self.status
for pp in self.postprocessors: for callback in self.hooks["finalize"]:
pp.run_final(pathfmt, status) callback(pathfmt, status)
def handle_skip(self): def handle_skip(self):
self.out.skip(self.pathfmt.path) pathfmt = self.pathfmt
self.out.skip(pathfmt.path)
if "skip" in self.hooks:
for callback in self.hooks["skip"]:
callback(pathfmt)
if self._skipexc: if self._skipexc:
self._skipcnt += 1 self._skipcnt += 1
if self._skipcnt >= self._skipmax: if self._skipcnt >= self._skipmax:
@@ -407,6 +413,7 @@ class DownloadJob(Job):
postprocessors = self.extractor.config_accumulate("postprocessors") postprocessors = self.extractor.config_accumulate("postprocessors")
if postprocessors: if postprocessors:
self.hooks = collections.defaultdict(list)
pp_log = self.get_logger("postprocessor") pp_log = self.get_logger("postprocessor")
pp_list = [] pp_list = []
category = self.extractor.category category = self.extractor.category
@@ -438,9 +445,11 @@ class DownloadJob(Job):
pp_list.append(pp_obj) pp_list.append(pp_obj)
if pp_list: if pp_list:
self.postprocessors = pp_list
self.extractor.log.debug( self.extractor.log.debug(
"Active postprocessor modules: %s", pp_list) "Active postprocessor modules: %s", pp_list)
if "init" in self.hooks:
for callback in self.hooks["init"]:
callback(pathfmt)
def _build_blacklist(self): def _build_blacklist(self):
wlist = self.extractor.config("whitelist") wlist = self.extractor.config("whitelist")

View File

@@ -32,13 +32,16 @@ class ClassifyPP(PostProcessor):
for ext in exts for ext in exts
} }
job.hooks["prepare"].append(self.prepare)
job.hooks["file"].append(self.move)
def prepare(self, pathfmt): def prepare(self, pathfmt):
ext = pathfmt.extension ext = pathfmt.extension
if ext in self.mapping: if ext in self.mapping:
# set initial paths to enable download skips # set initial paths to enable download skips
self._build_paths(pathfmt, self.mapping[ext]) self._build_paths(pathfmt, self.mapping[ext])
def run(self, pathfmt): def move(self, pathfmt):
ext = pathfmt.extension ext = pathfmt.extension
if ext in self.mapping: if ext in self.mapping:
# rebuild paths in case the filename extension changed # rebuild paths in case the filename extension changed

View File

@@ -16,25 +16,5 @@ class PostProcessor():
name = self.__class__.__name__[:-2].lower() name = self.__class__.__name__[:-2].lower()
self.log = job.get_logger("postprocessor." + name) self.log = job.get_logger("postprocessor." + name)
@staticmethod
def prepare(pathfmt):
"""Update file paths, etc."""
@staticmethod
def run(pathfmt):
"""Execute the postprocessor for a file"""
@staticmethod
def run_metadata(pathfmt):
"""Execute the postprocessor for a file"""
@staticmethod
def run_after(pathfmt):
"""Execute postprocessor after moving a file to its target location"""
@staticmethod
def run_final(pathfmt, status):
"""Postprocessor finalization after all files have been downloaded"""
def __repr__(self): def __repr__(self):
return self.__class__.__name__ return self.__class__.__name__

View File

@@ -16,22 +16,25 @@ class ComparePP(PostProcessor):
def __init__(self, job, options): def __init__(self, job, options):
PostProcessor.__init__(self, job) PostProcessor.__init__(self, job)
if options.get("action") == "enumerate":
self.run = self._run_enumerate
if options.get("shallow"): if options.get("shallow"):
self.compare = self._compare_size self._compare = self._compare_size
job.hooks["file"].append(
self.enumerate
if options.get("action") == "enumerate" else
self.compare
)
def run(self, pathfmt): def compare(self, pathfmt):
try: try:
if self.compare(pathfmt.realpath, pathfmt.temppath): if self._compare(pathfmt.realpath, pathfmt.temppath):
pathfmt.delete = True pathfmt.delete = True
except OSError: except OSError:
pass pass
def _run_enumerate(self, pathfmt): def enumerate(self, pathfmt):
num = 1 num = 1
try: try:
while not self.compare(pathfmt.realpath, pathfmt.temppath): while not self._compare(pathfmt.realpath, pathfmt.temppath):
pathfmt.prefix = str(num) + "." pathfmt.prefix = str(num) + "."
pathfmt.set_extension(pathfmt.extension, False) pathfmt.set_extension(pathfmt.extension, False)
num += 1 num += 1
@@ -39,7 +42,7 @@ class ComparePP(PostProcessor):
except OSError: except OSError:
pass pass
def compare(self, f1, f2): def _compare(self, f1, f2):
return self._compare_size(f1, f2) and self._compare_content(f1, f2) return self._compare_size(f1, f2) and self._compare_content(f1, f2)
@staticmethod @staticmethod

View File

@@ -41,18 +41,13 @@ class ExecPP(PostProcessor):
self.args = [util.Formatter(arg) for arg in args] self.args = [util.Formatter(arg) for arg in args]
self.shell = False self.shell = False
if final:
self.run_after = PostProcessor.run_after
else:
self.run_final = PostProcessor.run_final
if options.get("async", False): if options.get("async", False):
self._exec = self._exec_async self._exec = self._exec_async
def run_after(self, pathfmt): event = "finalize" if final else "after"
self._exec(self._format(pathfmt)) job.hooks[event].append(self.run)
def run_final(self, pathfmt, status): def run(self, pathfmt, status=0):
if status == 0: if status == 0:
self._exec(self._format(pathfmt)) self._exec(self._format(pathfmt))

View File

@@ -48,8 +48,8 @@ class MetadataPP(PostProcessor):
else: else:
self.extension = options.get("extension", ext) self.extension = options.get("extension", ext)
if options.get("bypost"): event = "metadata" if options.get("bypost") else "file"
self.run_metadata, self.run = self.run, self.run_metadata job.hooks[event].append(self.run)
def run(self, pathfmt): def run(self, pathfmt):
path = self._directory(pathfmt) + self._filename(pathfmt) path = self._directory(pathfmt) + self._filename(pathfmt)

View File

@@ -17,6 +17,7 @@ class MtimePP(PostProcessor):
def __init__(self, job, options): def __init__(self, job, options):
PostProcessor.__init__(self, job) PostProcessor.__init__(self, job)
self.key = options.get("key", "date") self.key = options.get("key", "date")
job.hooks["file"].append(self.run)
def run(self, pathfmt): def run(self, pathfmt):
mtime = pathfmt.kwdict.get(self.key) mtime = pathfmt.kwdict.get(self.key)

View File

@@ -49,6 +49,9 @@ class UgoiraPP(PostProcessor):
else: else:
self.prevent_odd = False self.prevent_odd = False
job.hooks["prepare"].append(self.prepare)
job.hooks["file"].append(self.convert)
def prepare(self, pathfmt): def prepare(self, pathfmt):
self._frames = None self._frames = None
@@ -65,7 +68,7 @@ class UgoiraPP(PostProcessor):
if self.delete: if self.delete:
pathfmt.set_extension(self.extension) pathfmt.set_extension(self.extension)
def run(self, pathfmt): def convert(self, pathfmt):
if not self._frames: if not self._frames:
return return

View File

@@ -38,12 +38,11 @@ class ZipPP(PostProcessor):
self.args = (self.path[:-1] + ext, "a", self.args = (self.path[:-1] + ext, "a",
self.COMPRESSION_ALGORITHMS[algorithm], True) self.COMPRESSION_ALGORITHMS[algorithm], True)
if options.get("mode") == "safe": job.hooks["file"].append(
self.run = self._write_safe self.write_safe if options.get("mode") == "safe" else self.write)
else: job.hooks["finalize"].append(self.finalize)
self.run = self._write
def _write(self, pathfmt, zfile=None): def write(self, pathfmt, zfile=None):
# 'NameToInfo' is not officially documented, but it's available # 'NameToInfo' is not officially documented, but it's available
# for all supported Python versions and using it directly is a lot # for all supported Python versions and using it directly is a lot
# faster than calling getinfo() # faster than calling getinfo()
@@ -55,11 +54,11 @@ class ZipPP(PostProcessor):
zfile.write(pathfmt.temppath, pathfmt.filename) zfile.write(pathfmt.temppath, pathfmt.filename)
pathfmt.delete = self.delete pathfmt.delete = self.delete
def _write_safe(self, pathfmt): def write_safe(self, pathfmt):
with zipfile.ZipFile(*self.args) as zfile: with zipfile.ZipFile(*self.args) as zfile:
self._write(pathfmt, zfile) self._write(pathfmt, zfile)
def run_final(self, pathfmt, status): def finalize(self, pathfmt, status):
if self.zfile: if self.zfile:
self.zfile.close() self.zfile.close()

View File

@@ -15,11 +15,12 @@ from unittest.mock import Mock, mock_open, patch
import logging import logging
import zipfile import zipfile
import tempfile import tempfile
import collections
from datetime import datetime, timezone as tz from datetime import datetime, timezone as tz
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gallery_dl import extractor, output, util # noqa E402 from gallery_dl import extractor, output, util # noqa E402
from gallery_dl import postprocessor, util, config # noqa E402 from gallery_dl import postprocessor, config # noqa E402
from gallery_dl.postprocessor.common import PostProcessor # noqa E402 from gallery_dl.postprocessor.common import PostProcessor # noqa E402
@@ -34,6 +35,7 @@ class FakeJob():
self.pathfmt = util.PathFormat(self.extractor) self.pathfmt = util.PathFormat(self.extractor)
self.out = output.NullOutput() self.out = output.NullOutput()
self.get_logger = logging.getLogger self.get_logger = logging.getLogger
self.hooks = collections.defaultdict(list)
class TestPostprocessorModule(unittest.TestCase): class TestPostprocessorModule(unittest.TestCase):
@@ -78,6 +80,9 @@ class BasePostprocessorTest(unittest.TestCase):
cls.dir.cleanup() cls.dir.cleanup()
config.clear() config.clear()
def tearDown(self):
self.job.hooks.clear()
def _create(self, options=None, data=None): def _create(self, options=None, data=None):
kwdict = {"category": "test", "filename": "file", "extension": "ext"} kwdict = {"category": "test", "filename": "file", "extension": "ext"}
if options is None: if options is None:
@@ -92,6 +97,11 @@ class BasePostprocessorTest(unittest.TestCase):
pp = postprocessor.find(self.__class__.__name__[:-4].lower()) pp = postprocessor.find(self.__class__.__name__[:-4].lower())
return pp(self.job, options) return pp(self.job, options)
def _trigger(self, events=None, *args):
for event in (events or ("prepare", "file")):
for callback in self.job.hooks[event]:
callback(self.pathfmt, *args)
class ClassifyTest(BasePostprocessorTest): class ClassifyTest(BasePostprocessorTest):
@@ -111,7 +121,7 @@ class ClassifyTest(BasePostprocessorTest):
self.assertEqual(self.pathfmt.realpath, path + "/file.jpg") self.assertEqual(self.pathfmt.realpath, path + "/file.jpg")
with patch("os.makedirs") as mkdirs: with patch("os.makedirs") as mkdirs:
pp.run(self.pathfmt) self._trigger()
mkdirs.assert_called_once_with(path, exist_ok=True) mkdirs.assert_called_once_with(path, exist_ok=True)
def test_classify_noop(self): def test_classify_noop(self):
@@ -123,7 +133,7 @@ class ClassifyTest(BasePostprocessorTest):
self.assertEqual(self.pathfmt.realpath, rp) self.assertEqual(self.pathfmt.realpath, rp)
with patch("os.makedirs") as mkdirs: with patch("os.makedirs") as mkdirs:
pp.run(self.pathfmt) self._trigger()
self.assertEqual(mkdirs.call_count, 0) self.assertEqual(mkdirs.call_count, 0)
def test_classify_custom(self): def test_classify_custom(self):
@@ -143,7 +153,7 @@ class ClassifyTest(BasePostprocessorTest):
self.assertEqual(self.pathfmt.realpath, path + "/file.foo") self.assertEqual(self.pathfmt.realpath, path + "/file.foo")
with patch("os.makedirs") as mkdirs: with patch("os.makedirs") as mkdirs:
pp.run(self.pathfmt) self._trigger()
mkdirs.assert_called_once_with(path, exist_ok=True) mkdirs.assert_called_once_with(path, exist_ok=True)
@@ -175,8 +185,7 @@ class MetadataTest(BasePostprocessorTest):
self.assertEqual(pp.extension, "JSON") self.assertEqual(pp.extension, "JSON")
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
path = self.pathfmt.realpath + ".JSON" path = self.pathfmt.realpath + ".JSON"
m.assert_called_once_with(path, "w", encoding="utf-8") m.assert_called_once_with(path, "w", encoding="utf-8")
@@ -197,41 +206,37 @@ class MetadataTest(BasePostprocessorTest):
self.assertEqual(pp.extension, "txt") self.assertEqual(pp.extension, "txt")
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
path = self.pathfmt.realpath + ".txt" path = self.pathfmt.realpath + ".txt"
m.assert_called_once_with(path, "w", encoding="utf-8") m.assert_called_once_with(path, "w", encoding="utf-8")
self.assertEqual(self._output(m), "foo\nbar\nbaz\n") self.assertEqual(self._output(m), "foo\nbar\nbaz\n")
def test_metadata_tags_split_1(self): def test_metadata_tags_split_1(self):
pp = self._create( self._create(
{"mode": "tags"}, {"mode": "tags"},
{"tags": "foo, bar, baz"}, {"tags": "foo, bar, baz"},
) )
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
self.assertEqual(self._output(m), "foo\nbar\nbaz\n") self.assertEqual(self._output(m), "foo\nbar\nbaz\n")
def test_metadata_tags_split_2(self): def test_metadata_tags_split_2(self):
pp = self._create( self._create(
{"mode": "tags"}, {"mode": "tags"},
{"tags": "foobar1 foobar2 foobarbaz"}, {"tags": "foobar1 foobar2 foobarbaz"},
) )
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
self.assertEqual(self._output(m), "foobar1\nfoobar2\nfoobarbaz\n") self.assertEqual(self._output(m), "foobar1\nfoobar2\nfoobarbaz\n")
def test_metadata_tags_tagstring(self): def test_metadata_tags_tagstring(self):
pp = self._create( self._create(
{"mode": "tags"}, {"mode": "tags"},
{"tag_string": "foo, bar, baz"}, {"tag_string": "foo, bar, baz"},
) )
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
self.assertEqual(self._output(m), "foo\nbar\nbaz\n") self.assertEqual(self._output(m), "foo\nbar\nbaz\n")
def test_metadata_custom(self): def test_metadata_custom(self):
@@ -242,9 +247,9 @@ class MetadataTest(BasePostprocessorTest):
self.assertTrue(pp.contentfmt) self.assertTrue(pp.contentfmt)
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
self.assertEqual(self._output(m), "bar\nNone\n") self.assertEqual(self._output(m), "bar\nNone\n")
self.job.hooks.clear()
test({"mode": "custom", "content-format": "{foo}\n{missing}\n"}) test({"mode": "custom", "content-format": "{foo}\n{missing}\n"})
test({"mode": "custom", "content-format": ["{foo}", "{missing}"]}) test({"mode": "custom", "content-format": ["{foo}", "{missing}"]})
@@ -259,46 +264,42 @@ class MetadataTest(BasePostprocessorTest):
self.assertEqual(pp._filename, pp._filename_custom) self.assertEqual(pp._filename, pp._filename_custom)
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
path = self.pathfmt.realdirectory + "file.json" path = self.pathfmt.realdirectory + "file.json"
m.assert_called_once_with(path, "w", encoding="utf-8") m.assert_called_once_with(path, "w", encoding="utf-8")
def test_metadata_extfmt_2(self): def test_metadata_extfmt_2(self):
pp = self._create({ self._create({
"extension-format": "{extension!u}-data:{category:Res/ES/}", "extension-format": "{extension!u}-data:{category:Res/ES/}",
}) })
self.pathfmt.prefix = "2." self.pathfmt.prefix = "2."
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
path = self.pathfmt.realdirectory + "file.2.EXT-data:tESt" path = self.pathfmt.realdirectory + "file.2.EXT-data:tESt"
m.assert_called_once_with(path, "w", encoding="utf-8") m.assert_called_once_with(path, "w", encoding="utf-8")
def test_metadata_directory(self): def test_metadata_directory(self):
pp = self._create({ self._create({
"directory": "metadata", "directory": "metadata",
}) })
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
path = self.pathfmt.realdirectory + "metadata/file.ext.json" path = self.pathfmt.realdirectory + "metadata/file.ext.json"
m.assert_called_once_with(path, "w", encoding="utf-8") m.assert_called_once_with(path, "w", encoding="utf-8")
def test_metadata_directory_2(self): def test_metadata_directory_2(self):
pp = self._create({ self._create({
"directory" : "metadata////", "directory" : "metadata////",
"extension-format": "json", "extension-format": "json",
}) })
with patch("builtins.open", mock_open()) as m: with patch("builtins.open", mock_open()) as m:
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
path = self.pathfmt.realdirectory + "metadata/file.json" path = self.pathfmt.realdirectory + "metadata/file.json"
m.assert_called_once_with(path, "w", encoding="utf-8") m.assert_called_once_with(path, "w", encoding="utf-8")
@@ -319,21 +320,18 @@ class MtimeTest(BasePostprocessorTest):
self.assertEqual(pp.key, "date") self.assertEqual(pp.key, "date")
def test_mtime_datetime(self): def test_mtime_datetime(self):
pp = self._create(None, {"date": datetime(1980, 1, 1, tzinfo=tz.utc)}) self._create(None, {"date": datetime(1980, 1, 1, tzinfo=tz.utc)})
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
self.assertEqual(self.pathfmt.kwdict["_mtime"], 315532800) self.assertEqual(self.pathfmt.kwdict["_mtime"], 315532800)
def test_mtime_timestamp(self): def test_mtime_timestamp(self):
pp = self._create(None, {"date": 315532800}) self._create(None, {"date": 315532800})
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
self.assertEqual(self.pathfmt.kwdict["_mtime"], 315532800) self.assertEqual(self.pathfmt.kwdict["_mtime"], 315532800)
def test_mtime_custom(self): def test_mtime_custom(self):
pp = self._create({"key": "foo"}, {"foo": 315532800}) self._create({"key": "foo"}, {"foo": 315532800})
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
self.assertEqual(self.pathfmt.kwdict["_mtime"], 315532800) self.assertEqual(self.pathfmt.kwdict["_mtime"], 315532800)
@@ -341,8 +339,8 @@ class ZipTest(BasePostprocessorTest):
def test_zip_default(self): def test_zip_default(self):
pp = self._create() pp = self._create()
self.assertEqual(self.job.hooks["file"][0], pp.write)
self.assertEqual(pp.path, self.pathfmt.realdirectory) self.assertEqual(pp.path, self.pathfmt.realdirectory)
self.assertEqual(pp.run, pp._write)
self.assertEqual(pp.delete, True) self.assertEqual(pp.delete, True)
self.assertEqual(pp.args, ( self.assertEqual(pp.args, (
pp.path[:-1] + ".zip", "a", zipfile.ZIP_STORED, True, pp.path[:-1] + ".zip", "a", zipfile.ZIP_STORED, True,
@@ -351,8 +349,8 @@ class ZipTest(BasePostprocessorTest):
def test_zip_safe(self): def test_zip_safe(self):
pp = self._create({"mode": "safe"}) pp = self._create({"mode": "safe"})
self.assertEqual(self.job.hooks["file"][0], pp.write_safe)
self.assertEqual(pp.path, self.pathfmt.realdirectory) self.assertEqual(pp.path, self.pathfmt.realdirectory)
self.assertEqual(pp.run, pp._write_safe)
self.assertEqual(pp.delete, True) self.assertEqual(pp.delete, True)
self.assertEqual(pp.args, ( self.assertEqual(pp.args, (
pp.path[:-1] + ".zip", "a", zipfile.ZIP_STORED, True, pp.path[:-1] + ".zip", "a", zipfile.ZIP_STORED, True,
@@ -383,8 +381,7 @@ class ZipTest(BasePostprocessorTest):
self.pathfmt.temppath = file.name self.pathfmt.temppath = file.name
self.pathfmt.filename = name self.pathfmt.filename = name
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
nti = pp.zfile.NameToInfo nti = pp.zfile.NameToInfo
self.assertEqual(len(nti), i+1) self.assertEqual(len(nti), i+1)
@@ -397,12 +394,11 @@ class ZipTest(BasePostprocessorTest):
self.assertIn("file2.ext", nti) self.assertIn("file2.ext", nti)
# write the last file a second time (will be skipped) # write the last file a second time (will be skipped)
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
self.assertEqual(len(pp.zfile.NameToInfo), 3) self.assertEqual(len(pp.zfile.NameToInfo), 3)
# close file # close file
pp.run_final(self.pathfmt, 0) self._trigger(("finalize",), 0)
# reopen to check persistence # reopen to check persistence
with zipfile.ZipFile(pp.zfile.filename) as file: with zipfile.ZipFile(pp.zfile.filename) as file:
@@ -428,14 +424,13 @@ class ZipTest(BasePostprocessorTest):
for i in range(3): for i in range(3):
self.pathfmt.temppath = self.pathfmt.realdirectory + "file.ext" self.pathfmt.temppath = self.pathfmt.realdirectory + "file.ext"
self.pathfmt.filename = "file{}.ext".format(i) self.pathfmt.filename = "file{}.ext".format(i)
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
# write the last file a second time (will be skipped) # write the last file a second time (should be skipped)
pp.prepare(self.pathfmt) self._trigger()
pp.run(self.pathfmt)
pp.run_final(self.pathfmt, 0) # close file
self._trigger(("finalize",), 0)
self.assertEqual(pp.zfile.write.call_count, 3) self.assertEqual(pp.zfile.write.call_count, 3)
for call in pp.zfile.write.call_args_list: for call in pp.zfile.write.call_args_list: