Tests: Refactor test initialization

Change RedditTest so that it maintains global state and will only fire up the app
context a single time. Inject mocks for the database manager and avoid making
calls to the database during thing initialization.
This commit is contained in:
Chris Stephens
2015-07-01 16:40:15 -07:00
parent 6579f515f0
commit 19d569696e
15 changed files with 46 additions and 66 deletions

View File

@@ -425,6 +425,11 @@ class RedditApp(PylonsApp):
def setup_app_env(self, environ, start_response):
PylonsApp.setup_app_env(self, environ, start_response)
from pylons import g
# When running tests don't load controllers or register hooks. Loading the
# controllers currently causes db initialization and runs queries.
if g.env == 'unit_test':
return
self.load()
def load(self):

View File

@@ -442,6 +442,7 @@ class Globals(object):
raise AttributeError
def setup(self):
self.env = 'unit_test' if 'test' in sys.argv[0] else ''
self.queues = queues.declare_queues(self)
self.extension_subdomains = dict(
@@ -915,9 +916,14 @@ class Globals(object):
def load_db_params(self):
self.databases = tuple(ConfigValue.to_iter(self.config.raw_data['databases']))
self.db_params = {}
self.predefined_type_ids = {}
if not self.databases:
return
if self.env == 'unit_test':
from mock import MagicMock
return MagicMock()
dbm = db_manager.db_manager()
db_param_names = ('name', 'db_host', 'db_user', 'db_pass', 'db_port',
'pool_size', 'max_overflow')
@@ -959,7 +965,6 @@ class Globals(object):
return params, flags
prefix = 'db_table_'
self.predefined_type_ids = {}
for k, v in self.config.raw_data.iteritems():
if not k.startswith(prefix):
continue

View File

@@ -533,7 +533,8 @@ class DataThing(object):
class ThingMeta(type):
def __init__(cls, name, bases, dct):
if name == 'Thing' or hasattr(cls, '_nodb') and cls._nodb: return
#print "checking thing", name
if g.env == 'unit_test':
return
#TODO exceptions
cls._type_name = name.lower()
@@ -666,6 +667,9 @@ class RelationMeta(type):
if name == 'RelationCls': return
#print "checking relation", name
if g.env == 'unit_test':
return
cls._type_name = name.lower()
try:
cls._type_id = tdb.rel_types_name[cls._type_name].type_id

View File

@@ -39,19 +39,7 @@ pkg_resources.working_set.add_entry(conf_dir)
pkg_resources.require('Paste')
pkg_resources.require('PasteScript')
def stage_for_paste():
wsgiapp = loadapp('config:test.ini', relative_to=conf_dir)
test_app = paste.fixture.TestApp(wsgiapp)
# this is basically what 'paster run' does (see r2/commands.py)
test_response = test_app.get("/_test_vars")
request_id = int(test_response.body)
test_app.pre_request_hook = lambda self: \
paste.registry.restorer.restoration_end()
test_app.post_request_hook = lambda self: \
paste.registry.restorer.restoration_begin(request_id)
paste.registry.restorer.restoration_begin(request_id)
_app_context = False
class RedditTestCase(TestCase):
@@ -61,6 +49,20 @@ class RedditTestCase(TestCase):
this isn't necessary as it'll save time.
"""
if not _app_context:
wsgiapp = loadapp('config:test.ini', relative_to=conf_dir)
test_app = paste.fixture.TestApp(wsgiapp)
# this is basically what 'paster run' does (see r2/commands.py)
test_response = test_app.get("/_test_vars")
request_id = int(test_response.body)
test_app.pre_request_hook = lambda self: \
paste.registry.restorer.restoration_end()
test_app.post_request_hook = lambda self: \
paste.registry.restorer.restoration_begin(request_id)
paste.registry.restorer.restoration_begin(request_id)
_app_context = True
def __init__(self, *args, **kwargs):
stage_for_paste()
TestCase.__init__(self, *args, **kwargs)

View File

@@ -21,9 +21,6 @@
# Inc. All Rights Reserved.
###############################################################################
from r2.tests import stage_for_paste
stage_for_paste()
import unittest
from r2.lib.media import _get_scrape_url

View File

@@ -23,10 +23,6 @@
import unittest
from r2.tests import stage_for_paste
stage_for_paste()
from r2.lib.permissions import PermissionSet, ModeratorPermissionSet
class TestPermissionSet(PermissionSet):

View File

@@ -21,10 +21,8 @@
# Inc. All Rights Reserved.
###############################################################################
from r2.tests import stage_for_paste
stage_for_paste()
import unittest
from r2.tests import RedditTestCase
from pylons import g

View File

@@ -21,9 +21,6 @@
# Inc. All Rights Reserved.
###############################################################################
from r2.tests import stage_for_paste
stage_for_paste()
import unittest
from r2.lib.providers.image_resizing.no_op import NoOpImageResizingProvider

View File

@@ -21,9 +21,6 @@
# Inc. All Rights Reserved.
###############################################################################
from r2.tests import stage_for_paste
stage_for_paste()
import unittest
from r2.lib.providers.image_resizing.unsplashit import UnsplashitImageResizingProvider

View File

@@ -24,15 +24,14 @@
import unittest
from r2.lib.utils import UrlParser
from r2.tests import stage_for_paste
from r2.tests import RedditTestCase
from pylons import g
class TestIsRedditURL(unittest.TestCase):
class TestIsRedditURL(RedditTestCase):
@classmethod
def setUpClass(cls):
stage_for_paste()
cls._old_offsite = g.offsite_subdomains
g.offsite_subdomains = ["blog"]
@@ -146,8 +145,7 @@ class TestIsRedditURL(unittest.TestCase):
self.assertIsSafeRedditUrl(u"/foo/bar/\xa0baz")
class TestSwitchSubdomainByExtension(unittest.TestCase):
class TestSwitchSubdomainByExtension(RedditTestCase):
@classmethod
def setUpClass(cls):
cls._old_domain = g.domain

View File

@@ -22,9 +22,7 @@
###############################################################################
import unittest
from r2.tests import stage_for_paste
stage_for_paste()
from r2.tests import RedditTestCase
from pylons import c
from r2.lib.errors import errors, ErrorSet

View File

@@ -29,8 +29,7 @@ from webob.exc import HTTPException
# Needs to be done before other r2 imports, since some code run on module import
# expects a sane pylons env
from r2.tests import stage_for_paste
stage_for_paste()
from r2.tests import RedditTestCase
from r2.lib.db.thing import NotFound
from r2.lib.errors import errors, ErrorSet, UserRequiredException
@@ -42,26 +41,14 @@ class TestVVerifyPassword(unittest.TestCase):
"""Test that only the current user's password satisfies VVerifyPassword"""
@classmethod
def setUpClass(cls):
cls._backup_user = c.user
cls._backup_loggedin = c.user_is_loggedin
# Create a dummy account for testing with; won't touch the database
# as long as we don't `._commit()`
name = "unit_tester_%s" % uuid.uuid4().hex
cls._password = uuid.uuid4().hex
try:
Account._by_name(name)
raise AccountExists
except NotFound:
cls._account = Account(
name=name,
password=bcrypt_password(cls._password)
)
@classmethod
def tearDownClass(cls):
c.user_is_loggedin = cls._backup_loggedin
c.user = cls._backup_user
cls._account = Account(
name=name,
password=bcrypt_password(cls._password)
)
def setUp(self):
c.user_is_loggedin = True

View File

@@ -19,7 +19,3 @@
# All portions of the code written by reddit are Copyright (c) 2006-2015 reddit
# Inc. All Rights Reserved.
###############################################################################
from r2.tests import stage_for_paste
stage_for_paste()

View File

@@ -35,9 +35,9 @@ class TestPermissionSet(PermissionSet):
class SRMemberTest(unittest.TestCase):
def setUp(self):
a = Account()
a._commit()
a._id = 1
sr = Subreddit()
sr._commit()
sr._id = 2
self.rel = SRMember(sr, a, 'test')
def test_get_permissions(self):

View File

@@ -91,8 +91,8 @@ setup(
# Extra dependencies that aren't needed for running the app.
# * https://pythonhosted.org/setuptools/setuptools.html#declaring-extras-optional-features-with-their-own-dependencies
# * https://github.com/pypa/sampleproject/blob/300f04dc44df51492deb859ac98ba521d2c7a17a/setup.py#L71-L77
extras_require = {
'test': ['mock'],
extras_require={
'test': ['mock', 'nose'],
},
dependency_links=[
"https://github.com/reddit/snudown/archive/v1.1.3.tar.gz#egg=snudown-1.1.3",