diff --git a/r2/r2/models/token.py b/r2/r2/models/token.py index f3f0b30d9..b4ad50183 100644 --- a/r2/r2/models/token.py +++ b/r2/r2/models/token.py @@ -27,8 +27,10 @@ from pycassa.system_manager import ASCII_TYPE, DATE_TYPE, UTF8_TYPE from r2.lib.db import tdb_cassandra + def generate_token(size): - return urlsafe_b64encode(urandom(size)).rstrip('=') + return urlsafe_b64encode(urandom(size)).rstrip("=") + class Token(tdb_cassandra.Thing): """A unique randomly-generated token used for authentication.""" @@ -68,7 +70,8 @@ class Token(tdb_cassandra.Thing): try: return cls._byID(_id) except tdb_cassandra.NotFound: - return False + return None + class ConsumableToken(Token): _defaults = dict(used=False) @@ -87,6 +90,7 @@ class ConsumableToken(Token): self.used = True self._commit() + class OAuth2Client(Token): """A client registered for OAuth2 access""" token_size = 10 @@ -99,7 +103,7 @@ class OAuth2Client(Token): redirect_uri="", ) _use_db = True - _connection_pool = 'main' + _connection_pool = "main" @classmethod def _new(cls, **kwargs): @@ -107,10 +111,11 @@ class OAuth2Client(Token): kwargs["secret"] = generate_token(cls.client_secret_size) return super(OAuth2Client, cls)._new(**kwargs) + class OAuth2AuthorizationCode(ConsumableToken): """An OAuth2 authorization code for completing authorization flow""" token_size = 20 - _ttl = 10*60 + _ttl = 10 * 60 _defaults = dict(ConsumableToken._defaults.items() + [ ("client_id", ""), ("redirect_uri", ""), @@ -119,7 +124,7 @@ class OAuth2AuthorizationCode(ConsumableToken): ) _int_props = ("user_id",) _use_db = True - _connection_pool = 'main' + _connection_pool = "main" @classmethod def _new(cls, client_id, redirect_uri, user_id, scope): @@ -132,20 +137,24 @@ class OAuth2AuthorizationCode(ConsumableToken): @classmethod def use_token(cls, _id, client_id, redirect_uri): token = cls.get_token(_id) - if token and token.client_id == client_id and token.redirect_uri == redirect_uri: + if token and (token.client_id == client_id and + token.redirect_uri == redirect_uri): token.consume() return token + else: + return None + class OAuth2AccessToken(Token): """An OAuth2 access token for accessing protected resources""" token_size = 20 - _ttl = 10*60 + _ttl = 10 * 60 _defaults = dict(scope="", token_type="bearer", ) _int_props = ("user_id",) _use_db = True - _connection_pool = 'main' + _connection_pool = "main" @classmethod def _new(cls, user_id, scope):