diff --git a/r2/r2/models/subreddit.py b/r2/r2/models/subreddit.py index 644215229..821dd458d 100644 --- a/r2/r2/models/subreddit.py +++ b/r2/r2/models/subreddit.py @@ -26,6 +26,8 @@ import base64 import collections import datetime import hashlib +import itertools +import json from pylons import c, g from pylons.i18n import _ @@ -1269,6 +1271,174 @@ class MultiReddit(FakeSubreddit): queries = [get_gilded_comments(sr_id) for sr_id in self.kept_sr_ids] return MergedCachedQuery(queries) + +class TooManySubredditsException(Exception): + pass + + +class LabeledMulti(MultiReddit, tdb_cassandra.Thing): + """Thing with special columns that hold Subreddit ids and properties.""" + _use_db = True + _views = [] + _defaults = {'visibility': 'private'} + _extra_schema_creation_args = { + "key_validation_class": tdb_cassandra.UTF8_TYPE, + "column_name_class": tdb_cassandra.UTF8_TYPE, + "default_validation_class": tdb_cassandra.UTF8_TYPE, + } + _compare_with = tdb_cassandra.UTF8_TYPE + + SR_PREFIX = 'SR_' + MAX_SR_COUNT = 100 + + def __init__(self, _id=None, *args, **kwargs): + tdb_cassandra.Thing.__init__(self, _id, *args, **kwargs) + MultiReddit.__init__(self) + self._owner = None + + @classmethod + def _byID(cls, ids, return_dict=True, properties=None): + ret = super(cls, cls)._byID(ids, return_dict=False, + properties=properties) + if not ret: + return + ret = cls._load(ret) + if isinstance(ret, cls): + return ret + elif return_dict: + return {thing._id: thing for thing in ret} + else: + return ret + + @classmethod + def _load_no_lookup(cls, things, srs_dict, owners_dict): + things, single = tup(things, ret_is_single=True) + for thing in things: + thing._srs = [srs_dict[sr_id] for sr_id in thing.sr_ids] + thing._owner = owners_dict[thing.owner_fullname] + return things[0] if single else things + + @classmethod + def _load(cls, things): + things, single = tup(things, ret_is_single=True) + sr_ids = set(itertools.chain(*[thing.sr_ids for thing in things])) + owner_fullnames = set((thing.owner_fullname for thing in things)) + + srs = Subreddit._byID(sr_ids, data=True, return_dict=True) + owners = Thing._by_fullname(owner_fullnames, data=True, return_dict=True) + ret = cls._load_no_lookup(things, srs, owners) + return ret[0] if single else things + + @property + def sr_ids(self): + return self.sr_props.keys() + + @property + def srs(self): + return self._srs + + @property + def owner(self): + return self._owner + + @property + def sr_props(self): + # limit to max subreddit count, allowing a little fudge room for + # cassandra inconsistency + remaining = self.MAX_SR_COUNT + 10 + sr_columns = {} + for k, v in self._t.iteritems(): + if remaining <= 0: + break + remaining -= 1 + + if k.startswith(self.SR_PREFIX): + sr_columns[k] = v + return self.columns_to_sr_props(sr_columns) + + @property + def path(self): + return self._id + + @property + def name(self): + return self.path.split('/')[-1] + + def can_view(self, user): + return user == self.owner or self.visibility == 'public' + + def can_edit(self, user): + return user == self.owner + + @classmethod + def by_owner(cls, owner): + return list(LabeledMultiByOwner.query([owner._fullname])) + + @classmethod + def create(cls, path, owner, sr_props=None): + # sr_props is {sr_id: properties_dict} + sr_columns = cls.sr_props_to_columns(sr_props) if sr_props else {} + obj = cls(_id=path, owner_fullname=owner._fullname, **sr_columns) + obj._commit() + obj._owner = owner + return obj + + @classmethod + def sr_props_to_columns(cls, sr_props): + sr_columns = {cls.SR_PREFIX + str(sr_id): json.dumps(props) + for sr_id, props in sr_props.iteritems()} + return sr_columns + + @classmethod + def columns_to_sr_props(cls, columns): + ret = {} + for s, sr_prop_dump in columns.iteritems(): + sr_id = long(s.strip(cls.SR_PREFIX)) + sr_props = json.loads(sr_prop_dump) + ret[sr_id] = sr_props + return ret + + def _on_create(self): + for view in self._views: + view.add_object(self) + + def add_srs(self, sr_props): + """Add/overwrite subreddit(s).""" + sr_columns = self.sr_props_to_columns(sr_props) + + if len(self._srs) + len(sr_columns) > self.MAX_SR_COUNT: + raise TooManySubredditsException + + for attr, val in sr_columns.items(): + self.__setattr__(attr, val) + self._commit() + + def del_srs(self, sr_ids): + """Delete subreddit(s).""" + sr_ids = tup(sr_ids) + keys = self.sr_props_to_columns(dict.fromkeys(sr_ids, '')).keys() + for key in keys: + self.__delitem__(key) + self._commit() + + def delete(self): + # Do we want to actually delete objects? + self._destroy() + for view in self._views: + rowkey = view._rowkey(self) + column = view._obj_to_column(self) + view._remove(rowkey, column) + + +@tdb_cassandra.view_of(LabeledMulti) +class LabeledMultiByOwner(tdb_cassandra.View): + _use_db = True + + @classmethod + def _rowkey(cls, lm): + return lm.owner_fullname + + class RandomReddit(FakeSubreddit): name = 'random' header = ""