mirror of
https://github.com/reddit-archive/reddit.git
synced 2026-01-24 06:18:08 -05:00
Refactor/rename OAuth2 token model so we can reuse it.
The OAuth2 Cassandra models are a perfect fit for other places in the app that need randomly generated tokens.
This commit is contained in:
@@ -31,7 +31,7 @@ from r2.config.extensions import set_extension
|
||||
from reddit_base import RedditController, MinimalController, require_https
|
||||
from r2.lib.db.thing import NotFound
|
||||
from r2.models import Account
|
||||
from r2.models.oauth2 import OAuth2Client, OAuth2AuthorizationCode, OAuth2AccessToken
|
||||
from r2.models.token import OAuth2Client, OAuth2AuthorizationCode, OAuth2AccessToken
|
||||
from r2.controllers.errors import errors
|
||||
from validator import validate, VRequired, VOneOf, VUser, VModhash
|
||||
from r2.lib.pages import OAuth2AuthorizationPage
|
||||
|
||||
@@ -27,7 +27,7 @@ from r2.models import Friends, All, Sub, NotFound, DomainSR, Random, Mod, Random
|
||||
from r2.models import Link, Printable, Trophy, bidding, PromotionWeights, Comment
|
||||
from r2.models import Flair, FlairTemplate, FlairTemplateBySubredditIndex
|
||||
from r2.models import USER_FLAIR, LINK_FLAIR
|
||||
from r2.models.oauth2 import OAuth2Client
|
||||
from r2.models.token import OAuth2Client
|
||||
from r2.models import traffic
|
||||
from r2.models import ModAction
|
||||
from r2.models import Thing
|
||||
|
||||
@@ -36,6 +36,6 @@ from bidding import *
|
||||
from mail_queue import Email, has_opted_out, opt_count
|
||||
from gold import *
|
||||
from admintools import *
|
||||
from oauth2 import *
|
||||
from token import *
|
||||
from modaction import *
|
||||
from promo import *
|
||||
|
||||
@@ -22,22 +22,34 @@
|
||||
|
||||
from os import urandom
|
||||
from base64 import urlsafe_b64encode
|
||||
|
||||
from pycassa.system_manager import ASCII_TYPE, DATE_TYPE, UTF8_TYPE
|
||||
|
||||
from r2.lib.db import tdb_cassandra
|
||||
|
||||
def generate_token(size):
|
||||
return urlsafe_b64encode(urandom(size)).rstrip('=')
|
||||
|
||||
class OAuth2Token(tdb_cassandra.Thing):
|
||||
"""An OAuth2 authorization code for completing authorization flow"""
|
||||
class Token(tdb_cassandra.Thing):
|
||||
"""A unique randomly-generated token used for authentication."""
|
||||
|
||||
_extra_schema_creation_args = dict(
|
||||
key_validation_class=ASCII_TYPE,
|
||||
default_validation_class=UTF8_TYPE,
|
||||
column_validation_classes=dict(
|
||||
date=DATE_TYPE,
|
||||
used=ASCII_TYPE
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _new(cls, **kwargs):
|
||||
if "_id" not in kwargs:
|
||||
kwargs["_id"] = cls._generate_unique_token()
|
||||
|
||||
client = cls(**kwargs)
|
||||
client._commit()
|
||||
return client
|
||||
token = cls(**kwargs)
|
||||
token._commit()
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def _generate_unique_token(cls):
|
||||
@@ -58,7 +70,24 @@ class OAuth2Token(tdb_cassandra.Thing):
|
||||
except tdb_cassandra.NotFound:
|
||||
return False
|
||||
|
||||
class OAuth2Client(OAuth2Token):
|
||||
class ConsumableToken(Token):
|
||||
_defaults = dict(used=False)
|
||||
_bool_props = ("used",)
|
||||
_warn_on_partial_ttl = False
|
||||
|
||||
@classmethod
|
||||
def get_token(cls, _id):
|
||||
token = super(ConsumableToken, cls).get_token(_id)
|
||||
if token and not token.used:
|
||||
return token
|
||||
else:
|
||||
return None
|
||||
|
||||
def consume(self):
|
||||
self.used = True
|
||||
self._commit()
|
||||
|
||||
class OAuth2Client(Token):
|
||||
"""A client registered for OAuth2 access"""
|
||||
token_size = 10
|
||||
client_secret_size = 20
|
||||
@@ -78,18 +107,17 @@ class OAuth2Client(OAuth2Token):
|
||||
kwargs["secret"] = generate_token(cls.client_secret_size)
|
||||
return super(OAuth2Client, cls)._new(**kwargs)
|
||||
|
||||
class OAuth2AuthorizationCode(OAuth2Token):
|
||||
class OAuth2AuthorizationCode(ConsumableToken):
|
||||
"""An OAuth2 authorization code for completing authorization flow"""
|
||||
token_size = 20
|
||||
_ttl = 10*60
|
||||
_defaults = dict(client_id="",
|
||||
redirect_uri="",
|
||||
scope="",
|
||||
used=False,
|
||||
)
|
||||
_bool_props = ("used",)
|
||||
_defaults = dict(ConsumableToken._defaults.items() + [
|
||||
("client_id", ""),
|
||||
("redirect_uri", ""),
|
||||
("scope", ""),
|
||||
]
|
||||
)
|
||||
_int_props = ("user_id",)
|
||||
_warn_on_partial_ttl = False
|
||||
_use_db = True
|
||||
_connection_pool = 'main'
|
||||
|
||||
@@ -101,25 +129,14 @@ class OAuth2AuthorizationCode(OAuth2Token):
|
||||
user_id=user_id,
|
||||
scope=scope)
|
||||
|
||||
@classmethod
|
||||
def get_token(cls, _id):
|
||||
token = super(OAuth2AuthorizationCode, cls).get_token(_id)
|
||||
if token and not token.used:
|
||||
return token
|
||||
else:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def use_token(cls, _id, client_id, redirect_uri):
|
||||
token = cls.get_token(_id)
|
||||
if token and token.client_id == client_id and token.redirect_uri == redirect_uri:
|
||||
token.used = True
|
||||
token._commit()
|
||||
token.consume()
|
||||
return token
|
||||
else:
|
||||
return False
|
||||
|
||||
class OAuth2AccessToken(OAuth2Token):
|
||||
class OAuth2AccessToken(Token):
|
||||
"""An OAuth2 access token for accessing protected resources"""
|
||||
token_size = 20
|
||||
_ttl = 10*60
|
||||
Reference in New Issue
Block a user