re-implement OAuth1.0 code
OAuth support for SmugMug needs some additional features (auth-rebuild on redirect, query parameters in URL, ...) and fixing this in the old code wouldn't work all that well.
This commit is contained in:
@@ -9,7 +9,7 @@
|
|||||||
"""Extract images from https://www.flickr.com/"""
|
"""Extract images from https://www.flickr.com/"""
|
||||||
|
|
||||||
from .common import Extractor, Message
|
from .common import Extractor, Message
|
||||||
from .. import text, util, exception
|
from .. import text, oauth, util, exception
|
||||||
|
|
||||||
|
|
||||||
class FlickrExtractor(Extractor):
|
class FlickrExtractor(Extractor):
|
||||||
@@ -264,17 +264,20 @@ class FlickrAPI():
|
|||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, extractor):
|
def __init__(self, extractor):
|
||||||
self.api_key = extractor.config("api-key", self.API_KEY)
|
api_key = extractor.config("api-key", self.API_KEY)
|
||||||
self.api_secret = extractor.config("api-secret", self.API_SECRET)
|
api_secret = extractor.config("api-secret", self.API_SECRET)
|
||||||
token = extractor.config("access-token")
|
token = extractor.config("access-token")
|
||||||
token_secret = extractor.config("access-token-secret")
|
token_secret = extractor.config("access-token-secret")
|
||||||
if token and token_secret:
|
|
||||||
self.session = util.OAuthSession(
|
if api_key and api_secret and token and token_secret:
|
||||||
extractor.session,
|
self.session = oauth.OAuth1Session(
|
||||||
self.api_key, self.api_secret, token, token_secret)
|
api_key, api_secret,
|
||||||
|
token, token_secret,
|
||||||
|
)
|
||||||
self.api_key = None
|
self.api_key = None
|
||||||
else:
|
else:
|
||||||
self.session = extractor.session
|
self.session = extractor.session
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
self.maxsize = extractor.config("size-max")
|
self.maxsize = extractor.config("size-max")
|
||||||
if isinstance(self.maxsize, str):
|
if isinstance(self.maxsize, str):
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
from .common import Extractor, Message
|
from .common import Extractor, Message
|
||||||
from . import deviantart, flickr, reddit, tumblr
|
from . import deviantart, flickr, reddit, tumblr
|
||||||
from .. import text, util, config
|
from .. import text, oauth, config
|
||||||
import os
|
import os
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
@@ -70,21 +70,19 @@ class OAuthBase(Extractor):
|
|||||||
def _oauth1_authorization_flow(
|
def _oauth1_authorization_flow(
|
||||||
self, request_token_url, authorize_url, access_token_url):
|
self, request_token_url, authorize_url, access_token_url):
|
||||||
"""Perform the OAuth 1.0a authorization flow"""
|
"""Perform the OAuth 1.0a authorization flow"""
|
||||||
del self.session.params["oauth_token"]
|
|
||||||
|
|
||||||
# get a request token
|
# get a request token
|
||||||
params = {"oauth_callback": self.redirect_uri}
|
params = {"oauth_callback": self.redirect_uri}
|
||||||
data = self.session.get(request_token_url, params=params).text
|
data = self.session.get(request_token_url, params=params).text
|
||||||
|
|
||||||
data = text.parse_query(data)
|
data = text.parse_query(data)
|
||||||
self.session.params["oauth_token"] = token = data["oauth_token"]
|
self.session.auth.token_secret = data["oauth_token_secret"]
|
||||||
self.session.token_secret = data["oauth_token_secret"]
|
|
||||||
|
|
||||||
# get the user's authorization
|
# get the user's authorization
|
||||||
params = {"oauth_token": token, "perms": "read"}
|
params = {"oauth_token": data["oauth_token"], "perms": "read"}
|
||||||
data = self.open(authorize_url, params)
|
data = self.open(authorize_url, params)
|
||||||
|
|
||||||
# exchange the request token for an access token
|
# exchange the request token for an access token
|
||||||
|
# self.session.token = data["oauth_token"]
|
||||||
data = self.session.get(access_token_url, params=data).text
|
data = self.session.get(access_token_url, params=data).text
|
||||||
|
|
||||||
data = text.parse_query(data)
|
data = text.parse_query(data)
|
||||||
@@ -101,7 +99,7 @@ class OAuthBase(Extractor):
|
|||||||
|
|
||||||
state = "gallery-dl_{}_{}".format(
|
state = "gallery-dl_{}_{}".format(
|
||||||
self.subcategory,
|
self.subcategory,
|
||||||
util.OAuthSession.nonce(8)
|
oauth.nonce(8),
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_params = {
|
auth_params = {
|
||||||
@@ -182,8 +180,7 @@ class OAuthFlickr(OAuthBase):
|
|||||||
|
|
||||||
def __init__(self, match):
|
def __init__(self, match):
|
||||||
OAuthBase.__init__(self, match)
|
OAuthBase.__init__(self, match)
|
||||||
self.session = util.OAuthSession(
|
self.session = oauth.OAuth1Session(
|
||||||
self.session,
|
|
||||||
self.oauth_config("api-key", flickr.FlickrAPI.API_KEY),
|
self.oauth_config("api-key", flickr.FlickrAPI.API_KEY),
|
||||||
self.oauth_config("api-secret", flickr.FlickrAPI.API_SECRET),
|
self.oauth_config("api-secret", flickr.FlickrAPI.API_SECRET),
|
||||||
)
|
)
|
||||||
@@ -221,8 +218,7 @@ class OAuthTumblr(OAuthBase):
|
|||||||
|
|
||||||
def __init__(self, match):
|
def __init__(self, match):
|
||||||
OAuthBase.__init__(self, match)
|
OAuthBase.__init__(self, match)
|
||||||
self.session = util.OAuthSession(
|
self.session = oauth.OAuth1Session(
|
||||||
self.session,
|
|
||||||
self.oauth_config("api-key", tumblr.TumblrAPI.API_KEY),
|
self.oauth_config("api-key", tumblr.TumblrAPI.API_KEY),
|
||||||
self.oauth_config("api-secret", tumblr.TumblrAPI.API_SECRET),
|
self.oauth_config("api-secret", tumblr.TumblrAPI.API_SECRET),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
"""Extract images from https://www.smugmug.com/"""
|
"""Extract images from https://www.smugmug.com/"""
|
||||||
|
|
||||||
from .common import Extractor, Message
|
from .common import Extractor, Message
|
||||||
from .. import text, util, exception
|
from .. import text, oauth, exception
|
||||||
|
|
||||||
BASE_PATTERN = (
|
BASE_PATTERN = (
|
||||||
r"(?:smugmug:(?!album:)(?:https?://)?([^/]+)|"
|
r"(?:smugmug:(?!album:)(?:https?://)?([^/]+)|"
|
||||||
@@ -186,8 +186,7 @@ class SmugmugAPI():
|
|||||||
token_secret = extractor.config("access-token-secret")
|
token_secret = extractor.config("access-token-secret")
|
||||||
|
|
||||||
if api_key and api_secret and token and token_secret:
|
if api_key and api_secret and token and token_secret:
|
||||||
self.session = util.OAuthSession(
|
self.session = oauth.OAuth1Session(
|
||||||
extractor.session,
|
|
||||||
api_key, api_secret,
|
api_key, api_secret,
|
||||||
token, token_secret,
|
token, token_secret,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
"""Extract images from https://www.tumblr.com/"""
|
"""Extract images from https://www.tumblr.com/"""
|
||||||
|
|
||||||
from .common import Extractor, Message
|
from .common import Extractor, Message
|
||||||
from .. import text, util, exception
|
from .. import text, oauth, exception
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@@ -261,8 +261,7 @@ class TumblrAPI():
|
|||||||
token_secret = extractor.config("access-token-secret")
|
token_secret = extractor.config("access-token-secret")
|
||||||
|
|
||||||
if api_key and api_secret and token and token_secret:
|
if api_key and api_secret and token and token_secret:
|
||||||
self.session = util.OAuthSession(
|
self.session = oauth.OAuth1Session(
|
||||||
extractor.session,
|
|
||||||
api_key, api_secret,
|
api_key, api_secret,
|
||||||
token, token_secret,
|
token, token_secret,
|
||||||
)
|
)
|
||||||
|
|||||||
101
gallery_dl/oauth.py
Normal file
101
gallery_dl/oauth.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Copyright 2018 Mike Fährmann
|
||||||
|
#
|
||||||
|
# This program is free software; you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU General Public License version 2 as
|
||||||
|
# published by the Free Software Foundation.
|
||||||
|
|
||||||
|
"""OAuth helper functions and classes"""
|
||||||
|
|
||||||
|
import hmac
|
||||||
|
import time
|
||||||
|
import base64
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import hashlib
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import requests.auth
|
||||||
|
|
||||||
|
from . import text
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth1Session(requests.Session):
|
||||||
|
"""Extension to requests.Session objects to support OAuth 1.0"""
|
||||||
|
|
||||||
|
def __init__(self, consumer_key, consumer_secret,
|
||||||
|
token=None, token_secret=None):
|
||||||
|
|
||||||
|
requests.Session.__init__(self)
|
||||||
|
self.auth = OAuth1Client(
|
||||||
|
consumer_key, consumer_secret,
|
||||||
|
token, token_secret,
|
||||||
|
)
|
||||||
|
|
||||||
|
def rebuild_auth(self, prepared_request, response):
|
||||||
|
if "Authorization" in prepared_request.headers:
|
||||||
|
del prepared_request.headers["Authorization"]
|
||||||
|
prepared_request.prepare_auth(self.auth)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth1Client(requests.auth.AuthBase):
|
||||||
|
"""OAuth1.0a authentication"""
|
||||||
|
def __init__(self, consumer_key, consumer_secret,
|
||||||
|
token=None, token_secret=None):
|
||||||
|
|
||||||
|
self.consumer_key = consumer_key
|
||||||
|
self.consumer_secret = consumer_secret
|
||||||
|
self.token = token
|
||||||
|
self.token_secret = token_secret
|
||||||
|
|
||||||
|
def __call__(self, request):
|
||||||
|
oauth_params = [
|
||||||
|
("oauth_consumer_key", self.consumer_key),
|
||||||
|
("oauth_nonce", nonce(16)),
|
||||||
|
("oauth_signature_method", "HMAC-SHA1"),
|
||||||
|
("oauth_timestamp", str(int(time.time()))),
|
||||||
|
("oauth_version", "1.0"),
|
||||||
|
]
|
||||||
|
if self.token:
|
||||||
|
oauth_params.append(("oauth_token", self.token))
|
||||||
|
|
||||||
|
signature = self.generate_signature(request, oauth_params)
|
||||||
|
oauth_params.append(("oauth_signature", signature))
|
||||||
|
|
||||||
|
request.headers["Authorization"] = "OAuth " + ",".join(
|
||||||
|
key + '="' + value + '"' for key, value in oauth_params)
|
||||||
|
|
||||||
|
return request
|
||||||
|
|
||||||
|
def generate_signature(self, request, params):
|
||||||
|
"""Generate 'oauth_signature' value"""
|
||||||
|
url, _, query = request.url.partition("?")
|
||||||
|
|
||||||
|
params = params.copy()
|
||||||
|
for key, value in text.parse_query(query).items():
|
||||||
|
params.append((quote(key), quote(value)))
|
||||||
|
params.sort()
|
||||||
|
query = "&".join("=".join(item) for item in params)
|
||||||
|
|
||||||
|
message = concat(request.method, url, query).encode()
|
||||||
|
key = concat(self.consumer_secret, self.token_secret or "").encode()
|
||||||
|
signature = hmac.new(key, message, hashlib.sha1).digest()
|
||||||
|
|
||||||
|
return quote(base64.b64encode(signature).decode())
|
||||||
|
|
||||||
|
|
||||||
|
def concat(*args):
|
||||||
|
"""Concatenate 'args'"""
|
||||||
|
return "&".join(quote(item) for item in args)
|
||||||
|
|
||||||
|
|
||||||
|
def nonce(size, alphabet=string.ascii_letters):
|
||||||
|
"""Generate a nonce value with 'size' characters"""
|
||||||
|
return "".join(random.choice(alphabet) for _ in range(size))
|
||||||
|
|
||||||
|
|
||||||
|
def quote(value, quote=urllib.parse.quote):
|
||||||
|
"""Quote 'value' according to the OAuth1.0 standard"""
|
||||||
|
return quote(value, "~")
|
||||||
@@ -11,14 +11,9 @@
|
|||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import hmac
|
|
||||||
import time
|
|
||||||
import base64
|
|
||||||
import random
|
|
||||||
import shutil
|
import shutil
|
||||||
import string
|
import string
|
||||||
import _string
|
import _string
|
||||||
import hashlib
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import datetime
|
import datetime
|
||||||
import itertools
|
import itertools
|
||||||
@@ -497,54 +492,6 @@ class PathFormat():
|
|||||||
return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path
|
return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path
|
||||||
|
|
||||||
|
|
||||||
class OAuthSession():
|
|
||||||
"""Minimal wrapper for requests.session objects to support OAuth 1.0"""
|
|
||||||
def __init__(self, session, consumer_key, consumer_secret,
|
|
||||||
token=None, token_secret=None):
|
|
||||||
self.session = session
|
|
||||||
self.consumer_secret = consumer_secret
|
|
||||||
self.token_secret = token_secret or ""
|
|
||||||
self.params = {}
|
|
||||||
self.params["oauth_consumer_key"] = consumer_key
|
|
||||||
self.params["oauth_token"] = token
|
|
||||||
self.params["oauth_signature_method"] = "HMAC-SHA1"
|
|
||||||
self.params["oauth_version"] = "1.0"
|
|
||||||
|
|
||||||
def get(self, url, params, **kwargs):
|
|
||||||
params.update(self.params)
|
|
||||||
params["oauth_nonce"] = self.nonce(16)
|
|
||||||
params["oauth_timestamp"] = int(time.time())
|
|
||||||
return self.session.get(url + self.sign(url, params), **kwargs)
|
|
||||||
|
|
||||||
def sign(self, url, params):
|
|
||||||
"""Generate 'oauth_signature' value and return query string"""
|
|
||||||
query = self.urlencode(params)
|
|
||||||
message = self.concat("GET", url, query).encode()
|
|
||||||
key = self.concat(self.consumer_secret, self.token_secret).encode()
|
|
||||||
signature = hmac.new(key, message, hashlib.sha1).digest()
|
|
||||||
return "?{}&oauth_signature={}".format(
|
|
||||||
query, self.quote(base64.b64encode(signature).decode()))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def concat(*args):
|
|
||||||
return "&".join(OAuthSession.quote(item) for item in args)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def nonce(N, alphabet=string.ascii_letters):
|
|
||||||
return "".join(random.choice(alphabet) for _ in range(N))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def quote(value, quote=urllib.parse.quote):
|
|
||||||
return quote(value, "~")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def urlencode(params):
|
|
||||||
return "&".join(
|
|
||||||
OAuthSession.quote(str(key)) + "=" + OAuthSession.quote(str(value))
|
|
||||||
for key, value in sorted(params.items()) if value
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DownloadArchive():
|
class DownloadArchive():
|
||||||
|
|
||||||
def __init__(self, path, extractor):
|
def __init__(self, path, extractor):
|
||||||
|
|||||||
@@ -8,10 +8,8 @@
|
|||||||
# published by the Free Software Foundation.
|
# published by the Free Software Foundation.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import requests
|
|
||||||
|
|
||||||
from gallery_dl import text
|
from gallery_dl import oauth, text
|
||||||
from gallery_dl.util import OAuthSession
|
|
||||||
|
|
||||||
TESTSERVER = "http://oauthbin.com"
|
TESTSERVER = "http://oauthbin.com"
|
||||||
CONSUMER_KEY = "key"
|
CONSUMER_KEY = "key"
|
||||||
@@ -25,7 +23,7 @@ ACCESS_TOKEN_SECRET = "accesssecret"
|
|||||||
class TestOAuthSession(unittest.TestCase):
|
class TestOAuthSession(unittest.TestCase):
|
||||||
|
|
||||||
def test_concat(self):
|
def test_concat(self):
|
||||||
concat = OAuthSession.concat
|
concat = oauth.concat
|
||||||
|
|
||||||
self.assertEqual(concat(), "")
|
self.assertEqual(concat(), "")
|
||||||
self.assertEqual(concat("str"), "str")
|
self.assertEqual(concat("str"), "str")
|
||||||
@@ -37,18 +35,18 @@ class TestOAuthSession(unittest.TestCase):
|
|||||||
"GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da"
|
"GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_nonce(self, N=16):
|
def test_nonce(self, size=16):
|
||||||
nonce_values = set(OAuthSession.nonce(N) for _ in range(N))
|
nonce_values = set(oauth.nonce(size) for _ in range(size))
|
||||||
|
|
||||||
# uniqueness
|
# uniqueness
|
||||||
self.assertEqual(len(nonce_values), N)
|
self.assertEqual(len(nonce_values), size)
|
||||||
|
|
||||||
# length
|
# length
|
||||||
for nonce in nonce_values:
|
for nonce in nonce_values:
|
||||||
self.assertEqual(len(nonce), N)
|
self.assertEqual(len(nonce), size)
|
||||||
|
|
||||||
def test_quote(self):
|
def test_quote(self):
|
||||||
quote = OAuthSession.quote
|
quote = oauth.quote
|
||||||
|
|
||||||
reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü"
|
reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü"
|
||||||
unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
@@ -65,33 +63,6 @@ class TestOAuthSession(unittest.TestCase):
|
|||||||
self.assertTrue(len(quoted) >= 3)
|
self.assertTrue(len(quoted) >= 3)
|
||||||
self.assertEqual(quoted_hex.upper(), quoted_hex)
|
self.assertEqual(quoted_hex.upper(), quoted_hex)
|
||||||
|
|
||||||
def test_urlencode(self):
|
|
||||||
urlencode = OAuthSession.urlencode
|
|
||||||
|
|
||||||
self.assertEqual(urlencode({}), "")
|
|
||||||
self.assertEqual(urlencode({"foo": "bar"}), "foo=bar")
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
urlencode({"foo": "bar", "baz": "a", "a": "baz"}),
|
|
||||||
"a=baz&baz=a&foo=bar"
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
urlencode({
|
|
||||||
"oauth_consumer_key": "0685bd9184jfhq22",
|
|
||||||
"oauth_token": "ad180jjd733klru7",
|
|
||||||
"oauth_signature_method": "HMAC-SHA1",
|
|
||||||
"oauth_timestamp": 137131200,
|
|
||||||
"oauth_nonce": "4572616e48616d6d65724c61686176",
|
|
||||||
"oauth_version": "1.0"
|
|
||||||
}),
|
|
||||||
"oauth_consumer_key=0685bd9184jfhq22&"
|
|
||||||
"oauth_nonce=4572616e48616d6d65724c61686176&"
|
|
||||||
"oauth_signature_method=HMAC-SHA1&"
|
|
||||||
"oauth_timestamp=137131200&"
|
|
||||||
"oauth_token=ad180jjd733klru7&"
|
|
||||||
"oauth_version=1.0"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_request_token(self):
|
def test_request_token(self):
|
||||||
response = self._oauth_request(
|
response = self._oauth_request(
|
||||||
"/v1/request-token", {})
|
"/v1/request-token", {})
|
||||||
@@ -113,23 +84,20 @@ class TestOAuthSession(unittest.TestCase):
|
|||||||
self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET)
|
self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET)
|
||||||
|
|
||||||
def test_authenticated_call(self):
|
def test_authenticated_call(self):
|
||||||
params = {"method": "foo", "bar": "baz", "a": "äöüß/?&#"}
|
params = {"method": "foo", "a": "äöüß/?&#", "äöüß/?&#": "a"}
|
||||||
response = self._oauth_request(
|
response = self._oauth_request(
|
||||||
"/v1/echo", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET)
|
"/v1/echo", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET)
|
||||||
expected = OAuthSession.urlencode(params)
|
|
||||||
|
|
||||||
self.assertEqual(response, expected, msg=response)
|
|
||||||
self.assertEqual(text.parse_query(response), params)
|
self.assertEqual(text.parse_query(response), params)
|
||||||
|
|
||||||
def _oauth_request(self, endpoint, params=None,
|
def _oauth_request(self, endpoint, params=None,
|
||||||
oauth_token=None, oauth_token_secret=None):
|
oauth_token=None, oauth_token_secret=None):
|
||||||
session = OAuthSession(
|
session = oauth.OAuth1Session(
|
||||||
requests.session(),
|
|
||||||
CONSUMER_KEY, CONSUMER_SECRET,
|
CONSUMER_KEY, CONSUMER_SECRET,
|
||||||
oauth_token, oauth_token_secret,
|
oauth_token, oauth_token_secret,
|
||||||
)
|
)
|
||||||
url = TESTSERVER + endpoint
|
url = TESTSERVER + endpoint
|
||||||
return session.get(url, params.copy()).text
|
return session.get(url, params=params).text
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user