From ba7841248076ad4532f53605286a41efea7c23b2 Mon Sep 17 00:00:00 2001 From: Max Goodman Date: Thu, 7 Feb 2013 23:24:26 -0800 Subject: [PATCH] Generalize wiki validator / error system. * Generalize validator error class and make an Exception subclass. * Provide per-controller handling of fatal validator errors. * Instantiate correct http exception classes using status_map. --- r2/r2/controllers/reddit_base.py | 6 +++ r2/r2/controllers/wiki.py | 73 ++++++++++++++------------- r2/r2/lib/errors.py | 84 ++++++++++++++++++++++---------- r2/r2/lib/validator/validator.py | 22 ++++++--- r2/r2/lib/validator/wiki.py | 16 ------ 5 files changed, 120 insertions(+), 81 deletions(-) diff --git a/r2/r2/controllers/reddit_base.py b/r2/r2/controllers/reddit_base.py index 9bb7323b7..3232e2113 100644 --- a/r2/r2/controllers/reddit_base.py +++ b/r2/r2/controllers/reddit_base.py @@ -823,6 +823,12 @@ class MinimalController(BaseController): c.request_timer.stop() g.stats.flush() + def on_validation_error(self, error): + if error.name == errors.USER_REQUIRED: + self.intermediate_redirect('/login') + elif error.name == errors.VERIFIED_USER_REQUIRED: + self.intermediate_redirect('/verify') + def abort404(self): abort(404, "not found") diff --git a/r2/r2/controllers/wiki.py b/r2/r2/controllers/wiki.py index f1ad021d3..4daeaefcb 100644 --- a/r2/r2/controllers/wiki.py +++ b/r2/r2/controllers/wiki.py @@ -35,6 +35,7 @@ from r2.lib.template_helpers import join_urls from r2.lib.validator import ( + validate, VExistingUname, VInt, VMarkdown, @@ -52,7 +53,6 @@ from r2.lib.validator.wiki import ( this_may_revise, this_may_view, VWikiPageName, - wiki_validate, ) from r2.controllers.api_docs import api_doc, api_section from r2.lib.pages.wiki import (WikiPageView, WikiNotFound, WikiRevisions, @@ -74,7 +74,7 @@ from r2.lib.pages import PaneStack from r2.lib.utils import timesince from r2.config import extensions from r2.lib.base import abort -from r2.lib.errors import WikiError +from r2.lib.errors import reddit_http_error import json @@ -89,12 +89,12 @@ ATTRIBUTE_BY_PAGE = {"config/sidebar": "description", class WikiController(RedditController): allow_stylesheets = True - @wiki_validate(pv=VWikiPageAndVersion(('page', 'v', 'v2'), - required=False, - restricted=False, - allow_hidden_revision=False), - page_name=VWikiPageName('page', - error_on_name_normalized=True)) + @validate(pv=VWikiPageAndVersion(('page', 'v', 'v2'), + required=False, + restricted=False, + allow_hidden_revision=False), + page_name=VWikiPageName('page', + error_on_name_normalized=True)) def GET_wiki_page(self, pv, page_name): message = None @@ -141,15 +141,15 @@ class WikiController(RedditController): edit_date=edit_date, page=page.name).render() @paginated_listing(max_page_size=100, backend='cassandra') - @wiki_validate(page=VWikiPage(('page'), restricted=False)) + @validate(page=VWikiPage(('page'), restricted=False)) def GET_wiki_revisions(self, num, after, reverse, count, page): revisions = page.get_revisions() builder = WikiRevisionBuilder(revisions, num=num, reverse=reverse, count=count, after=after, skip=not c.is_wiki_mod, wrap=default_thing_wrapper()) listing = WikiRevisionListing(builder).listing() return WikiRevisions(listing, page=page.name, may_revise=this_may_revise(page)).render() - @wiki_validate(wp=VWikiPageRevise('page'), - page=VWikiPageName('page')) + @validate(wp=VWikiPageRevise('page'), + page=VWikiPageName('page')) def GET_wiki_create(self, wp, page): api = c.render_style in extensions.API_TYPES error = c.errors.get(('WIKI_CREATE_ERROR', 'page')) @@ -174,7 +174,7 @@ class WikiController(RedditController): else: return WikiCreate(page=page, may_revise=True).render() - @wiki_validate(wp=VWikiPageRevise('page', restricted=True, required=True)) + @validate(wp=VWikiPageRevise('page', restricted=True, required=True)) def GET_wiki_revise(self, wp, page, message=None, **kw): wp = wp[0] previous = kw.get('previous', wp._get('revision')) @@ -204,7 +204,7 @@ class WikiController(RedditController): return redirect_to(str("%s/%s" % (c.wiki_base_url, page)), _code=301) @base_listing - @wiki_validate(page=VWikiPage('page', restricted=True)) + @validate(page=VWikiPage('page', restricted=True)) def GET_wiki_discussions(self, page, num, after, reverse, count): page_url = add_sr("%s/%s" % (c.wiki_base_url, page.name)) links = url_links(page_url) @@ -215,7 +215,7 @@ class WikiController(RedditController): return WikiDiscussions(listing, page=page.name, may_revise=this_may_revise(page)).render() - @wiki_validate(page=VWikiPage('page', restricted=True, modonly=True)) + @validate(page=VWikiPage('page', restricted=True, modonly=True)) def GET_wiki_settings(self, page): settings = {'permlevel': page._get('permlevel', 0)} mayedit = page.get_editor_accounts() @@ -226,9 +226,9 @@ class WikiController(RedditController): restricted=restricted, may_revise=True).render() - @wiki_validate(VModhash(), - page=VWikiPage('page', restricted=True, modonly=True), - permlevel=VInt('permlevel')) + @validate(VModhash(), + page=VWikiPage('page', restricted=True, modonly=True), + permlevel=VInt('permlevel')) def POST_wiki_settings(self, page, permlevel): oldpermlevel = page.permlevel try: @@ -240,8 +240,13 @@ class WikiController(RedditController): description=description) return self.GET_wiki_settings(page=page.name) + def on_validation_error(self, error): + RedditController.on_validation_error(self, error) + if error.code: + self.handle_error(error.code, error.name) + def handle_error(self, code, reason=None, **data): - abort(WikiError(code, reason, **data)) + abort(reddit_http_error(code, reason, **data)) def pre(self): RedditController.pre(self) @@ -274,11 +279,11 @@ class WikiController(RedditController): class WikiApiController(WikiController): - @wiki_validate(VModhash(), - pageandprevious=VWikiPageRevise(('page', 'previous'), restricted=True), - content=VMarkdown(('content'), renderer='wiki'), - page_name=VWikiPageName('page'), - reason=VPrintable('reason', 256)) + @validate(VModhash(), + pageandprevious=VWikiPageRevise(('page', 'previous'), restricted=True), + content=VMarkdown(('content'), renderer='wiki'), + page_name=VWikiPageName('page'), + reason=VPrintable('reason', 256)) @api_doc(api_section.wiki, uri='/api/wiki/edit') def POST_wiki_edit(self, pageandprevious, content, page_name, reason): page, previous = pageandprevious @@ -325,11 +330,11 @@ class WikiApiController(WikiController): self.handle_error(409, 'EDIT_CONFLICT', newcontent=e.new, newrevision=page.revision, diffcontent=e.htmldiff) return json.dumps({}) - @wiki_validate(VModhash(), - VWikiModerator(), - page=VWikiPage('page'), - act=VOneOf('act', ('del', 'add')), - user=VExistingUname('username')) + @validate(VModhash(), + VWikiModerator(), + page=VWikiPage('page'), + act=VOneOf('act', ('del', 'add')), + user=VExistingUname('username')) @api_doc(api_section.wiki, uri='/api/wiki/alloweditor/:act') def POST_wiki_allow_editor(self, act, page, user): if not user: @@ -342,9 +347,9 @@ class WikiApiController(WikiController): self.handle_error(400, 'INVALID_ACTION') return json.dumps({}) - @wiki_validate(VModhash(), - VWikiModerator(), - pv=VWikiPageAndVersion(('page', 'revision'))) + @validate(VModhash(), + VWikiModerator(), + pv=VWikiPageAndVersion(('page', 'revision'))) @api_doc(api_section.wiki, uri='/api/wiki/hide') def POST_wiki_revision_hide(self, pv): page, revision = pv @@ -352,9 +357,9 @@ class WikiApiController(WikiController): self.handle_error(400, 'INVALID_REVISION') return json.dumps({'status': revision.toggle_hide()}) - @wiki_validate(VModhash(), - VWikiModerator(), - pv=VWikiPageAndVersion(('page', 'revision'))) + @validate(VModhash(), + VWikiModerator(), + pv=VWikiPageAndVersion(('page', 'revision'))) @api_doc(api_section.wiki, uri='/api/wiki/revert') def POST_wiki_revision_revert(self, pv): page, revision = pv diff --git a/r2/r2/lib/errors.py b/r2/r2/lib/errors.py index b2f4c58ae..d674f9934 100644 --- a/r2/r2/lib/errors.py +++ b/r2/r2/lib/errors.py @@ -20,7 +20,7 @@ # Inc. All Rights Reserved. ############################################################################### -from webob.exc import HTTPBadRequest, HTTPForbidden, HTTPError +from webob.exc import HTTPBadRequest, HTTPForbidden, status_map from r2.lib.utils import Storage, tup from pylons import request from pylons.i18n import _ @@ -122,16 +122,27 @@ error_list = dict(( )) errors = Storage([(e, e) for e in error_list.keys()]) -class Error(object): +class RedditError(Exception): + name = None + fields = None + code = None + + def __init__(self, name=None, msg_params=None, fields=None, code=None): + Exception.__init__(self) + + if name is not None: + self.name = name + + self.i18n_message = error_list.get(self.name) + self.msg_params = msg_params or {} + + if fields is not None: + # list of fields in the original form that caused the error + self.fields = tup(fields) + + if code is not None: + self.code = code - def __init__(self, name, i18n_message, msg_params, field=None, code=None): - self.name = name - self.i18n_message = i18n_message - self.msg_params = msg_params - # list of fields in the original form that caused the error - self.fields = tup(field) if field else [] - self.code = code - @property def message(self): return _(self.i18n_message) % self.msg_params @@ -142,7 +153,11 @@ class Error(object): yield ('message', _(self.message)) def __repr__(self): - return '' % self.name + return '' % self.name + + def __str__(self): + return repr(self) + class ErrorSet(object): def __init__(self): @@ -168,12 +183,16 @@ class ErrorSet(object): def __len__(self): return len(self.errors) - - def add(self, error_name, msg_params={}, field=None, code=None): - msg = error_list.get(error_name) + + def add(self, error_name, msg_params=None, field=None, code=None): for field_name in tup(field): - e = Error(error_name, msg, msg_params, field=field_name, code=code) - self.errors[(error_name, field_name)] = e + e = RedditError(error_name, msg_params, fields=field_name, + code=code) + self.add_error(e) + + def add_error(self, error): + for field_name in tup(error.fields): + self.errors[(error.name, field_name)] = error def remove(self, pair): """Expectes an (error_name, field_name) tuple and removes it @@ -181,13 +200,6 @@ class ErrorSet(object): if self.errors.has_key(pair): del self.errors[pair] -class WikiError(HTTPError): - def __init__(self, code, reason=None, **data): - self.code = code - data['reason'] = self.explanation = reason or 'UNKNOWN_ERROR' - self.error_data = data - HTTPError.__init__(self) - class ForbiddenError(HTTPForbidden): def __init__(self, error): HTTPForbidden.__init__(self) @@ -201,5 +213,27 @@ class BadRequestError(HTTPBadRequest): 'explanation': error_list[error], } -class UserRequiredException(Exception): pass -class VerifiedUserRequiredException(Exception): pass + +def reddit_http_error(code=400, error='UNKNOWN_ERROR', **data): + exc = status_map[code]() + + data['reason'] = exc.explanation = error + if error in error_list: + data['explanation'] = exc.explanation = error_list[error] + + # omit 'fields' json attribute if it is empty + if 'fields' in data and not data['fields']: + del data['fields'] + + exc.error_data = data + return exc + + +class UserRequiredException(RedditError): + name = errors.USER_REQUIRED + code = 403 + + +class VerifiedUserRequiredException(RedditError): + name = errors.VERIFIED_USER_REQUIRED + code = 403 diff --git a/r2/r2/lib/validator/validator.py b/r2/r2/lib/validator/validator.py index 5e46c8347..321ac664e 100644 --- a/r2/r2/lib/validator/validator.py +++ b/r2/r2/lib/validator/validator.py @@ -39,7 +39,7 @@ from r2.models import * from r2.lib.authorize import Address, CreditCard from r2.lib.utils import constant_time_compare -from r2.lib.errors import errors, UserRequiredException +from r2.lib.errors import errors, RedditError, UserRequiredException from r2.lib.errors import VerifiedUserRequiredException from copy import copy @@ -167,19 +167,29 @@ def set_api_docs(fn, simple_vals, param_vals): param_info.update(validator.param_docs()) doc['parameters'] = param_info -make_validated_kw = _make_validated_kw def validate(*simple_vals, **param_vals): + """Validation decorator that delegates error handling to the controller. + + Runs the validators specified and calls self.on_validation_error to + process each error. This allows controllers to define their own fatal + error processing logic. + """ def val(fn): @wraps(fn) def newfn(self, *a, **env): try: kw = _make_validated_kw(fn, simple_vals, param_vals, env) + except RedditError as err: + self.on_validation_error(err) + + for err in c.errors: + self.on_validation_error(c.errors[err]) + + try: return fn(self, *a, **kw) - except UserRequiredException: - return self.intermediate_redirect('/login') - except VerifiedUserRequiredException: - return self.intermediate_redirect('/verify') + except RedditError as err: + self.on_validation_error(err) set_api_docs(newfn, simple_vals, param_vals) return newfn diff --git a/r2/r2/lib/validator/wiki.py b/r2/r2/lib/validator/wiki.py index 744f43bae..93c3ab96e 100644 --- a/r2/r2/lib/validator/wiki.py +++ b/r2/r2/lib/validator/wiki.py @@ -33,9 +33,7 @@ from pylons import c, g, request from r2.models.wiki import WikiPage, WikiRevision from r2.lib.validator import ( Validator, - validate, VSrModerator, - make_validated_kw, set_api_docs, ) from r2.lib.db import tdb_cassandra @@ -44,20 +42,6 @@ MAX_PAGE_NAME_LENGTH = g.wiki_max_page_name_length MAX_SEPARATORS = g.wiki_max_page_separators -def wiki_validate(*simple_vals, **param_vals): - def val(fn): - @wraps(fn) - def newfn(self, *a, **env): - kw = make_validated_kw(fn, simple_vals, param_vals, env) - for e in c.errors: - e = c.errors[e] - if e.code: - self.handle_error(e.code, e.name) - return fn(self, *a, **kw) - set_api_docs(newfn, simple_vals, param_vals) - return newfn - return val - def this_may_revise(page=None): if not c.user_is_loggedin: return False