From 576557244d1c43395b7862b7cff679eb5736483e Mon Sep 17 00:00:00 2001 From: Dave Pifke Date: Wed, 14 Mar 2012 18:49:51 +0000 Subject: [PATCH] Tweaks to OAuth2 models. Adds the clients-by-user lookup, deletion of clients, and renames some classes and methods to be more consistent. --- r2/r2/models/account.py | 6 ++++ r2/r2/models/token.py | 61 +++++++++++++++++++++++++++++++---------- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/r2/r2/models/account.py b/r2/r2/models/account.py index bcab96117..ba72a2149 100644 --- a/r2/r2/models/account.py +++ b/r2/r2/models/account.py @@ -374,6 +374,12 @@ class Account(Thing): for f in q: f._thing1.remove_enemy(f._thing2) + # Remove OAuth2Client developer permissions. This will delete any + # clients for which this account is the sole developer. + from r2.models.oauth2 import OAuth2Client + for client in OAuth2Client._by_developer(self): + client.remove_developer(self) + @property def subreddits(self): from subreddit import Subreddit diff --git a/r2/r2/models/token.py b/r2/r2/models/token.py index bada06d47..9b94cc320 100644 --- a/r2/r2/models/token.py +++ b/r2/r2/models/token.py @@ -149,7 +149,7 @@ class OAuth2Client(Token): else: return getattr(self, self._developer_colname(account), False) - def add_developer(self, account, grantor=None): + def add_developer(self, account): """Grants developer access to the supplied Account.""" if not getattr(self, self._developer_colname(account), False): @@ -157,18 +157,21 @@ class OAuth2Client(Token): self._commit() # Also update index - OAuth2ClientsByAccount._set_values(account._id36, {self._id: ''}) + OAuth2ClientsByDeveloper._set_values(account._id36, {self._id: ''}) def remove_developer(self, account): """Revokes the supplied Account's developer access.""" if hasattr(self, self._developer_colname(account)): del self[self._developer_colname(account)] + if not len(self._developers): + # No developers remain, delete the client + self.deleted = True self._commit() # Also update index try: - cba = OAuth2ClientsByAccount._byID(account._id36) + cba = OAuth2ClientsByDeveloper._byID(account._id36) del cba[self._id] except (tdb_cassandra.NotFound, KeyError): pass @@ -177,13 +180,13 @@ class OAuth2Client(Token): @classmethod def _by_developer(cls, account): - """Returns the list of clients for which Account is a developer.""" + """Returns a (possibly empty) list of clients for which Account is a developer.""" if account._deleted or account._spam: return [] try: - cba = OAuth2ClientsByAccount._byID(account._id36) + cba = OAuth2ClientsByDeveloper._byID(account._id36) except tdb_cassandra.NotFound: return [] @@ -191,7 +194,7 @@ class OAuth2Client(Token): for cid in cba._values().iterkeys(): try: client = cls._byID(cid) - if not client.has_developer(account): + if client.deleted or not client.has_developer(account): raise NotFound except tdb_cassandra.NotFound: pass @@ -200,11 +203,29 @@ class OAuth2Client(Token): return clients -class OAuth2ClientsByAccount(tdb_cassandra.View): + @classmethod + def _by_user(cls, account): + """Returns a (possibly empty) list of clients for which Account has outstanding access tokens.""" + + client_ids = set() + for token in OAuth2AccessToken._by_user(account): + if token.is_valid: + client_ids.add(token.client_id) + + return [ cls._byID(client_id) for client_id in client_ids ] + + def revoke(self, account): + """Revoke all of the outstanding OAuth2AccessTokens associated with this client and user Account.""" + + for token in OAuth2AccessToken._by_user(account): + if token.client_id == self._id: + token.revoke() + +class OAuth2ClientsByDeveloper(tdb_cassandra.View): """Index providing access to the list of OAuth2Clients of which an Account is a developer.""" _use_db = True - _type_prefix = 'OAuth2ClientsByAccount' + _type_prefix = 'OAuth2ClientsByDeveloper' _view_of = OAuth2Client _connection_pool = 'main' @@ -260,18 +281,28 @@ class OAuth2AccessToken(Token): scope=scope) def _on_create(self): - """Updates the OAuth2AccessTokensByAccount index upon creation.""" + """Updates the OAuth2AccessTokensByUser index upon creation.""" - OAuth2AccessTokensByAccount._set_values(self.user_id, {self._id: ''}) + OAuth2AccessTokensByUser._set_values(self.user_id, {self._id: ''}) return super(OAuth2AccessToken, self)._on_create() @property def is_valid(self): """Returns boolean indicating whether or not this access token is still valid.""" + # Has the token been revoked? if getattr(self, 'revoked', False): return False + # Is the OAuth2Client still valid? + try: + client = OAuth2Client._byID(self.client_id) + if client.deleted: + raise NotFound + except NotFound: + return False + + # Is the user account still valid? try: account = Account._byID36(self.user_id) if account._deleted or account._spam: @@ -288,7 +319,7 @@ class OAuth2AccessToken(Token): self._commit() try: - tba = OAuth2AccessTokensByAccount._byID(self.user_id) + tba = OAuth2AccessTokensByUser._byID(self.user_id) del tba[self._id] except (tdb_cassandra.NotFound, KeyError): # Not fatal, since self.is_valid() will still be False. @@ -297,11 +328,11 @@ class OAuth2AccessToken(Token): tba._commit() @classmethod - def _by_account(cls, account): + def _by_user(cls, account): """Returns a (possibly empty) list of valid access tokens for a given user Account.""" try: - tba = OAuth2AccessTokensByAccount._byID(account._id36) + tba = OAuth2AccessTokensByUser._byID(account._id36) except tdb_cassandra.NotFound: return [] @@ -318,12 +349,12 @@ class OAuth2AccessToken(Token): return tokens -class OAuth2AccessTokensByAccount(tdb_cassandra.View): +class OAuth2AccessTokensByUser(tdb_cassandra.View): """Index listing the outstanding access tokens for an account.""" _use_db = True _ttl = OAuth2AccessToken._ttl - _type_prefix = 'AccountOAuth2AccessToken' + _type_prefix = 'OAuth2AccessTokensByUser' _view_of = OAuth2AccessToken _connection_pool = 'main'