Ensure timezone-aware datetime comparisons and OAuth state cleanup

Updated identity provider service to use timezone-aware datetime comparisons for token age, refresh, and expiry checks. Added a function to delete OAuth state by ID and integrated OAuth state cleanup into session deletion to prevent orphaned records.
This commit is contained in:
João Vitória Silva
2025-12-18 17:30:52 +00:00
parent ded195a202
commit 0ba4d7123c
3 changed files with 54 additions and 4 deletions

View File

@@ -2013,8 +2013,9 @@ class IdentityProviderService:
)
return False
# Calculate token age
token_age = now - token_timestamp
# Calculate token age (ensure timezone-aware comparison - DB stores naive UTC)
token_timestamp_aware = token_timestamp.replace(tzinfo=timezone.utc)
token_age = now - token_timestamp_aware
# Check if token exceeds maximum age
max_age = timedelta(days=MAX_IDP_TOKEN_AGE_DAYS)
@@ -2076,13 +2077,19 @@ class IdentityProviderService:
# Check if token was refreshed very recently (rate limiting)
if link.idp_refresh_token_updated_at:
time_since_refresh = now - link.idp_refresh_token_updated_at
# Ensure timezone-aware comparison (DB stores naive UTC datetimes)
updated_at_aware = link.idp_refresh_token_updated_at.replace(
tzinfo=timezone.utc
)
time_since_refresh = now - updated_at_aware
if time_since_refresh < timedelta(minutes=TOKEN_REFRESH_RATE_LIMIT_MINUTES):
# Refreshed less than defined - don't refresh again
return TokenAction.SKIP
# Check if access token is close to expiry
time_until_expiry = link.idp_access_token_expires_at - now
# Ensure timezone-aware comparison (DB stores naive UTC datetimes)
expires_at_aware = link.idp_access_token_expires_at.replace(tzinfo=timezone.utc)
time_until_expiry = expires_at_aware - now
if time_until_expiry < timedelta(minutes=TOKEN_EXPIRY_THRESHOLD_MINUTES):
# Token expires soon - should refresh
return TokenAction.REFRESH

View File

@@ -1,6 +1,7 @@
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
import auth.oauth_state.models as oauth_state_models
import session.models as session_models
@@ -183,6 +184,40 @@ def mark_oauth_state_used(
return oauth_state
def delete_oauth_state(oauth_state_id: str, db: Session) -> int:
"""
Delete OAuth state for a specific OAuth state ID.
Args:
oauth_state_id: The OAuth state ID to delete tokens for.
db: Database session.
Returns:
Number of OAuth states deleted.
Raises:
HTTPException: If an error occurs during deletion (500).
"""
try:
num_deleted = (
db.query(oauth_state_models.OAuthState)
.filter(oauth_state_models.OAuthState.id == oauth_state_id)
.delete()
)
db.commit()
return num_deleted
except Exception as err:
db.rollback()
core_logger.print_to_log(
f"Error in delete_oauth_state: {err}", "error", exc=err
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to delete OAuth state",
) from err
def delete_expired_oauth_states(db: Session) -> int:
"""
Delete OAuth states older than 10 minutes.

View File

@@ -5,6 +5,7 @@ from fastapi import HTTPException, status
from sqlalchemy.orm import Session
import auth.oauth_state.models as oauth_state_models
import auth.oauth_state.crud as oauth_state_crud
import session.models as session_models
import session.schema as session_schema
import session.rotated_refresh_tokens.crud as rotated_tokens_crud
@@ -343,6 +344,9 @@ def delete_session(session_id: str, user_id: int, db: Session) -> None:
f"Session {session_id} not found for user {user_id}"
)
# Store oauth_state_id before deleting session (if exists)
oauth_state_id_to_delete = session.oauth_state_id
# Delete rotated tokens for this session's family (foreign key constraint)
rotated_tokens_crud.delete_by_family(session.token_family_id, db)
@@ -356,6 +360,10 @@ def delete_session(session_id: str, user_id: int, db: Session) -> None:
.delete()
)
# Delete OAuth state after session is deleted if exists
if oauth_state_id_to_delete:
oauth_state_crud.delete_oauth_state(oauth_state_id_to_delete, db)
# Commit the transaction
db.commit()
except SessionNotFoundError as err: