HSTS redirects: Avoid redirecting with extensions

This commit is contained in:
Keith Mitchell
2015-01-28 14:27:30 -08:00
parent 85f79ef889
commit b57d0679df
4 changed files with 24 additions and 12 deletions

View File

@@ -1412,10 +1412,12 @@ class FormsController(RedditController):
dest=VDestination())
def POST_logout(self, dest):
"""wipe login cookie and redirect to referer."""
if hsts_eligible():
dest = hsts_modify_redirect(dest)
# Check eligibility before calling logout(), as logout() changes
# cookies that hsts_eligible() looks at
is_hsts_eligible = hsts_eligible()
self.logout()
return self.redirect(dest)
self.hsts_redirect(dest, is_hsts_eligible=is_hsts_eligible)
@validate(VUser(),

View File

@@ -121,9 +121,7 @@ class PostController(ApiController):
return LoginPage(user_login = request.POST.get('user'),
dest = dest).render()
if hsts_eligible():
dest = hsts_modify_redirect(dest)
return self.redirect(dest)
return self.hsts_redirect(dest)
@csrf_exempt
@validate(dest = VDestination(default = "/"))
@@ -136,9 +134,7 @@ class PostController(ApiController):
return LoginPage(user_reg = request.POST.get('user'),
dest = dest).render()
if hsts_eligible():
dest = hsts_modify_redirect(dest)
return self.redirect(dest)
return self.hsts_redirect(dest)
def GET_login(self, *a, **kw):
return self.redirect('/login' + query_string(dict(dest="/")))

View File

@@ -1308,6 +1308,17 @@ class MinimalController(BaseController):
return request.method.upper() != "POST"
@classmethod
def hsts_redirect(cls, dest, is_hsts_eligible=None):
"""Redirect to `dest` via the HSTS grant endpoint"""
if is_hsts_eligible is None:
is_hsts_eligible = hsts_eligible()
if is_hsts_eligible:
dest = hsts_modify_redirect(dest)
return cls.redirect(dest, preserve_extension=False)
else:
return cls.redirect(dest)
class OAuth2ResourceController(MinimalController):
defer_ratelimiting = True

View File

@@ -157,6 +157,7 @@ class BaseController(WSGIController):
Node: for development purposes, also checks that the port
matches the request port
"""
preserve_extension = kw.pop("preserve_extension", True)
u = UrlParser(url)
if u.is_reddit_url():
@@ -169,7 +170,7 @@ class BaseController(WSGIController):
u.mk_cname(**kw)
# make sure the extensions agree with the current page
if c.extension:
if preserve_extension and c.extension:
u.set_extension(c.extension)
# unparse and encode it un utf8
@@ -197,15 +198,17 @@ class BaseController(WSGIController):
abort(302, location=path)
@classmethod
def redirect(cls, dest, code = 302):
def redirect(cls, dest, code=302, preserve_extension=True):
"""
Reformats the new Location (dest) using format_output_url and
sends the user to that location with the provided HTTP code.
"""
dest = cls.format_output_url(dest or "/")
dest = cls.format_output_url(dest or "/",
preserve_extension=preserve_extension)
response.status_int = code
response.headers['Location'] = dest
class EmbedHandler(urllib2.BaseHandler, urllib2.HTTPHandler,
urllib2.HTTPErrorProcessor, urllib2.HTTPDefaultErrorHandler):