mirror of
https://github.com/reddit-archive/reddit.git
synced 2026-04-27 03:00:12 -04:00
Bring OAuth2 into the core controllers
This makes it easier to avoid issues with running pre() functions multiple times (multiply subclassing was causing problems) and makes it so all resources are blocked from oauth access unless explicitly enabled (instead of randomly allowing access as a "logged out user" to endpoints that aren't part of an OAuth2ResourceController) Conflicts: r2/r2/controllers/apiv1.py r2/r2/controllers/oauth2.py
This commit is contained in:
@@ -75,7 +75,7 @@ from r2.lib.filters import safemarkdown
|
||||
from r2.lib.media import str_to_image
|
||||
from r2.controllers.api_docs import api_doc, api_section
|
||||
from r2.lib.search import SearchQuery
|
||||
from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope
|
||||
from r2.controllers.oauth2 import require_oauth2_scope
|
||||
from r2.lib.template_helpers import add_sr, get_domain
|
||||
from r2.lib.system_messages import notify_user_added
|
||||
from r2.controllers.ipn import generate_blob
|
||||
@@ -132,15 +132,10 @@ class ApiminimalController(MinimalController):
|
||||
form._send_data(iden = iden)
|
||||
|
||||
|
||||
class ApiController(RedditController, OAuth2ResourceController):
|
||||
class ApiController(RedditController):
|
||||
"""
|
||||
Controller which deals with almost all AJAX site interaction.
|
||||
"""
|
||||
|
||||
def pre(self):
|
||||
self.check_for_bearer_token()
|
||||
RedditController.pre(self)
|
||||
|
||||
@validatedForm()
|
||||
def ajax_login_redirect(self, form, jquery, dest):
|
||||
form.redirect("/login" + query_string(dict(dest=dest)))
|
||||
|
||||
@@ -22,11 +22,13 @@
|
||||
|
||||
from pylons import c
|
||||
from r2.controllers.api_docs import api_doc, api_section
|
||||
from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope
|
||||
from r2.controllers.oauth2 import require_oauth2_scope
|
||||
from r2.controllers.reddit_base import OAuth2ResourceController
|
||||
from r2.lib.jsontemplates import IdentityJsonTemplate
|
||||
|
||||
class APIv1Controller(OAuth2ResourceController):
|
||||
def pre(self):
|
||||
OAuth2ResourceController.pre(self)
|
||||
self.check_for_bearer_token()
|
||||
|
||||
def try_pagecache(self):
|
||||
|
||||
@@ -90,6 +90,9 @@ class ErrorController(RedditController):
|
||||
This behaviour can be altered by changing the parameters to the
|
||||
ErrorDocuments middleware in your config/middleware.py file.
|
||||
"""
|
||||
def check_for_bearer_token(self):
|
||||
pass
|
||||
|
||||
allowed_render_styles = ('html', 'xml', 'js', 'embed', '', "compact", 'api')
|
||||
# List of admins to blame (skip the first admin, "reddit")
|
||||
# If list is empty, just blame "an admin"
|
||||
|
||||
@@ -55,7 +55,7 @@ from r2.lib import sup
|
||||
import r2.lib.db.thing as thing
|
||||
from r2.lib.errors import errors
|
||||
from listingcontroller import ListingController
|
||||
from oauth2 import OAuth2ResourceController, require_oauth2_scope
|
||||
from oauth2 import require_oauth2_scope
|
||||
from api_docs import api_doc, api_section
|
||||
from pylons import c, request, response
|
||||
from r2.models.token import EmailVerificationToken
|
||||
@@ -68,14 +68,10 @@ import re, socket
|
||||
import time as time_module
|
||||
from urllib import quote_plus
|
||||
|
||||
class FrontController(RedditController, OAuth2ResourceController):
|
||||
class FrontController(RedditController):
|
||||
|
||||
allow_stylesheets = True
|
||||
|
||||
def pre(self):
|
||||
self.check_for_bearer_token()
|
||||
RedditController.pre(self)
|
||||
|
||||
@validate(article=VLink('article'),
|
||||
comment=VCommentID('comment'))
|
||||
def GET_oldinfo(self, article, type, dest, rest=None, comment=''):
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
# Inc. All Rights Reserved.
|
||||
###############################################################################
|
||||
|
||||
from oauth2 import OAuth2ResourceController, require_oauth2_scope
|
||||
from oauth2 import require_oauth2_scope
|
||||
from reddit_base import RedditController, base_listing, paginated_listing
|
||||
|
||||
from r2.models import *
|
||||
@@ -53,7 +53,7 @@ from pylons.controllers.util import redirect_to
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
class ListingController(RedditController, OAuth2ResourceController):
|
||||
class ListingController(RedditController):
|
||||
"""Generalized controller for pages with lists of links."""
|
||||
|
||||
# toggle skipping of links based on the users' save/hide/vote preferences
|
||||
@@ -89,10 +89,6 @@ class ListingController(RedditController, OAuth2ResourceController):
|
||||
render_params = {}
|
||||
extra_page_classes = ['listing-page']
|
||||
|
||||
def pre(self):
|
||||
self.check_for_bearer_token()
|
||||
RedditController.pre(self)
|
||||
|
||||
@property
|
||||
def menus(self):
|
||||
"""list of menus underneat the header (e.g., sort, time, kind,
|
||||
@@ -1025,13 +1021,9 @@ class RedditsController(ListingController):
|
||||
self.where = where
|
||||
return ListingController.GET_listing(self, **env)
|
||||
|
||||
class MyredditsController(ListingController, OAuth2ResourceController):
|
||||
class MyredditsController(ListingController):
|
||||
render_cls = MySubredditsPage
|
||||
|
||||
def pre(self):
|
||||
self.check_for_bearer_token()
|
||||
ListingController.pre(self)
|
||||
|
||||
@property
|
||||
def menus(self):
|
||||
buttons = (NavButton(plurals.subscriber, 'subscriber'),
|
||||
|
||||
@@ -26,10 +26,7 @@ from pylons.i18n import _
|
||||
from r2.config.extensions import set_extension
|
||||
from r2.controllers.api_docs import api_doc, api_section
|
||||
from r2.controllers.reddit_base import RedditController, abort_with_error
|
||||
from r2.controllers.oauth2 import (
|
||||
OAuth2ResourceController,
|
||||
require_oauth2_scope,
|
||||
)
|
||||
from r2.controllers.oauth2 import require_oauth2_scope
|
||||
from r2.models.account import Account
|
||||
from r2.models.subreddit import (
|
||||
FakeSubreddit,
|
||||
@@ -75,12 +72,11 @@ multi_description_json_spec = VValidatedJSON.Object({
|
||||
})
|
||||
|
||||
|
||||
class MultiApiController(RedditController, OAuth2ResourceController):
|
||||
class MultiApiController(RedditController):
|
||||
on_validation_error = staticmethod(abort_with_error)
|
||||
|
||||
def pre(self):
|
||||
set_extension(request.environ, "json")
|
||||
self.check_for_bearer_token()
|
||||
RedditController.pre(self)
|
||||
|
||||
@require_oauth2_scope("read")
|
||||
|
||||
@@ -52,6 +52,9 @@ from r2.lib.validator import (
|
||||
)
|
||||
|
||||
class OAuth2FrontendController(RedditController):
|
||||
def check_for_bearer_token(self):
|
||||
pass
|
||||
|
||||
def pre(self):
|
||||
RedditController.pre(self)
|
||||
require_https()
|
||||
@@ -300,56 +303,6 @@ class OAuth2AccessController(MinimalController):
|
||||
return self.api_wrapper(resp)
|
||||
|
||||
|
||||
class OAuth2ResourceController(MinimalController):
|
||||
def pre(self):
|
||||
set_extension(request.environ, "json")
|
||||
MinimalController.pre(self)
|
||||
require_https()
|
||||
|
||||
try:
|
||||
access_token = OAuth2AccessToken.get_token(self._get_bearer_token())
|
||||
require(access_token)
|
||||
require(access_token.check_valid())
|
||||
c.oauth2_access_token = access_token
|
||||
account = Account._byID36(access_token.user_id, data=True)
|
||||
require(account)
|
||||
require(not account._deleted)
|
||||
c.oauth_user = account
|
||||
except RequirementException:
|
||||
self._auth_error(401, "invalid_token")
|
||||
|
||||
handler = self._get_action_handler()
|
||||
if handler:
|
||||
oauth2_perms = getattr(handler, "oauth2_perms", None)
|
||||
if oauth2_perms:
|
||||
grant = OAuth2Scope(access_token.scope)
|
||||
required = set(oauth2_perms['allowed_scopes'])
|
||||
if not grant.has_access(c.site.name, required):
|
||||
self._auth_error(403, "insufficient_scope")
|
||||
c.oauth_scope = grant
|
||||
else:
|
||||
self._auth_error(400, "invalid_request")
|
||||
|
||||
def check_for_bearer_token(self):
|
||||
if self._get_bearer_token(strict=False):
|
||||
OAuth2ResourceController.pre(self)
|
||||
if c.oauth_user:
|
||||
c.user = c.oauth_user
|
||||
c.user_is_loggedin = True
|
||||
|
||||
def _auth_error(self, code, error):
|
||||
abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)])
|
||||
|
||||
def _get_bearer_token(self, strict=True):
|
||||
auth = request.headers.get("Authorization")
|
||||
try:
|
||||
auth_scheme, bearer_token = require_split(auth, 2)
|
||||
require(auth_scheme.lower() == "bearer")
|
||||
return bearer_token
|
||||
except RequirementException:
|
||||
if strict:
|
||||
self._auth_error(400, "invalid_request")
|
||||
|
||||
def require_oauth2_scope(*scopes):
|
||||
def oauth2_scope_wrap(fn):
|
||||
fn.oauth2_perms = {"allowed_scopes": scopes}
|
||||
|
||||
@@ -43,7 +43,7 @@ from pylons.controllers.util import redirect_to
|
||||
from pylons.i18n import _
|
||||
from pylons.i18n.translation import LanguageError
|
||||
|
||||
from r2.config.extensions import is_api
|
||||
from r2.config.extensions import is_api, set_extension
|
||||
from r2.lib import filters, pages, utils, hooks
|
||||
from r2.lib.authentication import authenticate_user
|
||||
from r2.lib.base import BaseController, abort
|
||||
@@ -56,6 +56,7 @@ from r2.lib.errors import (
|
||||
reddit_http_error,
|
||||
)
|
||||
from r2.lib.filters import _force_utf8, _force_unicode
|
||||
from r2.lib.require import RequirementException, require, require_split
|
||||
from r2.lib.strings import strings
|
||||
from r2.lib.template_helpers import add_sr, JSPreload
|
||||
from r2.lib.tracking import encrypt, decrypt
|
||||
@@ -83,6 +84,7 @@ from r2.lib.validator import (
|
||||
VTarget,
|
||||
)
|
||||
from r2.models import (
|
||||
Account,
|
||||
All,
|
||||
AllMinus,
|
||||
DefaultSR,
|
||||
@@ -97,6 +99,8 @@ from r2.models import (
|
||||
ModMinus,
|
||||
MultiReddit,
|
||||
NotFound,
|
||||
OAuth2AccessToken,
|
||||
OAuth2Scope,
|
||||
Random,
|
||||
RandomNSFW,
|
||||
RandomSubscription,
|
||||
@@ -687,6 +691,12 @@ def require_https():
|
||||
if not c.secure:
|
||||
abort(ForbiddenError(errors.HTTPS_REQUIRED))
|
||||
|
||||
|
||||
def require_domain(required_domain):
|
||||
if not is_subdomain(request.host, required_domain):
|
||||
abort(ForbiddenError(errors.WRONG_DOMAIN))
|
||||
|
||||
|
||||
def disable_subreddit_css():
|
||||
def wrap(f):
|
||||
@wraps(f)
|
||||
@@ -968,7 +978,63 @@ class MinimalController(BaseController):
|
||||
return request.method.upper() != "POST"
|
||||
|
||||
|
||||
class RedditController(MinimalController):
|
||||
class OAuth2ResourceController(MinimalController):
|
||||
def authenticate_with_token(self):
|
||||
set_extension(request.environ, "json")
|
||||
set_content_type()
|
||||
require_https()
|
||||
require_domain(g.oauth_domain)
|
||||
|
||||
try:
|
||||
access_token = OAuth2AccessToken.get_token(self._get_bearer_token())
|
||||
require(access_token)
|
||||
require(access_token.check_valid())
|
||||
c.oauth2_access_token = access_token
|
||||
account = Account._byID36(access_token.user_id, data=True)
|
||||
require(account)
|
||||
require(not account._deleted)
|
||||
c.oauth_user = account
|
||||
except RequirementException:
|
||||
self._auth_error(401, "invalid_token")
|
||||
|
||||
handler = self._get_action_handler()
|
||||
if handler:
|
||||
oauth2_perms = getattr(handler, "oauth2_perms", None)
|
||||
if oauth2_perms or True:
|
||||
grant = OAuth2Scope(access_token.scope)
|
||||
required = set(oauth2_perms['allowed_scopes'])
|
||||
if not grant.has_access(c.site.name, required):
|
||||
self._auth_error(403, "insufficient_scope")
|
||||
c.oauth_scope = grant
|
||||
else:
|
||||
self._auth_error(400, "invalid_request")
|
||||
|
||||
def check_for_bearer_token(self):
|
||||
if self._get_bearer_token(strict=False):
|
||||
self.authenticate_with_token()
|
||||
if c.oauth_user:
|
||||
c.user = c.oauth_user
|
||||
c.user_is_loggedin = True
|
||||
|
||||
def _auth_error(self, code, error):
|
||||
abort(code, headers=[("WWW-Authenticate", 'Bearer realm="reddit", error="%s"' % error)])
|
||||
|
||||
def _get_bearer_token(self, strict=True):
|
||||
auth = request.headers.get("Authorization")
|
||||
if not auth:
|
||||
return None
|
||||
try:
|
||||
auth_scheme, bearer_token = require_split(auth, 2)
|
||||
require(auth_scheme.lower() == "bearer")
|
||||
return bearer_token
|
||||
except RequirementException:
|
||||
if strict:
|
||||
self._auth_error(400, "invalid_request")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class RedditController(OAuth2ResourceController):
|
||||
|
||||
@staticmethod
|
||||
def login(user, rem=False):
|
||||
@@ -1039,6 +1105,8 @@ class RedditController(MinimalController):
|
||||
maybe_admin = False
|
||||
is_otpcookie_valid = False
|
||||
|
||||
self.check_for_bearer_token()
|
||||
|
||||
# no logins for RSS feed unless valid_feed has already been called
|
||||
if not c.user:
|
||||
if c.extension != "rss":
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
from pylons import request, g, c
|
||||
from pylons.controllers.util import redirect_to
|
||||
from reddit_base import RedditController
|
||||
from r2.controllers.oauth2 import OAuth2ResourceController, require_oauth2_scope
|
||||
from r2.controllers.oauth2 import require_oauth2_scope
|
||||
from r2.lib.utils import url_links_builder
|
||||
from reddit_base import paginated_listing
|
||||
from r2.models.wiki import (WikiPage, WikiRevision, ContentLengthError,
|
||||
@@ -94,7 +94,7 @@ RENDERERS_BY_PAGE = {"config/sidebar": "reddit",
|
||||
"config/description": "reddit",
|
||||
"config/stylesheet": "stylesheet"}
|
||||
|
||||
class WikiController(RedditController, OAuth2ResourceController):
|
||||
class WikiController(RedditController):
|
||||
allow_stylesheets = True
|
||||
|
||||
@require_oauth2_scope("wikiread")
|
||||
@@ -308,7 +308,6 @@ class WikiController(RedditController, OAuth2ResourceController):
|
||||
abort(reddit_http_error(code, reason, **data))
|
||||
|
||||
def pre(self):
|
||||
self.check_for_bearer_token()
|
||||
RedditController.pre(self)
|
||||
if g.disable_wiki and not c.user_is_admin:
|
||||
self.handle_error(403, 'WIKI_DOWN')
|
||||
|
||||
@@ -30,6 +30,7 @@ from copy import copy
|
||||
error_list = dict((
|
||||
('USER_REQUIRED', _("please login to do that")),
|
||||
('HTTPS_REQUIRED', _("this page must be accessed using https")),
|
||||
('WRONG_DOMAIN', _("you can't do that on this domain")),
|
||||
('VERIFIED_USER_REQUIRED', _("you need to set a valid email address to do that.")),
|
||||
('NO_URL', _('a url is required')),
|
||||
('BAD_URL', _('you should check that url')),
|
||||
|
||||
Reference in New Issue
Block a user