diff --git a/r2/r2/controllers/reddit_base.py b/r2/r2/controllers/reddit_base.py
index ce2ca3b7f..0a59deb1e 100644
--- a/r2/r2/controllers/reddit_base.py
+++ b/r2/r2/controllers/reddit_base.py
@@ -660,6 +660,7 @@ class MinimalController(BaseController):
         # push data to statsd
         g.stats.transact('web.%s' % action,
                          (end_time - c.start_time).total_seconds())
+        g.stats.flush_cassandra_events()
 
     def abort404(self):
         abort(404, "not found")
diff --git a/r2/r2/lib/app_globals.py b/r2/r2/lib/app_globals.py
index b40503da4..75c10adf3 100755
--- a/r2/r2/lib/app_globals.py
+++ b/r2/r2/lib/app_globals.py
@@ -26,7 +26,6 @@ import signal
 from datetime import timedelta, datetime
 from urlparse import urlparse
 import json
-from pycassa.pool import ConnectionPool as PycassaConnectionPool
 from r2.lib.cache import LocalCache, SelfEmptyingCache
 from r2.lib.cache import CMemcache, StaleCacheChain
 from r2.lib.cache import HardCache, MemcacheChain, MemcacheChain, HardcacheChain
@@ -36,7 +35,7 @@ from r2.lib.db.stats import QueryStats
 from r2.lib.translation import get_active_langs
 from r2.lib.lock import make_lock_factory
 from r2.lib.manager import db_manager
-from r2.lib.stats import Stats, CacheStats
+from r2.lib.stats import Stats, CacheStats, StatsCollectingConnectionPool
 
 
 class Globals(object):
@@ -215,6 +214,9 @@ class Globals(object):
         self.memcache = CMemcache(self.memcaches, num_clients = num_mc_clients)
         self.make_lock = make_lock_factory(self.memcache)
 
+        self.stats = Stats(global_conf.get('statsd_addr'),
+                           global_conf.get('statsd_sample_rate'))
+
         if not self.cassandra_seeds:
             raise ValueError("cassandra_seeds not set in the .ini")
 
@@ -222,8 +224,9 @@ class Globals(object):
         keyspace = "reddit"
         self.cassandra_pools = {
             "main":
-                PycassaConnectionPool(
+                StatsCollectingConnectionPool(
                     keyspace,
+                    stats=self.stats,
                     logging_name="main",
                     server_list=self.cassandra_seeds,
                     pool_size=len(self.cassandra_seeds),
@@ -232,8 +235,9 @@ class Globals(object):
                     prefill=False
                 ),
             "noretries":
-                PycassaConnectionPool(
+                StatsCollectingConnectionPool(
                     keyspace,
+                    stats=self.stats,
                     logging_name="noretries",
                     server_list=self.cassandra_seeds,
                     pool_size=len(self.cassandra_seeds),
@@ -291,9 +295,6 @@ class Globals(object):
                                         cache_negative_results = True)
         self.cache_chains.update(hardcache=self.hardcache)
 
-        self.stats = Stats(global_conf.get('statsd_addr'),
-                           global_conf.get('statsd_sample_rate'))
-
         # I know this sucks, but we need non-request-threads to be
         # able to reset the caches, so we need them be able to close
         # around 'cache_chains' without being able to call getattr on
diff --git a/r2/r2/lib/stats.py b/r2/r2/lib/stats.py
index 598166428..d5f266fd8 100644
--- a/r2/r2/lib/stats.py
+++ b/r2/r2/lib/stats.py
@@ -1,6 +1,10 @@
+import collections
 import random
 import time
 
+from pycassa import columnfamily
+from pycassa import pool
+
 from r2.lib import cache
 from r2.lib import utils
 
@@ -9,6 +13,8 @@ class Stats:
     # sample_rate.
     CACHE_SAMPLE_RATE = 0.01
 
+    CASSANDRA_KEY_SUFFIXES = ['error', 'ok']
+
     def __init__(self, addr, sample_rate):
         if addr:
             import statsd
@@ -23,6 +29,7 @@ class Stats:
             self.port = None
             self.sample_rate = None
             self.connection = None
+        self.cassandra_events = collections.defaultdict(int)
 
     def get_timer(self, name):
         if self.connection:
@@ -77,6 +84,36 @@ class Stats:
                 return wrap_processor
             return decorator
 
+    def cassandra_event(self, operation, column_families, success,
+                        service_time):
+        if not isinstance(column_families, list):
+            column_families = [column_families]
+        for cf in column_families:
+            key = '.'.join([
+                cf, operation, self.CASSANDRA_KEY_SUFFIXES[success]])
+            self.cassandra_events[key + '.time'] += service_time
+            self.cassandra_events[key] += 1
+
+    def flush_cassandra_events(self):
+        events = self.cassandra_events
+        self.cassandra_events = collections.defaultdict(int)
+        if self.connection:
+            data = {}
+            for k, v in events.iteritems():
+                if k.endswith('.ok.time'):
+                    suffix = '|ms'
+                    # these get stored under timers with ".ok.time" chopped
+                    # off; only report the mean over this request
+                    v /= events.get(k[:-5], 1)
+                    k = k[:-8]
+                elif k.endswith('.time'):
+                    # error times aren't reported; dropping them here avoids
+                    # clobbering the error counter key of the same name
+                    continue
+                else:
+                    suffix = '|c'
+                data['cassandra.' + k] = str(v) + suffix
+            self.connection.send(data)
+
 class CacheStats:
     def __init__(self, parent, cache_name):
         self.parent = parent
@@ -94,3 +131,74 @@ class CacheStats:
         if delta:
             self.parent.cache_count(self.miss_stat_name, delta=delta)
             self.parent.cache_count(self.total_stat_name, delta=delta)
+
+class StatsCollectingConnectionPool(pool.ConnectionPool):
+    def __init__(self, keyspace, stats=None, *args, **kwargs):
+        pool.ConnectionPool.__init__(self, keyspace, *args, **kwargs)
+        self.stats = stats
+
+    def _get_new_wrapper(self, server):
+        cf_types = (columnfamily.ColumnParent, columnfamily.ColumnPath)
+
+        def get_cf_name_from_args(args, kwargs):
+            for v in args:
+                if isinstance(v, cf_types):
+                    return v.column_family
+            for v in kwargs.itervalues():
+                if isinstance(v, cf_types):
+                    return v.column_family
+            return None
+
+        def get_cf_name_from_batch_mutation(args, kwargs):
+            cf_names = set()
+            mutation_map = args[0]
+            for key_mutations in mutation_map.itervalues():
+                cf_names.update(key_mutations)
+            return list(cf_names)
+
+        instrumented_methods = dict(
+            get=get_cf_name_from_args,
+            get_slice=get_cf_name_from_args,
+            multiget_slice=get_cf_name_from_args,
+            get_count=get_cf_name_from_args,
+            multiget_count=get_cf_name_from_args,
+            get_range_slices=get_cf_name_from_args,
+            get_indexed_slices=get_cf_name_from_args,
+            insert=get_cf_name_from_args,
+            batch_mutate=get_cf_name_from_batch_mutation,
+            add=get_cf_name_from_args,
+            remove=get_cf_name_from_args,
+            remove_counter=get_cf_name_from_args,
+            truncate=lambda args, kwargs: args[0],
+        )
+
+        def record_error(method_name, cf_name, service_time):
+            if cf_name and self.stats:
+                self.stats.cassandra_event(method_name, cf_name, False,
+                                           service_time)
+
+        def record_success(method_name, cf_name, service_time):
+            if cf_name and self.stats:
+                self.stats.cassandra_event(method_name, cf_name, True,
+                                           service_time)
+
+        def instrument(f, get_cf_name):
+            def call_with_instrumentation(*args, **kwargs):
+                cf_name = get_cf_name(args, kwargs)
+                start = time.time()
+                try:
+                    result = f(*args, **kwargs)
+                except:
+                    record_error(f.__name__, cf_name, time.time() - start)
+                    raise
+                else:
+                    record_success(f.__name__, cf_name,
+                                   time.time() - start)
+                return result
+            return call_with_instrumentation
+
+        wrapper = pool.ConnectionPool._get_new_wrapper(self, server)
+        for method_name, get_cf_name in instrumented_methods.iteritems():
+            f = getattr(wrapper, method_name)
+            setattr(wrapper, method_name, instrument(f, get_cf_name))
+        return wrapper
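
Note for reviewers: below is a standalone sketch (not part of the patch) of
what flush_cassandra_events ends up handing to statsd for one request. The
column family name ("Comment"), the operation mix, and the timings are all
invented for illustration; the key-chopping mirrors the loop in the patch.

    import collections

    CASSANDRA_KEY_SUFFIXES = ['error', 'ok']
    events = collections.defaultdict(int)

    def cassandra_event(operation, cf, success, service_time):
        # same bookkeeping as Stats.cassandra_event in the patch
        key = '.'.join([cf, operation, CASSANDRA_KEY_SUFFIXES[success]])
        events[key + '.time'] += service_time
        events[key] += 1

    # pretend one request issued three gets against the Comment CF:
    # two successes and one failure
    cassandra_event('get', 'Comment', True, 0.004)
    cassandra_event('get', 'Comment', True, 0.006)
    cassandra_event('get', 'Comment', False, 0.030)

    data = {}
    for k, v in sorted(events.items()):
        if k.endswith('.ok.time'):
            suffix = '|ms'
            v /= events.get(k[:-5], 1)  # mean success time this request
            k = k[:-8]
        elif k.endswith('.time'):
            continue  # error times aren't reported
        else:
            suffix = '|c'
        data['cassandra.' + k] = str(v) + suffix

    for k in sorted(data):
        print(k + ' = ' + data[k])

    # cassandra.Comment.get = 0.005|ms
    # cassandra.Comment.get.error = 1|c
    # cassandra.Comment.get.ok = 2|c

Accumulating in a plain dict and flushing once from the post-request hook
means statsd sees one counter update and a single mean timer sample per
column family and operation, rather than one send per Cassandra call.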
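The pool wrapper itself is plain closure-based monkey-patching: each thrift
method on the pycassa connection wrapper is swapped for a closure that times
the call and reports success or failure. A minimal sketch of that pattern,
with FakeConnection and report() as invented stand-ins for the real wrapper
object and the stats sink:

    import time

    def instrument(f, on_done):
        def call_with_instrumentation(*args, **kwargs):
            start = time.time()
            try:
                result = f(*args, **kwargs)
            except:
                on_done(f.__name__, False, time.time() - start)
                raise
            else:
                on_done(f.__name__, True, time.time() - start)
                return result
        return call_with_instrumentation

    class FakeConnection(object):
        def get(self, key):
            return 'value for %s' % key

    def report(name, ok, elapsed):
        print('%s %s in %.6fs' % (name, 'ok' if ok else 'error', elapsed))

    conn = FakeConnection()
    conn.get = instrument(conn.get, report)
    conn.get('some-key')  # prints something like: get ok in 0.000002s

In the patch, the per-method get_cf_name helpers additionally recover which
column families a call touched: batch_mutate's first argument is a mutation
map shaped like {row_key: {cf_name: [mutations]}}, so collecting the inner
dict keys is enough to attribute the call.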