Relax CORS origin restriction to all subdomains of trusted domains.

This commit is contained in:
Max Goodman
2011-10-18 15:17:14 -07:00
parent 60ac992522
commit 4cda2d0177
3 changed files with 26 additions and 13 deletions

View File

@@ -364,11 +364,11 @@ class ApiController(RedditController):
responder._send_data(modhash = user.modhash())
responder._send_data(cookie = user.make_cookie())
@cross_domain(g.trusted_origins, allow_credentials=True)
@cross_domain(allow_credentials=True)
def POST_login(self, *args, **kwargs):
return self._handle_login(*args, **kwargs)
@cross_domain(g.trusted_origins, allow_credentials=True)
@cross_domain(allow_credentials=True)
def POST_register(self, *args, **kwargs):
return self._handle_register(*args, **kwargs)

View File

@@ -26,7 +26,7 @@ from pylons.i18n import _
from pylons.i18n.translation import LanguageError
from r2.lib.base import BaseController, proxyurl
from r2.lib import pages, utils, filters, amqp
from r2.lib.utils import http_utils, UniqueIterator, ip_and_slash16
from r2.lib.utils import http_utils, is_subdomain, UniqueIterator, ip_and_slash16
from r2.lib.cache import LocalCache, make_key, MemcachedError
import random as rand
from r2.models.account import valid_cookie, FakeAccount, valid_feed
@@ -468,14 +468,26 @@ def paginated_listing(default_page_size=25, max_page_size=100):
def base_listing(fn):
return paginated_listing()(fn)
def cross_domain(origins, **options):
def is_trusted_origin(origin):
try:
origin = urlparse(origin)
except ValueError:
return False
return any(is_subdomain(origin.hostname, domain) for domain in g.trusted_domains)
def cross_domain(origin_check=is_trusted_origin, **options):
"""Set up cross domain validation and hoisting for a request handler."""
origins = filter(None, origins)
def cross_domain_wrap(fn):
cors_perms = {
"origin_check": origin_check,
"allow_credentials": bool(options.get("allow_credentials"))
}
def cross_domain_handler(self, *args, **kwargs):
if request.params.get("hoist") == "cookie":
# Cookie polling response
if g.origin in origins:
if cors_perms["origin_check"](g.origin):
name = request.environ["pylons.routes_dict"]["action_name"]
resp = fn(self, *args, **kwargs)
c.cookies.add('hoist_%s' % name, ''.join(resp.content))
@@ -488,10 +500,7 @@ def cross_domain(origins, **options):
self.check_cors()
return fn(self, *args, **kwargs)
cross_domain_handler.cors_perms = {
"allowed_origins": origins,
"allow_credentials": bool(options.get("allow_credentials"))
}
cross_domain_handler.cors_perms = cors_perms
return cross_domain_handler
return cross_domain_wrap
@@ -660,7 +669,7 @@ class MinimalController(BaseController):
handler = getattr(self, method + "_" + action, None)
cors = handler and getattr(handler, "cors_perms", None)
if cors and origin in cors["allowed_origins"]:
if cors and cors["origin_check"](origin):
response.headers["Access-Control-Allow-Origin"] = origin
if cors.get("allow_credentials"):
response.headers["Access-Control-Allow-Credentials"] = "true"

View File

@@ -287,10 +287,14 @@ class Globals(object):
origin_prefix = self.domain_prefix + "." if self.domain_prefix else ""
self.origin = "http://" + origin_prefix + self.domain
self.secure_domains = set([urlparse(self.payment_domain).netloc])
self.trusted_domains = set([self.domain])
self.trusted_domains.update(self.authorized_cnames)
if self.https_endpoint:
self.secure_domains.add(urlparse(self.https_endpoint).netloc)
https_url = urlparse(self.https_endpoint)
self.secure_domains.add(https_url.netloc)
self.trusted_domains.add(https_url.hostname)
self.trusted_origins = [self.origin, self.https_endpoint] + ['http://' + origin_prefix + cname for cname in self.authorized_cnames]
# load the unique hashed names of files under static
static_files = os.path.join(self.paths.get('static_files'), 'static')