From cc40171cf50f67cd85f1cd409e864f754fbbbceb Mon Sep 17 00:00:00 2001 From: Dave Pifke Date: Tue, 13 Mar 2012 06:48:15 +0000 Subject: [PATCH] OAuth client/token relationships with Accounts Update the models to be able to update and query: - Which OAuth2Clients are developed by a user? - Who are the developers of an OAuth2Client? - What OAuth2AccessTokens are outstanding for a user? There's some duplicated code here, and tests are needed, but it's functional. --- r2/r2/models/token.py | 168 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 166 insertions(+), 2 deletions(-) diff --git a/r2/r2/models/token.py b/r2/r2/models/token.py index 3eb6bc1de..bada06d47 100644 --- a/r2/r2/models/token.py +++ b/r2/r2/models/token.py @@ -26,6 +26,8 @@ from base64 import urlsafe_b64encode from pycassa.system_manager import ASCII_TYPE, DATE_TYPE, UTF8_TYPE from r2.lib.db import tdb_cassandra +from r2.lib.db.thing import NotFound +from r2.models.account import Account def generate_token(size): @@ -105,12 +107,107 @@ class OAuth2Client(Token): _use_db = True _connection_pool = "main" + _developer_colname_prefix = 'has_developer_' + @classmethod def _new(cls, **kwargs): if "secret" not in kwargs: kwargs["secret"] = generate_token(cls.client_secret_size) return super(OAuth2Client, cls)._new(**kwargs) + @property + def _developers(self): + """Returns a list of users who are developers of this client.""" + + devs = [] + for k in [ k for k, v in self._t.iteritems() if k.startswith(self._developer_colname_prefix) and v ]: + try: + dev = Account._byID36(k[len(self._developer_colname_prefix):]) + if dev._deleted or dev._spam: + raise NotFound + except NotFound: + # Developer account is no longer valid; ignore + pass + else: + devs.append(dev) + + return devs + + def _developer_colname(self, account): + """Developer access is granted by way of adding a column with the + account's ID36 to the client object. This function returns the + column name for a given Account. + """ + + return ''.join((self._developer_colname_prefix, account._id36)) + + def has_developer(self, account): + """Returns a boolean indicating whether or not the supplied Account is a developer of this application.""" + + if account._deleted or account._spam: + return False + else: + return getattr(self, self._developer_colname(account), False) + + def add_developer(self, account, grantor=None): + """Grants developer access to the supplied Account.""" + + if not getattr(self, self._developer_colname(account), False): + setattr(self, self._developer_colname(account), True) + self._commit() + + # Also update index + OAuth2ClientsByAccount._set_values(account._id36, {self._id: ''}) + + def remove_developer(self, account): + """Revokes the supplied Account's developer access.""" + + if hasattr(self, self._developer_colname(account)): + del self[self._developer_colname(account)] + self._commit() + + # Also update index + try: + cba = OAuth2ClientsByAccount._byID(account._id36) + del cba[self._id] + except (tdb_cassandra.NotFound, KeyError): + pass + else: + cba._commit() + + @classmethod + def _by_developer(cls, account): + """Returns the list of clients for which Account is a developer.""" + + if account._deleted or account._spam: + return [] + + try: + cba = OAuth2ClientsByAccount._byID(account._id36) + except tdb_cassandra.NotFound: + return [] + + clients = [] + for cid in cba._values().iterkeys(): + try: + client = cls._byID(cid) + if not client.has_developer(account): + raise NotFound + except tdb_cassandra.NotFound: + pass + else: + clients.append(client) + + return clients + +class OAuth2ClientsByAccount(tdb_cassandra.View): + """Index providing access to the list of OAuth2Clients of which an Account is a developer.""" + + _use_db = True + _type_prefix = 'OAuth2ClientsByAccount' + _view_of = OAuth2Client + _connection_pool = 'main' + class OAuth2AuthorizationCode(ConsumableToken): """An OAuth2 authorization code for completing authorization flow""" @@ -159,9 +256,76 @@ class OAuth2AccessToken(Token): @classmethod def _new(cls, user_id, scope): return super(OAuth2AccessToken, cls)._new( - user_id=user_id, - scope=scope) + user_id=user_id, + scope=scope) + def _on_create(self): + """Updates the OAuth2AccessTokensByAccount index upon creation.""" + + OAuth2AccessTokensByAccount._set_values(self.user_id, {self._id: ''}) + return super(OAuth2AccessToken, self)._on_create() + + @property + def is_valid(self): + """Returns boolean indicating whether or not this access token is still valid.""" + + if getattr(self, 'revoked', False): + return False + + try: + account = Account._byID36(self.user_id) + if account._deleted or account._spam: + raise NotFound + except NotFound: + return False + + return True + + def revoke(self): + """Revokes (invalidates) this access token.""" + + self.revoked = True + self._commit() + + try: + tba = OAuth2AccessTokensByAccount._byID(self.user_id) + del tba[self._id] + except (tdb_cassandra.NotFound, KeyError): + # Not fatal, since self.is_valid() will still be False. + pass + else: + tba._commit() + + @classmethod + def _by_account(cls, account): + """Returns a (possibly empty) list of valid access tokens for a given user Account.""" + + try: + tba = OAuth2AccessTokensByAccount._byID(account._id36) + except tdb_cassandra.NotFound: + return [] + + tokens = [] + for tid in tba._values().iterkeys(): + try: + token = cls._byID(tid) + if not token.is_valid: + raise NotFound + except tdb_cassandra.NotFound: + pass + else: + tokens.append(token) + + return tokens + +class OAuth2AccessTokensByAccount(tdb_cassandra.View): + """Index listing the outstanding access tokens for an account.""" + + _use_db = True + _ttl = OAuth2AccessToken._ttl + _type_prefix = 'AccountOAuth2AccessToken' + _view_of = OAuth2AccessToken + _connection_pool = 'main' class EmailVerificationToken(ConsumableToken): _use_db = True