diff --git a/r2/r2/lib/cloudsearch.py b/r2/r2/lib/cloudsearch.py index de95f6e57..f50764767 100644 --- a/r2/r2/lib/cloudsearch.py +++ b/r2/r2/lib/cloudsearch.py @@ -509,11 +509,13 @@ INVALID_QUERY_CODES = ('CS-UnknownFieldInMatchExpression', 'CS-IncorrectFieldTypeInMatchExpression', 'CS-InvalidMatchSetExpression',) DEFAULT_FACETS = {"reddit": {"count":20}} -def basic_query(query=None, bq=None, faceting=DEFAULT_FACETS, size=1000, +def basic_query(query=None, bq=None, faceting=None, size=1000, start=0, rank="-relevance", return_fields=None, record_stats=False, search_api=None): if search_api is None: search_api = g.CLOUDSEARCH_SEARCH_API + if faceting is None: + faceting = DEFAULT_FACETS path = _encode_query(query, bq, faceting, size, start, rank, return_fields) timer = None if record_stats: @@ -600,7 +602,8 @@ class CloudSearchQuery(object): default_syntax = "plain" lucene_parser = None - def __init__(self, query, sr=None, sort=None, syntax=None, raw_sort=None): + def __init__(self, query, sr=None, sort=None, syntax=None, raw_sort=None, + faceting=None): if syntax is None: syntax = self.default_syntax elif syntax not in self.known_syntaxes: @@ -614,6 +617,7 @@ class CloudSearchQuery(object): self.sort = raw_sort else: self.sort = self.sorts[sort] + self.faceting = faceting self.bq = '' self.results = None @@ -644,8 +648,8 @@ class CloudSearchQuery(object): q = self.query if g.sqlprinting: g.log.info("%s", self) - return self._run_cached(q, self.bq, self.sort, start=start, num=num, - _update=_update) + return self._run_cached(q, self.bq, self.sort, self.faceting, + start=start, num=num, _update=_update) def customize_query(self, bq): return bq @@ -663,8 +667,8 @@ class CloudSearchQuery(object): return ''.join(result) @classmethod - def _run_cached(cls, query, bq, sort="relevance", start=0, num=1000, - _update=False): + def _run_cached(cls, query, bq, sort="relevance", faceting=None, start=0, + num=1000, _update=False): '''Query the cloudsearch API. _update parameter allows for supposed easy memoization at later date. @@ -701,7 +705,7 @@ class CloudSearchQuery(object): ''' response = basic_query(query=query, bq=bq, size=num, start=start, rank=sort, search_api=cls.search_api, - record_stats=True) + faceting=faceting, record_stats=True) warnings = response['info'].get('messages', []) for warning in warnings: