diff --git a/r2/r2/lib/db/userrel.py b/r2/r2/lib/db/userrel.py index 00cca834d..3346dfa3f 100644 --- a/r2/r2/lib/db/userrel.py +++ b/r2/r2/lib/db/userrel.py @@ -29,20 +29,25 @@ from r2.lib.memoize import memoize class UserRelManager(object): """Manages access to a relation between a type of thing and users.""" - def __init__(self, name, relation): + def __init__(self, name, relation, permission_class): self.name = name self.relation = relation + self.permission_class = permission_class def get(self, thing, user): if user: q = self.relation._fast_query([thing], [user], self.name) - return q.get((thing, user, self.name)) + rel = q.get((thing, user, self.name)) + if rel: + rel._permission_class = self.permission_class + return rel def add(self, thing, user, **attrs): if self.get(thing, user): return None r = self.relation(thing, user, self.name, **attrs) r._commit() + r._permission_class = self.permission_class return r def remove(self, thing, user): @@ -58,6 +63,7 @@ class UserRelManager(object): for k, v in attrs.iteritems(): setattr(r, k, v) r._commit() + r._permission_class = self.permission_class return r else: return self.add(thing, user, **attrs) @@ -71,18 +77,20 @@ class UserRelManager(object): return [r._thing1_id for r in q] def by_thing(self, thing): - return self.relation._query(self.relation.c._thing1_id == thing._id, - self.relation.c._name == self.name, - sort='_date') - + for r in self.relation._query(self.relation.c._thing1_id == thing._id, + self.relation.c._name == self.name, + sort='_date'): + r._permission_class = self.permission_class + yield r class MemoizedUserRelManager(UserRelManager): """Memoized manager for a relation to users.""" - def __init__(self, name, relation, + def __init__(self, name, relation, permission_class, disable_ids_fn=False, disable_reverse_ids_fn=False): - super(MemoizedUserRelManager, self).__init__(name, relation) + super(MemoizedUserRelManager, self).__init__( + name, relation, permission_class) self.disable_ids_fn = disable_ids_fn self.disable_reverse_ids_fn = disable_reverse_ids_fn @@ -114,7 +122,8 @@ class MemoizedUserRelManager(UserRelManager): return wrapper -def UserRel(name, relation, disable_ids_fn=False, disable_reverse_ids_fn=False): +def UserRel(name, relation, disable_ids_fn=False, disable_reverse_ids_fn=False, + permission_class=None): """Mixin for Thing subclasses for managing a relation to users. Provides the following suite of methods for a relation named "": @@ -130,7 +139,8 @@ def UserRel(name, relation, disable_ids_fn=False, disable_reverse_ids_fn=False): related to """ mgr = MemoizedUserRelManager( - name, relation, disable_ids_fn, disable_reverse_ids_fn) + name, relation, permission_class, + disable_ids_fn, disable_reverse_ids_fn) class UR: @classmethod diff --git a/r2/r2/lib/validator/validator.py b/r2/r2/lib/validator/validator.py index 90e854788..984ea6fe0 100644 --- a/r2/r2/lib/validator/validator.py +++ b/r2/r2/lib/validator/validator.py @@ -859,13 +859,15 @@ class VTrafficViewer(VSponsor): promote.is_traffic_viewer(thing, c.user)) class VSrModerator(Validator): - def __init__(self, fatal=True, *a, **kw): - Validator.__init__(self, *a, **kw) + def __init__(self, fatal=True, perms=(), *a, **kw): # If True, abort rather than setting an error self.fatal = fatal + self.perms = utils.tup(perms) + super(VSrModerator, self).__init__(*a, **kw) def run(self): - if not (c.user_is_loggedin and c.site.is_moderator(c.user) + if not (c.user_is_loggedin + and c.site.is_moderator_with_perms(c.user, *self.perms) or c.user_is_admin): if self.fatal: abort(403, "forbidden") diff --git a/r2/r2/models/subreddit.py b/r2/r2/models/subreddit.py index cc6d0f037..846b2a867 100644 --- a/r2/r2/models/subreddit.py +++ b/r2/r2/models/subreddit.py @@ -56,6 +56,81 @@ from r2.models.wiki import WikiPage import os.path import random +class PermissionSet(dict): + ALL = 'all' + + info = None + + def __init__(self, *args, **kwargs): + super(PermissionSet, self).__init__(*args, **kwargs) + + @classmethod + def loads(cls, encoded, validate=False): + if not encoded: + return cls() + result = cls(((term[1:], term[0] == '+') + for term in encoded.split(','))) + if result.get(cls.ALL) == False: + del result[cls.ALL] + if validate and not result.is_valid(): + raise ValueError + return result + + def dumps(self): + if self.is_superuser(): + return '+all' + return ','.join('-+'[bool(v)] + k for k, v in sorted(self.iteritems())) + + def is_superuser(self): + return super(PermissionSet, self).get(self.ALL) + + def is_valid(self): + if not self.info: + return False + for k in self: + if k != self.ALL and k not in self.info: + return False + return True + + def get(self, key, default=None): + if self.info and self.is_superuser(): + return True if key in self.info else default + return super(PermissionSet, self).get(key, default) + + def __getitem__(self, key): + if self.info and self.is_superuser(): + return key in self.info + return super(PermissionSet, self).get(key, False) + + +class ModeratorPermissionSet(PermissionSet): + info = dict( + access=dict( + title=_('access'), + description=_('manage the lists of contributors and banned users'), + ), + config=dict( + title=_('config'), + description=_('edit settings, sidebar, css, and images'), + ), + flair=dict( + title=_('flair'), + description=_('manage user flair, link flair, and flair templates'), + ), + posts=dict( + title=_('posts'), + description=_( + 'use the approve, remove, spam, distinguish, and nsfw buttons'), + ), + ) + + @classmethod + def loads(cls, encoded, **kwargs): + if encoded is None: + return cls(all=True) + return super(ModeratorPermissionSet, cls).loads(encoded, **kwargs) + + class SubredditExists(Exception): pass class Subreddit(Thing, Printable): @@ -825,6 +900,29 @@ class Subreddit(Thing, Printable): # is really slow return [rel._thing2_id for rel in list(merged)] + def is_moderator_with_perms(self, user, *perms): + rel = self.is_moderator(user) + if rel: + return all(rel.has_permission(perm) for perm in perms) + + def is_limited_moderator(self, user): + rel = self.is_moderator(user) + return rel and rel.permissions is not None + + def update_moderator_permissions(self, user, **kwargs): + """Grants or denies permissions to this moderator. + + Does nothing if the given user is not a moderator. + + Args are named parameters with bool or None values (use None to disable + granting or denying the permission). + """ + rel = self.get_moderator(user) + if rel: + rel.update_permissions(**kwargs) + rel._commit() + + class FakeSubreddit(Subreddit): over_18 = False _nodb = True @@ -1261,10 +1359,54 @@ Subreddit._specials.update(dict(friends = Friends, contrib = Contrib, all = All)) -class SRMember(Relation(Subreddit, Account)): pass +class SRMember(Relation(Subreddit, Account)): + _defaults = dict(encoded_permissions=None) + _permission_class = None + + def has_permission(self, perm): + """Returns whether this member has explicitly been granted a permission. + """ + return self.get_permissions().get(perm, False) + + def get_permissions(self): + """Returns permission set for this member (or None if N/A).""" + if not self._permission_class: + raise NotImplementedError + return self._permission_class.loads(self.encoded_permissions) + + def update_permissions(self, **kwargs): + """Grants or denies permissions to this member. + + Args are named parameters with bool or None values (use None to disable + granting or denying the permission). After calling this method, + the relation will be _dirty until _commit is called. + """ + if not self._permission_class: + raise NotImplementedError + perm_set = self._permission_class.loads(self.encoded_permissions) + if perm_set is None: + perm_set = self._permission_class() + for k, v in kwargs.iteritems(): + if v is None: + if k in perm_set: + del perm_set[k] + else: + perm_set[k] = v + self.encoded_permissions = perm_set.dumps() + + def set_permissions(self, perm_set): + """Assigns a permission set to this relation.""" + self.encoded_permissions = perm_set.dumps() + + def is_superuser(self): + return self.get_permissions().is_superuser() + + Subreddit.__bases__ += ( - UserRel('moderator', SRMember), - UserRel('moderator_invite', SRMember), + UserRel('moderator', SRMember, + permission_class=ModeratorPermissionSet), + UserRel('moderator_invite', SRMember, + permission_class=ModeratorPermissionSet), UserRel('contributor', SRMember), UserRel('subscriber', SRMember, disable_ids_fn=True), UserRel('banned', SRMember), @@ -1272,6 +1414,7 @@ Subreddit.__bases__ += ( UserRel('wikicontributor', SRMember), ) + class SubredditPopularityByLanguage(tdb_cassandra.View): _use_db = True _value_type = 'pickle' diff --git a/r2/r2/tests/__init__.py b/r2/r2/tests/__init__.py index 4293af01c..ffcc3ee7a 100644 --- a/r2/r2/tests/__init__.py +++ b/r2/r2/tests/__init__.py @@ -39,6 +39,21 @@ pkg_resources.working_set.add_entry(conf_dir) pkg_resources.require('Paste') pkg_resources.require('PasteScript') + +def stage_for_paste(): + wsgiapp = loadapp('config:test.ini', relative_to=conf_dir) + test_app = paste.fixture.TestApp(wsgiapp) + + # this is basically what 'paster run' does (see r2/commands.py) + test_response = test_app.get("/_test_vars") + request_id = int(test_response.body) + test_app.pre_request_hook = lambda self: \ + paste.registry.restorer.restoration_end() + test_app.post_request_hook = lambda self: \ + paste.registry.restorer.restoration_begin(request_id) + paste.registry.restorer.restoration_begin(request_id) + + class RedditTestCase(TestCase): """Base Test Case for tests that require the app environment to run. @@ -47,16 +62,5 @@ class RedditTestCase(TestCase): """ def __init__(self, *args, **kwargs): - wsgiapp = loadapp('config:test.ini', relative_to=conf_dir) - test_app = paste.fixture.TestApp(wsgiapp) - - # this is basically what 'paster run' does (see r2/commands.py) - test_response = test_app.get("/_test_vars") - request_id = int(test_response.body) - test_app.pre_request_hook = lambda self: \ - paste.registry.restorer.restoration_end() - test_app.post_request_hook = lambda self: \ - paste.registry.restorer.restoration_begin(request_id) - paste.registry.restorer.restoration_begin(request_id) - + stage_for_paste() TestCase.__init__(self, *args, **kwargs) diff --git a/r2/r2/tests/unit/models/__init__.py b/r2/r2/tests/unit/models/__init__.py new file mode 100644 index 000000000..0deedce08 --- /dev/null +++ b/r2/r2/tests/unit/models/__init__.py @@ -0,0 +1,3 @@ +from r2.tests import stage_for_paste + +stage_for_paste() diff --git a/r2/r2/tests/unit/models/subreddit_test.py b/r2/r2/tests/unit/models/subreddit_test.py new file mode 100644 index 000000000..a8a119952 --- /dev/null +++ b/r2/r2/tests/unit/models/subreddit_test.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python + +import unittest + +from r2.models.account import Account +from r2.models.subreddit import ( + ModeratorPermissionSet, + PermissionSet, + SRMember, + Subreddit, +) + +class TestPermissionSet(PermissionSet): + info = dict(x={}, y={}) + +class PermissionSetTest(unittest.TestCase): + def test_dumps(self): + self.assertEquals( + '+all', PermissionSet(all=True).dumps()) + self.assertEquals( + '+all', PermissionSet(all=True, other=True).dumps()) + self.assertEquals( + '+a,-b', PermissionSet(a=True, b=False).dumps()) + + def test_loads(self): + self.assertEquals("", TestPermissionSet.loads(None).dumps()) + self.assertEquals("", TestPermissionSet.loads("").dumps()) + self.assertEquals("+x,+y", TestPermissionSet.loads("+x,+y").dumps()) + self.assertEquals("+x,-y", TestPermissionSet.loads("+x,-y").dumps()) + self.assertEquals("+all", TestPermissionSet.loads("+x,-y,+all").dumps()) + self.assertEquals("+x,-y,+z", + TestPermissionSet.loads("+x,-y,+z").dumps()) + self.assertRaises(ValueError, + TestPermissionSet.loads, "+x,-y,+z", validate=True) + self.assertEquals( + "+x,-y", + TestPermissionSet.loads("-all,+x,-y", validate=True).dumps()) + + def test_is_superuser(self): + perm_set = PermissionSet() + self.assertFalse(perm_set.is_superuser()) + perm_set[perm_set.ALL] = True + self.assertTrue(perm_set.is_superuser()) + perm_set[perm_set.ALL] = False + self.assertFalse(perm_set.is_superuser()) + + def test_is_valid(self): + perm_set = PermissionSet() + self.assertFalse(perm_set.is_valid()) + + perm_set = TestPermissionSet() + self.assertTrue(perm_set.is_valid()) + perm_set['x'] = True + self.assertTrue(perm_set.is_valid()) + perm_set[perm_set.ALL] = True + self.assertTrue(perm_set.is_valid()) + perm_set['z'] = True + self.assertFalse(perm_set.is_valid()) + + def test_getitem(self): + perm_set = PermissionSet() + perm_set[perm_set.ALL] = True + self.assertFalse(perm_set['x']) + + perm_set = TestPermissionSet() + perm_set['x'] = True + self.assertTrue(perm_set['x']) + self.assertFalse(perm_set['y']) + perm_set['x'] = False + self.assertFalse(perm_set['x']) + perm_set[perm_set.ALL] = True + self.assertTrue(perm_set['x']) + self.assertTrue(perm_set['y']) + self.assertFalse(perm_set['z']) + self.assertTrue(perm_set.get('x', False)) + self.assertFalse(perm_set.get('z', False)) + self.assertTrue(perm_set.get('z', True)) + + +class ModeratorPermissionSetTest(unittest.TestCase): + def test_loads(self): + self.assertTrue(ModeratorPermissionSet.loads(None).is_superuser()) + self.assertFalse(ModeratorPermissionSet.loads('').is_superuser()) + + +class SRMemberTest(unittest.TestCase): + def setUp(self): + a = Account() + a._commit() + sr = Subreddit() + sr._commit() + self.rel = SRMember(sr, a, 'test') + + def test_get_permissions(self): + self.assertRaises(NotImplementedError, self.rel.get_permissions) + self.rel._permission_class = TestPermissionSet + self.assertEquals('', self.rel.get_permissions().dumps()) + self.rel.encoded_permissions = '+x,-y' + self.assertEquals('+x,-y', self.rel.get_permissions().dumps()) + + def test_has_permission(self): + self.assertRaises(NotImplementedError, self.rel.has_permission, 'x') + self.rel._permission_class = TestPermissionSet + self.assertFalse(self.rel.has_permission('x')) + self.rel.encoded_permissions = '+x,-y' + self.assertTrue(self.rel.has_permission('x')) + self.assertFalse(self.rel.has_permission('y')) + self.rel.encoded_permissions = '+all' + self.assertTrue(self.rel.has_permission('x')) + self.assertTrue(self.rel.has_permission('y')) + self.assertFalse(self.rel.has_permission('z')) + + def test_update_permissions(self): + self.assertRaises(NotImplementedError, + self.rel.update_permissions, x=True) + self.rel._permission_class = TestPermissionSet + self.rel.update_permissions(x=True, y=False) + self.assertEquals('+x,-y', self.rel.encoded_permissions) + self.rel.update_permissions(x=None) + self.assertEquals('-y', self.rel.encoded_permissions) + self.rel.update_permissions(y=None, z=None) + self.assertEquals('', self.rel.encoded_permissions) + self.rel.update_permissions(x=True, y=False, all=True) + self.assertEquals('+all', self.rel.encoded_permissions) + + def test_set_permissions(self): + self.rel.set_permissions(PermissionSet(x=True, y=False)) + self.assertEquals('+x,-y', self.rel.encoded_permissions) + + def test_is_superuser(self): + self.assertRaises(NotImplementedError, self.rel.is_superuser) + self.rel._permission_class = TestPermissionSet + self.assertFalse(self.rel.is_superuser()) + self.rel.encoded_permissions = '+all' + self.assertTrue(self.rel.is_superuser()) + +if __name__ == '__main__': + unittest.main()