diff --git a/r2/r2/controllers/oauth2.py b/r2/r2/controllers/oauth2.py index e00d041fc..dd3be731c 100644 --- a/r2/r2/controllers/oauth2.py +++ b/r2/r2/controllers/oauth2.py @@ -69,9 +69,11 @@ class OAuth2FrontendController(RedditController): if not redirect_uri or not client or redirect_uri != client.redirect_uri: abort(ForbiddenError(errors.OAUTH2_INVALID_REDIRECT_URI)) - def _error_response(self, resp, redirect_uri): + def _error_response(self, state, redirect_uri): """Return an error redirect, but only if client_id and redirect_uri are valid.""" + resp = {"state": state} + if (errors.OAUTH2_INVALID_CLIENT, "client_id") in c.errors: resp["error"] = "unauthorized_client" elif (errors.OAUTH2_ACCESS_DENIED, "authorize") in c.errors: @@ -117,15 +119,11 @@ class OAuth2FrontendController(RedditController): self._check_redirect_uri(client, redirect_uri) - resp = {} - if state: - resp["state"] = state - if not c.errors: c.deny_frames = True return OAuth2AuthorizationPage(client, redirect_uri, scope_info[scope], state).render() else: - return self._error_response(resp, redirect_uri) + return self._error_response(state, redirect_uri) @validate(VUser(), VModhash(fatal=False), @@ -139,16 +137,12 @@ class OAuth2FrontendController(RedditController): self._check_redirect_uri(client, redirect_uri) - resp = {} - if state: - resp["state"] = state - if not c.errors: code = OAuth2AuthorizationCode._new(client._id, redirect_uri, c.user._id, scope) - resp["code"] = code._id + resp = {"code": code._id, "state": state} return self.redirect(redirect_uri+"?"+urlencode(resp), code=302) else: - return self._error_response(resp, redirect_uri) + return self._error_response(state, redirect_uri) class OAuth2AccessController(MinimalController): def pre(self):