Bring OAuth2 into the core controllers

This makes it easier to avoid issues with running pre()
functions multiple times (multiply subclassing was causing
problems) and makes it so all resources are blocked from oauth
access unless explicitly enabled (instead of randomly allowing
access as a "logged out user" to endpoints that aren't part of
an OAuth2ResourceController)

Conflicts:

	r2/r2/controllers/apiv1.py
	r2/r2/controllers/oauth2.py
This commit is contained in:
Keith Mitchell
2014-02-14 15:42:23 -08:00
parent d368e31cb4
commit b1e97e8bff
10 changed files with 91 additions and 86 deletions

View File

@@ -75,7 +75,7 @@ from r2.lib.filters import safemarkdown
from r2.lib.media import str_to_image
from r2.controllers.api_docs import api_doc, api_section
from r2.lib.search import SearchQuery
from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope
from r2.controllers.oauth2 import require_oauth2_scope
from r2.lib.template_helpers import add_sr, get_domain
from r2.lib.system_messages import notify_user_added
from r2.controllers.ipn import generate_blob
@@ -132,15 +132,10 @@ class ApiminimalController(MinimalController):
form._send_data(iden = iden)
class ApiController(RedditController, OAuth2ResourceController):
class ApiController(RedditController):
"""
Controller which deals with almost all AJAX site interaction.
"""
def pre(self):
self.check_for_bearer_token()
RedditController.pre(self)
@validatedForm()
def ajax_login_redirect(self, form, jquery, dest):
form.redirect("/login" + query_string(dict(dest=dest)))

View File

@@ -22,11 +22,13 @@
from pylons import c
from r2.controllers.api_docs import api_doc, api_section
from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope
from r2.controllers.oauth2 import require_oauth2_scope
from r2.controllers.reddit_base import OAuth2ResourceController
from r2.lib.jsontemplates import IdentityJsonTemplate
class APIv1Controller(OAuth2ResourceController):
def pre(self):
OAuth2ResourceController.pre(self)
self.check_for_bearer_token()
def try_pagecache(self):

View File

@@ -90,6 +90,9 @@ class ErrorController(RedditController):
This behaviour can be altered by changing the parameters to the
ErrorDocuments middleware in your config/middleware.py file.
"""
def check_for_bearer_token(self):
pass
allowed_render_styles = ('html', 'xml', 'js', 'embed', '', "compact", 'api')
# List of admins to blame (skip the first admin, "reddit")
# If list is empty, just blame "an admin"

View File

@@ -55,7 +55,7 @@ from r2.lib import sup
import r2.lib.db.thing as thing
from r2.lib.errors import errors
from listingcontroller import ListingController
from oauth2 import OAuth2ResourceController, require_oauth2_scope
from oauth2 import require_oauth2_scope
from api_docs import api_doc, api_section
from pylons import c, request, response
from r2.models.token import EmailVerificationToken
@@ -68,14 +68,10 @@ import re, socket
import time as time_module
from urllib import quote_plus
class FrontController(RedditController, OAuth2ResourceController):
class FrontController(RedditController):
allow_stylesheets = True
def pre(self):
self.check_for_bearer_token()
RedditController.pre(self)
@validate(article=VLink('article'),
comment=VCommentID('comment'))
def GET_oldinfo(self, article, type, dest, rest=None, comment=''):

View File

@@ -20,7 +20,7 @@
# Inc. All Rights Reserved.
###############################################################################
from oauth2 import OAuth2ResourceController, require_oauth2_scope
from oauth2 import require_oauth2_scope
from reddit_base import RedditController, base_listing, paginated_listing
from r2.models import *
@@ -53,7 +53,7 @@ from pylons.controllers.util import redirect_to
import random
from functools import partial
class ListingController(RedditController, OAuth2ResourceController):
class ListingController(RedditController):
"""Generalized controller for pages with lists of links."""
# toggle skipping of links based on the users' save/hide/vote preferences
@@ -89,10 +89,6 @@ class ListingController(RedditController, OAuth2ResourceController):
render_params = {}
extra_page_classes = ['listing-page']
def pre(self):
self.check_for_bearer_token()
RedditController.pre(self)
@property
def menus(self):
"""list of menus underneat the header (e.g., sort, time, kind,
@@ -1025,13 +1021,9 @@ class RedditsController(ListingController):
self.where = where
return ListingController.GET_listing(self, **env)
class MyredditsController(ListingController, OAuth2ResourceController):
class MyredditsController(ListingController):
render_cls = MySubredditsPage
def pre(self):
self.check_for_bearer_token()
ListingController.pre(self)
@property
def menus(self):
buttons = (NavButton(plurals.subscriber, 'subscriber'),

View File

@@ -26,10 +26,7 @@ from pylons.i18n import _
from r2.config.extensions import set_extension
from r2.controllers.api_docs import api_doc, api_section
from r2.controllers.reddit_base import RedditController, abort_with_error
from r2.controllers.oauth2 import (
OAuth2ResourceController,
require_oauth2_scope,
)
from r2.controllers.oauth2 import require_oauth2_scope
from r2.models.account import Account
from r2.models.subreddit import (
FakeSubreddit,
@@ -75,12 +72,11 @@ multi_description_json_spec = VValidatedJSON.Object({
})
class MultiApiController(RedditController, OAuth2ResourceController):
class MultiApiController(RedditController):
on_validation_error = staticmethod(abort_with_error)
def pre(self):
set_extension(request.environ, "json")
self.check_for_bearer_token()
RedditController.pre(self)
@require_oauth2_scope("read")

View File

@@ -52,6 +52,9 @@ from r2.lib.validator import (
)
class OAuth2FrontendController(RedditController):
def check_for_bearer_token(self):
pass
def pre(self):
RedditController.pre(self)
require_https()
@@ -300,56 +303,6 @@ class OAuth2AccessController(MinimalController):
return self.api_wrapper(resp)
class OAuth2ResourceController(MinimalController):
def pre(self):
set_extension(request.environ, "json")
MinimalController.pre(self)
require_https()
try:
access_token = OAuth2AccessToken.get_token(self._get_bearer_token())
require(access_token)
require(access_token.check_valid())
c.oauth2_access_token = access_token
account = Account._byID36(access_token.user_id, data=True)
require(account)
require(not account._deleted)
c.oauth_user = account
except RequirementException:
self._auth_error(401, "invalid_token")
handler = self._get_action_handler()
if handler:
oauth2_perms = getattr(handler, "oauth2_perms", None)
if oauth2_perms:
grant = OAuth2Scope(access_token.scope)
required = set(oauth2_perms['allowed_scopes'])
if not grant.has_access(c.site.name, required):
self._auth_error(403, "insufficient_scope")
c.oauth_scope = grant
else:
self._auth_error(400, "invalid_request")
def check_for_bearer_token(self):
if self._get_bearer_token(strict=False):
OAuth2ResourceController.pre(self)
if c.oauth_user:
c.user = c.oauth_user
c.user_is_loggedin = True
def _auth_error(self, code, error):
abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)])
def _get_bearer_token(self, strict=True):
auth = request.headers.get("Authorization")
try:
auth_scheme, bearer_token = require_split(auth, 2)
require(auth_scheme.lower() == "bearer")
return bearer_token
except RequirementException:
if strict:
self._auth_error(400, "invalid_request")
def require_oauth2_scope(*scopes):
def oauth2_scope_wrap(fn):
fn.oauth2_perms = {"allowed_scopes": scopes}

View File

@@ -43,7 +43,7 @@ from pylons.controllers.util import redirect_to
from pylons.i18n import _
from pylons.i18n.translation import LanguageError
from r2.config.extensions import is_api
from r2.config.extensions import is_api, set_extension
from r2.lib import filters, pages, utils, hooks
from r2.lib.authentication import authenticate_user
from r2.lib.base import BaseController, abort
@@ -56,6 +56,7 @@ from r2.lib.errors import (
reddit_http_error,
)
from r2.lib.filters import _force_utf8, _force_unicode
from r2.lib.require import RequirementException, require, require_split
from r2.lib.strings import strings
from r2.lib.template_helpers import add_sr, JSPreload
from r2.lib.tracking import encrypt, decrypt
@@ -83,6 +84,7 @@ from r2.lib.validator import (
VTarget,
)
from r2.models import (
Account,
All,
AllMinus,
DefaultSR,
@@ -97,6 +99,8 @@ from r2.models import (
ModMinus,
MultiReddit,
NotFound,
OAuth2AccessToken,
OAuth2Scope,
Random,
RandomNSFW,
RandomSubscription,
@@ -687,6 +691,12 @@ def require_https():
if not c.secure:
abort(ForbiddenError(errors.HTTPS_REQUIRED))
def require_domain(required_domain):
if not is_subdomain(request.host, required_domain):
abort(ForbiddenError(errors.WRONG_DOMAIN))
def disable_subreddit_css():
def wrap(f):
@wraps(f)
@@ -968,7 +978,63 @@ class MinimalController(BaseController):
return request.method.upper() != "POST"
class RedditController(MinimalController):
class OAuth2ResourceController(MinimalController):
def authenticate_with_token(self):
set_extension(request.environ, "json")
set_content_type()
require_https()
require_domain(g.oauth_domain)
try:
access_token = OAuth2AccessToken.get_token(self._get_bearer_token())
require(access_token)
require(access_token.check_valid())
c.oauth2_access_token = access_token
account = Account._byID36(access_token.user_id, data=True)
require(account)
require(not account._deleted)
c.oauth_user = account
except RequirementException:
self._auth_error(401, "invalid_token")
handler = self._get_action_handler()
if handler:
oauth2_perms = getattr(handler, "oauth2_perms", None)
if oauth2_perms or True:
grant = OAuth2Scope(access_token.scope)
required = set(oauth2_perms['allowed_scopes'])
if not grant.has_access(c.site.name, required):
self._auth_error(403, "insufficient_scope")
c.oauth_scope = grant
else:
self._auth_error(400, "invalid_request")
def check_for_bearer_token(self):
if self._get_bearer_token(strict=False):
self.authenticate_with_token()
if c.oauth_user:
c.user = c.oauth_user
c.user_is_loggedin = True
def _auth_error(self, code, error):
abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)])
def _get_bearer_token(self, strict=True):
auth = request.headers.get("Authorization")
if not auth:
return None
try:
auth_scheme, bearer_token = require_split(auth, 2)
require(auth_scheme.lower() == "bearer")
return bearer_token
except RequirementException:
if strict:
self._auth_error(400, "invalid_request")
else:
return None
class RedditController(OAuth2ResourceController):
@staticmethod
def login(user, rem=False):
@@ -1039,6 +1105,8 @@ class RedditController(MinimalController):
maybe_admin = False
is_otpcookie_valid = False
self.check_for_bearer_token()
# no logins for RSS feed unless valid_feed has already been called
if not c.user:
if c.extension != "rss":

View File

@@ -23,7 +23,7 @@
from pylons import request, g, c
from pylons.controllers.util import redirect_to
from reddit_base import RedditController
from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope
from r2.controllers.oauth2 import require_oauth2_scope
from r2.lib.utils import url_links_builder
from reddit_base import paginated_listing
from r2.models.wiki import (WikiPage, WikiRevision, ContentLengthError,
@@ -94,7 +94,7 @@ RENDERERS_BY_PAGE = {"config/sidebar": "reddit",
"config/description": "reddit",
"config/stylesheet": "stylesheet"}
class WikiController(RedditController, OAuth2ResourceController):
class WikiController(RedditController):
allow_stylesheets = True
@require_oauth2_scope("wikiread")
@@ -308,7 +308,6 @@ class WikiController(RedditController, OAuth2ResourceController):
abort(reddit_http_error(code, reason, **data))
def pre(self):
self.check_for_bearer_token()
RedditController.pre(self)
if g.disable_wiki and not c.user_is_admin:
self.handle_error(403, 'WIKI_DOWN')

View File

@@ -30,6 +30,7 @@ from copy import copy
error_list = dict((
('USER_REQUIRED', _("please login to do that")),
('HTTPS_REQUIRED', _("this page must be accessed using https")),
('WRONG_DOMAIN', _("you can't do that on this domain")),
('VERIFIED_USER_REQUIRED', _("you need to set a valid email address to do that.")),
('NO_URL', _('a url is required')),
('BAD_URL', _('you should check that url')),