diff --git a/r2/r2/controllers/oauth2.py b/r2/r2/controllers/oauth2.py index 809601c38..28f709dcb 100644 --- a/r2/r2/controllers/oauth2.py +++ b/r2/r2/controllers/oauth2.py @@ -37,7 +37,7 @@ from r2.models.token import ( from r2.lib.errors import ForbiddenError, errors from r2.lib.pages import OAuth2AuthorizationPage from r2.lib.require import RequirementException, require, require_split -from r2.lib.utils import constant_time_compare, parse_http_basic +from r2.lib.utils import constant_time_compare, parse_http_basic, UrlParser from r2.lib.validator import ( nop, validate, @@ -51,6 +51,13 @@ from r2.lib.validator import ( VOAuth2RefreshToken, ) + +def _update_redirect_uri(base_redirect_uri, params): + parsed = UrlParser(base_redirect_uri) + parsed.update_query(**params) + return parsed.unparse() + + class OAuth2FrontendController(RedditController): def check_for_bearer_token(self): pass @@ -81,7 +88,8 @@ class OAuth2FrontendController(RedditController): else: resp["error"] = "invalid_request" - return self.redirect(redirect_uri+"?"+urlencode(resp), code=302) + final_redirect = _update_redirect_uri(redirect_uri, resp) + return self.redirect(final_redirect, code=302) @validate(VUser(), response_type = VOneOf("response_type", ("code",)), @@ -142,7 +150,8 @@ class OAuth2FrontendController(RedditController): c.user._id36, scope, duration == "permanent") resp = {"code": code._id, "state": state} - return self.redirect(redirect_uri+"?"+urlencode(resp), code=302) + final_redirect = _update_redirect_uri(redirect_uri, resp) + return self.redirect(final_redirect, code=302) else: return self._error_response(state, redirect_uri)