diff --git a/r2/r2/config/middleware.py b/r2/r2/config/middleware.py index 9f8588898..29d597c7d 100644 --- a/r2/r2/config/middleware.py +++ b/r2/r2/config/middleware.py @@ -425,6 +425,11 @@ class RedditApp(PylonsApp): def setup_app_env(self, environ, start_response): PylonsApp.setup_app_env(self, environ, start_response) + from pylons import g + # When running tests don't load controllers or register hooks. Loading the + # controllers currently causes db initialization and runs queries. + if g.env == 'unit_test': + return self.load() def load(self): diff --git a/r2/r2/lib/app_globals.py b/r2/r2/lib/app_globals.py index 32a339ab9..6c000893f 100644 --- a/r2/r2/lib/app_globals.py +++ b/r2/r2/lib/app_globals.py @@ -442,6 +442,7 @@ class Globals(object): raise AttributeError def setup(self): + self.env = 'unit_test' if 'test' in sys.argv[0] else '' self.queues = queues.declare_queues(self) self.extension_subdomains = dict( @@ -915,9 +916,14 @@ class Globals(object): def load_db_params(self): self.databases = tuple(ConfigValue.to_iter(self.config.raw_data['databases'])) self.db_params = {} + self.predefined_type_ids = {} if not self.databases: return + if self.env == 'unit_test': + from mock import MagicMock + return MagicMock() + dbm = db_manager.db_manager() db_param_names = ('name', 'db_host', 'db_user', 'db_pass', 'db_port', 'pool_size', 'max_overflow') @@ -959,7 +965,6 @@ class Globals(object): return params, flags prefix = 'db_table_' - self.predefined_type_ids = {} for k, v in self.config.raw_data.iteritems(): if not k.startswith(prefix): continue diff --git a/r2/r2/lib/db/thing.py b/r2/r2/lib/db/thing.py index 940ecb3be..52fb7a3c8 100644 --- a/r2/r2/lib/db/thing.py +++ b/r2/r2/lib/db/thing.py @@ -533,7 +533,8 @@ class DataThing(object): class ThingMeta(type): def __init__(cls, name, bases, dct): if name == 'Thing' or hasattr(cls, '_nodb') and cls._nodb: return - #print "checking thing", name + if g.env == 'unit_test': + return #TODO exceptions cls._type_name = name.lower() @@ -666,6 +667,9 @@ class RelationMeta(type): if name == 'RelationCls': return #print "checking relation", name + if g.env == 'unit_test': + return + cls._type_name = name.lower() try: cls._type_id = tdb.rel_types_name[cls._type_name].type_id diff --git a/r2/r2/tests/__init__.py b/r2/r2/tests/__init__.py index e1a1226d0..94499f2b7 100644 --- a/r2/r2/tests/__init__.py +++ b/r2/r2/tests/__init__.py @@ -39,19 +39,7 @@ pkg_resources.working_set.add_entry(conf_dir) pkg_resources.require('Paste') pkg_resources.require('PasteScript') - -def stage_for_paste(): - wsgiapp = loadapp('config:test.ini', relative_to=conf_dir) - test_app = paste.fixture.TestApp(wsgiapp) - - # this is basically what 'paster run' does (see r2/commands.py) - test_response = test_app.get("/_test_vars") - request_id = int(test_response.body) - test_app.pre_request_hook = lambda self: \ - paste.registry.restorer.restoration_end() - test_app.post_request_hook = lambda self: \ - paste.registry.restorer.restoration_begin(request_id) - paste.registry.restorer.restoration_begin(request_id) +_app_context = False class RedditTestCase(TestCase): @@ -61,6 +49,20 @@ class RedditTestCase(TestCase): this isn't necessary as it'll save time. """ + if not _app_context: + wsgiapp = loadapp('config:test.ini', relative_to=conf_dir) + test_app = paste.fixture.TestApp(wsgiapp) + + # this is basically what 'paster run' does (see r2/commands.py) + test_response = test_app.get("/_test_vars") + request_id = int(test_response.body) + test_app.pre_request_hook = lambda self: \ + paste.registry.restorer.restoration_end() + test_app.post_request_hook = lambda self: \ + paste.registry.restorer.restoration_begin(request_id) + paste.registry.restorer.restoration_begin(request_id) + + _app_context = True + def __init__(self, *args, **kwargs): - stage_for_paste() TestCase.__init__(self, *args, **kwargs) diff --git a/r2/r2/tests/unit/lib/media_test.py b/r2/r2/tests/unit/lib/media_test.py index f2725ae44..7ebc6e8b2 100644 --- a/r2/r2/tests/unit/lib/media_test.py +++ b/r2/r2/tests/unit/lib/media_test.py @@ -21,9 +21,6 @@ # Inc. All Rights Reserved. ############################################################################### -from r2.tests import stage_for_paste -stage_for_paste() - import unittest from r2.lib.media import _get_scrape_url diff --git a/r2/r2/tests/unit/lib/permissions_test.py b/r2/r2/tests/unit/lib/permissions_test.py index f4c5959a5..36b69fc97 100644 --- a/r2/r2/tests/unit/lib/permissions_test.py +++ b/r2/r2/tests/unit/lib/permissions_test.py @@ -23,10 +23,6 @@ import unittest -from r2.tests import stage_for_paste - -stage_for_paste() - from r2.lib.permissions import PermissionSet, ModeratorPermissionSet class TestPermissionSet(PermissionSet): diff --git a/r2/r2/tests/unit/lib/providers/image_resizing/imgix_test.py b/r2/r2/tests/unit/lib/providers/image_resizing/imgix_test.py index 919d881e5..f381ae4f3 100644 --- a/r2/r2/tests/unit/lib/providers/image_resizing/imgix_test.py +++ b/r2/r2/tests/unit/lib/providers/image_resizing/imgix_test.py @@ -21,10 +21,8 @@ # Inc. All Rights Reserved. ############################################################################### -from r2.tests import stage_for_paste -stage_for_paste() - import unittest +from r2.tests import RedditTestCase from pylons import g diff --git a/r2/r2/tests/unit/lib/providers/image_resizing/no_op_test.py b/r2/r2/tests/unit/lib/providers/image_resizing/no_op_test.py index 98d4ebb70..f9543b3e4 100644 --- a/r2/r2/tests/unit/lib/providers/image_resizing/no_op_test.py +++ b/r2/r2/tests/unit/lib/providers/image_resizing/no_op_test.py @@ -21,9 +21,6 @@ # Inc. All Rights Reserved. ############################################################################### -from r2.tests import stage_for_paste -stage_for_paste() - import unittest from r2.lib.providers.image_resizing.no_op import NoOpImageResizingProvider diff --git a/r2/r2/tests/unit/lib/providers/image_resizing/unsplashit_test.py b/r2/r2/tests/unit/lib/providers/image_resizing/unsplashit_test.py index 78e84bc74..c1ed26f66 100644 --- a/r2/r2/tests/unit/lib/providers/image_resizing/unsplashit_test.py +++ b/r2/r2/tests/unit/lib/providers/image_resizing/unsplashit_test.py @@ -21,9 +21,6 @@ # Inc. All Rights Reserved. ############################################################################### -from r2.tests import stage_for_paste -stage_for_paste() - import unittest from r2.lib.providers.image_resizing.unsplashit import UnsplashitImageResizingProvider diff --git a/r2/r2/tests/unit/lib/urlparser_test.py b/r2/r2/tests/unit/lib/urlparser_test.py index ed15d0f25..a78e0356d 100644 --- a/r2/r2/tests/unit/lib/urlparser_test.py +++ b/r2/r2/tests/unit/lib/urlparser_test.py @@ -24,15 +24,14 @@ import unittest from r2.lib.utils import UrlParser -from r2.tests import stage_for_paste +from r2.tests import RedditTestCase from pylons import g -class TestIsRedditURL(unittest.TestCase): +class TestIsRedditURL(RedditTestCase): @classmethod def setUpClass(cls): - stage_for_paste() cls._old_offsite = g.offsite_subdomains g.offsite_subdomains = ["blog"] @@ -146,8 +145,7 @@ class TestIsRedditURL(unittest.TestCase): self.assertIsSafeRedditUrl(u"/foo/bar/\xa0baz") - -class TestSwitchSubdomainByExtension(unittest.TestCase): +class TestSwitchSubdomainByExtension(RedditTestCase): @classmethod def setUpClass(cls): cls._old_domain = g.domain diff --git a/r2/r2/tests/unit/lib/validator/test_validator.py b/r2/r2/tests/unit/lib/validator/test_validator.py index 8f3e85e69..08706423e 100644 --- a/r2/r2/tests/unit/lib/validator/test_validator.py +++ b/r2/r2/tests/unit/lib/validator/test_validator.py @@ -22,9 +22,7 @@ ############################################################################### import unittest - -from r2.tests import stage_for_paste -stage_for_paste() +from r2.tests import RedditTestCase from pylons import c from r2.lib.errors import errors, ErrorSet diff --git a/r2/r2/tests/unit/lib/validator/test_vverifypassword.py b/r2/r2/tests/unit/lib/validator/test_vverifypassword.py index c429a37ef..60388d2c3 100644 --- a/r2/r2/tests/unit/lib/validator/test_vverifypassword.py +++ b/r2/r2/tests/unit/lib/validator/test_vverifypassword.py @@ -29,8 +29,7 @@ from webob.exc import HTTPException # Needs to be done before other r2 imports, since some code run on module import # expects a sane pylons env -from r2.tests import stage_for_paste -stage_for_paste() +from r2.tests import RedditTestCase from r2.lib.db.thing import NotFound from r2.lib.errors import errors, ErrorSet, UserRequiredException @@ -42,26 +41,14 @@ class TestVVerifyPassword(unittest.TestCase): """Test that only the current user's password satisfies VVerifyPassword""" @classmethod def setUpClass(cls): - cls._backup_user = c.user - cls._backup_loggedin = c.user_is_loggedin - # Create a dummy account for testing with; won't touch the database # as long as we don't `._commit()` name = "unit_tester_%s" % uuid.uuid4().hex cls._password = uuid.uuid4().hex - try: - Account._by_name(name) - raise AccountExists - except NotFound: - cls._account = Account( - name=name, - password=bcrypt_password(cls._password) - ) - - @classmethod - def tearDownClass(cls): - c.user_is_loggedin = cls._backup_loggedin - c.user = cls._backup_user + cls._account = Account( + name=name, + password=bcrypt_password(cls._password) + ) def setUp(self): c.user_is_loggedin = True diff --git a/r2/r2/tests/unit/models/__init__.py b/r2/r2/tests/unit/models/__init__.py index 7dde6b0ba..82ca0fb27 100644 --- a/r2/r2/tests/unit/models/__init__.py +++ b/r2/r2/tests/unit/models/__init__.py @@ -19,7 +19,3 @@ # All portions of the code written by reddit are Copyright (c) 2006-2015 reddit # Inc. All Rights Reserved. ############################################################################### - -from r2.tests import stage_for_paste - -stage_for_paste() diff --git a/r2/r2/tests/unit/models/subreddit_test.py b/r2/r2/tests/unit/models/subreddit_test.py index 28c07064a..0f969cb28 100644 --- a/r2/r2/tests/unit/models/subreddit_test.py +++ b/r2/r2/tests/unit/models/subreddit_test.py @@ -35,9 +35,9 @@ class TestPermissionSet(PermissionSet): class SRMemberTest(unittest.TestCase): def setUp(self): a = Account() - a._commit() + a._id = 1 sr = Subreddit() - sr._commit() + sr._id = 2 self.rel = SRMember(sr, a, 'test') def test_get_permissions(self): diff --git a/r2/setup.py b/r2/setup.py index cefc035da..891dca84f 100644 --- a/r2/setup.py +++ b/r2/setup.py @@ -91,8 +91,8 @@ setup( # Extra dependencies that aren't needed for running the app. # * https://pythonhosted.org/setuptools/setuptools.html#declaring-extras-optional-features-with-their-own-dependencies # * https://github.com/pypa/sampleproject/blob/300f04dc44df51492deb859ac98ba521d2c7a17a/setup.py#L71-L77 - extras_require = { - 'test': ['mock'], + extras_require={ + 'test': ['mock', 'nose'], }, dependency_links=[ "https://github.com/reddit/snudown/archive/v1.1.3.tar.gz#egg=snudown-1.1.3",