diff --git a/r2/r2/lib/app_globals.py b/r2/r2/lib/app_globals.py index f8c108dc3..5a8a4142e 100755 --- a/r2/r2/lib/app_globals.py +++ b/r2/r2/lib/app_globals.py @@ -589,11 +589,6 @@ class Globals(object): self.startup_timer.intermediate("cassandra") ################# POSTGRES - event.listens_for(engine.Engine, 'before_cursor_execute')( - self.stats.pg_before_cursor_execute) - event.listens_for(engine.Engine, 'after_cursor_execute')( - self.stats.pg_after_cursor_execute) - self.dbm = self.load_db_params() self.startup_timer.intermediate("postgres") diff --git a/r2/r2/lib/manager/db_manager.py b/r2/r2/lib/manager/db_manager.py index a2e4214f7..8c8187040 100644 --- a/r2/r2/lib/manager/db_manager.py +++ b/r2/r2/lib/manager/db_manager.py @@ -35,7 +35,7 @@ APPLICATION_NAME = "reddit@%s:%d" % (socket.gethostname(), os.getpid()) def get_engine(name, db_host='', db_user='', db_pass='', db_port='5432', - pool_size=5, max_overflow=5): + pool_size=5, max_overflow=5, g_override=None): db_port = int(db_port) arguments = { @@ -50,7 +50,7 @@ def get_engine(name, db_host='', db_user='', db_pass='', db_port='5432', arguments["password"] = db_pass dsn = "%20".join("%s=%s" % x for x in arguments.iteritems()) - return sqlalchemy.create_engine( + engine = sqlalchemy.create_engine( 'postgresql:///?dsn=' + dsn, strategy='threadlocal', pool_size=int(pool_size), @@ -60,6 +60,14 @@ def get_engine(name, db_host='', db_user='', db_pass='', db_port='5432', use_native_unicode=False, ) + if g_override: + sqlalchemy.event.listens_for(engine, 'before_cursor_execute')( + g_override.stats.pg_before_cursor_execute) + sqlalchemy.event.listens_for(engine, 'after_cursor_execute')( + g_override.stats.pg_after_cursor_execute) + + return engine + class db_manager: def __init__(self): @@ -83,7 +91,7 @@ class db_manager: self.avoid_master_reads[name] = avoid_master def setup_db(self, db_name, g_override=None, **params): - engine = get_engine(**params) + engine = get_engine(g_override=g_override, **params) self._engines[db_name] = engine self.test_engine(engine, g_override)