mirror of
https://github.com/reddit-archive/reddit.git
synced 2026-04-27 03:00:12 -04:00
Allow "script" apps grant_type=password
Only application devs will be permitted to use password grants, to discourage widespread use. Password grants are intended to be a convenience for bots & personal scripts, to encourage use of OAuth tokens and ease transition to OAuth2. For more info on password grants, see: http://tools.ietf.org/html/rfc6749#section-4.3
This commit is contained in:
@@ -42,6 +42,7 @@ from r2.lib.validator import (
|
||||
nop,
|
||||
validate,
|
||||
VRequired,
|
||||
VThrottledLogin,
|
||||
VOneOf,
|
||||
VUser,
|
||||
VModhash,
|
||||
@@ -160,13 +161,11 @@ class OAuth2AccessController(MinimalController):
|
||||
except RequirementException:
|
||||
abort(401, headers=[("WWW-Authenticate", 'Basic realm="reddit"')])
|
||||
|
||||
@validate(grant_type = VOneOf("grant_type",
|
||||
("authorization_code", "refresh_token")),
|
||||
code = nop("code"),
|
||||
refresh_token = VOAuth2RefreshToken("refresh_token"),
|
||||
redirect_uri = VRequired("redirect_uri",
|
||||
errors.OAUTH2_INVALID_REDIRECT_URI))
|
||||
def POST_access_token(self, grant_type, code, refresh_token, redirect_uri):
|
||||
@validate(grant_type=VOneOf("grant_type",
|
||||
("authorization_code",
|
||||
"refresh_token",
|
||||
"password")))
|
||||
def POST_access_token(self, grant_type):
|
||||
"""
|
||||
Exchange an [OAuth 2.0](http://oauth.net/2/) authorization code
|
||||
or refresh token (from [/api/v1/authorize](#api_method_authorize)) for
|
||||
@@ -185,54 +184,122 @@ class OAuth2AccessController(MinimalController):
|
||||
|
||||
Per the OAuth specification, **grant_type** must
|
||||
be ``authorization_code`` for the initial access token or
|
||||
``refresh_token`` for renewing the access token. In either case,
|
||||
``refresh_token`` for renewing the access token.
|
||||
|
||||
**redirect_uri** must exactly match the value that was used in the call
|
||||
to [/api/v1/authorize](#api_method_authorize) that created this grant.
|
||||
"""
|
||||
|
||||
resp = {}
|
||||
if not (code or refresh_token):
|
||||
c.errors.add("NO_TEXT", field=("code", "refresh_token"))
|
||||
if not c.errors:
|
||||
access_token = None
|
||||
|
||||
if grant_type == "authorization_code":
|
||||
auth_token = OAuth2AuthorizationCode.use_token(
|
||||
code, c.oauth2_client._id, redirect_uri)
|
||||
if auth_token:
|
||||
if auth_token.refreshable:
|
||||
refresh_token = OAuth2RefreshToken._new(
|
||||
auth_token.client_id, auth_token.user_id,
|
||||
auth_token.scope)
|
||||
access_token = OAuth2AccessToken._new(
|
||||
auth_token.client_id, auth_token.user_id,
|
||||
auth_token.scope,
|
||||
refresh_token._id if refresh_token else None)
|
||||
elif grant_type == "refresh_token" and refresh_token:
|
||||
access_token = OAuth2AccessToken._new(
|
||||
refresh_token.client_id, refresh_token.user_id,
|
||||
refresh_token.scope,
|
||||
refresh_token=refresh_token._id)
|
||||
|
||||
if access_token:
|
||||
resp["access_token"] = access_token._id
|
||||
resp["token_type"] = access_token.token_type
|
||||
resp["expires_in"] = int(access_token._ttl) if access_token._ttl else None
|
||||
resp["scope"] = access_token.scope
|
||||
if refresh_token:
|
||||
resp["refresh_token"] = refresh_token._id
|
||||
else:
|
||||
resp["error"] = "invalid_grant"
|
||||
if grant_type == "authorization_code":
|
||||
return self._access_token_code()
|
||||
elif grant_type == "refresh_token":
|
||||
return self._access_token_refresh()
|
||||
elif grant_type == "password":
|
||||
return self._access_token_password()
|
||||
else:
|
||||
if (errors.INVALID_OPTION, "grant_type") in c.errors:
|
||||
resp["error"] = "unsupported_grant_type"
|
||||
elif (errors.INVALID_OPTION, "scope") in c.errors:
|
||||
resp["error"] = "invalid_scope"
|
||||
else:
|
||||
resp["error"] = "invalid_request"
|
||||
resp = {"error": "unsupported_grant_type"}
|
||||
return self.api_wrapper(resp)
|
||||
|
||||
def _check_for_errors(self):
|
||||
resp = {}
|
||||
if (errors.INVALID_OPTION, "scope") in c.errors:
|
||||
resp["error"] = "invalid_scope"
|
||||
else:
|
||||
resp["error"] = "invalid_request"
|
||||
return resp
|
||||
|
||||
def _make_token_dict(self, access_token, refresh_token=None):
|
||||
if not access_token:
|
||||
return {"error": "invalid_grant"}
|
||||
expires_in = int(access_token._ttl) if access_token._ttl else None
|
||||
resp = {
|
||||
"access_token": access_token._id,
|
||||
"token_type": access_token.token_type,
|
||||
"expires_in": expires_in,
|
||||
"scope": access_token.scope,
|
||||
}
|
||||
if refresh_token:
|
||||
resp["refresh_token"] = refresh_token._id
|
||||
return resp
|
||||
|
||||
@validate(code=nop("code"),
|
||||
redirect_uri=VRequired("redirect_uri",
|
||||
errors.OAUTH2_INVALID_REDIRECT_URI))
|
||||
def _access_token_code(self, code, redirect_uri):
|
||||
if not code:
|
||||
c.errors.add("NO_TEXT", field="code")
|
||||
if c.errors:
|
||||
return self.api_wrapper(self._check_for_errors())
|
||||
|
||||
access_token = None
|
||||
refresh_token = None
|
||||
|
||||
auth_token = OAuth2AuthorizationCode.use_token(
|
||||
code, c.oauth2_client._id, redirect_uri)
|
||||
if auth_token:
|
||||
if auth_token.refreshable:
|
||||
refresh_token = OAuth2RefreshToken._new(
|
||||
auth_token.client_id, auth_token.user_id,
|
||||
auth_token.scope)
|
||||
access_token = OAuth2AccessToken._new(
|
||||
auth_token.client_id, auth_token.user_id,
|
||||
auth_token.scope,
|
||||
refresh_token._id if refresh_token else None)
|
||||
|
||||
resp = self._make_token_dict(access_token, refresh_token)
|
||||
|
||||
return self.api_wrapper(resp)
|
||||
|
||||
@validate(refresh_token=VOAuth2RefreshToken("refresh_token"))
|
||||
def _access_token_refresh(self, refresh_token):
|
||||
resp = {}
|
||||
access_token = None
|
||||
if refresh_token:
|
||||
access_token = OAuth2AccessToken._new(
|
||||
refresh_token.client_id, refresh_token.user_id,
|
||||
refresh_token.scope,
|
||||
refresh_token=refresh_token._id)
|
||||
else:
|
||||
c.errors.add("NO_TEXT", field="refresh_token")
|
||||
|
||||
if c.errors:
|
||||
resp = self._check_for_errors()
|
||||
else:
|
||||
resp = self._make_token_dict(access_token)
|
||||
return self.api_wrapper(resp)
|
||||
|
||||
@validate(user=VThrottledLogin(["username", "password"]),
|
||||
scope=nop("scope"))
|
||||
def _access_token_password(self, user, scope):
|
||||
# username:password auth via OAuth is only allowed for
|
||||
# private use scripts
|
||||
client = c.oauth2_client
|
||||
if client.app_type != "script":
|
||||
return self.api_wrapper({"error": "unauthorized_client",
|
||||
"error_description": "Only script apps may use password auth"})
|
||||
dev_ids = client._developer_ids
|
||||
if not user or user._id not in dev_ids:
|
||||
return self.api_wrapper({"error": "invalid_grant"})
|
||||
if c.errors:
|
||||
return self.api_wrapper(self._check_for_errors())
|
||||
|
||||
if scope:
|
||||
scope = OAuth2Scope(scope)
|
||||
if not scope.is_valid():
|
||||
c.errors.add(errors.INVALID_OPTION, "scope")
|
||||
return self.api_wrapper({"error": "invalid_scope"})
|
||||
else:
|
||||
scope = OAuth2Scope(OAuth2Scope.FULL_ACCESS)
|
||||
|
||||
access_token = OAuth2AccessToken._new(
|
||||
client._id,
|
||||
user._id,
|
||||
scope
|
||||
)
|
||||
resp = self._make_token_dict(access_token)
|
||||
return self.api_wrapper(resp)
|
||||
|
||||
|
||||
class OAuth2ResourceController(MinimalController):
|
||||
def pre(self):
|
||||
set_extension(request.environ, "json")
|
||||
@@ -256,10 +323,8 @@ class OAuth2ResourceController(MinimalController):
|
||||
oauth2_perms = getattr(handler, "oauth2_perms", None)
|
||||
if oauth2_perms:
|
||||
grant = OAuth2Scope(access_token.scope)
|
||||
if grant.subreddit_only and c.site.name not in grant.subreddits:
|
||||
self._auth_error(403, "insufficient_scope")
|
||||
required_scopes = set(oauth2_perms['allowed_scopes'])
|
||||
if not (grant.scopes >= required_scopes):
|
||||
required = set(oauth2_perms['allowed_scopes'])
|
||||
if not grant.has_access(c.site.name, required):
|
||||
self._auth_error(403, "insufficient_scope")
|
||||
else:
|
||||
self._auth_error(400, "invalid_request")
|
||||
|
||||
@@ -206,6 +206,9 @@ class OAuth2Scope:
|
||||
},
|
||||
}
|
||||
|
||||
# Special scope, granted implicitly to clients with app_type == "script"
|
||||
FULL_ACCESS = "*"
|
||||
|
||||
def __init__(self, scope_str=None, subreddits=None, scopes=None):
|
||||
if scope_str:
|
||||
self._parse_scope_str(scope_str)
|
||||
@@ -235,11 +238,22 @@ class OAuth2Scope:
|
||||
sr_part = ''
|
||||
return sr_part + ','.join(sorted(self.scopes))
|
||||
|
||||
def has_access(self, subreddit, required_scopes):
|
||||
if self.FULL_ACCESS in self.scopes:
|
||||
return True
|
||||
if self.subreddit_only and subreddit not in self.subreddits:
|
||||
return False
|
||||
return (self.scopes >= required_scopes)
|
||||
|
||||
def is_valid(self):
|
||||
return all(scope in self.scope_info for scope in self.scopes)
|
||||
|
||||
def details(self):
|
||||
return [(scope, self.scope_info[scope]) for scope in self.scopes]
|
||||
if self.FULL_ACCESS in self.scopes:
|
||||
scopes = self.scope_info.keys()
|
||||
else:
|
||||
scopes = self.scopes
|
||||
return [(scope, self.scope_info[scope]) for scope in scopes]
|
||||
|
||||
@classmethod
|
||||
def merge_scopes(cls, scopes):
|
||||
|
||||
Reference in New Issue
Block a user