diff --git a/scripts/init.py b/scripts/init.py index c92b21a1..7f2e72d7 100755 --- a/scripts/init.py +++ b/scripts/init.py @@ -29,39 +29,43 @@ LICENSE = """\ """ -def init_extractor_module(args): +def init_extractor(args): + category = args.category + + files = [(util.path("test", "results", f"{category}.py"), + generate_test, False)] if args.init_module: - try: - create_extractor_module(args) - except FileExistsError: - LOG.warning("… already present") - except Exception as exc: - LOG.error("%s: %s", exc.__class__.__name__, exc, exc_info=exc) - - if msg := insert_into_modules_list(args): - LOG.warning(msg) - - try: - create_test_results_file(args) - except FileExistsError: - LOG.warning("… already present") - except Exception as exc: - LOG.error("%s: %s", exc.__class__.__name__, exc, exc_info=exc) - + files.append((util.path("gallery_dl", "extractor", f"{category}.py"), + generate_module, False)) + files.append((util.path("gallery_dl", "extractor", "__init__.py"), + insert_into_modules_list, True)) if args.site_name: - if msg := insert_into_supportedsites(args): - LOG.warning(msg) + files.append((util.path("scripts", "supportedsites.py"), + insert_into_supportedsites, True)) + + for path, func, lines in files: + LOG.info(util.trim(path)) + + if lines: + with util.open(path) as fp: + lines = fp.readlines() + if func(args, lines): + with util.lazy(path) as fp: + fp.writelines(lines) + else: + try: + with util.open(path, args.open_mode) as fp: + fp.write(func(args)) + except FileExistsError: + LOG.warning("File already present") + except Exception as exc: + LOG.error("%s: %s", exc.__class__.__name__, exc, exc_info=exc) ############################################################################### -# File Creation ############################################################### - -def create_extractor_module(args): - category = args.category - - path = util.path("gallery_dl", "extractor", f"{category}.py") - LOG.info("Creating '%s'", util.trim(path)) +# Extractor ################################################################### +def generate_module(args): type = args.type if type == "manga": generate_extractors = generate_extractors_manga @@ -70,17 +74,16 @@ def create_extractor_module(args): else: generate_extractors = generate_extractors_basic - with util.open(path, args.open_mode) as fp: - if copyright := args.copyright: - copyright = f"\n# Copyright {dt.date.today().year} {copyright}\n#" + if copyright := args.copyright: + copyright = f"\n# Copyright {dt.date.today().year} {copyright}\n#" - fp.write(f'''\ + return f'''\ {ENCODING}{copyright} {LICENSE} """Extractors for {args.root}/""" {generate_extractors(args)}\ -''') +''' def generate_extractors_basic(args): @@ -216,85 +219,66 @@ BASE_PATTERN = r"(?:https?://)?{subdomain}{re.escape(domain)}" ############################################################################### # Test Results ################################################################ -def create_test_results_file(args): - path = util.path("test", "results", f"{args.category}.py") - LOG.info("Creating '%s'", util.trim(path)) +def generate_test(args): + category = args.category - import_stmt = generate_test_result_import(args) - with util.open(path, "x") as fp: - fp.write(f"""\ + if category[0].isdecimal(): + import_stmt = f"""\ +gallery_dl = __import__("gallery_dl.extractor.{category}") +_{category} = getattr(gallery_dl.extractor, "{category}") +""" + else: + import_stmt = f"""\ +from gallery_dl.extractor import {category} +""" + + return f"""\ {ENCODING} {LICENSE} {import_stmt} __tests__ = ( ) -""") - - -def generate_test_result_import(args): - cat = args.category - - if cat[0].isdecimal(): - import_stmt = f"""\ -gallery_dl = __import__("gallery_dl.extractor.{cat}") -_{cat} = getattr(gallery_dl.extractor, "{cat}") """ - else: - import_stmt = f"""\ -from gallery_dl.extractor import {cat} -""" - - return import_stmt ############################################################################### -# Code Modification ########################################################### +# Modules List ################################################################ -def insert_into_modules_list(args): +def insert_into_modules_list(args, lines): category = args.category - LOG.info("Adding '%s' to gallery_dl/extractor/__init__.py modules list", - category) - - path = util.path("gallery_dl", "extractor", "__init__.py") - with util.open(path) as fp: - lines = fp.readlines() module_name = f' "{category}",\n' if module_name in lines: - return "… already present" + return False compare = False for idx, line in enumerate(lines): if compare: cat = text.extr(line, '"', '"') if cat == category: - return "… already present" + return False if cat > category or cat == "booru": break elif line.startswith("modules = "): compare = True lines.insert(idx, module_name) - with util.lazy(path) as fp: - fp.writelines(lines) + return True -def insert_into_supportedsites(args): +############################################################################### +# Supported Sites ############################################################# + +def insert_into_supportedsites(args, lines): category = args.category - LOG.info("Adding '%s' to scripts/supportedsites.py category list", - category) - - path = util.path("scripts", "supportedsites.py") - with util.open(path) as fp: - lines = fp.readlines() compare = False for idx, line in enumerate(lines): if compare: cat = text.extr(line, '"', '"') if cat == category: - return "… already present" + return False if cat > category: break elif line.startswith("CATEGORY_MAP = "): @@ -303,13 +287,11 @@ def insert_into_supportedsites(args): ws = " " * max(15 - len(category), 0) line = f''' "{category}"{ws}: "{args.site_name}",\n''' lines.insert(idx, line) - - with util.lazy(path) as fp: - fp.writelines(lines) + return True ############################################################################### -# General ##################################################################### +# Command-Line Options ######################################################## def parse_args(args=None): parser = argparse.ArgumentParser(args) @@ -371,7 +353,7 @@ def parse_args(args=None): def main(): args = parse_args() - init_extractor_module(args) + init_extractor(args) if __name__ == "__main__":