diff --git a/r2/r2/lib/db/tdb_cassandra.py b/r2/r2/lib/db/tdb_cassandra.py index 2a8b00995..8c98fa137 100644 --- a/r2/r2/lib/db/tdb_cassandra.py +++ b/r2/r2/lib/db/tdb_cassandra.py @@ -34,6 +34,7 @@ from r2.lib import cache from uuid import uuid1, UUID from itertools import chain import cPickle as pickle +from pycassa.util import OrderedDict connection_pools = g.cassandra_pools default_connection_pool = g.cassandra_default_pool @@ -874,12 +875,12 @@ class ColumnQuery(object): """ _chunk_size = 100 - def __init__(self, cls, rowkey, column_start="", column_finish="", + def __init__(self, cls, rowkeys, column_start="", column_finish="", column_count=100, column_reversed=True, column_to_obj=None, obj_to_column=None): self.cls = cls - self.rowkey = rowkey + self.rowkeys = rowkeys self.column_start = column_start self.column_finish = column_finish self._limit = column_count @@ -888,6 +889,19 @@ class ColumnQuery(object): self.obj_to_column = obj_to_column or self.default_obj_to_column self._rules = [] # dummy parameter to mimic tdb_sql queries + # Sorting for TimeUuid objects + if self.cls._compare_with == TIME_UUID_TYPE: + def sort_key(i): + return i.time + else: + def sort_key(i): + return i + self.sort_key = sort_key + + @staticmethod + def combine(queries): + raise NotImplementedError + @staticmethod def default_column_to_obj(columns): """ @@ -926,32 +940,55 @@ class ColumnQuery(object): self.column_reversed = False def __iter__(self, yield_column_names=False): - # Get the max number of columns we could grab in this query - total_columns = self.cls._cf.get_count(self.rowkey, - column_start=self.column_start, - column_finish=self.column_finish) - retrievable_columns = min(total_columns, self._limit) - retrieved = 0 column_start = self.column_start - while retrieved <= retrievable_columns: + while retrieved < self._limit: try: - r = self.cls._cf.get(self.rowkey, column_start=column_start, - column_finish=self.column_finish, - column_count=self._chunk_size, - column_reversed=self.column_reversed) + column_count = min(self._chunk_size, self._limit - retrieved) + if column_start: + column_count += 1 # cassandra includes column_start + r = self.cls._cf.multiget(self.rowkeys, + column_start=column_start, + column_finish=self.column_finish, + column_count=column_count, + column_reversed=self.column_reversed) + + # multiget returns OrderedDict {rowkey: {column_name: column_value}} + # combine into single OrderedDict of {column_name: column_value} + nrows = len(r.keys()) + if nrows == 0: + return + elif nrows == 1: + columns = r.values()[0] + else: + r_combined = {} + for d in r.values(): + r_combined.update(d) + columns = OrderedDict(sorted(r_combined.items(), + key=lambda t: self.sort_key(t[0]), + reverse=self.column_reversed)) except NotFoundException: return retrieved += self._chunk_size - columns = [{col_name: r[col_name]} for col_name in r if col_name != column_start] + + if column_start: + try: + del columns[column_start] + except KeyError: + columns.popitem(last=True) # remove extra column + if not columns: return - column_start = columns[-1].keys()[0] - objs = self.column_to_obj(columns) + + # Convert to list of columns + l_columns = [{col_name: columns[col_name]} for col_name in columns] + + column_start = l_columns[-1].keys()[0] + objs = self.column_to_obj(l_columns) if yield_column_names: - column_names = [column.keys()[0] for column in columns] + column_names = [column.keys()[0] for column in l_columns] if len(column_names) == 1: ret = (column_names[0], objs), else: @@ -965,7 +1002,7 @@ class ColumnQuery(object): def __repr__(self): return "<%s(%s-%r)>" % (self.__class__.__name__, self.cls.__name__, - self.rowkey) + self.rowkeys) class MultiColumnQuery(object): def __init__(self, queries, num, sort_key=None): @@ -1164,12 +1201,12 @@ class View(ThingBase): cls._set_values(rowkey, column, **kw) @classmethod - def query(cls, rowkey, after=None, reverse=False, count=1000): + def query(cls, rowkeys, after=None, reverse=False, count=1000): """Return a query to get objects from the underlying _view_of class.""" column_reversed = not reverse # Reverse convention for cassandra is opposite - q = cls._query_cls(cls, rowkey, column_count=count, + q = cls._query_cls(cls, rowkeys, column_count=count, column_reversed=column_reversed, column_to_obj=cls._column_to_obj, obj_to_column=cls._obj_to_column)