Implement refresh token rotation and reuse detection

Adds rotated refresh token tracking to detect and prevent token reuse attacks. Introduces new models, schemas, and utilities for storing and checking rotated tokens, and invalidates all sessions in a token family if reuse is detected. Updates session and authentication logic to support token families, rotation counts, and last rotation timestamps. Includes Alembic migration for new columns and tables, and schedules cleanup of expired rotated tokens. Also improves frontend logout to refresh tokens before logging out.
This commit is contained in:
João Vitória Silva
2025-12-18 12:32:13 +00:00
parent 04b489df7d
commit 4ee166fbfa
14 changed files with 671 additions and 48 deletions

View File

@@ -28,6 +28,7 @@ import password_reset_tokens.models
import sign_up_tokens.models
import server_settings.models
import session.models
import session.rotated_refresh_tokens.models
import users.user.models
import users.user_goals.models
import users.user_default_gear.models

View File

@@ -1,8 +1,8 @@
"""v0.16.4 migration
Revision ID: ef6cd7775aa2
Revision ID: ed5f1c867943
Revises: 2af2c0629b37
Create Date: 2025-12-16 12:47:18.298420
Create Date: 2025-12-18 12:02:47.808747
"""
@@ -13,7 +13,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "ef6cd7775aa2"
revision: str = "ed5f1c867943"
down_revision: Union[str, None] = "2af2c0629b37"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -94,12 +94,10 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(
["idp_id"],
["identity_providers.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
@@ -116,7 +114,20 @@ def upgrade() -> None:
op.create_index(
op.f("ix_oauth_states_user_id"), "oauth_states", ["user_id"], unique=False
)
# Add oauth_state_id and tokens_exchanged to users_sessions table
# Delete all existing sessions before altering user_sessions table
op.execute("DELETE FROM users_sessions")
# Add new columns to users_sessions
op.add_column(
"users_sessions",
sa.Column(
"last_activity_at",
sa.DateTime(),
nullable=False,
comment="Last activity timestamp for idle timeout",
),
)
op.add_column(
"users_sessions",
sa.Column(
@@ -131,23 +142,36 @@ def upgrade() -> None:
sa.Column(
"tokens_exchanged",
sa.Boolean(),
nullable=True,
nullable=False,
comment="Prevents duplicate token exchange for mobile",
),
)
op.execute(
"""
UPDATE users_sessions
SET tokens_exchanged = false
WHERE tokens_exchanged IS NULL;
"""
)
op.alter_column(
op.add_column(
"users_sessions",
"tokens_exchanged",
nullable=False,
comment="Prevents duplicate token exchange for mobile",
existing_type=sa.Boolean(),
sa.Column(
"token_family_id",
sa.String(length=36),
nullable=False,
comment="UUID identifying token family for reuse detection",
),
)
op.add_column(
"users_sessions",
sa.Column(
"rotation_count",
sa.Integer(),
nullable=False,
comment="Number of times refresh token has been rotated",
),
)
op.add_column(
"users_sessions",
sa.Column(
"last_rotation_at",
sa.DateTime(),
nullable=True,
comment="Timestamp of last token rotation",
),
)
op.create_index(
op.f("ix_users_sessions_oauth_state_id"),
@@ -155,43 +179,84 @@ def upgrade() -> None:
["oauth_state_id"],
unique=False,
)
op.create_index(
op.f("ix_users_sessions_token_family_id"),
"users_sessions",
["token_family_id"],
unique=True,
)
op.create_foreign_key(
None, "users_sessions", "oauth_states", ["oauth_state_id"], ["id"]
)
# Add last_activity_at column with default value = created_at
op.add_column(
"users_sessions",
# Create rotated_refresh_tokens table
op.create_table(
"rotated_refresh_tokens",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"last_activity_at",
sa.DateTime(),
nullable=True,
comment="Last activity timestamp for idle timeout",
"token_family_id",
sa.String(length=36),
nullable=False,
comment="UUID of the token family",
),
sa.Column(
"hashed_token",
sa.String(length=255),
nullable=False,
comment="Hashed old refresh token",
),
sa.Column(
"rotation_count",
sa.Integer(),
nullable=False,
comment="Which rotation this token belonged to",
),
sa.Column(
"rotated_at",
sa.DateTime(),
nullable=False,
comment="When this token was rotated",
),
sa.Column(
"expires_at",
sa.DateTime(),
nullable=False,
comment="Cleanup marker (rotated_at + 60 seconds)",
),
sa.ForeignKeyConstraint(
["token_family_id"],
["users_sessions.token_family_id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("hashed_token"),
)
# Backfill existing sessions: set last_activity_at = created_at
op.execute(
"UPDATE users_sessions SET last_activity_at = created_at WHERE last_activity_at IS NULL"
)
# Make column non-nullable after backfill
op.alter_column(
"users_sessions",
"last_activity_at",
nullable=False,
comment="Last activity timestamp for idle timeout",
existing_type=sa.DateTime(),
op.create_index(
op.f("ix_rotated_refresh_tokens_token_family_id"),
"rotated_refresh_tokens",
["token_family_id"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("users_sessions", "last_activity_at")
op.drop_constraint(None, "users_sessions", type_="foreignkey")
op.drop_index(
op.f("ix_users_sessions_token_family_id"), table_name="users_sessions"
)
op.drop_index(op.f("ix_users_sessions_oauth_state_id"), table_name="users_sessions")
op.drop_column("users_sessions", "last_rotation_at")
op.drop_column("users_sessions", "rotation_count")
op.drop_column("users_sessions", "token_family_id")
op.drop_column("users_sessions", "tokens_exchanged")
op.drop_column("users_sessions", "oauth_state_id")
op.drop_column("users_sessions", "last_activity_at")
op.drop_index(
op.f("ix_rotated_refresh_tokens_token_family_id"),
table_name="rotated_refresh_tokens",
)
op.drop_table("rotated_refresh_tokens")
op.drop_index(op.f("ix_oauth_states_user_id"), table_name="oauth_states")
op.drop_index(op.f("ix_oauth_states_used"), table_name="oauth_states")
op.drop_index(op.f("ix_oauth_states_idp_id"), table_name="oauth_states")

View File

@@ -31,6 +31,8 @@ import profile.utils as profile_utils
import core.database as core_database
import core.rate_limit as core_rate_limit
import session.rotated_refresh_tokens.utils as rotated_tokens_utils
# Define the API router
router = APIRouter()
@@ -325,13 +327,29 @@ async def refresh_token(
headers={"WWW-Authenticate": "Bearer"},
)
# NEW: Validate session hasn't exceeded idle or absolute timeout
# Validate session hasn't exceeded idle or absolute timeout
session_utils.validate_session_timeout(session)
# Check for token reuse BEFORE validating token
# Hash the incoming token to compare with rotated tokens
hashed_refresh_token = password_hasher.hash_password(refresh_token_value)
is_reused, in_grace = rotated_tokens_utils.check_token_reuse(
hashed_refresh_token, db
)
if is_reused and not in_grace:
# Token theft detected - invalidate entire family
rotated_tokens_utils.invalidate_token_family(session.token_family_id, db)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token reuse detected. All sessions invalidated.",
headers={"WWW-Authenticate": "Bearer"},
)
# Validate CSRF token matches session
# Note: CSRF token is stored in session during initial authentication
# For now, we validate that a CSRF token was provided (checked by middleware)
# Future enhancement: Store CSRF token hash in session for validation
# Note: CSRF token is stored in session during initial auth
# For now, we validate that a CSRF token was provided
# Future enhancement: Store CSRF token hash in session
is_valid = password_hasher.verify(refresh_token_value, session.refresh_token)
@@ -355,6 +373,15 @@ async def refresh_token(
# Check if the user is active
users_utils.check_user_is_active(user)
# Store old refresh token BEFORE rotating
# This enables detection if the old token is reused later
rotated_tokens_utils.store_rotated_token(
session.refresh_token,
session.token_family_id,
session.rotation_count,
db,
)
# Create the tokens
(
session_id,
@@ -365,7 +392,9 @@ async def refresh_token(
new_csrf_token,
) = auth_utils.create_tokens(user, token_manager, session.id)
# Edit the session and store it in the database
# Edit session and store in database
# Note: edit_session automatically increments rotation_count
# and updates last_rotation_at
session_utils.edit_session(session, request, new_refresh_token, password_hasher, db)
# Opportunistically refresh IdP tokens for all linked identity providers

View File

@@ -13,6 +13,8 @@ import sign_up_tokens.utils as sign_up_tokens_utils
import session.utils as session_utils
import session.rotated_refresh_tokens.utils as rotated_tokens_utils
import auth.oauth_state.utils as oauth_state_utils
import core.logger as core_logger
@@ -90,6 +92,14 @@ def start_scheduler():
"delete expired sessions from the database",
)
add_scheduler_job(
rotated_tokens_utils.cleanup_expired_rotated_tokens,
"interval",
5,
[],
"delete expired rotated tokens from the database",
)
def add_scheduler_job(func, interval, minutes, args, description):
try:

View File

@@ -405,3 +405,39 @@ def delete_idle_sessions(cutoff_time: datetime, db: Session) -> int:
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to delete idle sessions",
) from err
def delete_sessions_by_family(token_family_id: str, db: Session) -> int:
"""
Delete all sessions belonging to a token family.
Args:
token_family_id: The family ID to delete sessions for.
db: The SQLAlchemy database session.
Returns:
Number of sessions deleted.
Raises:
HTTPException: If an error occurs during deletion (500).
"""
try:
num_deleted = (
db.query(session_models.UsersSessions)
.filter(session_models.UsersSessions.token_family_id == token_family_id)
.delete()
)
db.commit()
return num_deleted
except Exception as err:
db.rollback()
core_logger.print_to_log(
f"Error in delete_sessions_by_family: {err}",
"error",
exc=err,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to delete sessions by family",
) from err

View File

@@ -28,6 +28,7 @@ class UsersSessions(Base):
last_activity_at (datetime): Timestamp of last user activity (for idle timeout).
expires_at (datetime): Timestamp when the session expires.
user (User): Relationship to the User model.
rotated_refresh_tokens (list): Rotated tokens for this session.
"""
__tablename__ = "users_sessions"
@@ -77,9 +78,32 @@ class UsersSessions(Base):
nullable=False,
comment="Prevents duplicate token exchange for mobile",
)
token_family_id = Column(
String(36),
nullable=False,
unique=True,
index=True,
comment="UUID identifying token family for reuse detection",
)
rotation_count = Column(
Integer,
default=0,
nullable=False,
comment="Number of times refresh token has been rotated",
)
last_rotation_at = Column(
DateTime, nullable=True, comment="Timestamp of last token rotation"
)
# Define a relationship to the User model
user = relationship("User", back_populates="users_sessions")
# Define a relationship to the OAuthState model
oauth_state = relationship("OAuthState", back_populates="users_sessions")
# Define a relationship to RotatedRefreshToken model
rotated_refresh_tokens = relationship(
"RotatedRefreshToken",
back_populates="user_session",
cascade="all, delete-orphan",
)

View File

@@ -0,0 +1,5 @@
"""Rotated refresh tokens module for token reuse detection."""
from session.rotated_refresh_tokens.models import RotatedRefreshToken
__all__ = ["RotatedRefreshToken"]

View File

@@ -0,0 +1,158 @@
"""CRUD operations for rotated refresh tokens."""
from datetime import datetime
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
import session.rotated_refresh_tokens.models as rotated_token_models
import session.rotated_refresh_tokens.schema as rotated_token_schema
import core.logger as core_logger
def get_rotated_token_by_hash(
hashed_token: str, db: Session
) -> rotated_token_models.RotatedRefreshToken | None:
"""
Retrieve a rotated token by its hashed value.
Args:
hashed_token: The hashed refresh token to search for.
db: Database session.
Returns:
The RotatedRefreshToken if found, None otherwise.
Raises:
HTTPException: If an error occurs during retrieval (500).
"""
try:
return (
db.query(rotated_token_models.RotatedRefreshToken)
.filter(
rotated_token_models.RotatedRefreshToken.hashed_token == hashed_token
)
.first()
)
except Exception as err:
core_logger.print_to_log(
f"Error in get_rotated_token_by_hash: {err}",
"error",
exc=err,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve rotated token",
) from err
def create_rotated_token(
rotated_token: rotated_token_schema.RotatedRefreshTokenCreate,
db: Session,
) -> rotated_token_models.RotatedRefreshToken:
"""
Store a rotated refresh token in the database.
Args:
rotated_token: The rotated token data to store.
db: Database session.
Returns:
The created RotatedRefreshToken object.
Raises:
HTTPException: If an error occurs during creation (500).
"""
try:
db_rotated_token = rotated_token_models.RotatedRefreshToken(
token_family_id=rotated_token.token_family_id,
hashed_token=rotated_token.hashed_token,
rotation_count=rotated_token.rotation_count,
rotated_at=rotated_token.rotated_at,
expires_at=rotated_token.expires_at,
)
db.add(db_rotated_token)
db.commit()
db.refresh(db_rotated_token)
return db_rotated_token
except Exception as err:
db.rollback()
core_logger.print_to_log(
f"Error in create_rotated_token: {err}",
"error",
exc=err,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to store rotated token",
) from err
def delete_expired_tokens(cutoff_time: datetime, db: Session) -> int:
"""
Delete rotated tokens older than the cutoff time.
Args:
cutoff_time: Tokens with expires_at before this will be deleted.
db: Database session.
Returns:
Number of tokens deleted.
Raises:
HTTPException: If an error occurs during deletion (500).
"""
try:
num_deleted = (
db.query(rotated_token_models.RotatedRefreshToken)
.filter(rotated_token_models.RotatedRefreshToken.expires_at < cutoff_time)
.delete()
)
db.commit()
return num_deleted
except Exception as err:
db.rollback()
core_logger.print_to_log(
f"Error in delete_expired_tokens: {err}", "error", exc=err
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to delete expired tokens",
) from err
def delete_by_family(token_family_id: str, db: Session) -> int:
"""
Delete all rotated tokens for a specific token family.
Args:
token_family_id: The family ID to delete tokens for.
db: Database session.
Returns:
Number of tokens deleted.
Raises:
HTTPException: If an error occurs during deletion (500).
"""
try:
num_deleted = (
db.query(rotated_token_models.RotatedRefreshToken)
.filter(
rotated_token_models.RotatedRefreshToken.token_family_id
== token_family_id
)
.delete()
)
db.commit()
return num_deleted
except Exception as err:
db.rollback()
core_logger.print_to_log(f"Error in delete_by_family: {err}", "error", exc=err)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to delete tokens by family",
) from err

View File

@@ -0,0 +1,51 @@
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
ForeignKey,
)
from sqlalchemy.orm import relationship
from core.database import Base
class RotatedRefreshToken(Base):
"""
Represents a rotated refresh token in the system.
Attributes:
id: Unique identifier for the rotated token.
token_family_id: UUID of the token family.
hashed_token: Hashed old refresh token.
rotation_count: Which rotation this token belonged to.
rotated_at: When this token was rotated.
expires_at: Cleanup marker (rotated_at + 60 seconds).
user_session: Relationship to UsersSessions model.
"""
__tablename__ = "rotated_refresh_tokens"
id = Column(Integer, primary_key=True, autoincrement=True)
token_family_id = Column(
String(36),
ForeignKey("users_sessions.token_family_id"),
nullable=False,
index=True,
comment="UUID of the token family",
)
hashed_token = Column(
String(255), nullable=False, unique=True, comment="Hashed old refresh token"
)
rotation_count = Column(
Integer, nullable=False, comment="Which rotation this token belonged to"
)
rotated_at = Column(DateTime, nullable=False, comment="When this token was rotated")
expires_at = Column(
DateTime, nullable=False, comment="Cleanup marker (rotated_at + 60 seconds)"
)
# Define a relationship to UsersSessions model
user_session = relationship(
"UsersSessions",
back_populates="rotated_refresh_tokens",
)

View File

@@ -0,0 +1,54 @@
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime
class RotatedRefreshTokenCreate(BaseModel):
"""
Schema for creating a rotated refresh token record for security audit and reuse detection.
Attributes:
token_family_id (str): UUID of the token family this token belongs to.
hashed_token (str): Hashed old refresh token that was rotated out.
rotation_count (int): Sequential rotation number for this token.
rotated_at (datetime): Timestamp when this token was rotated.
expires_at (datetime): Cleanup marker timestamp (rotated_at + 60 seconds).
Config:
from_attributes (bool): Allows model initialization from attributes.
extra (str): Forbids extra fields not defined in the model.
validate_assignment (bool): Enables validation on assignment.
"""
token_family_id: str = Field(..., description="UUID of the token family")
hashed_token: str = Field(..., description="Hashed old refresh token")
rotation_count: int = Field(
..., description="Which rotation this token belonged to"
)
rotated_at: datetime = Field(..., description="When this token was rotated")
expires_at: datetime = Field(
..., description="Cleanup marker (rotated_at + 60 seconds)"
)
model_config = ConfigDict(
from_attributes=True, extra="forbid", validate_assignment=True
)
class RotatedRefreshTokenRead(RotatedRefreshTokenCreate):
"""
Schema for reading a rotated refresh token record from the database.
Inherits all attributes from RotatedRefreshTokenCreate and adds:
id (int): Unique identifier for the rotated token record.
Config:
from_attributes (bool): Allows model initialization from attributes.
extra (str): Forbids extra fields not defined in the model.
validate_assignment (bool): Enables validation on assignment.
"""
id: int = Field(..., description="Unique identifier for the rotated token record")
model_config = ConfigDict(
from_attributes=True, extra="forbid", validate_assignment=True
)

View File

@@ -0,0 +1,165 @@
"""Utility functions for refresh token reuse detection."""
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
import session.rotated_refresh_tokens.crud as rotated_token_crud
import session.rotated_refresh_tokens.schema as rotated_token_schema
import session.crud as session_crud
import core.logger as core_logger
from core.database import SessionLocal
# Grace period for token reuse (60 seconds)
# Allows for network retries/delays without false positives
TOKEN_REUSE_GRACE_PERIOD_SECONDS = 60
def store_rotated_token(
hashed_token: str,
token_family_id: str,
rotation_count: int,
db: Session,
) -> None:
"""
Store an old refresh token after rotation for reuse detection.
Args:
hashed_token: The hashed refresh token being rotated out.
token_family_id: UUID of the token family.
rotation_count: Current rotation count for this token.
db: Database session.
Raises:
HTTPException: If storage fails (500).
"""
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=TOKEN_REUSE_GRACE_PERIOD_SECONDS)
rotated_token = rotated_token_schema.RotatedRefreshTokenCreate(
token_family_id=token_family_id,
hashed_token=hashed_token,
rotation_count=rotation_count,
rotated_at=now,
expires_at=expires_at,
)
rotated_token_crud.create_rotated_token(rotated_token, db)
def check_token_reuse(hashed_token: str, db: Session) -> tuple[bool, bool]:
"""
Check if a refresh token has been reused (already rotated).
Args:
hashed_token: The hashed refresh token to check.
db: Database session.
Returns:
Tuple of (is_reused, in_grace_period):
- (False, False): Token is valid, not reused
- (True, True): Reused but within 60s grace period
- (True, False): Reused after grace period - THEFT!
Raises:
HTTPException: If lookup fails (500).
"""
rotated_token = rotated_token_crud.get_rotated_token_by_hash(hashed_token, db)
if not rotated_token:
return (False, False)
# Token was already rotated - check grace period
now = datetime.now(timezone.utc)
if now <= rotated_token.expires_at:
# Within grace period - might be legitimate retry
core_logger.print_to_log(
f"Token reuse within grace period for family "
f"{rotated_token.token_family_id}",
"warning",
context={
"token_family_id": rotated_token.token_family_id,
"rotation_count": rotated_token.rotation_count,
},
)
return (True, True)
# Past grace period - likely theft!
core_logger.print_to_log(
f"Token reuse detected after grace period for family "
f"{rotated_token.token_family_id}",
"error",
context={
"token_family_id": rotated_token.token_family_id,
"rotation_count": rotated_token.rotation_count,
"rotated_at": rotated_token.rotated_at.isoformat(),
},
)
return (True, False)
def invalidate_token_family(token_family_id: str, db: Session) -> int:
"""
Invalidate all sessions in a token family due to reuse detection.
Args:
token_family_id: The family ID to invalidate.
db: Database session.
Returns:
Number of sessions invalidated.
Raises:
HTTPException: If invalidation fails (500).
"""
# Delete all sessions in the family
num_sessions_deleted = session_crud.delete_sessions_by_family(token_family_id, db)
# Delete all rotated tokens for this family
num_tokens_deleted = rotated_token_crud.delete_by_family(token_family_id, db)
core_logger.print_to_log(
f"Invalidated token family {token_family_id} due to reuse: "
f"{num_sessions_deleted} sessions, {num_tokens_deleted} tokens",
"error",
context={
"token_family_id": token_family_id,
"sessions_deleted": num_sessions_deleted,
"tokens_deleted": num_tokens_deleted,
},
)
return num_sessions_deleted
def cleanup_expired_rotated_tokens() -> None:
"""
Cleanup job to delete expired rotated tokens.
This function is called by the scheduler to periodically remove
tokens that have exceeded the grace period. Should run every 5
minutes.
Raises:
Any exceptions are caught, logged, and not propagated to avoid
breaking the scheduler.
"""
db = SessionLocal()
try:
cutoff_time = datetime.now(timezone.utc)
deleted_count = rotated_token_crud.delete_expired_tokens(cutoff_time, db)
if deleted_count > 0:
core_logger.print_to_log(
f"Cleaned up {deleted_count} expired rotated tokens",
"info",
)
except Exception as err:
core_logger.print_to_log(
f"Error in cleanup_expired_rotated_tokens: {err}",
"error",
exc=err,
)
finally:
db.close()

View File

@@ -22,8 +22,6 @@ class UsersSessions(BaseModel):
from_attributes (bool): Allows model initialization from attributes.
extra (str): Forbids extra fields not defined in the model.
validate_assignment (bool): Enables validation on assignment.
Validators:
expires_at: Ensures that the expiration timestamp is after the creation timestamp.
"""
id: str = Field(..., description="Unique session identifier")
@@ -46,6 +44,15 @@ class UsersSessions(BaseModel):
tokens_exchanged: bool = Field(
default=False, description="Prevents duplicate token exchange for mobile"
)
token_family_id: str = Field(
..., description="UUID identifying token family for reuse detection"
)
rotation_count: int = Field(
default=0, description="Number of times refresh token has been rotated"
)
last_rotation_at: datetime | None = Field(
None, description="Timestamp of last token rotation"
)
model_config = ConfigDict(
from_attributes=True, extra="forbid", validate_assignment=True

View File

@@ -144,6 +144,9 @@ def create_session_object(
expires_at=refresh_token_exp,
oauth_state_id=oauth_state_id,
tokens_exchanged=False,
token_family_id=session_id,
rotation_count=0,
last_rotation_at=None,
)
@@ -168,6 +171,9 @@ def edit_session_object(
user_agent = get_user_agent(request)
device_info = parse_user_agent(user_agent)
now = datetime.now(timezone.utc)
new_rotation_count = session.rotation_count + 1
return session_schema.UsersSessions(
id=session.id,
user_id=session.user_id,
@@ -179,10 +185,13 @@ def edit_session_object(
browser=device_info.browser,
browser_version=device_info.browser_version,
created_at=session.created_at,
last_activity_at=datetime.now(timezone.utc),
last_activity_at=now,
expires_at=refresh_token_exp,
oauth_state_id=session.oauth_state_id,
tokens_exchanged=session.tokens_exchanged,
token_family_id=session.token_family_id,
rotation_count=new_rotation_count,
last_rotation_at=now,
)

View File

@@ -48,6 +48,15 @@ export const useAuthStore = defineStore('auth', {
actions: {
async logoutUser(router = null, locale = null) {
try {
// Ensure we have fresh tokens before logout (handles page refresh case)
if (!this.csrfToken || !this.accessToken) {
try {
await this.refreshAccessToken()
} catch (refreshError) {
console.error('Failed to refresh tokens before logout:', refreshError)
// Continue with logout attempt even if refresh fails
}
}
await session.logoutUser()
} catch (error) {
console.error('Error during logout:', error)