mirror of
https://github.com/reddit-archive/reddit.git
synced 2026-04-27 03:00:12 -04:00
HSTS redirects: Avoid redirecting with extensions
This commit is contained in:
@@ -1412,10 +1412,12 @@ class FormsController(RedditController):
|
||||
dest=VDestination())
|
||||
def POST_logout(self, dest):
|
||||
"""wipe login cookie and redirect to referer."""
|
||||
if hsts_eligible():
|
||||
dest = hsts_modify_redirect(dest)
|
||||
|
||||
# Check eligibility before calling logout(), as logout() changes
|
||||
# cookies that hsts_eligible() looks at
|
||||
is_hsts_eligible = hsts_eligible()
|
||||
self.logout()
|
||||
return self.redirect(dest)
|
||||
self.hsts_redirect(dest, is_hsts_eligible=is_hsts_eligible)
|
||||
|
||||
|
||||
@validate(VUser(),
|
||||
|
||||
@@ -121,9 +121,7 @@ class PostController(ApiController):
|
||||
return LoginPage(user_login = request.POST.get('user'),
|
||||
dest = dest).render()
|
||||
|
||||
if hsts_eligible():
|
||||
dest = hsts_modify_redirect(dest)
|
||||
return self.redirect(dest)
|
||||
return self.hsts_redirect(dest)
|
||||
|
||||
@csrf_exempt
|
||||
@validate(dest = VDestination(default = "/"))
|
||||
@@ -136,9 +134,7 @@ class PostController(ApiController):
|
||||
return LoginPage(user_reg = request.POST.get('user'),
|
||||
dest = dest).render()
|
||||
|
||||
if hsts_eligible():
|
||||
dest = hsts_modify_redirect(dest)
|
||||
return self.redirect(dest)
|
||||
return self.hsts_redirect(dest)
|
||||
|
||||
def GET_login(self, *a, **kw):
|
||||
return self.redirect('/login' + query_string(dict(dest="/")))
|
||||
|
||||
@@ -1308,6 +1308,17 @@ class MinimalController(BaseController):
|
||||
|
||||
return request.method.upper() != "POST"
|
||||
|
||||
@classmethod
|
||||
def hsts_redirect(cls, dest, is_hsts_eligible=None):
|
||||
"""Redirect to `dest` via the HSTS grant endpoint"""
|
||||
if is_hsts_eligible is None:
|
||||
is_hsts_eligible = hsts_eligible()
|
||||
if is_hsts_eligible:
|
||||
dest = hsts_modify_redirect(dest)
|
||||
return cls.redirect(dest, preserve_extension=False)
|
||||
else:
|
||||
return cls.redirect(dest)
|
||||
|
||||
|
||||
class OAuth2ResourceController(MinimalController):
|
||||
defer_ratelimiting = True
|
||||
|
||||
@@ -157,6 +157,7 @@ class BaseController(WSGIController):
|
||||
Node: for development purposes, also checks that the port
|
||||
matches the request port
|
||||
"""
|
||||
preserve_extension = kw.pop("preserve_extension", True)
|
||||
u = UrlParser(url)
|
||||
|
||||
if u.is_reddit_url():
|
||||
@@ -169,7 +170,7 @@ class BaseController(WSGIController):
|
||||
u.mk_cname(**kw)
|
||||
|
||||
# make sure the extensions agree with the current page
|
||||
if c.extension:
|
||||
if preserve_extension and c.extension:
|
||||
u.set_extension(c.extension)
|
||||
|
||||
# unparse and encode it un utf8
|
||||
@@ -197,15 +198,17 @@ class BaseController(WSGIController):
|
||||
abort(302, location=path)
|
||||
|
||||
@classmethod
|
||||
def redirect(cls, dest, code = 302):
|
||||
def redirect(cls, dest, code=302, preserve_extension=True):
|
||||
"""
|
||||
Reformats the new Location (dest) using format_output_url and
|
||||
sends the user to that location with the provided HTTP code.
|
||||
"""
|
||||
dest = cls.format_output_url(dest or "/")
|
||||
dest = cls.format_output_url(dest or "/",
|
||||
preserve_extension=preserve_extension)
|
||||
response.status_int = code
|
||||
response.headers['Location'] = dest
|
||||
|
||||
|
||||
class EmbedHandler(urllib2.BaseHandler, urllib2.HTTPHandler,
|
||||
urllib2.HTTPErrorProcessor, urllib2.HTTPDefaultErrorHandler):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user