mirror of
https://github.com/joaovitoriasilva/endurain.git
synced 2026-01-07 23:13:57 -05:00
Implement refresh token rotation and reuse detection
Adds rotated refresh token tracking to detect and prevent token reuse attacks. Introduces new models, schemas, and utilities for storing and checking rotated tokens, and invalidates all sessions in a token family if reuse is detected. Updates session and authentication logic to support token families, rotation counts, and last rotation timestamps. Includes Alembic migration for new columns and tables, and schedules cleanup of expired rotated tokens. Also improves frontend logout to refresh tokens before logging out.
This commit is contained in:
@@ -28,6 +28,7 @@ import password_reset_tokens.models
|
||||
import sign_up_tokens.models
|
||||
import server_settings.models
|
||||
import session.models
|
||||
import session.rotated_refresh_tokens.models
|
||||
import users.user.models
|
||||
import users.user_goals.models
|
||||
import users.user_default_gear.models
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""v0.16.4 migration
|
||||
|
||||
Revision ID: ef6cd7775aa2
|
||||
Revision ID: ed5f1c867943
|
||||
Revises: 2af2c0629b37
|
||||
Create Date: 2025-12-16 12:47:18.298420
|
||||
Create Date: 2025-12-18 12:02:47.808747
|
||||
|
||||
"""
|
||||
|
||||
@@ -13,7 +13,7 @@ import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "ef6cd7775aa2"
|
||||
revision: str = "ed5f1c867943"
|
||||
down_revision: Union[str, None] = "2af2c0629b37"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
@@ -94,12 +94,10 @@ def upgrade() -> None:
|
||||
sa.ForeignKeyConstraint(
|
||||
["idp_id"],
|
||||
["identity_providers.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
@@ -116,7 +114,20 @@ def upgrade() -> None:
|
||||
op.create_index(
|
||||
op.f("ix_oauth_states_user_id"), "oauth_states", ["user_id"], unique=False
|
||||
)
|
||||
# Add oauth_state_id and tokens_exchanged to users_sessions table
|
||||
|
||||
# Delete all existing sessions before altering user_sessions table
|
||||
op.execute("DELETE FROM users_sessions")
|
||||
|
||||
# Add new columns to users_sessions
|
||||
op.add_column(
|
||||
"users_sessions",
|
||||
sa.Column(
|
||||
"last_activity_at",
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
comment="Last activity timestamp for idle timeout",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"users_sessions",
|
||||
sa.Column(
|
||||
@@ -131,23 +142,36 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"tokens_exchanged",
|
||||
sa.Boolean(),
|
||||
nullable=True,
|
||||
nullable=False,
|
||||
comment="Prevents duplicate token exchange for mobile",
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE users_sessions
|
||||
SET tokens_exchanged = false
|
||||
WHERE tokens_exchanged IS NULL;
|
||||
"""
|
||||
)
|
||||
op.alter_column(
|
||||
op.add_column(
|
||||
"users_sessions",
|
||||
"tokens_exchanged",
|
||||
nullable=False,
|
||||
comment="Prevents duplicate token exchange for mobile",
|
||||
existing_type=sa.Boolean(),
|
||||
sa.Column(
|
||||
"token_family_id",
|
||||
sa.String(length=36),
|
||||
nullable=False,
|
||||
comment="UUID identifying token family for reuse detection",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"users_sessions",
|
||||
sa.Column(
|
||||
"rotation_count",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
comment="Number of times refresh token has been rotated",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"users_sessions",
|
||||
sa.Column(
|
||||
"last_rotation_at",
|
||||
sa.DateTime(),
|
||||
nullable=True,
|
||||
comment="Timestamp of last token rotation",
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_users_sessions_oauth_state_id"),
|
||||
@@ -155,43 +179,84 @@ def upgrade() -> None:
|
||||
["oauth_state_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_users_sessions_token_family_id"),
|
||||
"users_sessions",
|
||||
["token_family_id"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_foreign_key(
|
||||
None, "users_sessions", "oauth_states", ["oauth_state_id"], ["id"]
|
||||
)
|
||||
# Add last_activity_at column with default value = created_at
|
||||
op.add_column(
|
||||
"users_sessions",
|
||||
|
||||
# Create rotated_refresh_tokens table
|
||||
op.create_table(
|
||||
"rotated_refresh_tokens",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column(
|
||||
"last_activity_at",
|
||||
sa.DateTime(),
|
||||
nullable=True,
|
||||
comment="Last activity timestamp for idle timeout",
|
||||
"token_family_id",
|
||||
sa.String(length=36),
|
||||
nullable=False,
|
||||
comment="UUID of the token family",
|
||||
),
|
||||
sa.Column(
|
||||
"hashed_token",
|
||||
sa.String(length=255),
|
||||
nullable=False,
|
||||
comment="Hashed old refresh token",
|
||||
),
|
||||
sa.Column(
|
||||
"rotation_count",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
comment="Which rotation this token belonged to",
|
||||
),
|
||||
sa.Column(
|
||||
"rotated_at",
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
comment="When this token was rotated",
|
||||
),
|
||||
sa.Column(
|
||||
"expires_at",
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
comment="Cleanup marker (rotated_at + 60 seconds)",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["token_family_id"],
|
||||
["users_sessions.token_family_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("hashed_token"),
|
||||
)
|
||||
|
||||
# Backfill existing sessions: set last_activity_at = created_at
|
||||
op.execute(
|
||||
"UPDATE users_sessions SET last_activity_at = created_at WHERE last_activity_at IS NULL"
|
||||
)
|
||||
|
||||
# Make column non-nullable after backfill
|
||||
op.alter_column(
|
||||
"users_sessions",
|
||||
"last_activity_at",
|
||||
nullable=False,
|
||||
comment="Last activity timestamp for idle timeout",
|
||||
existing_type=sa.DateTime(),
|
||||
op.create_index(
|
||||
op.f("ix_rotated_refresh_tokens_token_family_id"),
|
||||
"rotated_refresh_tokens",
|
||||
["token_family_id"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("users_sessions", "last_activity_at")
|
||||
op.drop_constraint(None, "users_sessions", type_="foreignkey")
|
||||
op.drop_index(
|
||||
op.f("ix_users_sessions_token_family_id"), table_name="users_sessions"
|
||||
)
|
||||
op.drop_index(op.f("ix_users_sessions_oauth_state_id"), table_name="users_sessions")
|
||||
op.drop_column("users_sessions", "last_rotation_at")
|
||||
op.drop_column("users_sessions", "rotation_count")
|
||||
op.drop_column("users_sessions", "token_family_id")
|
||||
op.drop_column("users_sessions", "tokens_exchanged")
|
||||
op.drop_column("users_sessions", "oauth_state_id")
|
||||
op.drop_column("users_sessions", "last_activity_at")
|
||||
op.drop_index(
|
||||
op.f("ix_rotated_refresh_tokens_token_family_id"),
|
||||
table_name="rotated_refresh_tokens",
|
||||
)
|
||||
op.drop_table("rotated_refresh_tokens")
|
||||
op.drop_index(op.f("ix_oauth_states_user_id"), table_name="oauth_states")
|
||||
op.drop_index(op.f("ix_oauth_states_used"), table_name="oauth_states")
|
||||
op.drop_index(op.f("ix_oauth_states_idp_id"), table_name="oauth_states")
|
||||
|
||||
@@ -31,6 +31,8 @@ import profile.utils as profile_utils
|
||||
import core.database as core_database
|
||||
import core.rate_limit as core_rate_limit
|
||||
|
||||
import session.rotated_refresh_tokens.utils as rotated_tokens_utils
|
||||
|
||||
# Define the API router
|
||||
router = APIRouter()
|
||||
|
||||
@@ -325,13 +327,29 @@ async def refresh_token(
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# NEW: Validate session hasn't exceeded idle or absolute timeout
|
||||
# Validate session hasn't exceeded idle or absolute timeout
|
||||
session_utils.validate_session_timeout(session)
|
||||
|
||||
# Check for token reuse BEFORE validating token
|
||||
# Hash the incoming token to compare with rotated tokens
|
||||
hashed_refresh_token = password_hasher.hash_password(refresh_token_value)
|
||||
is_reused, in_grace = rotated_tokens_utils.check_token_reuse(
|
||||
hashed_refresh_token, db
|
||||
)
|
||||
|
||||
if is_reused and not in_grace:
|
||||
# Token theft detected - invalidate entire family
|
||||
rotated_tokens_utils.invalidate_token_family(session.token_family_id, db)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token reuse detected. All sessions invalidated.",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Validate CSRF token matches session
|
||||
# Note: CSRF token is stored in session during initial authentication
|
||||
# For now, we validate that a CSRF token was provided (checked by middleware)
|
||||
# Future enhancement: Store CSRF token hash in session for validation
|
||||
# Note: CSRF token is stored in session during initial auth
|
||||
# For now, we validate that a CSRF token was provided
|
||||
# Future enhancement: Store CSRF token hash in session
|
||||
|
||||
is_valid = password_hasher.verify(refresh_token_value, session.refresh_token)
|
||||
|
||||
@@ -355,6 +373,15 @@ async def refresh_token(
|
||||
# Check if the user is active
|
||||
users_utils.check_user_is_active(user)
|
||||
|
||||
# Store old refresh token BEFORE rotating
|
||||
# This enables detection if the old token is reused later
|
||||
rotated_tokens_utils.store_rotated_token(
|
||||
session.refresh_token,
|
||||
session.token_family_id,
|
||||
session.rotation_count,
|
||||
db,
|
||||
)
|
||||
|
||||
# Create the tokens
|
||||
(
|
||||
session_id,
|
||||
@@ -365,7 +392,9 @@ async def refresh_token(
|
||||
new_csrf_token,
|
||||
) = auth_utils.create_tokens(user, token_manager, session.id)
|
||||
|
||||
# Edit the session and store it in the database
|
||||
# Edit session and store in database
|
||||
# Note: edit_session automatically increments rotation_count
|
||||
# and updates last_rotation_at
|
||||
session_utils.edit_session(session, request, new_refresh_token, password_hasher, db)
|
||||
|
||||
# Opportunistically refresh IdP tokens for all linked identity providers
|
||||
|
||||
@@ -13,6 +13,8 @@ import sign_up_tokens.utils as sign_up_tokens_utils
|
||||
|
||||
import session.utils as session_utils
|
||||
|
||||
import session.rotated_refresh_tokens.utils as rotated_tokens_utils
|
||||
|
||||
import auth.oauth_state.utils as oauth_state_utils
|
||||
|
||||
import core.logger as core_logger
|
||||
@@ -90,6 +92,14 @@ def start_scheduler():
|
||||
"delete expired sessions from the database",
|
||||
)
|
||||
|
||||
add_scheduler_job(
|
||||
rotated_tokens_utils.cleanup_expired_rotated_tokens,
|
||||
"interval",
|
||||
5,
|
||||
[],
|
||||
"delete expired rotated tokens from the database",
|
||||
)
|
||||
|
||||
|
||||
def add_scheduler_job(func, interval, minutes, args, description):
|
||||
try:
|
||||
|
||||
@@ -405,3 +405,39 @@ def delete_idle_sessions(cutoff_time: datetime, db: Session) -> int:
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete idle sessions",
|
||||
) from err
|
||||
|
||||
|
||||
def delete_sessions_by_family(token_family_id: str, db: Session) -> int:
|
||||
"""
|
||||
Delete all sessions belonging to a token family.
|
||||
|
||||
Args:
|
||||
token_family_id: The family ID to delete sessions for.
|
||||
db: The SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted.
|
||||
|
||||
Raises:
|
||||
HTTPException: If an error occurs during deletion (500).
|
||||
"""
|
||||
try:
|
||||
num_deleted = (
|
||||
db.query(session_models.UsersSessions)
|
||||
.filter(session_models.UsersSessions.token_family_id == token_family_id)
|
||||
.delete()
|
||||
)
|
||||
|
||||
db.commit()
|
||||
return num_deleted
|
||||
except Exception as err:
|
||||
db.rollback()
|
||||
core_logger.print_to_log(
|
||||
f"Error in delete_sessions_by_family: {err}",
|
||||
"error",
|
||||
exc=err,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete sessions by family",
|
||||
) from err
|
||||
|
||||
@@ -28,6 +28,7 @@ class UsersSessions(Base):
|
||||
last_activity_at (datetime): Timestamp of last user activity (for idle timeout).
|
||||
expires_at (datetime): Timestamp when the session expires.
|
||||
user (User): Relationship to the User model.
|
||||
rotated_refresh_tokens (list): Rotated tokens for this session.
|
||||
"""
|
||||
|
||||
__tablename__ = "users_sessions"
|
||||
@@ -77,9 +78,32 @@ class UsersSessions(Base):
|
||||
nullable=False,
|
||||
comment="Prevents duplicate token exchange for mobile",
|
||||
)
|
||||
token_family_id = Column(
|
||||
String(36),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
comment="UUID identifying token family for reuse detection",
|
||||
)
|
||||
rotation_count = Column(
|
||||
Integer,
|
||||
default=0,
|
||||
nullable=False,
|
||||
comment="Number of times refresh token has been rotated",
|
||||
)
|
||||
last_rotation_at = Column(
|
||||
DateTime, nullable=True, comment="Timestamp of last token rotation"
|
||||
)
|
||||
|
||||
# Define a relationship to the User model
|
||||
user = relationship("User", back_populates="users_sessions")
|
||||
|
||||
# Define a relationship to the OAuthState model
|
||||
oauth_state = relationship("OAuthState", back_populates="users_sessions")
|
||||
|
||||
# Define a relationship to RotatedRefreshToken model
|
||||
rotated_refresh_tokens = relationship(
|
||||
"RotatedRefreshToken",
|
||||
back_populates="user_session",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
5
backend/app/session/rotated_refresh_tokens/__init__.py
Normal file
5
backend/app/session/rotated_refresh_tokens/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Rotated refresh tokens module for token reuse detection."""
|
||||
|
||||
from session.rotated_refresh_tokens.models import RotatedRefreshToken
|
||||
|
||||
__all__ = ["RotatedRefreshToken"]
|
||||
158
backend/app/session/rotated_refresh_tokens/crud.py
Normal file
158
backend/app/session/rotated_refresh_tokens/crud.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""CRUD operations for rotated refresh tokens."""
|
||||
|
||||
from datetime import datetime
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import session.rotated_refresh_tokens.models as rotated_token_models
|
||||
import session.rotated_refresh_tokens.schema as rotated_token_schema
|
||||
import core.logger as core_logger
|
||||
|
||||
|
||||
def get_rotated_token_by_hash(
|
||||
hashed_token: str, db: Session
|
||||
) -> rotated_token_models.RotatedRefreshToken | None:
|
||||
"""
|
||||
Retrieve a rotated token by its hashed value.
|
||||
|
||||
Args:
|
||||
hashed_token: The hashed refresh token to search for.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
The RotatedRefreshToken if found, None otherwise.
|
||||
|
||||
Raises:
|
||||
HTTPException: If an error occurs during retrieval (500).
|
||||
"""
|
||||
try:
|
||||
return (
|
||||
db.query(rotated_token_models.RotatedRefreshToken)
|
||||
.filter(
|
||||
rotated_token_models.RotatedRefreshToken.hashed_token == hashed_token
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Error in get_rotated_token_by_hash: {err}",
|
||||
"error",
|
||||
exc=err,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve rotated token",
|
||||
) from err
|
||||
|
||||
|
||||
def create_rotated_token(
|
||||
rotated_token: rotated_token_schema.RotatedRefreshTokenCreate,
|
||||
db: Session,
|
||||
) -> rotated_token_models.RotatedRefreshToken:
|
||||
"""
|
||||
Store a rotated refresh token in the database.
|
||||
|
||||
Args:
|
||||
rotated_token: The rotated token data to store.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
The created RotatedRefreshToken object.
|
||||
|
||||
Raises:
|
||||
HTTPException: If an error occurs during creation (500).
|
||||
"""
|
||||
try:
|
||||
db_rotated_token = rotated_token_models.RotatedRefreshToken(
|
||||
token_family_id=rotated_token.token_family_id,
|
||||
hashed_token=rotated_token.hashed_token,
|
||||
rotation_count=rotated_token.rotation_count,
|
||||
rotated_at=rotated_token.rotated_at,
|
||||
expires_at=rotated_token.expires_at,
|
||||
)
|
||||
|
||||
db.add(db_rotated_token)
|
||||
db.commit()
|
||||
db.refresh(db_rotated_token)
|
||||
|
||||
return db_rotated_token
|
||||
except Exception as err:
|
||||
db.rollback()
|
||||
core_logger.print_to_log(
|
||||
f"Error in create_rotated_token: {err}",
|
||||
"error",
|
||||
exc=err,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to store rotated token",
|
||||
) from err
|
||||
|
||||
|
||||
def delete_expired_tokens(cutoff_time: datetime, db: Session) -> int:
|
||||
"""
|
||||
Delete rotated tokens older than the cutoff time.
|
||||
|
||||
Args:
|
||||
cutoff_time: Tokens with expires_at before this will be deleted.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
Number of tokens deleted.
|
||||
|
||||
Raises:
|
||||
HTTPException: If an error occurs during deletion (500).
|
||||
"""
|
||||
try:
|
||||
num_deleted = (
|
||||
db.query(rotated_token_models.RotatedRefreshToken)
|
||||
.filter(rotated_token_models.RotatedRefreshToken.expires_at < cutoff_time)
|
||||
.delete()
|
||||
)
|
||||
|
||||
db.commit()
|
||||
return num_deleted
|
||||
except Exception as err:
|
||||
db.rollback()
|
||||
core_logger.print_to_log(
|
||||
f"Error in delete_expired_tokens: {err}", "error", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete expired tokens",
|
||||
) from err
|
||||
|
||||
|
||||
def delete_by_family(token_family_id: str, db: Session) -> int:
|
||||
"""
|
||||
Delete all rotated tokens for a specific token family.
|
||||
|
||||
Args:
|
||||
token_family_id: The family ID to delete tokens for.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
Number of tokens deleted.
|
||||
|
||||
Raises:
|
||||
HTTPException: If an error occurs during deletion (500).
|
||||
"""
|
||||
try:
|
||||
num_deleted = (
|
||||
db.query(rotated_token_models.RotatedRefreshToken)
|
||||
.filter(
|
||||
rotated_token_models.RotatedRefreshToken.token_family_id
|
||||
== token_family_id
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
|
||||
db.commit()
|
||||
return num_deleted
|
||||
except Exception as err:
|
||||
db.rollback()
|
||||
core_logger.print_to_log(f"Error in delete_by_family: {err}", "error", exc=err)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete tokens by family",
|
||||
) from err
|
||||
51
backend/app/session/rotated_refresh_tokens/models.py
Normal file
51
backend/app/session/rotated_refresh_tokens/models.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from core.database import Base
|
||||
|
||||
|
||||
class RotatedRefreshToken(Base):
|
||||
"""
|
||||
Represents a rotated refresh token in the system.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the rotated token.
|
||||
token_family_id: UUID of the token family.
|
||||
hashed_token: Hashed old refresh token.
|
||||
rotation_count: Which rotation this token belonged to.
|
||||
rotated_at: When this token was rotated.
|
||||
expires_at: Cleanup marker (rotated_at + 60 seconds).
|
||||
user_session: Relationship to UsersSessions model.
|
||||
"""
|
||||
|
||||
__tablename__ = "rotated_refresh_tokens"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
token_family_id = Column(
|
||||
String(36),
|
||||
ForeignKey("users_sessions.token_family_id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="UUID of the token family",
|
||||
)
|
||||
hashed_token = Column(
|
||||
String(255), nullable=False, unique=True, comment="Hashed old refresh token"
|
||||
)
|
||||
rotation_count = Column(
|
||||
Integer, nullable=False, comment="Which rotation this token belonged to"
|
||||
)
|
||||
rotated_at = Column(DateTime, nullable=False, comment="When this token was rotated")
|
||||
expires_at = Column(
|
||||
DateTime, nullable=False, comment="Cleanup marker (rotated_at + 60 seconds)"
|
||||
)
|
||||
|
||||
# Define a relationship to UsersSessions model
|
||||
user_session = relationship(
|
||||
"UsersSessions",
|
||||
back_populates="rotated_refresh_tokens",
|
||||
)
|
||||
54
backend/app/session/rotated_refresh_tokens/schema.py
Normal file
54
backend/app/session/rotated_refresh_tokens/schema.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class RotatedRefreshTokenCreate(BaseModel):
|
||||
"""
|
||||
Schema for creating a rotated refresh token record for security audit and reuse detection.
|
||||
|
||||
Attributes:
|
||||
token_family_id (str): UUID of the token family this token belongs to.
|
||||
hashed_token (str): Hashed old refresh token that was rotated out.
|
||||
rotation_count (int): Sequential rotation number for this token.
|
||||
rotated_at (datetime): Timestamp when this token was rotated.
|
||||
expires_at (datetime): Cleanup marker timestamp (rotated_at + 60 seconds).
|
||||
|
||||
Config:
|
||||
from_attributes (bool): Allows model initialization from attributes.
|
||||
extra (str): Forbids extra fields not defined in the model.
|
||||
validate_assignment (bool): Enables validation on assignment.
|
||||
"""
|
||||
|
||||
token_family_id: str = Field(..., description="UUID of the token family")
|
||||
hashed_token: str = Field(..., description="Hashed old refresh token")
|
||||
rotation_count: int = Field(
|
||||
..., description="Which rotation this token belonged to"
|
||||
)
|
||||
rotated_at: datetime = Field(..., description="When this token was rotated")
|
||||
expires_at: datetime = Field(
|
||||
..., description="Cleanup marker (rotated_at + 60 seconds)"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True, extra="forbid", validate_assignment=True
|
||||
)
|
||||
|
||||
|
||||
class RotatedRefreshTokenRead(RotatedRefreshTokenCreate):
|
||||
"""
|
||||
Schema for reading a rotated refresh token record from the database.
|
||||
|
||||
Inherits all attributes from RotatedRefreshTokenCreate and adds:
|
||||
id (int): Unique identifier for the rotated token record.
|
||||
|
||||
Config:
|
||||
from_attributes (bool): Allows model initialization from attributes.
|
||||
extra (str): Forbids extra fields not defined in the model.
|
||||
validate_assignment (bool): Enables validation on assignment.
|
||||
"""
|
||||
|
||||
id: int = Field(..., description="Unique identifier for the rotated token record")
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True, extra="forbid", validate_assignment=True
|
||||
)
|
||||
165
backend/app/session/rotated_refresh_tokens/utils.py
Normal file
165
backend/app/session/rotated_refresh_tokens/utils.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Utility functions for refresh token reuse detection."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import session.rotated_refresh_tokens.crud as rotated_token_crud
|
||||
import session.rotated_refresh_tokens.schema as rotated_token_schema
|
||||
import session.crud as session_crud
|
||||
import core.logger as core_logger
|
||||
from core.database import SessionLocal
|
||||
|
||||
|
||||
# Grace period for token reuse (60 seconds)
|
||||
# Allows for network retries/delays without false positives
|
||||
TOKEN_REUSE_GRACE_PERIOD_SECONDS = 60
|
||||
|
||||
|
||||
def store_rotated_token(
|
||||
hashed_token: str,
|
||||
token_family_id: str,
|
||||
rotation_count: int,
|
||||
db: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Store an old refresh token after rotation for reuse detection.
|
||||
|
||||
Args:
|
||||
hashed_token: The hashed refresh token being rotated out.
|
||||
token_family_id: UUID of the token family.
|
||||
rotation_count: Current rotation count for this token.
|
||||
db: Database session.
|
||||
|
||||
Raises:
|
||||
HTTPException: If storage fails (500).
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=TOKEN_REUSE_GRACE_PERIOD_SECONDS)
|
||||
|
||||
rotated_token = rotated_token_schema.RotatedRefreshTokenCreate(
|
||||
token_family_id=token_family_id,
|
||||
hashed_token=hashed_token,
|
||||
rotation_count=rotation_count,
|
||||
rotated_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
rotated_token_crud.create_rotated_token(rotated_token, db)
|
||||
|
||||
|
||||
def check_token_reuse(hashed_token: str, db: Session) -> tuple[bool, bool]:
|
||||
"""
|
||||
Check if a refresh token has been reused (already rotated).
|
||||
|
||||
Args:
|
||||
hashed_token: The hashed refresh token to check.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_reused, in_grace_period):
|
||||
- (False, False): Token is valid, not reused
|
||||
- (True, True): Reused but within 60s grace period
|
||||
- (True, False): Reused after grace period - THEFT!
|
||||
|
||||
Raises:
|
||||
HTTPException: If lookup fails (500).
|
||||
"""
|
||||
rotated_token = rotated_token_crud.get_rotated_token_by_hash(hashed_token, db)
|
||||
|
||||
if not rotated_token:
|
||||
return (False, False)
|
||||
|
||||
# Token was already rotated - check grace period
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if now <= rotated_token.expires_at:
|
||||
# Within grace period - might be legitimate retry
|
||||
core_logger.print_to_log(
|
||||
f"Token reuse within grace period for family "
|
||||
f"{rotated_token.token_family_id}",
|
||||
"warning",
|
||||
context={
|
||||
"token_family_id": rotated_token.token_family_id,
|
||||
"rotation_count": rotated_token.rotation_count,
|
||||
},
|
||||
)
|
||||
return (True, True)
|
||||
|
||||
# Past grace period - likely theft!
|
||||
core_logger.print_to_log(
|
||||
f"Token reuse detected after grace period for family "
|
||||
f"{rotated_token.token_family_id}",
|
||||
"error",
|
||||
context={
|
||||
"token_family_id": rotated_token.token_family_id,
|
||||
"rotation_count": rotated_token.rotation_count,
|
||||
"rotated_at": rotated_token.rotated_at.isoformat(),
|
||||
},
|
||||
)
|
||||
return (True, False)
|
||||
|
||||
|
||||
def invalidate_token_family(token_family_id: str, db: Session) -> int:
|
||||
"""
|
||||
Invalidate all sessions in a token family due to reuse detection.
|
||||
|
||||
Args:
|
||||
token_family_id: The family ID to invalidate.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
Number of sessions invalidated.
|
||||
|
||||
Raises:
|
||||
HTTPException: If invalidation fails (500).
|
||||
"""
|
||||
# Delete all sessions in the family
|
||||
num_sessions_deleted = session_crud.delete_sessions_by_family(token_family_id, db)
|
||||
|
||||
# Delete all rotated tokens for this family
|
||||
num_tokens_deleted = rotated_token_crud.delete_by_family(token_family_id, db)
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Invalidated token family {token_family_id} due to reuse: "
|
||||
f"{num_sessions_deleted} sessions, {num_tokens_deleted} tokens",
|
||||
"error",
|
||||
context={
|
||||
"token_family_id": token_family_id,
|
||||
"sessions_deleted": num_sessions_deleted,
|
||||
"tokens_deleted": num_tokens_deleted,
|
||||
},
|
||||
)
|
||||
|
||||
return num_sessions_deleted
|
||||
|
||||
|
||||
def cleanup_expired_rotated_tokens() -> None:
|
||||
"""
|
||||
Cleanup job to delete expired rotated tokens.
|
||||
|
||||
This function is called by the scheduler to periodically remove
|
||||
tokens that have exceeded the grace period. Should run every 5
|
||||
minutes.
|
||||
|
||||
Raises:
|
||||
Any exceptions are caught, logged, and not propagated to avoid
|
||||
breaking the scheduler.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
cutoff_time = datetime.now(timezone.utc)
|
||||
deleted_count = rotated_token_crud.delete_expired_tokens(cutoff_time, db)
|
||||
|
||||
if deleted_count > 0:
|
||||
core_logger.print_to_log(
|
||||
f"Cleaned up {deleted_count} expired rotated tokens",
|
||||
"info",
|
||||
)
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Error in cleanup_expired_rotated_tokens: {err}",
|
||||
"error",
|
||||
exc=err,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
@@ -22,8 +22,6 @@ class UsersSessions(BaseModel):
|
||||
from_attributes (bool): Allows model initialization from attributes.
|
||||
extra (str): Forbids extra fields not defined in the model.
|
||||
validate_assignment (bool): Enables validation on assignment.
|
||||
Validators:
|
||||
expires_at: Ensures that the expiration timestamp is after the creation timestamp.
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="Unique session identifier")
|
||||
@@ -46,6 +44,15 @@ class UsersSessions(BaseModel):
|
||||
tokens_exchanged: bool = Field(
|
||||
default=False, description="Prevents duplicate token exchange for mobile"
|
||||
)
|
||||
token_family_id: str = Field(
|
||||
..., description="UUID identifying token family for reuse detection"
|
||||
)
|
||||
rotation_count: int = Field(
|
||||
default=0, description="Number of times refresh token has been rotated"
|
||||
)
|
||||
last_rotation_at: datetime | None = Field(
|
||||
None, description="Timestamp of last token rotation"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True, extra="forbid", validate_assignment=True
|
||||
|
||||
@@ -144,6 +144,9 @@ def create_session_object(
|
||||
expires_at=refresh_token_exp,
|
||||
oauth_state_id=oauth_state_id,
|
||||
tokens_exchanged=False,
|
||||
token_family_id=session_id,
|
||||
rotation_count=0,
|
||||
last_rotation_at=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -168,6 +171,9 @@ def edit_session_object(
|
||||
user_agent = get_user_agent(request)
|
||||
device_info = parse_user_agent(user_agent)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
new_rotation_count = session.rotation_count + 1
|
||||
|
||||
return session_schema.UsersSessions(
|
||||
id=session.id,
|
||||
user_id=session.user_id,
|
||||
@@ -179,10 +185,13 @@ def edit_session_object(
|
||||
browser=device_info.browser,
|
||||
browser_version=device_info.browser_version,
|
||||
created_at=session.created_at,
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
last_activity_at=now,
|
||||
expires_at=refresh_token_exp,
|
||||
oauth_state_id=session.oauth_state_id,
|
||||
tokens_exchanged=session.tokens_exchanged,
|
||||
token_family_id=session.token_family_id,
|
||||
rotation_count=new_rotation_count,
|
||||
last_rotation_at=now,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -48,6 +48,15 @@ export const useAuthStore = defineStore('auth', {
|
||||
actions: {
|
||||
async logoutUser(router = null, locale = null) {
|
||||
try {
|
||||
// Ensure we have fresh tokens before logout (handles page refresh case)
|
||||
if (!this.csrfToken || !this.accessToken) {
|
||||
try {
|
||||
await this.refreshAccessToken()
|
||||
} catch (refreshError) {
|
||||
console.error('Failed to refresh tokens before logout:', refreshError)
|
||||
// Continue with logout attempt even if refresh fails
|
||||
}
|
||||
}
|
||||
await session.logoutUser()
|
||||
} catch (error) {
|
||||
console.error('Error during logout:', error)
|
||||
|
||||
Reference in New Issue
Block a user