mirror of
https://github.com/reddit-archive/reddit.git
synced 2026-01-26 23:39:11 -05:00
Add OAuth2 handling to the main APIController.
This commit is contained in:
@@ -60,6 +60,7 @@ from r2.lib.filters import safemarkdown
|
||||
from r2.lib.scraper import str_to_image
|
||||
from r2.controllers.api_docs import api_doc, api_section
|
||||
from r2.lib.search import SearchQuery
|
||||
from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope
|
||||
|
||||
import csv
|
||||
from collections import defaultdict
|
||||
@@ -96,11 +97,19 @@ class ApiminimalController(MinimalController):
|
||||
form._send_data(iden = iden)
|
||||
|
||||
|
||||
class ApiController(RedditController):
|
||||
class ApiController(RedditController, OAuth2ResourceController):
|
||||
"""
|
||||
Controller which deals with almost all AJAX site interaction.
|
||||
"""
|
||||
|
||||
def pre(self):
|
||||
RedditController.pre(self)
|
||||
if self._get_bearer_token(strict=False):
|
||||
OAuth2ResourceController.pre(self)
|
||||
if c.oauth_user:
|
||||
c.user = c.oauth_user
|
||||
c.user_is_loggedin = True
|
||||
|
||||
@validatedForm()
|
||||
def ajax_login_redirect(self, form, jquery, dest):
|
||||
form.redirect("/login" + query_string(dict(dest=dest)))
|
||||
|
||||
@@ -205,7 +205,7 @@ class OAuth2ResourceController(MinimalController):
|
||||
require_https()
|
||||
|
||||
try:
|
||||
access_token = self._get_bearer_token()
|
||||
access_token = OAuth2AccessToken.get_token(self._get_bearer_token())
|
||||
require(access_token)
|
||||
c.oauth2_access_token = access_token
|
||||
account = Account._byID(access_token.user_id, data=True)
|
||||
@@ -227,14 +227,15 @@ class OAuth2ResourceController(MinimalController):
|
||||
def _auth_error(self, code, error):
|
||||
abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)])
|
||||
|
||||
def _get_bearer_token(self):
|
||||
def _get_bearer_token(self, strict=True):
|
||||
auth = request.headers.get("Authorization")
|
||||
try:
|
||||
auth_scheme, bearer_token = require_split(auth, 2)
|
||||
require(auth_scheme.lower() == "bearer")
|
||||
return OAuth2AccessToken.get_token(bearer_token)
|
||||
return bearer_token
|
||||
except RequirementException:
|
||||
self._auth_error(400, "invalid_request")
|
||||
if strict:
|
||||
self._auth_error(400, "invalid_request")
|
||||
|
||||
def require_oauth2_scope(*scopes):
|
||||
def oauth2_scope_wrap(fn):
|
||||
|
||||
@@ -264,7 +264,11 @@ class Account(Thing):
|
||||
return modhash(self, rand = rand, test = test)
|
||||
|
||||
def valid_hash(self, hash):
|
||||
return valid_hash(self, hash)
|
||||
if self == c.oauth_user:
|
||||
# OAuth authenticated requests do not require CSRF protection.
|
||||
return True
|
||||
else:
|
||||
return valid_hash(self, hash)
|
||||
|
||||
@classmethod
|
||||
@memoize('account._by_name')
|
||||
|
||||
Reference in New Issue
Block a user