From b1e97e8bff0c0bfaa035e0a9ea16a4cd58630476 Mon Sep 17 00:00:00 2001 From: Keith Mitchell Date: Fri, 14 Feb 2014 15:42:23 -0800 Subject: [PATCH] Bring OAuth2 into the core controllers This makes it easier to avoid issues with running pre() functions multiple times (multiply subclassing was causing problems) and makes it so all resources are blocked from oauth access unless explicitly enabled (instead of randomly allowing access as a "logged out user" to endpoints that aren't part of an OAuth2ResourceController) Conflicts: r2/r2/controllers/apiv1.py r2/r2/controllers/oauth2.py --- r2/r2/controllers/api.py | 9 +--- r2/r2/controllers/apiv1.py | 4 +- r2/r2/controllers/error.py | 3 ++ r2/r2/controllers/front.py | 8 +-- r2/r2/controllers/listingcontroller.py | 14 ++--- r2/r2/controllers/multi.py | 8 +-- r2/r2/controllers/oauth2.py | 53 ++----------------- r2/r2/controllers/reddit_base.py | 72 +++++++++++++++++++++++++- r2/r2/controllers/wiki.py | 5 +- r2/r2/lib/errors.py | 1 + 10 files changed, 91 insertions(+), 86 deletions(-) diff --git a/r2/r2/controllers/api.py b/r2/r2/controllers/api.py index 3d9d36814..6e6d54cf6 100755 --- a/r2/r2/controllers/api.py +++ b/r2/r2/controllers/api.py @@ -75,7 +75,7 @@ from r2.lib.filters import safemarkdown from r2.lib.media import str_to_image from r2.controllers.api_docs import api_doc, api_section from r2.lib.search import SearchQuery -from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope +from r2.controllers.oauth2 import require_oauth2_scope from r2.lib.template_helpers import add_sr, get_domain from r2.lib.system_messages import notify_user_added from r2.controllers.ipn import generate_blob @@ -132,15 +132,10 @@ class ApiminimalController(MinimalController): form._send_data(iden = iden) -class ApiController(RedditController, OAuth2ResourceController): +class ApiController(RedditController): """ Controller which deals with almost all AJAX site interaction. """ - - def pre(self): - self.check_for_bearer_token() - RedditController.pre(self) - @validatedForm() def ajax_login_redirect(self, form, jquery, dest): form.redirect("/login" + query_string(dict(dest=dest))) diff --git a/r2/r2/controllers/apiv1.py b/r2/r2/controllers/apiv1.py index f10182bcb..d77e728be 100644 --- a/r2/r2/controllers/apiv1.py +++ b/r2/r2/controllers/apiv1.py @@ -22,11 +22,13 @@ from pylons import c from r2.controllers.api_docs import api_doc, api_section -from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope +from r2.controllers.oauth2 import require_oauth2_scope +from r2.controllers.reddit_base import OAuth2ResourceController from r2.lib.jsontemplates import IdentityJsonTemplate class APIv1Controller(OAuth2ResourceController): def pre(self): + OAuth2ResourceController.pre(self) self.check_for_bearer_token() def try_pagecache(self): diff --git a/r2/r2/controllers/error.py b/r2/r2/controllers/error.py index 1ce86536b..2651ad590 100644 --- a/r2/r2/controllers/error.py +++ b/r2/r2/controllers/error.py @@ -90,6 +90,9 @@ class ErrorController(RedditController): This behaviour can be altered by changing the parameters to the ErrorDocuments middleware in your config/middleware.py file. """ + def check_for_bearer_token(self): + pass + allowed_render_styles = ('html', 'xml', 'js', 'embed', '', "compact", 'api') # List of admins to blame (skip the first admin, "reddit") # If list is empty, just blame "an admin" diff --git a/r2/r2/controllers/front.py b/r2/r2/controllers/front.py index 64c8a7742..a9e1034dc 100755 --- a/r2/r2/controllers/front.py +++ b/r2/r2/controllers/front.py @@ -55,7 +55,7 @@ from r2.lib import sup import r2.lib.db.thing as thing from r2.lib.errors import errors from listingcontroller import ListingController -from oauth2 import OAuth2ResourceController, require_oauth2_scope +from oauth2 import require_oauth2_scope from api_docs import api_doc, api_section from pylons import c, request, response from r2.models.token import EmailVerificationToken @@ -68,14 +68,10 @@ import re, socket import time as time_module from urllib import quote_plus -class FrontController(RedditController, OAuth2ResourceController): +class FrontController(RedditController): allow_stylesheets = True - def pre(self): - self.check_for_bearer_token() - RedditController.pre(self) - @validate(article=VLink('article'), comment=VCommentID('comment')) def GET_oldinfo(self, article, type, dest, rest=None, comment=''): diff --git a/r2/r2/controllers/listingcontroller.py b/r2/r2/controllers/listingcontroller.py index 9cf548404..a5a9b3a52 100755 --- a/r2/r2/controllers/listingcontroller.py +++ b/r2/r2/controllers/listingcontroller.py @@ -20,7 +20,7 @@ # Inc. All Rights Reserved. ############################################################################### -from oauth2 import OAuth2ResourceController, require_oauth2_scope +from oauth2 import require_oauth2_scope from reddit_base import RedditController, base_listing, paginated_listing from r2.models import * @@ -53,7 +53,7 @@ from pylons.controllers.util import redirect_to import random from functools import partial -class ListingController(RedditController, OAuth2ResourceController): +class ListingController(RedditController): """Generalized controller for pages with lists of links.""" # toggle skipping of links based on the users' save/hide/vote preferences @@ -89,10 +89,6 @@ class ListingController(RedditController, OAuth2ResourceController): render_params = {} extra_page_classes = ['listing-page'] - def pre(self): - self.check_for_bearer_token() - RedditController.pre(self) - @property def menus(self): """list of menus underneat the header (e.g., sort, time, kind, @@ -1025,13 +1021,9 @@ class RedditsController(ListingController): self.where = where return ListingController.GET_listing(self, **env) -class MyredditsController(ListingController, OAuth2ResourceController): +class MyredditsController(ListingController): render_cls = MySubredditsPage - def pre(self): - self.check_for_bearer_token() - ListingController.pre(self) - @property def menus(self): buttons = (NavButton(plurals.subscriber, 'subscriber'), diff --git a/r2/r2/controllers/multi.py b/r2/r2/controllers/multi.py index a5a57a2b2..b982e7fd0 100644 --- a/r2/r2/controllers/multi.py +++ b/r2/r2/controllers/multi.py @@ -26,10 +26,7 @@ from pylons.i18n import _ from r2.config.extensions import set_extension from r2.controllers.api_docs import api_doc, api_section from r2.controllers.reddit_base import RedditController, abort_with_error -from r2.controllers.oauth2 import ( - OAuth2ResourceController, - require_oauth2_scope, -) +from r2.controllers.oauth2 import require_oauth2_scope from r2.models.account import Account from r2.models.subreddit import ( FakeSubreddit, @@ -75,12 +72,11 @@ multi_description_json_spec = VValidatedJSON.Object({ }) -class MultiApiController(RedditController, OAuth2ResourceController): +class MultiApiController(RedditController): on_validation_error = staticmethod(abort_with_error) def pre(self): set_extension(request.environ, "json") - self.check_for_bearer_token() RedditController.pre(self) @require_oauth2_scope("read") diff --git a/r2/r2/controllers/oauth2.py b/r2/r2/controllers/oauth2.py index 86069b82c..809601c38 100644 --- a/r2/r2/controllers/oauth2.py +++ b/r2/r2/controllers/oauth2.py @@ -52,6 +52,9 @@ from r2.lib.validator import ( ) class OAuth2FrontendController(RedditController): + def check_for_bearer_token(self): + pass + def pre(self): RedditController.pre(self) require_https() @@ -300,56 +303,6 @@ class OAuth2AccessController(MinimalController): return self.api_wrapper(resp) -class OAuth2ResourceController(MinimalController): - def pre(self): - set_extension(request.environ, "json") - MinimalController.pre(self) - require_https() - - try: - access_token = OAuth2AccessToken.get_token(self._get_bearer_token()) - require(access_token) - require(access_token.check_valid()) - c.oauth2_access_token = access_token - account = Account._byID36(access_token.user_id, data=True) - require(account) - require(not account._deleted) - c.oauth_user = account - except RequirementException: - self._auth_error(401, "invalid_token") - - handler = self._get_action_handler() - if handler: - oauth2_perms = getattr(handler, "oauth2_perms", None) - if oauth2_perms: - grant = OAuth2Scope(access_token.scope) - required = set(oauth2_perms['allowed_scopes']) - if not grant.has_access(c.site.name, required): - self._auth_error(403, "insufficient_scope") - c.oauth_scope = grant - else: - self._auth_error(400, "invalid_request") - - def check_for_bearer_token(self): - if self._get_bearer_token(strict=False): - OAuth2ResourceController.pre(self) - if c.oauth_user: - c.user = c.oauth_user - c.user_is_loggedin = True - - def _auth_error(self, code, error): - abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)]) - - def _get_bearer_token(self, strict=True): - auth = request.headers.get("Authorization") - try: - auth_scheme, bearer_token = require_split(auth, 2) - require(auth_scheme.lower() == "bearer") - return bearer_token - except RequirementException: - if strict: - self._auth_error(400, "invalid_request") - def require_oauth2_scope(*scopes): def oauth2_scope_wrap(fn): fn.oauth2_perms = {"allowed_scopes": scopes} diff --git a/r2/r2/controllers/reddit_base.py b/r2/r2/controllers/reddit_base.py index c290e6953..9640ea930 100644 --- a/r2/r2/controllers/reddit_base.py +++ b/r2/r2/controllers/reddit_base.py @@ -43,7 +43,7 @@ from pylons.controllers.util import redirect_to from pylons.i18n import _ from pylons.i18n.translation import LanguageError -from r2.config.extensions import is_api +from r2.config.extensions import is_api, set_extension from r2.lib import filters, pages, utils, hooks from r2.lib.authentication import authenticate_user from r2.lib.base import BaseController, abort @@ -56,6 +56,7 @@ from r2.lib.errors import ( reddit_http_error, ) from r2.lib.filters import _force_utf8, _force_unicode +from r2.lib.require import RequirementException, require, require_split from r2.lib.strings import strings from r2.lib.template_helpers import add_sr, JSPreload from r2.lib.tracking import encrypt, decrypt @@ -83,6 +84,7 @@ from r2.lib.validator import ( VTarget, ) from r2.models import ( + Account, All, AllMinus, DefaultSR, @@ -97,6 +99,8 @@ from r2.models import ( ModMinus, MultiReddit, NotFound, + OAuth2AccessToken, + OAuth2Scope, Random, RandomNSFW, RandomSubscription, @@ -687,6 +691,12 @@ def require_https(): if not c.secure: abort(ForbiddenError(errors.HTTPS_REQUIRED)) + +def require_domain(required_domain): + if not is_subdomain(request.host, required_domain): + abort(ForbiddenError(errors.WRONG_DOMAIN)) + + def disable_subreddit_css(): def wrap(f): @wraps(f) @@ -968,7 +978,63 @@ class MinimalController(BaseController): return request.method.upper() != "POST" -class RedditController(MinimalController): +class OAuth2ResourceController(MinimalController): + def authenticate_with_token(self): + set_extension(request.environ, "json") + set_content_type() + require_https() + require_domain(g.oauth_domain) + + try: + access_token = OAuth2AccessToken.get_token(self._get_bearer_token()) + require(access_token) + require(access_token.check_valid()) + c.oauth2_access_token = access_token + account = Account._byID36(access_token.user_id, data=True) + require(account) + require(not account._deleted) + c.oauth_user = account + except RequirementException: + self._auth_error(401, "invalid_token") + + handler = self._get_action_handler() + if handler: + oauth2_perms = getattr(handler, "oauth2_perms", None) + if oauth2_perms or True: + grant = OAuth2Scope(access_token.scope) + required = set(oauth2_perms['allowed_scopes']) + if not grant.has_access(c.site.name, required): + self._auth_error(403, "insufficient_scope") + c.oauth_scope = grant + else: + self._auth_error(400, "invalid_request") + + def check_for_bearer_token(self): + if self._get_bearer_token(strict=False): + self.authenticate_with_token() + if c.oauth_user: + c.user = c.oauth_user + c.user_is_loggedin = True + + def _auth_error(self, code, error): + abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)]) + + def _get_bearer_token(self, strict=True): + auth = request.headers.get("Authorization") + if not auth: + return None + try: + auth_scheme, bearer_token = require_split(auth, 2) + require(auth_scheme.lower() == "bearer") + return bearer_token + except RequirementException: + if strict: + self._auth_error(400, "invalid_request") + else: + return None + + +class RedditController(OAuth2ResourceController): @staticmethod def login(user, rem=False): @@ -1039,6 +1105,8 @@ class RedditController(MinimalController): maybe_admin = False is_otpcookie_valid = False + self.check_for_bearer_token() + # no logins for RSS feed unless valid_feed has already been called if not c.user: if c.extension != "rss": diff --git a/r2/r2/controllers/wiki.py b/r2/r2/controllers/wiki.py index 2cfae8908..422d2a352 100644 --- a/r2/r2/controllers/wiki.py +++ b/r2/r2/controllers/wiki.py @@ -23,7 +23,7 @@ from pylons import request, g, c from pylons.controllers.util import redirect_to from reddit_base import RedditController -from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope +from r2.controllers.oauth2 import require_oauth2_scope from r2.lib.utils import url_links_builder from reddit_base import paginated_listing from r2.models.wiki import (WikiPage, WikiRevision, ContentLengthError, @@ -94,7 +94,7 @@ RENDERERS_BY_PAGE = {"config/sidebar": "reddit", "config/description": "reddit", "config/stylesheet": "stylesheet"} -class WikiController(RedditController, OAuth2ResourceController): +class WikiController(RedditController): allow_stylesheets = True @require_oauth2_scope("wikiread") @@ -308,7 +308,6 @@ class WikiController(RedditController, OAuth2ResourceController): abort(reddit_http_error(code, reason, **data)) def pre(self): - self.check_for_bearer_token() RedditController.pre(self) if g.disable_wiki and not c.user_is_admin: self.handle_error(403, 'WIKI_DOWN') diff --git a/r2/r2/lib/errors.py b/r2/r2/lib/errors.py index 60b977de1..100c192ad 100644 --- a/r2/r2/lib/errors.py +++ b/r2/r2/lib/errors.py @@ -30,6 +30,7 @@ from copy import copy error_list = dict(( ('USER_REQUIRED', _("please login to do that")), ('HTTPS_REQUIRED', _("this page must be accessed using https")), + ('WRONG_DOMAIN', _("you can't do that on this domain")), ('VERIFIED_USER_REQUIRED', _("you need to set a valid email address to do that.")), ('NO_URL', _('a url is required')), ('BAD_URL', _('you should check that url')),