class MultiColumnQuery(object):
    """Merge several column queries into a single ordered stream of items.

    Every wrapped query is iterated with ``yield_column_names=True`` so it
    produces ``(column_name, item)`` pairs.  The pairs from all queries are
    merged by column name (optionally through ``sort_key``) and only the
    items are yielded, up to ``num`` of them.
    """

    def __init__(self, queries, num, sort_key=None):
        """Store the queries to merge.

        queries  -- sequence of query objects; each must provide
                    ``__iter__(yield_column_names=True)``, ``_after``,
                    ``_reverse`` and accept a ``_limit`` attribute.
        num      -- maximum number of items yielded by one iteration.
        sort_key -- optional callable applied to a column name to get the
                    value used for ordering.
        """
        self.num = num
        self._queries = queries
        # python doesn't sort UUID1's correctly, need to pass in a sorter
        self.sort_key = sort_key
        self._rules = []  # dummy parameter to mimic tdb_sql queries

    def _after(self, thing):
        """Forward ``_after(thing)`` to every wrapped query."""
        for q in self._queries:
            q._after(thing)

    def _reverse(self):
        """Forward ``_reverse()`` to every wrapped query."""
        for q in self._queries:
            q._reverse()

    def __setattr__(self, attr, val):
        """Set an attribute, fanning ``_limit`` out to the wrapped queries.

        ``_limit`` is never stored on this object itself; it is only
        assigned on each wrapped query.
        """
        # Catch _limit to set on all queries
        if attr == '_limit':
            for q in self._queries:
                q._limit = val
        else:
            object.__setattr__(self, attr, val)

    def __iter__(self):
        """Yield at most ``self.num`` items, merged across all queries."""
        if self.sort_key:
            def head_key(entry):
                # Point the supplied sort key at the column name in the
                # (column_name, item, generator) tuple
                return self.sort_key(entry[0])
        else:
            def head_key(entry):
                return entry[0]

        # One (column_name, item, generator) entry per non-empty query,
        # holding the next not-yet-yielded item of that query.
        heads = []
        for q in self._queries:
            try:
                gen = q.__iter__(yield_column_names=True)
                # builtin next() works on both Python 2.6+ and Python 3
                column_name, item = next(gen)
            except StopIteration:
                # Query has no results; it contributes nothing
                continue
            heads.append((column_name, item, gen))
        heads.sort(key=head_key)

        num_ret = 0
        while heads and num_ret < self.num:
            column_name, item, gen = heads.pop(0)
            yield item
            num_ret += 1

            if not num_ret < self.num:
                # Limit reached: don't pull (and possibly fetch) an item
                # that would never be returned
                break

            # Refill from the query that gave us the item just yielded
            try:
                column_name, item = next(gen)
            except StopIteration:
                continue
            # list.sort is stable, so when several queries share a column
            # name the entry appended last is returned last
            heads.append((column_name, item, gen))
            heads.sort(key=head_key)