diff --git a/docs/configuration.rst b/docs/configuration.rst index b385585d..a0c44619 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -910,6 +910,16 @@ Description Controls how a user is directed to an OAuth authorization site. =========== ===== +extractor.oauth.cache +--------------------- +=========== ===== +Type ``bool`` +Default ``true`` +Description Store tokens received during OAuth authorizations + in `cache `__. +=========== ===== + + extractor.oauth.port -------------------- =========== ===== diff --git a/docs/gallery-dl.conf b/docs/gallery-dl.conf index 1f312982..c120d25b 100644 --- a/docs/gallery-dl.conf +++ b/docs/gallery-dl.conf @@ -30,7 +30,6 @@ }, "deviantart": { - "refresh-token": null, "extra": false, "flat": true, "folders": false, @@ -51,8 +50,6 @@ }, "flickr": { - "access-token": null, - "access-token-secret": null, "videos": true, "size-max": null }, @@ -96,6 +93,7 @@ "oauth": { "browser": true, + "cache": true, "port": 6414 }, "pixiv": @@ -120,7 +118,6 @@ }, "reddit": { - "refresh-token": null, "comments": 0, "morecomments": false, "date-min": 0, diff --git a/gallery_dl/extractor/deviantart.py b/gallery_dl/extractor/deviantart.py index c5d410fc..cda357a3 100644 --- a/gallery_dl/extractor/deviantart.py +++ b/gallery_dl/extractor/deviantart.py @@ -850,9 +850,12 @@ class DeviantartOAuthAPI(): self.client_secret = extractor.config( "client-secret", self.CLIENT_SECRET) - self.refresh_token = extractor.config("refresh-token") - if self.refresh_token == "cache": - self.refresh_token = "#" + str(self.client_id) + token = extractor.config("refresh-token") + if token is None or token == "cache": + token = "#" + str(self.client_id) + if not _refresh_token_cache(token): + token = None + self.refresh_token_key = token self.log.debug( "Using %s API credentials (client-id %s)", @@ -952,18 +955,19 @@ class DeviantartOAuthAPI(): endpoint = "user/profile/" + username return self._call(endpoint, fatal=False) - def authenticate(self, refresh_token): + def authenticate(self, refresh_token_key): """Authenticate the application by requesting an access token""" - self.headers["Authorization"] = self._authenticate_impl(refresh_token) + self.headers["Authorization"] = \ + self._authenticate_impl(refresh_token_key) @cache(maxage=3600, keyarg=1) - def _authenticate_impl(self, refresh_token): + def _authenticate_impl(self, refresh_token_key): """Actual authenticate implementation""" url = "https://www.deviantart.com/oauth2/token" - if refresh_token: + if refresh_token_key: self.log.info("Refreshing private access token") data = {"grant_type": "refresh_token", - "refresh_token": _refresh_token_cache(refresh_token)} + "refresh_token": _refresh_token_cache(refresh_token_key)} else: self.log.info("Requesting public access token") data = {"grant_type": "client_credentials"} @@ -977,8 +981,9 @@ class DeviantartOAuthAPI(): self.log.debug("Server response: %s", data) raise exception.AuthenticationError('"{}" ({})'.format( data.get("error_description"), data.get("error"))) - if refresh_token: - _refresh_token_cache.update(refresh_token, data["refresh_token"]) + if refresh_token_key: + _refresh_token_cache.update( + refresh_token_key, data["refresh_token"]) return "Bearer " + data["access_token"] def _call(self, endpoint, params=None, fatal=True, public=True): @@ -988,7 +993,7 @@ class DeviantartOAuthAPI(): if self.delay >= 0: time.sleep(2 ** self.delay) - self.authenticate(None if public else self.refresh_token) + self.authenticate(None if public else self.refresh_token_key) response = self.extractor.request( url, headers=self.headers, params=params, fatal=None) data = response.json() @@ -1024,7 +1029,7 @@ class DeviantartOAuthAPI(): if extend: if public and len(data["results"]) < params["limit"]: - if self.refresh_token: + if self.refresh_token_key: self.log.debug("Switching to private access token") public = False continue @@ -1155,9 +1160,11 @@ class DeviantartEclipseAPI(): return text.rextract(page, '\\"id\\":', ',', pos)[0].strip('" ') -@cache(maxage=10*365*24*3600, keyarg=0) -def _refresh_token_cache(original_token, new_token=None): - return new_token or original_token +@cache(maxage=100*365*24*3600, keyarg=0) +def _refresh_token_cache(token): + if token and token[0] == "#": + return None + return token ############################################################################### diff --git a/gallery_dl/extractor/oauth.py b/gallery_dl/extractor/oauth.py index 123bb44f..076c7707 100644 --- a/gallery_dl/extractor/oauth.py +++ b/gallery_dl/extractor/oauth.py @@ -26,6 +26,7 @@ class OAuthBase(Extractor): def __init__(self, match): Extractor.__init__(self, match) self.client = None + self.cache = config.get(("extractor", self.category), "cache", True) def oauth_config(self, key, default=None): return config.interpolate( @@ -94,6 +95,13 @@ class OAuthBase(Extractor): token_secret=data["oauth_token_secret"], )) + # write to cache + if self.cache: + key = (self.subcategory, self.session.auth.consumer_key) + tokens = (data["oauth_token"], data["oauth_token_secret"]) + oauth._token_cache.update(key, tokens) + self.log.info("Writing tokens to cache") + def _oauth2_authorization_code_grant( self, client_id, client_secret, auth_url, token_url, scope="read", key="refresh_token", auth=True, @@ -162,7 +170,7 @@ class OAuthBase(Extractor): )) # write to cache - if cache and config.get(("extractor", self.category), "cache"): + if self.cache and cache: cache.update("#" + str(client_id), data[key]) self.log.info("Writing 'refresh-token' to cache") @@ -223,6 +231,7 @@ class OAuthReddit(OAuthBase): "https://www.reddit.com/api/v1/authorize", "https://www.reddit.com/api/v1/access_token", scope="read history", + cache=reddit._refresh_token_cache, ) diff --git a/gallery_dl/extractor/reddit.py b/gallery_dl/extractor/reddit.py index c3b64793..2e3864a3 100644 --- a/gallery_dl/extractor/reddit.py +++ b/gallery_dl/extractor/reddit.py @@ -222,7 +222,6 @@ class RedditAPI(): self.extractor = extractor self.comments = text.parse_int(extractor.config("comments", 0)) self.morecomments = extractor.config("morecomments", False) - self.refresh_token = extractor.config("refresh-token") self.log = extractor.log client_id = extractor.config("client-id", self.CLIENT_ID) @@ -236,6 +235,13 @@ class RedditAPI(): self.client_id = client_id self.headers = {"User-Agent": user_agent} + token = extractor.config("refresh-token") + if token is None or token == "cache": + key = "#" + self.client_id + self.refresh_token = _refresh_token_cache(key) + else: + self.refresh_token = token + def submission(self, submission_id): """Fetch the (submission, comments)=-tuple for a submission id""" endpoint = "/comments/" + submission_id + "/.json" @@ -382,3 +388,10 @@ class RedditAPI(): @staticmethod def _decode(sid): return util.bdecode(sid, "0123456789abcdefghijklmnopqrstuvwxyz") + + +@cache(maxage=100*365*24*3600, keyarg=0) +def _refresh_token_cache(token): + if token and token[0] == "#": + return None + return token diff --git a/gallery_dl/oauth.py b/gallery_dl/oauth.py index 9ceefbf3..e9dfff02 100644 --- a/gallery_dl/oauth.py +++ b/gallery_dl/oauth.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2018-2019 Mike Fährmann +# Copyright 2018-2020 Mike Fährmann # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 as @@ -20,6 +20,7 @@ import requests import requests.auth from . import text +from .cache import cache def nonce(size, alphabet=string.ascii_letters): @@ -117,6 +118,10 @@ class OAuth1API(): token_secret = extractor.config("access-token-secret") key_type = "default" if api_key == self.API_KEY else "custom" + if token is None or token == "cache": + key = (extractor.category, api_key) + token, token_secret = _token_cache(key) + if api_key and api_secret and token and token_secret: self.log.debug("Using %s OAuth1.0 authentication", key_type) self.session = OAuth1Session( @@ -131,3 +136,8 @@ class OAuth1API(): kwargs["fatal"] = None kwargs["session"] = self.session return self.extractor.request(url, **kwargs) + + +@cache(maxage=100*365*24*3600, keyarg=0) +def _token_cache(key): + return None, None diff --git a/test/test_cache.py b/test/test_cache.py index 753bda77..ecf482ca 100644 --- a/test/test_cache.py +++ b/test/test_cache.py @@ -22,8 +22,8 @@ config.set(("cache",), "file", dbpath) from gallery_dl import cache # noqa E402 -def tearDownModule(): - util.remove_file(dbpath) +# def tearDownModule(): +# util.remove_file(dbpath) class TestCache(unittest.TestCase):