diff --git a/gallery_dl/downloader/__init__.py b/gallery_dl/downloader/__init__.py index 97972cd3..e0851ea7 100644 --- a/gallery_dl/downloader/__init__.py +++ b/gallery_dl/downloader/__init__.py @@ -22,15 +22,23 @@ def find(scheme): try: return _cache[scheme] except KeyError: - klass = None + pass + + klass = None + if scheme == "https": + scheme = "http" + if scheme in modules: # prevent unwanted imports try: - if scheme in modules: # prevent unwanted imports - module = importlib.import_module("." + scheme, __package__) - klass = module.__downloader__ - except (ImportError, AttributeError, TypeError): + module = importlib.import_module("." + scheme, __package__) + klass = module.__downloader__ + except ImportError: pass + + if scheme == "http": + _cache["http"] = _cache["https"] = klass + else: _cache[scheme] = klass - return klass + return klass # -------------------------------------------------------------------- diff --git a/gallery_dl/job.py b/gallery_dl/job.py index 667b9b37..749bde72 100644 --- a/gallery_dl/job.py +++ b/gallery_dl/job.py @@ -281,20 +281,22 @@ class DownloadJob(Job): def get_downloader(self, scheme): """Return a downloader suitable for 'scheme'""" - if scheme == "https": - scheme = "http" try: return self.downloaders[scheme] except KeyError: pass klass = downloader.find(scheme) - if klass and config.get(("downloader", scheme, "enabled"), True): + if klass and config.get(("downloader", klass.scheme, "enabled"), True): instance = klass(self.extractor, self.out) else: instance = None self.log.error("'%s:' URLs are not supported/enabled", scheme) - self.downloaders[scheme] = instance + + if klass.scheme == "http": + self.downloaders["http"] = self.downloaders["https"] = instance + else: + self.downloaders[scheme] = instance return instance def initialize(self, keywords=None): diff --git a/test/test_downloader.py b/test/test_downloader.py index 7c2b981b..caed9838 100644 --- a/test/test_downloader.py +++ b/test/test_downloader.py @@ -8,13 +8,16 @@ # published by the Free Software Foundation. import re +import sys import base64 import os.path import tempfile -import unittest import threading import http.server +import unittest +from unittest.mock import Mock, MagicMock, patch + import gallery_dl.downloader as downloader import gallery_dl.extractor as extractor import gallery_dl.config as config @@ -23,6 +26,73 @@ from gallery_dl.output import NullOutput from gallery_dl.util import PathFormat +class MockDownloaderModule(Mock): + __downloader__ = "mock" + + +class TestDownloaderModule(unittest.TestCase): + + @classmethod + def setUpClass(cls): + # allow import of ytdl downloader module without youtube_dl installed + sys.modules["youtube_dl"] = MagicMock() + + @classmethod + def tearDownClass(cls): + del sys.modules["youtube_dl"] + + def tearDown(self): + downloader._cache.clear() + + def test_find(self): + cls = downloader.find("http") + self.assertEqual(cls.__name__, "HttpDownloader") + self.assertEqual(cls.scheme , "http") + + cls = downloader.find("https") + self.assertEqual(cls.__name__, "HttpDownloader") + self.assertEqual(cls.scheme , "http") + + cls = downloader.find("text") + self.assertEqual(cls.__name__, "TextDownloader") + self.assertEqual(cls.scheme , "text") + + cls = downloader.find("ytdl") + self.assertEqual(cls.__name__, "YoutubeDLDownloader") + self.assertEqual(cls.scheme , "ytdl") + + self.assertEqual(downloader.find("ftp"), None) + self.assertEqual(downloader.find("foo"), None) + self.assertEqual(downloader.find(1234) , None) + self.assertEqual(downloader.find(None) , None) + + @patch("importlib.import_module") + def test_cache(self, import_module): + import_module.return_value = MockDownloaderModule() + downloader.find("http") + downloader.find("text") + downloader.find("ytdl") + self.assertEqual(import_module.call_count, 3) + downloader.find("http") + downloader.find("text") + downloader.find("ytdl") + self.assertEqual(import_module.call_count, 3) + + @patch("importlib.import_module") + def test_cache_http(self, import_module): + import_module.return_value = MockDownloaderModule() + downloader.find("http") + downloader.find("https") + self.assertEqual(import_module.call_count, 1) + + @patch("importlib.import_module") + def test_cache_https(self, import_module): + import_module.return_value = MockDownloaderModule() + downloader.find("https") + downloader.find("http") + self.assertEqual(import_module.call_count, 1) + + class TestDownloaderBase(unittest.TestCase): @classmethod