From 4cda2d017700c88736c3f0be244500e89eaa89a7 Mon Sep 17 00:00:00 2001 From: Max Goodman Date: Tue, 18 Oct 2011 15:17:14 -0700 Subject: [PATCH] Relax CORS origin restriction to all subdomains of trusted domains. --- r2/r2/controllers/api.py | 4 ++-- r2/r2/controllers/reddit_base.py | 27 ++++++++++++++++++--------- r2/r2/lib/app_globals.py | 8 ++++++-- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/r2/r2/controllers/api.py b/r2/r2/controllers/api.py index f6951a003..af890629a 100644 --- a/r2/r2/controllers/api.py +++ b/r2/r2/controllers/api.py @@ -364,11 +364,11 @@ class ApiController(RedditController): responder._send_data(modhash = user.modhash()) responder._send_data(cookie = user.make_cookie()) - @cross_domain(g.trusted_origins, allow_credentials=True) + @cross_domain(allow_credentials=True) def POST_login(self, *args, **kwargs): return self._handle_login(*args, **kwargs) - @cross_domain(g.trusted_origins, allow_credentials=True) + @cross_domain(allow_credentials=True) def POST_register(self, *args, **kwargs): return self._handle_register(*args, **kwargs) diff --git a/r2/r2/controllers/reddit_base.py b/r2/r2/controllers/reddit_base.py index 8de3c0649..b233c8ef7 100644 --- a/r2/r2/controllers/reddit_base.py +++ b/r2/r2/controllers/reddit_base.py @@ -26,7 +26,7 @@ from pylons.i18n import _ from pylons.i18n.translation import LanguageError from r2.lib.base import BaseController, proxyurl from r2.lib import pages, utils, filters, amqp -from r2.lib.utils import http_utils, UniqueIterator, ip_and_slash16 +from r2.lib.utils import http_utils, is_subdomain, UniqueIterator, ip_and_slash16 from r2.lib.cache import LocalCache, make_key, MemcachedError import random as rand from r2.models.account import valid_cookie, FakeAccount, valid_feed @@ -468,14 +468,26 @@ def paginated_listing(default_page_size=25, max_page_size=100): def base_listing(fn): return paginated_listing()(fn) -def cross_domain(origins, **options): +def is_trusted_origin(origin): + try: + origin = urlparse(origin) + except ValueError: + return False + + return any(is_subdomain(origin.hostname, domain) for domain in g.trusted_domains) + +def cross_domain(origin_check=is_trusted_origin, **options): """Set up cross domain validation and hoisting for a request handler.""" - origins = filter(None, origins) def cross_domain_wrap(fn): + cors_perms = { + "origin_check": origin_check, + "allow_credentials": bool(options.get("allow_credentials")) + } + def cross_domain_handler(self, *args, **kwargs): if request.params.get("hoist") == "cookie": # Cookie polling response - if g.origin in origins: + if cors_perms["origin_check"](g.origin): name = request.environ["pylons.routes_dict"]["action_name"] resp = fn(self, *args, **kwargs) c.cookies.add('hoist_%s' % name, ''.join(resp.content)) @@ -488,10 +500,7 @@ def cross_domain(origins, **options): self.check_cors() return fn(self, *args, **kwargs) - cross_domain_handler.cors_perms = { - "allowed_origins": origins, - "allow_credentials": bool(options.get("allow_credentials")) - } + cross_domain_handler.cors_perms = cors_perms return cross_domain_handler return cross_domain_wrap @@ -660,7 +669,7 @@ class MinimalController(BaseController): handler = getattr(self, method + "_" + action, None) cors = handler and getattr(handler, "cors_perms", None) - if cors and origin in cors["allowed_origins"]: + if cors and cors["origin_check"](origin): response.headers["Access-Control-Allow-Origin"] = origin if cors.get("allow_credentials"): response.headers["Access-Control-Allow-Credentials"] = "true" diff --git a/r2/r2/lib/app_globals.py b/r2/r2/lib/app_globals.py index 742160087..34f9a7da6 100755 --- a/r2/r2/lib/app_globals.py +++ b/r2/r2/lib/app_globals.py @@ -287,10 +287,14 @@ class Globals(object): origin_prefix = self.domain_prefix + "." if self.domain_prefix else "" self.origin = "http://" + origin_prefix + self.domain self.secure_domains = set([urlparse(self.payment_domain).netloc]) + + self.trusted_domains = set([self.domain]) + self.trusted_domains.update(self.authorized_cnames) if self.https_endpoint: - self.secure_domains.add(urlparse(self.https_endpoint).netloc) + https_url = urlparse(self.https_endpoint) + self.secure_domains.add(https_url.netloc) + self.trusted_domains.add(https_url.hostname) - self.trusted_origins = [self.origin, self.https_endpoint] + ['http://' + origin_prefix + cname for cname in self.authorized_cnames] # load the unique hashed names of files under static static_files = os.path.join(self.paths.get('static_files'), 'static')