diff --git a/r2/r2/controllers/oauth2.py b/r2/r2/controllers/oauth2.py index 9a96f5d23..a7b633204 100644 --- a/r2/r2/controllers/oauth2.py +++ b/r2/r2/controllers/oauth2.py @@ -42,6 +42,7 @@ from r2.lib.validator import ( nop, validate, VRequired, + VThrottledLogin, VOneOf, VUser, VModhash, @@ -160,13 +161,11 @@ class OAuth2AccessController(MinimalController): except RequirementException: abort(401, headers=[("WWW-Authenticate", 'Basic realm="reddit"')]) - @validate(grant_type = VOneOf("grant_type", - ("authorization_code", "refresh_token")), - code = nop("code"), - refresh_token = VOAuth2RefreshToken("refresh_token"), - redirect_uri = VRequired("redirect_uri", - errors.OAUTH2_INVALID_REDIRECT_URI)) - def POST_access_token(self, grant_type, code, refresh_token, redirect_uri): + @validate(grant_type=VOneOf("grant_type", + ("authorization_code", + "refresh_token", + "password"))) + def POST_access_token(self, grant_type): """ Exchange an [OAuth 2.0](http://oauth.net/2/) authorization code or refresh token (from [/api/v1/authorize](#api_method_authorize)) for @@ -185,54 +184,122 @@ class OAuth2AccessController(MinimalController): Per the OAuth specification, **grant_type** must be ``authorization_code`` for the initial access token or - ``refresh_token`` for renewing the access token. In either case, + ``refresh_token`` for renewing the access token. + **redirect_uri** must exactly match the value that was used in the call to [/api/v1/authorize](#api_method_authorize) that created this grant. """ - resp = {} - if not (code or refresh_token): - c.errors.add("NO_TEXT", field=("code", "refresh_token")) - if not c.errors: - access_token = None - - if grant_type == "authorization_code": - auth_token = OAuth2AuthorizationCode.use_token( - code, c.oauth2_client._id, redirect_uri) - if auth_token: - if auth_token.refreshable: - refresh_token = OAuth2RefreshToken._new( - auth_token.client_id, auth_token.user_id, - auth_token.scope) - access_token = OAuth2AccessToken._new( - auth_token.client_id, auth_token.user_id, - auth_token.scope, - refresh_token._id if refresh_token else None) - elif grant_type == "refresh_token" and refresh_token: - access_token = OAuth2AccessToken._new( - refresh_token.client_id, refresh_token.user_id, - refresh_token.scope, - refresh_token=refresh_token._id) - - if access_token: - resp["access_token"] = access_token._id - resp["token_type"] = access_token.token_type - resp["expires_in"] = int(access_token._ttl) if access_token._ttl else None - resp["scope"] = access_token.scope - if refresh_token: - resp["refresh_token"] = refresh_token._id - else: - resp["error"] = "invalid_grant" + if grant_type == "authorization_code": + return self._access_token_code() + elif grant_type == "refresh_token": + return self._access_token_refresh() + elif grant_type == "password": + return self._access_token_password() else: - if (errors.INVALID_OPTION, "grant_type") in c.errors: - resp["error"] = "unsupported_grant_type" - elif (errors.INVALID_OPTION, "scope") in c.errors: - resp["error"] = "invalid_scope" - else: - resp["error"] = "invalid_request" + resp = {"error": "unsupported_grant_type"} + return self.api_wrapper(resp) + + def _check_for_errors(self): + resp = {} + if (errors.INVALID_OPTION, "scope") in c.errors: + resp["error"] = "invalid_scope" + else: + resp["error"] = "invalid_request" + return resp + + def _make_token_dict(self, access_token, refresh_token=None): + if not access_token: + return {"error": "invalid_grant"} + expires_in = int(access_token._ttl) if access_token._ttl else None + resp = { + "access_token": access_token._id, + "token_type": access_token.token_type, + "expires_in": expires_in, + "scope": access_token.scope, + } + if refresh_token: + resp["refresh_token"] = refresh_token._id + return resp + + @validate(code=nop("code"), + redirect_uri=VRequired("redirect_uri", + errors.OAUTH2_INVALID_REDIRECT_URI)) + def _access_token_code(self, code, redirect_uri): + if not code: + c.errors.add("NO_TEXT", field="code") + if c.errors: + return self.api_wrapper(self._check_for_errors()) + + access_token = None + refresh_token = None + + auth_token = OAuth2AuthorizationCode.use_token( + code, c.oauth2_client._id, redirect_uri) + if auth_token: + if auth_token.refreshable: + refresh_token = OAuth2RefreshToken._new( + auth_token.client_id, auth_token.user_id, + auth_token.scope) + access_token = OAuth2AccessToken._new( + auth_token.client_id, auth_token.user_id, + auth_token.scope, + refresh_token._id if refresh_token else None) + + resp = self._make_token_dict(access_token, refresh_token) return self.api_wrapper(resp) + @validate(refresh_token=VOAuth2RefreshToken("refresh_token")) + def _access_token_refresh(self, refresh_token): + resp = {} + access_token = None + if refresh_token: + access_token = OAuth2AccessToken._new( + refresh_token.client_id, refresh_token.user_id, + refresh_token.scope, + refresh_token=refresh_token._id) + else: + c.errors.add("NO_TEXT", field="refresh_token") + + if c.errors: + resp = self._check_for_errors() + else: + resp = self._make_token_dict(access_token) + return self.api_wrapper(resp) + + @validate(user=VThrottledLogin(["username", "password"]), + scope=nop("scope")) + def _access_token_password(self, user, scope): + # username:password auth via OAuth is only allowed for + # private use scripts + client = c.oauth2_client + if client.app_type != "script": + return self.api_wrapper({"error": "unauthorized_client", + "error_description": "Only script apps may use password auth"}) + dev_ids = client._developer_ids + if not user or user._id not in dev_ids: + return self.api_wrapper({"error": "invalid_grant"}) + if c.errors: + return self.api_wrapper(self._check_for_errors()) + + if scope: + scope = OAuth2Scope(scope) + if not scope.is_valid(): + c.errors.add(errors.INVALID_OPTION, "scope") + return self.api_wrapper({"error": "invalid_scope"}) + else: + scope = OAuth2Scope(OAuth2Scope.FULL_ACCESS) + + access_token = OAuth2AccessToken._new( + client._id, + user._id, + scope + ) + resp = self._make_token_dict(access_token) + return self.api_wrapper(resp) + + class OAuth2ResourceController(MinimalController): def pre(self): set_extension(request.environ, "json") @@ -256,10 +323,8 @@ class OAuth2ResourceController(MinimalController): oauth2_perms = getattr(handler, "oauth2_perms", None) if oauth2_perms: grant = OAuth2Scope(access_token.scope) - if grant.subreddit_only and c.site.name not in grant.subreddits: - self._auth_error(403, "insufficient_scope") - required_scopes = set(oauth2_perms['allowed_scopes']) - if not (grant.scopes >= required_scopes): + required = set(oauth2_perms['allowed_scopes']) + if not grant.has_access(c.site.name, required): self._auth_error(403, "insufficient_scope") else: self._auth_error(400, "invalid_request") diff --git a/r2/r2/models/token.py b/r2/r2/models/token.py index 8dfc47340..1c956d9a0 100644 --- a/r2/r2/models/token.py +++ b/r2/r2/models/token.py @@ -206,6 +206,9 @@ class OAuth2Scope: }, } + # Special scope, granted implicitly to clients with app_type == "script" + FULL_ACCESS = "*" + def __init__(self, scope_str=None, subreddits=None, scopes=None): if scope_str: self._parse_scope_str(scope_str) @@ -235,11 +238,22 @@ class OAuth2Scope: sr_part = '' return sr_part + ','.join(sorted(self.scopes)) + def has_access(self, subreddit, required_scopes): + if self.FULL_ACCESS in self.scopes: + return True + if self.subreddit_only and subreddit not in self.subreddits: + return False + return (self.scopes >= required_scopes) + def is_valid(self): return all(scope in self.scope_info for scope in self.scopes) def details(self): - return [(scope, self.scope_info[scope]) for scope in self.scopes] + if self.FULL_ACCESS in self.scopes: + scopes = self.scope_info.keys() + else: + scopes = self.scopes + return [(scope, self.scope_info[scope]) for scope in scopes] @classmethod def merge_scopes(cls, scopes):