update downloader.find() and related code
Instead of replacing 'https' with 'http' for every URL in 'get_downloader()', this now only happens once during downloader initialization. Also unit tests.
This commit is contained in:
@@ -8,13 +8,16 @@
|
||||
# published by the Free Software Foundation.
|
||||
|
||||
import re
|
||||
import sys
|
||||
import base64
|
||||
import os.path
|
||||
import tempfile
|
||||
import unittest
|
||||
import threading
|
||||
import http.server
|
||||
|
||||
import unittest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
|
||||
import gallery_dl.downloader as downloader
|
||||
import gallery_dl.extractor as extractor
|
||||
import gallery_dl.config as config
|
||||
@@ -23,6 +26,73 @@ from gallery_dl.output import NullOutput
|
||||
from gallery_dl.util import PathFormat
|
||||
|
||||
|
||||
class MockDownloaderModule(Mock):
|
||||
__downloader__ = "mock"
|
||||
|
||||
|
||||
class TestDownloaderModule(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# allow import of ytdl downloader module without youtube_dl installed
|
||||
sys.modules["youtube_dl"] = MagicMock()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
del sys.modules["youtube_dl"]
|
||||
|
||||
def tearDown(self):
|
||||
downloader._cache.clear()
|
||||
|
||||
def test_find(self):
|
||||
cls = downloader.find("http")
|
||||
self.assertEqual(cls.__name__, "HttpDownloader")
|
||||
self.assertEqual(cls.scheme , "http")
|
||||
|
||||
cls = downloader.find("https")
|
||||
self.assertEqual(cls.__name__, "HttpDownloader")
|
||||
self.assertEqual(cls.scheme , "http")
|
||||
|
||||
cls = downloader.find("text")
|
||||
self.assertEqual(cls.__name__, "TextDownloader")
|
||||
self.assertEqual(cls.scheme , "text")
|
||||
|
||||
cls = downloader.find("ytdl")
|
||||
self.assertEqual(cls.__name__, "YoutubeDLDownloader")
|
||||
self.assertEqual(cls.scheme , "ytdl")
|
||||
|
||||
self.assertEqual(downloader.find("ftp"), None)
|
||||
self.assertEqual(downloader.find("foo"), None)
|
||||
self.assertEqual(downloader.find(1234) , None)
|
||||
self.assertEqual(downloader.find(None) , None)
|
||||
|
||||
@patch("importlib.import_module")
|
||||
def test_cache(self, import_module):
|
||||
import_module.return_value = MockDownloaderModule()
|
||||
downloader.find("http")
|
||||
downloader.find("text")
|
||||
downloader.find("ytdl")
|
||||
self.assertEqual(import_module.call_count, 3)
|
||||
downloader.find("http")
|
||||
downloader.find("text")
|
||||
downloader.find("ytdl")
|
||||
self.assertEqual(import_module.call_count, 3)
|
||||
|
||||
@patch("importlib.import_module")
|
||||
def test_cache_http(self, import_module):
|
||||
import_module.return_value = MockDownloaderModule()
|
||||
downloader.find("http")
|
||||
downloader.find("https")
|
||||
self.assertEqual(import_module.call_count, 1)
|
||||
|
||||
@patch("importlib.import_module")
|
||||
def test_cache_https(self, import_module):
|
||||
import_module.return_value = MockDownloaderModule()
|
||||
downloader.find("https")
|
||||
downloader.find("http")
|
||||
self.assertEqual(import_module.call_count, 1)
|
||||
|
||||
|
||||
class TestDownloaderBase(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user