re-implement OAuth1.0 code
OAuth support for SmugMug needs some additional features (auth-rebuild on redirect, query parameters in URL, ...) and fixing this in the old code wouldn't work all that well.
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
"""Extract images from https://www.flickr.com/"""
|
||||
|
||||
from .common import Extractor, Message
|
||||
from .. import text, util, exception
|
||||
from .. import text, oauth, util, exception
|
||||
|
||||
|
||||
class FlickrExtractor(Extractor):
|
||||
@@ -264,17 +264,20 @@ class FlickrAPI():
|
||||
]
|
||||
|
||||
def __init__(self, extractor):
|
||||
self.api_key = extractor.config("api-key", self.API_KEY)
|
||||
self.api_secret = extractor.config("api-secret", self.API_SECRET)
|
||||
api_key = extractor.config("api-key", self.API_KEY)
|
||||
api_secret = extractor.config("api-secret", self.API_SECRET)
|
||||
token = extractor.config("access-token")
|
||||
token_secret = extractor.config("access-token-secret")
|
||||
if token and token_secret:
|
||||
self.session = util.OAuthSession(
|
||||
extractor.session,
|
||||
self.api_key, self.api_secret, token, token_secret)
|
||||
|
||||
if api_key and api_secret and token and token_secret:
|
||||
self.session = oauth.OAuth1Session(
|
||||
api_key, api_secret,
|
||||
token, token_secret,
|
||||
)
|
||||
self.api_key = None
|
||||
else:
|
||||
self.session = extractor.session
|
||||
self.api_key = api_key
|
||||
|
||||
self.maxsize = extractor.config("size-max")
|
||||
if isinstance(self.maxsize, str):
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
from .common import Extractor, Message
|
||||
from . import deviantart, flickr, reddit, tumblr
|
||||
from .. import text, util, config
|
||||
from .. import text, oauth, config
|
||||
import os
|
||||
import urllib.parse
|
||||
|
||||
@@ -70,21 +70,19 @@ class OAuthBase(Extractor):
|
||||
def _oauth1_authorization_flow(
|
||||
self, request_token_url, authorize_url, access_token_url):
|
||||
"""Perform the OAuth 1.0a authorization flow"""
|
||||
del self.session.params["oauth_token"]
|
||||
|
||||
# get a request token
|
||||
params = {"oauth_callback": self.redirect_uri}
|
||||
data = self.session.get(request_token_url, params=params).text
|
||||
|
||||
data = text.parse_query(data)
|
||||
self.session.params["oauth_token"] = token = data["oauth_token"]
|
||||
self.session.token_secret = data["oauth_token_secret"]
|
||||
self.session.auth.token_secret = data["oauth_token_secret"]
|
||||
|
||||
# get the user's authorization
|
||||
params = {"oauth_token": token, "perms": "read"}
|
||||
params = {"oauth_token": data["oauth_token"], "perms": "read"}
|
||||
data = self.open(authorize_url, params)
|
||||
|
||||
# exchange the request token for an access token
|
||||
# self.session.token = data["oauth_token"]
|
||||
data = self.session.get(access_token_url, params=data).text
|
||||
|
||||
data = text.parse_query(data)
|
||||
@@ -101,7 +99,7 @@ class OAuthBase(Extractor):
|
||||
|
||||
state = "gallery-dl_{}_{}".format(
|
||||
self.subcategory,
|
||||
util.OAuthSession.nonce(8)
|
||||
oauth.nonce(8),
|
||||
)
|
||||
|
||||
auth_params = {
|
||||
@@ -182,8 +180,7 @@ class OAuthFlickr(OAuthBase):
|
||||
|
||||
def __init__(self, match):
|
||||
OAuthBase.__init__(self, match)
|
||||
self.session = util.OAuthSession(
|
||||
self.session,
|
||||
self.session = oauth.OAuth1Session(
|
||||
self.oauth_config("api-key", flickr.FlickrAPI.API_KEY),
|
||||
self.oauth_config("api-secret", flickr.FlickrAPI.API_SECRET),
|
||||
)
|
||||
@@ -221,8 +218,7 @@ class OAuthTumblr(OAuthBase):
|
||||
|
||||
def __init__(self, match):
|
||||
OAuthBase.__init__(self, match)
|
||||
self.session = util.OAuthSession(
|
||||
self.session,
|
||||
self.session = oauth.OAuth1Session(
|
||||
self.oauth_config("api-key", tumblr.TumblrAPI.API_KEY),
|
||||
self.oauth_config("api-secret", tumblr.TumblrAPI.API_SECRET),
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
"""Extract images from https://www.smugmug.com/"""
|
||||
|
||||
from .common import Extractor, Message
|
||||
from .. import text, util, exception
|
||||
from .. import text, oauth, exception
|
||||
|
||||
BASE_PATTERN = (
|
||||
r"(?:smugmug:(?!album:)(?:https?://)?([^/]+)|"
|
||||
@@ -186,8 +186,7 @@ class SmugmugAPI():
|
||||
token_secret = extractor.config("access-token-secret")
|
||||
|
||||
if api_key and api_secret and token and token_secret:
|
||||
self.session = util.OAuthSession(
|
||||
extractor.session,
|
||||
self.session = oauth.OAuth1Session(
|
||||
api_key, api_secret,
|
||||
token, token_secret,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
"""Extract images from https://www.tumblr.com/"""
|
||||
|
||||
from .common import Extractor, Message
|
||||
from .. import text, util, exception
|
||||
from .. import text, oauth, exception
|
||||
from datetime import datetime, timedelta
|
||||
import re
|
||||
import time
|
||||
@@ -261,8 +261,7 @@ class TumblrAPI():
|
||||
token_secret = extractor.config("access-token-secret")
|
||||
|
||||
if api_key and api_secret and token and token_secret:
|
||||
self.session = util.OAuthSession(
|
||||
extractor.session,
|
||||
self.session = oauth.OAuth1Session(
|
||||
api_key, api_secret,
|
||||
token, token_secret,
|
||||
)
|
||||
|
||||
101
gallery_dl/oauth.py
Normal file
101
gallery_dl/oauth.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2018 Mike Fährmann
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License version 2 as
|
||||
# published by the Free Software Foundation.
|
||||
|
||||
"""OAuth helper functions and classes"""
|
||||
|
||||
import hmac
|
||||
import time
|
||||
import base64
|
||||
import random
|
||||
import string
|
||||
import hashlib
|
||||
import urllib.parse
|
||||
|
||||
import requests
|
||||
import requests.auth
|
||||
|
||||
from . import text
|
||||
|
||||
|
||||
class OAuth1Session(requests.Session):
|
||||
"""Extension to requests.Session objects to support OAuth 1.0"""
|
||||
|
||||
def __init__(self, consumer_key, consumer_secret,
|
||||
token=None, token_secret=None):
|
||||
|
||||
requests.Session.__init__(self)
|
||||
self.auth = OAuth1Client(
|
||||
consumer_key, consumer_secret,
|
||||
token, token_secret,
|
||||
)
|
||||
|
||||
def rebuild_auth(self, prepared_request, response):
|
||||
if "Authorization" in prepared_request.headers:
|
||||
del prepared_request.headers["Authorization"]
|
||||
prepared_request.prepare_auth(self.auth)
|
||||
|
||||
|
||||
class OAuth1Client(requests.auth.AuthBase):
|
||||
"""OAuth1.0a authentication"""
|
||||
def __init__(self, consumer_key, consumer_secret,
|
||||
token=None, token_secret=None):
|
||||
|
||||
self.consumer_key = consumer_key
|
||||
self.consumer_secret = consumer_secret
|
||||
self.token = token
|
||||
self.token_secret = token_secret
|
||||
|
||||
def __call__(self, request):
|
||||
oauth_params = [
|
||||
("oauth_consumer_key", self.consumer_key),
|
||||
("oauth_nonce", nonce(16)),
|
||||
("oauth_signature_method", "HMAC-SHA1"),
|
||||
("oauth_timestamp", str(int(time.time()))),
|
||||
("oauth_version", "1.0"),
|
||||
]
|
||||
if self.token:
|
||||
oauth_params.append(("oauth_token", self.token))
|
||||
|
||||
signature = self.generate_signature(request, oauth_params)
|
||||
oauth_params.append(("oauth_signature", signature))
|
||||
|
||||
request.headers["Authorization"] = "OAuth " + ",".join(
|
||||
key + '="' + value + '"' for key, value in oauth_params)
|
||||
|
||||
return request
|
||||
|
||||
def generate_signature(self, request, params):
|
||||
"""Generate 'oauth_signature' value"""
|
||||
url, _, query = request.url.partition("?")
|
||||
|
||||
params = params.copy()
|
||||
for key, value in text.parse_query(query).items():
|
||||
params.append((quote(key), quote(value)))
|
||||
params.sort()
|
||||
query = "&".join("=".join(item) for item in params)
|
||||
|
||||
message = concat(request.method, url, query).encode()
|
||||
key = concat(self.consumer_secret, self.token_secret or "").encode()
|
||||
signature = hmac.new(key, message, hashlib.sha1).digest()
|
||||
|
||||
return quote(base64.b64encode(signature).decode())
|
||||
|
||||
|
||||
def concat(*args):
|
||||
"""Concatenate 'args'"""
|
||||
return "&".join(quote(item) for item in args)
|
||||
|
||||
|
||||
def nonce(size, alphabet=string.ascii_letters):
|
||||
"""Generate a nonce value with 'size' characters"""
|
||||
return "".join(random.choice(alphabet) for _ in range(size))
|
||||
|
||||
|
||||
def quote(value, quote=urllib.parse.quote):
|
||||
"""Quote 'value' according to the OAuth1.0 standard"""
|
||||
return quote(value, "~")
|
||||
@@ -11,14 +11,9 @@
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
import hmac
|
||||
import time
|
||||
import base64
|
||||
import random
|
||||
import shutil
|
||||
import string
|
||||
import _string
|
||||
import hashlib
|
||||
import sqlite3
|
||||
import datetime
|
||||
import itertools
|
||||
@@ -497,54 +492,6 @@ class PathFormat():
|
||||
return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path
|
||||
|
||||
|
||||
class OAuthSession():
|
||||
"""Minimal wrapper for requests.session objects to support OAuth 1.0"""
|
||||
def __init__(self, session, consumer_key, consumer_secret,
|
||||
token=None, token_secret=None):
|
||||
self.session = session
|
||||
self.consumer_secret = consumer_secret
|
||||
self.token_secret = token_secret or ""
|
||||
self.params = {}
|
||||
self.params["oauth_consumer_key"] = consumer_key
|
||||
self.params["oauth_token"] = token
|
||||
self.params["oauth_signature_method"] = "HMAC-SHA1"
|
||||
self.params["oauth_version"] = "1.0"
|
||||
|
||||
def get(self, url, params, **kwargs):
|
||||
params.update(self.params)
|
||||
params["oauth_nonce"] = self.nonce(16)
|
||||
params["oauth_timestamp"] = int(time.time())
|
||||
return self.session.get(url + self.sign(url, params), **kwargs)
|
||||
|
||||
def sign(self, url, params):
|
||||
"""Generate 'oauth_signature' value and return query string"""
|
||||
query = self.urlencode(params)
|
||||
message = self.concat("GET", url, query).encode()
|
||||
key = self.concat(self.consumer_secret, self.token_secret).encode()
|
||||
signature = hmac.new(key, message, hashlib.sha1).digest()
|
||||
return "?{}&oauth_signature={}".format(
|
||||
query, self.quote(base64.b64encode(signature).decode()))
|
||||
|
||||
@staticmethod
|
||||
def concat(*args):
|
||||
return "&".join(OAuthSession.quote(item) for item in args)
|
||||
|
||||
@staticmethod
|
||||
def nonce(N, alphabet=string.ascii_letters):
|
||||
return "".join(random.choice(alphabet) for _ in range(N))
|
||||
|
||||
@staticmethod
|
||||
def quote(value, quote=urllib.parse.quote):
|
||||
return quote(value, "~")
|
||||
|
||||
@staticmethod
|
||||
def urlencode(params):
|
||||
return "&".join(
|
||||
OAuthSession.quote(str(key)) + "=" + OAuthSession.quote(str(value))
|
||||
for key, value in sorted(params.items()) if value
|
||||
)
|
||||
|
||||
|
||||
class DownloadArchive():
|
||||
|
||||
def __init__(self, path, extractor):
|
||||
|
||||
@@ -8,10 +8,8 @@
|
||||
# published by the Free Software Foundation.
|
||||
|
||||
import unittest
|
||||
import requests
|
||||
|
||||
from gallery_dl import text
|
||||
from gallery_dl.util import OAuthSession
|
||||
from gallery_dl import oauth, text
|
||||
|
||||
TESTSERVER = "http://oauthbin.com"
|
||||
CONSUMER_KEY = "key"
|
||||
@@ -25,7 +23,7 @@ ACCESS_TOKEN_SECRET = "accesssecret"
|
||||
class TestOAuthSession(unittest.TestCase):
|
||||
|
||||
def test_concat(self):
|
||||
concat = OAuthSession.concat
|
||||
concat = oauth.concat
|
||||
|
||||
self.assertEqual(concat(), "")
|
||||
self.assertEqual(concat("str"), "str")
|
||||
@@ -37,18 +35,18 @@ class TestOAuthSession(unittest.TestCase):
|
||||
"GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da"
|
||||
)
|
||||
|
||||
def test_nonce(self, N=16):
|
||||
nonce_values = set(OAuthSession.nonce(N) for _ in range(N))
|
||||
def test_nonce(self, size=16):
|
||||
nonce_values = set(oauth.nonce(size) for _ in range(size))
|
||||
|
||||
# uniqueness
|
||||
self.assertEqual(len(nonce_values), N)
|
||||
self.assertEqual(len(nonce_values), size)
|
||||
|
||||
# length
|
||||
for nonce in nonce_values:
|
||||
self.assertEqual(len(nonce), N)
|
||||
self.assertEqual(len(nonce), size)
|
||||
|
||||
def test_quote(self):
|
||||
quote = OAuthSession.quote
|
||||
quote = oauth.quote
|
||||
|
||||
reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü"
|
||||
unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
@@ -65,33 +63,6 @@ class TestOAuthSession(unittest.TestCase):
|
||||
self.assertTrue(len(quoted) >= 3)
|
||||
self.assertEqual(quoted_hex.upper(), quoted_hex)
|
||||
|
||||
def test_urlencode(self):
|
||||
urlencode = OAuthSession.urlencode
|
||||
|
||||
self.assertEqual(urlencode({}), "")
|
||||
self.assertEqual(urlencode({"foo": "bar"}), "foo=bar")
|
||||
|
||||
self.assertEqual(
|
||||
urlencode({"foo": "bar", "baz": "a", "a": "baz"}),
|
||||
"a=baz&baz=a&foo=bar"
|
||||
)
|
||||
self.assertEqual(
|
||||
urlencode({
|
||||
"oauth_consumer_key": "0685bd9184jfhq22",
|
||||
"oauth_token": "ad180jjd733klru7",
|
||||
"oauth_signature_method": "HMAC-SHA1",
|
||||
"oauth_timestamp": 137131200,
|
||||
"oauth_nonce": "4572616e48616d6d65724c61686176",
|
||||
"oauth_version": "1.0"
|
||||
}),
|
||||
"oauth_consumer_key=0685bd9184jfhq22&"
|
||||
"oauth_nonce=4572616e48616d6d65724c61686176&"
|
||||
"oauth_signature_method=HMAC-SHA1&"
|
||||
"oauth_timestamp=137131200&"
|
||||
"oauth_token=ad180jjd733klru7&"
|
||||
"oauth_version=1.0"
|
||||
)
|
||||
|
||||
def test_request_token(self):
|
||||
response = self._oauth_request(
|
||||
"/v1/request-token", {})
|
||||
@@ -113,23 +84,20 @@ class TestOAuthSession(unittest.TestCase):
|
||||
self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET)
|
||||
|
||||
def test_authenticated_call(self):
|
||||
params = {"method": "foo", "bar": "baz", "a": "äöüß/?&#"}
|
||||
params = {"method": "foo", "a": "äöüß/?&#", "äöüß/?&#": "a"}
|
||||
response = self._oauth_request(
|
||||
"/v1/echo", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET)
|
||||
expected = OAuthSession.urlencode(params)
|
||||
|
||||
self.assertEqual(response, expected, msg=response)
|
||||
self.assertEqual(text.parse_query(response), params)
|
||||
|
||||
def _oauth_request(self, endpoint, params=None,
|
||||
oauth_token=None, oauth_token_secret=None):
|
||||
session = OAuthSession(
|
||||
requests.session(),
|
||||
session = oauth.OAuth1Session(
|
||||
CONSUMER_KEY, CONSUMER_SECRET,
|
||||
oauth_token, oauth_token_secret,
|
||||
)
|
||||
url = TESTSERVER + endpoint
|
||||
return session.get(url, params.copy()).text
|
||||
return session.get(url, params=params).text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user