diff --git a/r2/r2/controllers/api.py b/r2/r2/controllers/api.py index 3d9d36814..6e6d54cf6 100755 --- a/r2/r2/controllers/api.py +++ b/r2/r2/controllers/api.py @@ -75,7 +75,7 @@ from r2.lib.filters import safemarkdown from r2.lib.media import str_to_image from r2.controllers.api_docs import api_doc, api_section from r2.lib.search import SearchQuery -from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope +from r2.controllers.oauth2 import require_oauth2_scope from r2.lib.template_helpers import add_sr, get_domain from r2.lib.system_messages import notify_user_added from r2.controllers.ipn import generate_blob @@ -132,15 +132,10 @@ class ApiminimalController(MinimalController): form._send_data(iden = iden) -class ApiController(RedditController, OAuth2ResourceController): +class ApiController(RedditController): """ Controller which deals with almost all AJAX site interaction. """ - - def pre(self): - self.check_for_bearer_token() - RedditController.pre(self) - @validatedForm() def ajax_login_redirect(self, form, jquery, dest): form.redirect("/login" + query_string(dict(dest=dest))) diff --git a/r2/r2/controllers/apiv1.py b/r2/r2/controllers/apiv1.py index f10182bcb..d77e728be 100644 --- a/r2/r2/controllers/apiv1.py +++ b/r2/r2/controllers/apiv1.py @@ -22,11 +22,13 @@ from pylons import c from r2.controllers.api_docs import api_doc, api_section -from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope +from r2.controllers.oauth2 import require_oauth2_scope +from r2.controllers.reddit_base import OAuth2ResourceController from r2.lib.jsontemplates import IdentityJsonTemplate class APIv1Controller(OAuth2ResourceController): def pre(self): + OAuth2ResourceController.pre(self) self.check_for_bearer_token() def try_pagecache(self): diff --git a/r2/r2/controllers/error.py b/r2/r2/controllers/error.py index 1ce86536b..2651ad590 100644 --- a/r2/r2/controllers/error.py +++ b/r2/r2/controllers/error.py @@ -90,6 +90,9 @@ class ErrorController(RedditController): This behaviour can be altered by changing the parameters to the ErrorDocuments middleware in your config/middleware.py file. """ + def check_for_bearer_token(self): + pass + allowed_render_styles = ('html', 'xml', 'js', 'embed', '', "compact", 'api') # List of admins to blame (skip the first admin, "reddit") # If list is empty, just blame "an admin" diff --git a/r2/r2/controllers/front.py b/r2/r2/controllers/front.py index 64c8a7742..a9e1034dc 100755 --- a/r2/r2/controllers/front.py +++ b/r2/r2/controllers/front.py @@ -55,7 +55,7 @@ from r2.lib import sup import r2.lib.db.thing as thing from r2.lib.errors import errors from listingcontroller import ListingController -from oauth2 import OAuth2ResourceController, require_oauth2_scope +from oauth2 import require_oauth2_scope from api_docs import api_doc, api_section from pylons import c, request, response from r2.models.token import EmailVerificationToken @@ -68,14 +68,10 @@ import re, socket import time as time_module from urllib import quote_plus -class FrontController(RedditController, OAuth2ResourceController): +class FrontController(RedditController): allow_stylesheets = True - def pre(self): - self.check_for_bearer_token() - RedditController.pre(self) - @validate(article=VLink('article'), comment=VCommentID('comment')) def GET_oldinfo(self, article, type, dest, rest=None, comment=''): diff --git a/r2/r2/controllers/listingcontroller.py b/r2/r2/controllers/listingcontroller.py index 9cf548404..a5a9b3a52 100755 --- a/r2/r2/controllers/listingcontroller.py +++ b/r2/r2/controllers/listingcontroller.py @@ -20,7 +20,7 @@ # Inc. All Rights Reserved. ############################################################################### -from oauth2 import OAuth2ResourceController, require_oauth2_scope +from oauth2 import require_oauth2_scope from reddit_base import RedditController, base_listing, paginated_listing from r2.models import * @@ -53,7 +53,7 @@ from pylons.controllers.util import redirect_to import random from functools import partial -class ListingController(RedditController, OAuth2ResourceController): +class ListingController(RedditController): """Generalized controller for pages with lists of links.""" # toggle skipping of links based on the users' save/hide/vote preferences @@ -89,10 +89,6 @@ class ListingController(RedditController, OAuth2ResourceController): render_params = {} extra_page_classes = ['listing-page'] - def pre(self): - self.check_for_bearer_token() - RedditController.pre(self) - @property def menus(self): """list of menus underneat the header (e.g., sort, time, kind, @@ -1025,13 +1021,9 @@ class RedditsController(ListingController): self.where = where return ListingController.GET_listing(self, **env) -class MyredditsController(ListingController, OAuth2ResourceController): +class MyredditsController(ListingController): render_cls = MySubredditsPage - def pre(self): - self.check_for_bearer_token() - ListingController.pre(self) - @property def menus(self): buttons = (NavButton(plurals.subscriber, 'subscriber'), diff --git a/r2/r2/controllers/multi.py b/r2/r2/controllers/multi.py index a5a57a2b2..b982e7fd0 100644 --- a/r2/r2/controllers/multi.py +++ b/r2/r2/controllers/multi.py @@ -26,10 +26,7 @@ from pylons.i18n import _ from r2.config.extensions import set_extension from r2.controllers.api_docs import api_doc, api_section from r2.controllers.reddit_base import RedditController, abort_with_error -from r2.controllers.oauth2 import ( - OAuth2ResourceController, - require_oauth2_scope, -) +from r2.controllers.oauth2 import require_oauth2_scope from r2.models.account import Account from r2.models.subreddit import ( FakeSubreddit, @@ -75,12 +72,11 @@ multi_description_json_spec = VValidatedJSON.Object({ }) -class MultiApiController(RedditController, OAuth2ResourceController): +class MultiApiController(RedditController): on_validation_error = staticmethod(abort_with_error) def pre(self): set_extension(request.environ, "json") - self.check_for_bearer_token() RedditController.pre(self) @require_oauth2_scope("read") diff --git a/r2/r2/controllers/oauth2.py b/r2/r2/controllers/oauth2.py index 86069b82c..809601c38 100644 --- a/r2/r2/controllers/oauth2.py +++ b/r2/r2/controllers/oauth2.py @@ -52,6 +52,9 @@ from r2.lib.validator import ( ) class OAuth2FrontendController(RedditController): + def check_for_bearer_token(self): + pass + def pre(self): RedditController.pre(self) require_https() @@ -300,56 +303,6 @@ class OAuth2AccessController(MinimalController): return self.api_wrapper(resp) -class OAuth2ResourceController(MinimalController): - def pre(self): - set_extension(request.environ, "json") - MinimalController.pre(self) - require_https() - - try: - access_token = OAuth2AccessToken.get_token(self._get_bearer_token()) - require(access_token) - require(access_token.check_valid()) - c.oauth2_access_token = access_token - account = Account._byID36(access_token.user_id, data=True) - require(account) - require(not account._deleted) - c.oauth_user = account - except RequirementException: - self._auth_error(401, "invalid_token") - - handler = self._get_action_handler() - if handler: - oauth2_perms = getattr(handler, "oauth2_perms", None) - if oauth2_perms: - grant = OAuth2Scope(access_token.scope) - required = set(oauth2_perms['allowed_scopes']) - if not grant.has_access(c.site.name, required): - self._auth_error(403, "insufficient_scope") - c.oauth_scope = grant - else: - self._auth_error(400, "invalid_request") - - def check_for_bearer_token(self): - if self._get_bearer_token(strict=False): - OAuth2ResourceController.pre(self) - if c.oauth_user: - c.user = c.oauth_user - c.user_is_loggedin = True - - def _auth_error(self, code, error): - abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)]) - - def _get_bearer_token(self, strict=True): - auth = request.headers.get("Authorization") - try: - auth_scheme, bearer_token = require_split(auth, 2) - require(auth_scheme.lower() == "bearer") - return bearer_token - except RequirementException: - if strict: - self._auth_error(400, "invalid_request") - def require_oauth2_scope(*scopes): def oauth2_scope_wrap(fn): fn.oauth2_perms = {"allowed_scopes": scopes} diff --git a/r2/r2/controllers/reddit_base.py b/r2/r2/controllers/reddit_base.py index c290e6953..9640ea930 100644 --- a/r2/r2/controllers/reddit_base.py +++ b/r2/r2/controllers/reddit_base.py @@ -43,7 +43,7 @@ from pylons.controllers.util import redirect_to from pylons.i18n import _ from pylons.i18n.translation import LanguageError -from r2.config.extensions import is_api +from r2.config.extensions import is_api, set_extension from r2.lib import filters, pages, utils, hooks from r2.lib.authentication import authenticate_user from r2.lib.base import BaseController, abort @@ -56,6 +56,7 @@ from r2.lib.errors import ( reddit_http_error, ) from r2.lib.filters import _force_utf8, _force_unicode +from r2.lib.require import RequirementException, require, require_split from r2.lib.strings import strings from r2.lib.template_helpers import add_sr, JSPreload from r2.lib.tracking import encrypt, decrypt @@ -83,6 +84,7 @@ from r2.lib.validator import ( VTarget, ) from r2.models import ( + Account, All, AllMinus, DefaultSR, @@ -97,6 +99,8 @@ from r2.models import ( ModMinus, MultiReddit, NotFound, + OAuth2AccessToken, + OAuth2Scope, Random, RandomNSFW, RandomSubscription, @@ -687,6 +691,12 @@ def require_https(): if not c.secure: abort(ForbiddenError(errors.HTTPS_REQUIRED)) + +def require_domain(required_domain): + if not is_subdomain(request.host, required_domain): + abort(ForbiddenError(errors.WRONG_DOMAIN)) + + def disable_subreddit_css(): def wrap(f): @wraps(f) @@ -968,7 +978,63 @@ class MinimalController(BaseController): return request.method.upper() != "POST" -class RedditController(MinimalController): +class OAuth2ResourceController(MinimalController): + def authenticate_with_token(self): + set_extension(request.environ, "json") + set_content_type() + require_https() + require_domain(g.oauth_domain) + + try: + access_token = OAuth2AccessToken.get_token(self._get_bearer_token()) + require(access_token) + require(access_token.check_valid()) + c.oauth2_access_token = access_token + account = Account._byID36(access_token.user_id, data=True) + require(account) + require(not account._deleted) + c.oauth_user = account + except RequirementException: + self._auth_error(401, "invalid_token") + + handler = self._get_action_handler() + if handler: + oauth2_perms = getattr(handler, "oauth2_perms", None) + if oauth2_perms or True: + grant = OAuth2Scope(access_token.scope) + required = set(oauth2_perms['allowed_scopes']) + if not grant.has_access(c.site.name, required): + self._auth_error(403, "insufficient_scope") + c.oauth_scope = grant + else: + self._auth_error(400, "invalid_request") + + def check_for_bearer_token(self): + if self._get_bearer_token(strict=False): + self.authenticate_with_token() + if c.oauth_user: + c.user = c.oauth_user + c.user_is_loggedin = True + + def _auth_error(self, code, error): + abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)]) + + def _get_bearer_token(self, strict=True): + auth = request.headers.get("Authorization") + if not auth: + return None + try: + auth_scheme, bearer_token = require_split(auth, 2) + require(auth_scheme.lower() == "bearer") + return bearer_token + except RequirementException: + if strict: + self._auth_error(400, "invalid_request") + else: + return None + + +class RedditController(OAuth2ResourceController): @staticmethod def login(user, rem=False): @@ -1039,6 +1105,8 @@ class RedditController(MinimalController): maybe_admin = False is_otpcookie_valid = False + self.check_for_bearer_token() + # no logins for RSS feed unless valid_feed has already been called if not c.user: if c.extension != "rss": diff --git a/r2/r2/controllers/wiki.py b/r2/r2/controllers/wiki.py index 2cfae8908..422d2a352 100644 --- a/r2/r2/controllers/wiki.py +++ b/r2/r2/controllers/wiki.py @@ -23,7 +23,7 @@ from pylons import request, g, c from pylons.controllers.util import redirect_to from reddit_base import RedditController -from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope +from r2.controllers.oauth2 import require_oauth2_scope from r2.lib.utils import url_links_builder from reddit_base import paginated_listing from r2.models.wiki import (WikiPage, WikiRevision, ContentLengthError, @@ -94,7 +94,7 @@ RENDERERS_BY_PAGE = {"config/sidebar": "reddit", "config/description": "reddit", "config/stylesheet": "stylesheet"} -class WikiController(RedditController, OAuth2ResourceController): +class WikiController(RedditController): allow_stylesheets = True @require_oauth2_scope("wikiread") @@ -308,7 +308,6 @@ class WikiController(RedditController, OAuth2ResourceController): abort(reddit_http_error(code, reason, **data)) def pre(self): - self.check_for_bearer_token() RedditController.pre(self) if g.disable_wiki and not c.user_is_admin: self.handle_error(403, 'WIKI_DOWN') diff --git a/r2/r2/lib/errors.py b/r2/r2/lib/errors.py index 60b977de1..100c192ad 100644 --- a/r2/r2/lib/errors.py +++ b/r2/r2/lib/errors.py @@ -30,6 +30,7 @@ from copy import copy error_list = dict(( ('USER_REQUIRED', _("please login to do that")), ('HTTPS_REQUIRED', _("this page must be accessed using https")), + ('WRONG_DOMAIN', _("you can't do that on this domain")), ('VERIFIED_USER_REQUIRED', _("you need to set a valid email address to do that.")), ('NO_URL', _('a url is required')), ('BAD_URL', _('you should check that url')),