update downloader.find() and related code

Instead of replacing 'https' with 'http' for every URL in
'get_downloader()', this now only happens once during downloader
initialization. Also unit tests.
This commit is contained in:
Mike Fährmann
2019-06-20 16:59:44 +02:00
parent f4ba98771d
commit ee4d7c3d89
3 changed files with 91 additions and 11 deletions

View File

@@ -8,13 +8,16 @@
# published by the Free Software Foundation.
import re
import sys
import base64
import os.path
import tempfile
import unittest
import threading
import http.server
import unittest
from unittest.mock import Mock, MagicMock, patch
import gallery_dl.downloader as downloader
import gallery_dl.extractor as extractor
import gallery_dl.config as config
@@ -23,6 +26,73 @@ from gallery_dl.output import NullOutput
from gallery_dl.util import PathFormat
class MockDownloaderModule(Mock):
__downloader__ = "mock"
class TestDownloaderModule(unittest.TestCase):
@classmethod
def setUpClass(cls):
# allow import of ytdl downloader module without youtube_dl installed
sys.modules["youtube_dl"] = MagicMock()
@classmethod
def tearDownClass(cls):
del sys.modules["youtube_dl"]
def tearDown(self):
downloader._cache.clear()
def test_find(self):
cls = downloader.find("http")
self.assertEqual(cls.__name__, "HttpDownloader")
self.assertEqual(cls.scheme , "http")
cls = downloader.find("https")
self.assertEqual(cls.__name__, "HttpDownloader")
self.assertEqual(cls.scheme , "http")
cls = downloader.find("text")
self.assertEqual(cls.__name__, "TextDownloader")
self.assertEqual(cls.scheme , "text")
cls = downloader.find("ytdl")
self.assertEqual(cls.__name__, "YoutubeDLDownloader")
self.assertEqual(cls.scheme , "ytdl")
self.assertEqual(downloader.find("ftp"), None)
self.assertEqual(downloader.find("foo"), None)
self.assertEqual(downloader.find(1234) , None)
self.assertEqual(downloader.find(None) , None)
@patch("importlib.import_module")
def test_cache(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("http")
downloader.find("text")
downloader.find("ytdl")
self.assertEqual(import_module.call_count, 3)
downloader.find("http")
downloader.find("text")
downloader.find("ytdl")
self.assertEqual(import_module.call_count, 3)
@patch("importlib.import_module")
def test_cache_http(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("http")
downloader.find("https")
self.assertEqual(import_module.call_count, 1)
@patch("importlib.import_module")
def test_cache_https(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("https")
downloader.find("http")
self.assertEqual(import_module.call_count, 1)
class TestDownloaderBase(unittest.TestCase):
@classmethod