diff --git a/r2/r2/controllers/apiv1.py b/r2/r2/controllers/apiv1.py index 46abf0443..4dbeac3a1 100644 --- a/r2/r2/controllers/apiv1.py +++ b/r2/r2/controllers/apiv1.py @@ -19,14 +19,18 @@ # All portions of the code written by reddit are Copyright (c) 2006-2013 reddit # Inc. All Rights Reserved. ############################################################################### +import json from pylons import c from r2.controllers.api_docs import api_doc, api_section from r2.controllers.oauth2 import require_oauth2_scope -from r2.controllers.reddit_base import OAuth2ResourceController +from r2.controllers.reddit_base import ( + abort_with_error, + OAuth2ResourceController, +) +from r2.lib.base import abort from r2.lib.jsontemplates import IdentityJsonTemplate, PrefsJsonTemplate from r2.lib.validator import ( - nop, validate, VContentLang, VList, @@ -44,7 +48,7 @@ PREFS_JSON_SPEC = VValidatedJSON.PartialObject({ }) PREFS_JSON_SPEC.spec["content_langs"] = VValidatedJSON.ArrayOf( - VContentLang("content_langs") + VContentLang("content_langs") ) @@ -81,3 +85,21 @@ class APIv1Controller(OAuth2ResourceController): """Return the preference settings of the logged in user""" resp = PrefsJsonTemplate(fields).data(c.oauth_user) return self.api_wrapper(resp) + + PREFS_JSON_VALIDATOR = VValidatedJSON("json", PREFS_JSON_SPEC, + body=True) + + @require_oauth2_scope("account") + @api_doc(api_section.account, json_model=PREFS_JSON_VALIDATOR) + @validate(validated_prefs=PREFS_JSON_VALIDATOR) + def PATCH_prefs(self, validated_prefs): + user_prefs = c.user.preferences() + for short_name, new_value in validated_prefs.iteritems(): + pref_name = "pref_" + short_name + if pref_name == "pref_content_langs": + new_value = vprefs.format_content_lang_pref(new_value) + user_prefs[pref_name] = new_value + vprefs.filter_prefs(user_prefs, c.user) + vprefs.set_prefs(c.user, user_prefs) + c.user._commit() + return self.api_wrapper(PrefsJsonTemplate().data(c.user)) diff --git a/r2/r2/controllers/error.py b/r2/r2/controllers/error.py index 2651ad590..7fab8c728 100644 --- a/r2/r2/controllers/error.py +++ b/r2/r2/controllers/error.py @@ -208,7 +208,11 @@ class ErrorController(RedditController): except Exception as e: return handle_awful_failure("ErrorController.GET_document: %r" % e) - POST_document = PUT_document = DELETE_document = GET_document + POST_document = GET_document + PUT_document = GET_document + PATCH_document = GET_document + DELETE_document = GET_document + def handle_awful_failure(fail_text): """ diff --git a/r2/r2/controllers/post.py b/r2/r2/controllers/post.py index f31534741..d397d8d04 100644 --- a/r2/r2/controllers/post.py +++ b/r2/r2/controllers/post.py @@ -19,7 +19,6 @@ # All portions of the code written by reddit are Copyright (c) 2006-2013 reddit # Inc. All Rights Reserved. ############################################################################### - from r2.lib.pages import * from reddit_base import cross_domain from api import ApiController @@ -50,7 +49,6 @@ class PostController(ApiController): langs.append(str(lang)) return format_content_lang_pref(langs) - @validate(pref_lang = VLang('lang'), all_langs = VOneOf('all-langs', ('all', 'some'), default='all')) def POST_unlogged_options(self, all_langs, pref_lang): diff --git a/r2/r2/controllers/reddit_base.py b/r2/r2/controllers/reddit_base.py index c66d42a86..b5f4941ff 100644 --- a/r2/r2/controllers/reddit_base.py +++ b/r2/r2/controllers/reddit_base.py @@ -724,12 +724,12 @@ def flatten_response(content): return "".join(_force_utf8(x) for x in tup(content) if x) -def abort_with_error(error): - if not error.code: +def abort_with_error(error, code=None): + if not code and not error.code: raise ValueError('Error %r missing status code' % error) abort(reddit_http_error( - code=error.code, + code=code or error.code, error_name=error.name, explanation=error.message, fields=error.fields, diff --git a/r2/r2/lib/validator/validator.py b/r2/r2/lib/validator/validator.py index db7bc43ff..2cefd4392 100644 --- a/r2/r2/lib/validator/validator.py +++ b/r2/r2/lib/validator/validator.py @@ -78,7 +78,7 @@ def can_comment_link(article): class Validator(object): default_param = None def __init__(self, param=None, default=None, post=True, get=True, url=True, - docs=None): + body=False, docs=None): if param: self.param = param else: @@ -86,6 +86,7 @@ class Validator(object): self.default = default self.post, self.get, self.url, self.docs = post, get, url, docs + self.body = body self.has_errors = False def set_error(self, error, msg_params={}, field=False, code=None): @@ -120,6 +121,8 @@ class Validator(object): val = request.GET[p] elif self.url and url.get(p): val = url[p] + elif self.body: + val = request.body else: val = self.default a.append(val) @@ -371,6 +374,17 @@ class VLang(Validator): self.param: "a valid IETF language tag (underscore separated)", } + +class VContentLang(VLang): + def run(self, lang): + if lang == "all": + return lang + try: + return VLang.validate_lang(lang, strict=True) + except ValueError: + self.set_error(errors.INVALID_LANG) + + class VRequired(Validator): def __init__(self, param, error, *a, **kw): Validator.__init__(self, param, *a, **kw) @@ -1436,6 +1450,9 @@ class VUserWithEmail(VExistingUname): class VBoolean(Validator): def run(self, val): + if val is True or val is False: + # val is already a bool object, no processing needed + return val lv = str(val).lower() if lv == 'off' or lv == '' or lv[0] in ("f", "n"): return False @@ -2436,7 +2453,7 @@ class VValidatedJSON(VJSON): def __init__(self, spec): self.spec = spec - def run(self, data): + def run(self, data, ignore_missing=False): if not isinstance(data, dict): raise RedditError('JSON_INVALID', code=400) @@ -2445,6 +2462,8 @@ class VValidatedJSON(VJSON): try: validated_data[key] = validator.run(data[key]) except KeyError: + if ignore_missing: + continue raise RedditError('JSON_MISSING_KEY', code=400, msg_params={'key': key}) return validated_data @@ -2470,6 +2489,10 @@ class VValidatedJSON(VJSON): spec_lines.append('}') return '\n'.join(spec_lines) + class PartialObject(Object): + def run(self, data): + super_ = super(VValidatedJSON.PartialObject, self) + return super_.run(data, ignore_missing=True) def __init__(self, param, spec, **kw): VJSON.__init__(self, param, **kw) diff --git a/r2/r2/models/account.py b/r2/r2/models/account.py index 494e90b28..716644023 100644 --- a/r2/r2/models/account.py +++ b/r2/r2/models/account.py @@ -123,7 +123,11 @@ class Account(Thing): state=0, modmsgtime=None, ) - _preference_attrs = (k for k in _defaults.keys() if k.startswith("pref_")) + _preference_attrs = tuple(k for k in _defaults.keys() + if k.startswith("pref_")) + + def preferences(self): + return {pref: getattr(self, pref) for pref in self._preference_attrs} def __eq__(self, other): if type(self) != type(other): diff --git a/r2/r2/models/token.py b/r2/r2/models/token.py index 3bebf1e6d..e0629dc63 100644 --- a/r2/r2/models/token.py +++ b/r2/r2/models/token.py @@ -101,6 +101,13 @@ class ConsumableToken(Token): class OAuth2Scope: scope_info = { + "account": { + "id": "account", + "name": _("Update account information"), + "description": _("Update preferences and related account " + "information. Will not have access to your email or " + "password."), + }, "edit": { "id": "edit", "name": _("Edit Posts"),