Add OAuth2 handling to the main APIController.

This commit is contained in:
Max Goodman
2012-03-14 19:33:58 -07:00
committed by Logan Hanks
parent 81caf1213d
commit deff9405da
3 changed files with 20 additions and 6 deletions

View File

@@ -60,6 +60,7 @@ from r2.lib.filters import safemarkdown
from r2.lib.scraper import str_to_image
from r2.controllers.api_docs import api_doc, api_section
from r2.lib.search import SearchQuery
from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope
import csv
from collections import defaultdict
@@ -96,11 +97,19 @@ class ApiminimalController(MinimalController):
form._send_data(iden = iden)
class ApiController(RedditController):
class ApiController(RedditController, OAuth2ResourceController):
"""
Controller which deals with almost all AJAX site interaction.
"""
def pre(self):
RedditController.pre(self)
if self._get_bearer_token(strict=False):
OAuth2ResourceController.pre(self)
if c.oauth_user:
c.user = c.oauth_user
c.user_is_loggedin = True
@validatedForm()
def ajax_login_redirect(self, form, jquery, dest):
form.redirect("/login" + query_string(dict(dest=dest)))

View File

@@ -205,7 +205,7 @@ class OAuth2ResourceController(MinimalController):
require_https()
try:
access_token = self._get_bearer_token()
access_token = OAuth2AccessToken.get_token(self._get_bearer_token())
require(access_token)
c.oauth2_access_token = access_token
account = Account._byID(access_token.user_id, data=True)
@@ -227,14 +227,15 @@ class OAuth2ResourceController(MinimalController):
def _auth_error(self, code, error):
abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)])
def _get_bearer_token(self):
def _get_bearer_token(self, strict=True):
auth = request.headers.get("Authorization")
try:
auth_scheme, bearer_token = require_split(auth, 2)
require(auth_scheme.lower() == "bearer")
return OAuth2AccessToken.get_token(bearer_token)
return bearer_token
except RequirementException:
self._auth_error(400, "invalid_request")
if strict:
self._auth_error(400, "invalid_request")
def require_oauth2_scope(*scopes):
def oauth2_scope_wrap(fn):

View File

@@ -264,7 +264,11 @@ class Account(Thing):
return modhash(self, rand = rand, test = test)
def valid_hash(self, hash):
return valid_hash(self, hash)
if self == c.oauth_user:
# OAuth authenticated requests do not require CSRF protection.
return True
else:
return valid_hash(self, hash)
@classmethod
@memoize('account._by_name')