diff --git a/r2/r2/controllers/api.py b/r2/r2/controllers/api.py index 78b36557b..1a4f204c1 100755 --- a/r2/r2/controllers/api.py +++ b/r2/r2/controllers/api.py @@ -60,6 +60,7 @@ from r2.lib.filters import safemarkdown from r2.lib.scraper import str_to_image from r2.controllers.api_docs import api_doc, api_section from r2.lib.search import SearchQuery +from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope import csv from collections import defaultdict @@ -96,11 +97,19 @@ class ApiminimalController(MinimalController): form._send_data(iden = iden) -class ApiController(RedditController): +class ApiController(RedditController, OAuth2ResourceController): """ Controller which deals with almost all AJAX site interaction. """ + def pre(self): + RedditController.pre(self) + if self._get_bearer_token(strict=False): + OAuth2ResourceController.pre(self) + if c.oauth_user: + c.user = c.oauth_user + c.user_is_loggedin = True + @validatedForm() def ajax_login_redirect(self, form, jquery, dest): form.redirect("/login" + query_string(dict(dest=dest))) diff --git a/r2/r2/controllers/oauth2.py b/r2/r2/controllers/oauth2.py index c1cd29be9..a4e6b40f0 100644 --- a/r2/r2/controllers/oauth2.py +++ b/r2/r2/controllers/oauth2.py @@ -205,7 +205,7 @@ class OAuth2ResourceController(MinimalController): require_https() try: - access_token = self._get_bearer_token() + access_token = OAuth2AccessToken.get_token(self._get_bearer_token()) require(access_token) c.oauth2_access_token = access_token account = Account._byID(access_token.user_id, data=True) @@ -227,14 +227,15 @@ class OAuth2ResourceController(MinimalController): def _auth_error(self, code, error): abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)]) - def _get_bearer_token(self): + def _get_bearer_token(self, strict=True): auth = request.headers.get("Authorization") try: auth_scheme, bearer_token = require_split(auth, 2) require(auth_scheme.lower() == "bearer") - return OAuth2AccessToken.get_token(bearer_token) + return bearer_token except RequirementException: - self._auth_error(400, "invalid_request") + if strict: + self._auth_error(400, "invalid_request") def require_oauth2_scope(*scopes): def oauth2_scope_wrap(fn): diff --git a/r2/r2/models/account.py b/r2/r2/models/account.py index ba72a2149..ae0120c38 100644 --- a/r2/r2/models/account.py +++ b/r2/r2/models/account.py @@ -264,7 +264,11 @@ class Account(Thing): return modhash(self, rand = rand, test = test) def valid_hash(self, hash): - return valid_hash(self, hash) + if self == c.oauth_user: + # OAuth authenticated requests do not require CSRF protection. + return True + else: + return valid_hash(self, hash) @classmethod @memoize('account._by_name')