diff --git a/r2/r2/lib/cache.py b/r2/r2/lib/cache.py index 3c1a3ec81..52831b730 100644 --- a/r2/r2/lib/cache.py +++ b/r2/r2/lib/cache.py @@ -41,6 +41,10 @@ from r2.lib.sgm import sgm # get this into our namespace so that it's class NoneResult(object): pass class CacheUtils(object): + # Caches that never expire entries should set this to true, so that + # CacheChain can properly count hits and misses. + permanent = False + def incr_multi(self, keys, delta=1, prefix=''): for k in keys: try: @@ -185,6 +189,7 @@ class CMemcache(CacheUtils): class HardCache(CacheUtils): backend = None + permanent = True def __init__(self, gc): self.backend = HardCacheBackend(gc) @@ -354,37 +359,41 @@ class CacheChain(CacheUtils, local): flush_all = make_set_fn('flush_all') cache_negative_results = False - def get(self, key, default = None, allow_local = True): - for c in self.caches: - if not allow_local and isinstance(c,LocalCache): - continue + def get(self, key, default = None, allow_local = True, stale=None): + stat_outcome = False # assume a miss until a result is found + try: + for c in self.caches: + if not allow_local and isinstance(c,LocalCache): + continue - val = c.get(key) + val = c.get(key) - if val is not None: - if self.stats: + if val is not None: + if not c.permanent: + stat_outcome = True + + #update other caches + for d in self.caches: + if c is d: + break # so we don't set caches later in the chain + d.set(key, val) + + if val == NoneResult: + return default + else: + return val + + if self.cache_negative_results: + for c in self.caches[:-1]: + c.set(key, NoneResult) + + return default + finally: + if self.stats: + if stat_outcome: self.stats.cache_hit() - - #update other caches - for d in self.caches: - if c is d: - break # so we don't set caches later in the chain - d.set(key, val) - - if val == NoneResult: - return default else: - return val - - #didn't find anything - if self.stats: - self.stats.cache_miss() - - if self.cache_negative_results: - for c in self.caches[:-1]: - c.set(key, NoneResult) - - return default + self.stats.cache_miss() def get_multi(self, keys, prefix='', allow_local = True, **kw): l = lambda ks: self.simple_get_multi(ks, allow_local = allow_local, **kw) @@ -393,16 +402,25 @@ class CacheChain(CacheUtils, local): def simple_get_multi(self, keys, allow_local = True, stale=None): out = {} need = set(keys) + hits = 0 + misses = 0 for c in self.caches: if not allow_local and isinstance(c, LocalCache): continue + if c.permanent and not misses: + # Once we reach a "permanent" cache, we count any outstanding + # items as misses. + misses = len(need) + if len(out) == len(keys): # we've found them all break r = c.simple_get_multi(need) #update other caches if r: + if not c.permanent: + hits += len(r) for d in self.caches: if c is d: break # so we don't set caches later in the chain @@ -421,8 +439,8 @@ class CacheChain(CacheUtils, local): if v != NoneResult) if self.stats: - self.stats.cache_hit(len(out)) - self.stats.cache_miss(len(need)) + self.stats.cache_hit(hits) + self.stats.cache_miss(misses) return out @@ -638,6 +656,8 @@ class CassandraCacheChain(CacheChain): class CassandraCache(CacheUtils): + permanent = True + """A cache that uses a Cassandra ColumnFamily. Uses only the column-name 'value'""" def __init__(self, column_family, client,