mirror of
https://github.com/reddit-archive/reddit.git
synced 2026-01-27 07:48:16 -05:00
Relax CORS origin restriction to all subdomains of trusted domains.
This commit is contained in:
@@ -364,11 +364,11 @@ class ApiController(RedditController):
|
||||
responder._send_data(modhash = user.modhash())
|
||||
responder._send_data(cookie = user.make_cookie())
|
||||
|
||||
@cross_domain(g.trusted_origins, allow_credentials=True)
|
||||
@cross_domain(allow_credentials=True)
|
||||
def POST_login(self, *args, **kwargs):
|
||||
return self._handle_login(*args, **kwargs)
|
||||
|
||||
@cross_domain(g.trusted_origins, allow_credentials=True)
|
||||
@cross_domain(allow_credentials=True)
|
||||
def POST_register(self, *args, **kwargs):
|
||||
return self._handle_register(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from pylons.i18n import _
|
||||
from pylons.i18n.translation import LanguageError
|
||||
from r2.lib.base import BaseController, proxyurl
|
||||
from r2.lib import pages, utils, filters, amqp
|
||||
from r2.lib.utils import http_utils, UniqueIterator, ip_and_slash16
|
||||
from r2.lib.utils import http_utils, is_subdomain, UniqueIterator, ip_and_slash16
|
||||
from r2.lib.cache import LocalCache, make_key, MemcachedError
|
||||
import random as rand
|
||||
from r2.models.account import valid_cookie, FakeAccount, valid_feed
|
||||
@@ -468,14 +468,26 @@ def paginated_listing(default_page_size=25, max_page_size=100):
|
||||
def base_listing(fn):
|
||||
return paginated_listing()(fn)
|
||||
|
||||
def cross_domain(origins, **options):
|
||||
def is_trusted_origin(origin):
|
||||
try:
|
||||
origin = urlparse(origin)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
return any(is_subdomain(origin.hostname, domain) for domain in g.trusted_domains)
|
||||
|
||||
def cross_domain(origin_check=is_trusted_origin, **options):
|
||||
"""Set up cross domain validation and hoisting for a request handler."""
|
||||
origins = filter(None, origins)
|
||||
def cross_domain_wrap(fn):
|
||||
cors_perms = {
|
||||
"origin_check": origin_check,
|
||||
"allow_credentials": bool(options.get("allow_credentials"))
|
||||
}
|
||||
|
||||
def cross_domain_handler(self, *args, **kwargs):
|
||||
if request.params.get("hoist") == "cookie":
|
||||
# Cookie polling response
|
||||
if g.origin in origins:
|
||||
if cors_perms["origin_check"](g.origin):
|
||||
name = request.environ["pylons.routes_dict"]["action_name"]
|
||||
resp = fn(self, *args, **kwargs)
|
||||
c.cookies.add('hoist_%s' % name, ''.join(resp.content))
|
||||
@@ -488,10 +500,7 @@ def cross_domain(origins, **options):
|
||||
self.check_cors()
|
||||
return fn(self, *args, **kwargs)
|
||||
|
||||
cross_domain_handler.cors_perms = {
|
||||
"allowed_origins": origins,
|
||||
"allow_credentials": bool(options.get("allow_credentials"))
|
||||
}
|
||||
cross_domain_handler.cors_perms = cors_perms
|
||||
return cross_domain_handler
|
||||
return cross_domain_wrap
|
||||
|
||||
@@ -660,7 +669,7 @@ class MinimalController(BaseController):
|
||||
handler = getattr(self, method + "_" + action, None)
|
||||
cors = handler and getattr(handler, "cors_perms", None)
|
||||
|
||||
if cors and origin in cors["allowed_origins"]:
|
||||
if cors and cors["origin_check"](origin):
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
if cors.get("allow_credentials"):
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
@@ -287,10 +287,14 @@ class Globals(object):
|
||||
origin_prefix = self.domain_prefix + "." if self.domain_prefix else ""
|
||||
self.origin = "http://" + origin_prefix + self.domain
|
||||
self.secure_domains = set([urlparse(self.payment_domain).netloc])
|
||||
|
||||
self.trusted_domains = set([self.domain])
|
||||
self.trusted_domains.update(self.authorized_cnames)
|
||||
if self.https_endpoint:
|
||||
self.secure_domains.add(urlparse(self.https_endpoint).netloc)
|
||||
https_url = urlparse(self.https_endpoint)
|
||||
self.secure_domains.add(https_url.netloc)
|
||||
self.trusted_domains.add(https_url.hostname)
|
||||
|
||||
self.trusted_origins = [self.origin, self.https_endpoint] + ['http://' + origin_prefix + cname for cname in self.authorized_cnames]
|
||||
|
||||
# load the unique hashed names of files under static
|
||||
static_files = os.path.join(self.paths.get('static_files'), 'static')
|
||||
|
||||
Reference in New Issue
Block a user