diff --git a/r2/r2/controllers/oauth2.py b/r2/r2/controllers/oauth2.py index bf6802174..32f72891f 100644 --- a/r2/r2/controllers/oauth2.py +++ b/r2/r2/controllers/oauth2.py @@ -36,6 +36,7 @@ from r2.controllers.errors import ForbiddenError, errors from validator import validate, VRequired, VOneOf, VUser, VModhash, VOAuth2ClientID, VOAuth2Scope from r2.lib.pages import OAuth2AuthorizationPage from r2.lib.require import RequirementException, require, require_split +from r2.lib.utils import parse_http_basic scope_info = { "identity": { @@ -155,13 +156,7 @@ class OAuth2AccessController(MinimalController): def _get_client_auth(self): auth = request.headers.get("Authorization") try: - auth_scheme, auth_token = require_split(auth, 2) - require(auth_scheme.lower() == "basic") - try: - auth_data = base64.b64decode(auth_token) - except TypeError: - raise RequirementException - client_id, client_secret = require_split(auth_data, 2, ":") + client_id, client_secret = parse_http_basic(auth) client = OAuth2Client.get_token(client_id) require(client) require(client.secret == client_secret) diff --git a/r2/r2/lib/utils/utils.py b/r2/r2/lib/utils/utils.py index 41b2f5a25..20b653aa4 100644 --- a/r2/r2/lib/utils/utils.py +++ b/r2/r2/lib/utils/utils.py @@ -1409,3 +1409,17 @@ def find_containing_network(ip_ranges, address): def is_throttled(address): """Determine if an IP address is in a throttled range.""" return bool(find_containing_network(g.throttles, address)) + + +def parse_http_basic(authorization_header): + """Parse the username/credentials out of an HTTP Basic Auth header. + + Raises RequirementException if anything is uncool. + """ + auth_scheme, auth_token = require_split(auth, 2) + require(auth_scheme.lower() == "basic") + try: + auth_data = base64.b64decode(auth_token) + except TypeError: + raise RequirementException + return require_split(auth_data, 2, ":")