re-implement OAuth1.0 code

OAuth support for SmugMug needs some additional features
(auth-rebuild on redirect, query parameters in URL, ...)
and fixing this in the old code wouldn't work all that well.
This commit is contained in:
Mike Fährmann
2018-05-10 18:26:10 +02:00
parent 0e3883303f
commit 6a31ada9e3
7 changed files with 132 additions and 119 deletions

View File

@@ -9,7 +9,7 @@
"""Extract images from https://www.flickr.com/""" """Extract images from https://www.flickr.com/"""
from .common import Extractor, Message from .common import Extractor, Message
from .. import text, util, exception from .. import text, oauth, util, exception
class FlickrExtractor(Extractor): class FlickrExtractor(Extractor):
@@ -264,17 +264,20 @@ class FlickrAPI():
] ]
def __init__(self, extractor): def __init__(self, extractor):
self.api_key = extractor.config("api-key", self.API_KEY) api_key = extractor.config("api-key", self.API_KEY)
self.api_secret = extractor.config("api-secret", self.API_SECRET) api_secret = extractor.config("api-secret", self.API_SECRET)
token = extractor.config("access-token") token = extractor.config("access-token")
token_secret = extractor.config("access-token-secret") token_secret = extractor.config("access-token-secret")
if token and token_secret:
self.session = util.OAuthSession( if api_key and api_secret and token and token_secret:
extractor.session, self.session = oauth.OAuth1Session(
self.api_key, self.api_secret, token, token_secret) api_key, api_secret,
token, token_secret,
)
self.api_key = None self.api_key = None
else: else:
self.session = extractor.session self.session = extractor.session
self.api_key = api_key
self.maxsize = extractor.config("size-max") self.maxsize = extractor.config("size-max")
if isinstance(self.maxsize, str): if isinstance(self.maxsize, str):

View File

@@ -10,7 +10,7 @@
from .common import Extractor, Message from .common import Extractor, Message
from . import deviantart, flickr, reddit, tumblr from . import deviantart, flickr, reddit, tumblr
from .. import text, util, config from .. import text, oauth, config
import os import os
import urllib.parse import urllib.parse
@@ -70,21 +70,19 @@ class OAuthBase(Extractor):
def _oauth1_authorization_flow( def _oauth1_authorization_flow(
self, request_token_url, authorize_url, access_token_url): self, request_token_url, authorize_url, access_token_url):
"""Perform the OAuth 1.0a authorization flow""" """Perform the OAuth 1.0a authorization flow"""
del self.session.params["oauth_token"]
# get a request token # get a request token
params = {"oauth_callback": self.redirect_uri} params = {"oauth_callback": self.redirect_uri}
data = self.session.get(request_token_url, params=params).text data = self.session.get(request_token_url, params=params).text
data = text.parse_query(data) data = text.parse_query(data)
self.session.params["oauth_token"] = token = data["oauth_token"] self.session.auth.token_secret = data["oauth_token_secret"]
self.session.token_secret = data["oauth_token_secret"]
# get the user's authorization # get the user's authorization
params = {"oauth_token": token, "perms": "read"} params = {"oauth_token": data["oauth_token"], "perms": "read"}
data = self.open(authorize_url, params) data = self.open(authorize_url, params)
# exchange the request token for an access token # exchange the request token for an access token
# self.session.token = data["oauth_token"]
data = self.session.get(access_token_url, params=data).text data = self.session.get(access_token_url, params=data).text
data = text.parse_query(data) data = text.parse_query(data)
@@ -101,7 +99,7 @@ class OAuthBase(Extractor):
state = "gallery-dl_{}_{}".format( state = "gallery-dl_{}_{}".format(
self.subcategory, self.subcategory,
util.OAuthSession.nonce(8) oauth.nonce(8),
) )
auth_params = { auth_params = {
@@ -182,8 +180,7 @@ class OAuthFlickr(OAuthBase):
def __init__(self, match): def __init__(self, match):
OAuthBase.__init__(self, match) OAuthBase.__init__(self, match)
self.session = util.OAuthSession( self.session = oauth.OAuth1Session(
self.session,
self.oauth_config("api-key", flickr.FlickrAPI.API_KEY), self.oauth_config("api-key", flickr.FlickrAPI.API_KEY),
self.oauth_config("api-secret", flickr.FlickrAPI.API_SECRET), self.oauth_config("api-secret", flickr.FlickrAPI.API_SECRET),
) )
@@ -221,8 +218,7 @@ class OAuthTumblr(OAuthBase):
def __init__(self, match): def __init__(self, match):
OAuthBase.__init__(self, match) OAuthBase.__init__(self, match)
self.session = util.OAuthSession( self.session = oauth.OAuth1Session(
self.session,
self.oauth_config("api-key", tumblr.TumblrAPI.API_KEY), self.oauth_config("api-key", tumblr.TumblrAPI.API_KEY),
self.oauth_config("api-secret", tumblr.TumblrAPI.API_SECRET), self.oauth_config("api-secret", tumblr.TumblrAPI.API_SECRET),
) )

View File

@@ -9,7 +9,7 @@
"""Extract images from https://www.smugmug.com/""" """Extract images from https://www.smugmug.com/"""
from .common import Extractor, Message from .common import Extractor, Message
from .. import text, util, exception from .. import text, oauth, exception
BASE_PATTERN = ( BASE_PATTERN = (
r"(?:smugmug:(?!album:)(?:https?://)?([^/]+)|" r"(?:smugmug:(?!album:)(?:https?://)?([^/]+)|"
@@ -186,8 +186,7 @@ class SmugmugAPI():
token_secret = extractor.config("access-token-secret") token_secret = extractor.config("access-token-secret")
if api_key and api_secret and token and token_secret: if api_key and api_secret and token and token_secret:
self.session = util.OAuthSession( self.session = oauth.OAuth1Session(
extractor.session,
api_key, api_secret, api_key, api_secret,
token, token_secret, token, token_secret,
) )

View File

@@ -9,7 +9,7 @@
"""Extract images from https://www.tumblr.com/""" """Extract images from https://www.tumblr.com/"""
from .common import Extractor, Message from .common import Extractor, Message
from .. import text, util, exception from .. import text, oauth, exception
from datetime import datetime, timedelta from datetime import datetime, timedelta
import re import re
import time import time
@@ -261,8 +261,7 @@ class TumblrAPI():
token_secret = extractor.config("access-token-secret") token_secret = extractor.config("access-token-secret")
if api_key and api_secret and token and token_secret: if api_key and api_secret and token and token_secret:
self.session = util.OAuthSession( self.session = oauth.OAuth1Session(
extractor.session,
api_key, api_secret, api_key, api_secret,
token, token_secret, token, token_secret,
) )

101
gallery_dl/oauth.py Normal file
View File

@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.
"""OAuth helper functions and classes"""
import hmac
import time
import base64
import random
import string
import hashlib
import urllib.parse
import requests
import requests.auth
from . import text
class OAuth1Session(requests.Session):
"""Extension to requests.Session objects to support OAuth 1.0"""
def __init__(self, consumer_key, consumer_secret,
token=None, token_secret=None):
requests.Session.__init__(self)
self.auth = OAuth1Client(
consumer_key, consumer_secret,
token, token_secret,
)
def rebuild_auth(self, prepared_request, response):
if "Authorization" in prepared_request.headers:
del prepared_request.headers["Authorization"]
prepared_request.prepare_auth(self.auth)
class OAuth1Client(requests.auth.AuthBase):
"""OAuth1.0a authentication"""
def __init__(self, consumer_key, consumer_secret,
token=None, token_secret=None):
self.consumer_key = consumer_key
self.consumer_secret = consumer_secret
self.token = token
self.token_secret = token_secret
def __call__(self, request):
oauth_params = [
("oauth_consumer_key", self.consumer_key),
("oauth_nonce", nonce(16)),
("oauth_signature_method", "HMAC-SHA1"),
("oauth_timestamp", str(int(time.time()))),
("oauth_version", "1.0"),
]
if self.token:
oauth_params.append(("oauth_token", self.token))
signature = self.generate_signature(request, oauth_params)
oauth_params.append(("oauth_signature", signature))
request.headers["Authorization"] = "OAuth " + ",".join(
key + '="' + value + '"' for key, value in oauth_params)
return request
def generate_signature(self, request, params):
"""Generate 'oauth_signature' value"""
url, _, query = request.url.partition("?")
params = params.copy()
for key, value in text.parse_query(query).items():
params.append((quote(key), quote(value)))
params.sort()
query = "&".join("=".join(item) for item in params)
message = concat(request.method, url, query).encode()
key = concat(self.consumer_secret, self.token_secret or "").encode()
signature = hmac.new(key, message, hashlib.sha1).digest()
return quote(base64.b64encode(signature).decode())
def concat(*args):
"""Concatenate 'args'"""
return "&".join(quote(item) for item in args)
def nonce(size, alphabet=string.ascii_letters):
"""Generate a nonce value with 'size' characters"""
return "".join(random.choice(alphabet) for _ in range(size))
def quote(value, quote=urllib.parse.quote):
"""Quote 'value' according to the OAuth1.0 standard"""
return quote(value, "~")

View File

@@ -11,14 +11,9 @@
import re import re
import os import os
import sys import sys
import hmac
import time
import base64
import random
import shutil import shutil
import string import string
import _string import _string
import hashlib
import sqlite3 import sqlite3
import datetime import datetime
import itertools import itertools
@@ -497,54 +492,6 @@ class PathFormat():
return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path
class OAuthSession():
"""Minimal wrapper for requests.session objects to support OAuth 1.0"""
def __init__(self, session, consumer_key, consumer_secret,
token=None, token_secret=None):
self.session = session
self.consumer_secret = consumer_secret
self.token_secret = token_secret or ""
self.params = {}
self.params["oauth_consumer_key"] = consumer_key
self.params["oauth_token"] = token
self.params["oauth_signature_method"] = "HMAC-SHA1"
self.params["oauth_version"] = "1.0"
def get(self, url, params, **kwargs):
params.update(self.params)
params["oauth_nonce"] = self.nonce(16)
params["oauth_timestamp"] = int(time.time())
return self.session.get(url + self.sign(url, params), **kwargs)
def sign(self, url, params):
"""Generate 'oauth_signature' value and return query string"""
query = self.urlencode(params)
message = self.concat("GET", url, query).encode()
key = self.concat(self.consumer_secret, self.token_secret).encode()
signature = hmac.new(key, message, hashlib.sha1).digest()
return "?{}&oauth_signature={}".format(
query, self.quote(base64.b64encode(signature).decode()))
@staticmethod
def concat(*args):
return "&".join(OAuthSession.quote(item) for item in args)
@staticmethod
def nonce(N, alphabet=string.ascii_letters):
return "".join(random.choice(alphabet) for _ in range(N))
@staticmethod
def quote(value, quote=urllib.parse.quote):
return quote(value, "~")
@staticmethod
def urlencode(params):
return "&".join(
OAuthSession.quote(str(key)) + "=" + OAuthSession.quote(str(value))
for key, value in sorted(params.items()) if value
)
class DownloadArchive(): class DownloadArchive():
def __init__(self, path, extractor): def __init__(self, path, extractor):

View File

@@ -8,10 +8,8 @@
# published by the Free Software Foundation. # published by the Free Software Foundation.
import unittest import unittest
import requests
from gallery_dl import text from gallery_dl import oauth, text
from gallery_dl.util import OAuthSession
TESTSERVER = "http://oauthbin.com" TESTSERVER = "http://oauthbin.com"
CONSUMER_KEY = "key" CONSUMER_KEY = "key"
@@ -25,7 +23,7 @@ ACCESS_TOKEN_SECRET = "accesssecret"
class TestOAuthSession(unittest.TestCase): class TestOAuthSession(unittest.TestCase):
def test_concat(self): def test_concat(self):
concat = OAuthSession.concat concat = oauth.concat
self.assertEqual(concat(), "") self.assertEqual(concat(), "")
self.assertEqual(concat("str"), "str") self.assertEqual(concat("str"), "str")
@@ -37,18 +35,18 @@ class TestOAuthSession(unittest.TestCase):
"GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da" "GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da"
) )
def test_nonce(self, N=16): def test_nonce(self, size=16):
nonce_values = set(OAuthSession.nonce(N) for _ in range(N)) nonce_values = set(oauth.nonce(size) for _ in range(size))
# uniqueness # uniqueness
self.assertEqual(len(nonce_values), N) self.assertEqual(len(nonce_values), size)
# length # length
for nonce in nonce_values: for nonce in nonce_values:
self.assertEqual(len(nonce), N) self.assertEqual(len(nonce), size)
def test_quote(self): def test_quote(self):
quote = OAuthSession.quote quote = oauth.quote
reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü" reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü"
unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ" unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
@@ -65,33 +63,6 @@ class TestOAuthSession(unittest.TestCase):
self.assertTrue(len(quoted) >= 3) self.assertTrue(len(quoted) >= 3)
self.assertEqual(quoted_hex.upper(), quoted_hex) self.assertEqual(quoted_hex.upper(), quoted_hex)
def test_urlencode(self):
urlencode = OAuthSession.urlencode
self.assertEqual(urlencode({}), "")
self.assertEqual(urlencode({"foo": "bar"}), "foo=bar")
self.assertEqual(
urlencode({"foo": "bar", "baz": "a", "a": "baz"}),
"a=baz&baz=a&foo=bar"
)
self.assertEqual(
urlencode({
"oauth_consumer_key": "0685bd9184jfhq22",
"oauth_token": "ad180jjd733klru7",
"oauth_signature_method": "HMAC-SHA1",
"oauth_timestamp": 137131200,
"oauth_nonce": "4572616e48616d6d65724c61686176",
"oauth_version": "1.0"
}),
"oauth_consumer_key=0685bd9184jfhq22&"
"oauth_nonce=4572616e48616d6d65724c61686176&"
"oauth_signature_method=HMAC-SHA1&"
"oauth_timestamp=137131200&"
"oauth_token=ad180jjd733klru7&"
"oauth_version=1.0"
)
def test_request_token(self): def test_request_token(self):
response = self._oauth_request( response = self._oauth_request(
"/v1/request-token", {}) "/v1/request-token", {})
@@ -113,23 +84,20 @@ class TestOAuthSession(unittest.TestCase):
self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET) self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET)
def test_authenticated_call(self): def test_authenticated_call(self):
params = {"method": "foo", "bar": "baz", "a": "äöüß/?&#"} params = {"method": "foo", "a": "äöüß/?&#", "äöüß/?&#": "a"}
response = self._oauth_request( response = self._oauth_request(
"/v1/echo", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET) "/v1/echo", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET)
expected = OAuthSession.urlencode(params)
self.assertEqual(response, expected, msg=response)
self.assertEqual(text.parse_query(response), params) self.assertEqual(text.parse_query(response), params)
def _oauth_request(self, endpoint, params=None, def _oauth_request(self, endpoint, params=None,
oauth_token=None, oauth_token_secret=None): oauth_token=None, oauth_token_secret=None):
session = OAuthSession( session = oauth.OAuth1Session(
requests.session(),
CONSUMER_KEY, CONSUMER_SECRET, CONSUMER_KEY, CONSUMER_SECRET,
oauth_token, oauth_token_secret, oauth_token, oauth_token_secret,
) )
url = TESTSERVER + endpoint url = TESTSERVER + endpoint
return session.get(url, params.copy()).text return session.get(url, params=params).text
if __name__ == "__main__": if __name__ == "__main__":