Implement OAuth 2.1 CSRF bootstrap pattern for refresh

Adopts the OAuth 2.1 bootstrap pattern by not storing the CSRF token hash on initial login or token exchange, allowing the first /refresh call after a page reload to establish the CSRF binding. Updates CSRF validation logic to only require the CSRF token if provided, and documents the security model. Exempts the /refresh endpoint from CSRF middleware for the bootstrap scenario. Also ensures rotated refresh tokens are deleted when a session is deleted.
This commit is contained in:
João Vitória Silva
2025-12-18 17:06:53 +00:00
parent dc7990875c
commit ded195a202
5 changed files with 47 additions and 16 deletions

View File

@@ -464,9 +464,10 @@ async def exchange_tokens_for_session(
(access_token_exp - datetime.now(timezone.utc)).total_seconds()
)
# Update session with the actual hashed refresh token and CSRF hash
# Update session with the actual hashed refresh token
# Note: csrf_token_hash is NOT stored here (OAuth 2.1 bootstrap pattern).
# The first /refresh call after page reload establishes the CSRF binding.
session_obj.refresh_token = password_hasher.hash_password(refresh_token)
session_obj.csrf_token_hash = password_hasher.hash_password(csrf_token)
db.commit()
# Set refresh token cookie for web clients (enables logout)

View File

@@ -302,6 +302,13 @@ async def refresh_token(
This endpoint validates the provided refresh token, checks session status,
validates the CSRF token (web clients only), and issues new tokens.
OAuth 2.1 Bootstrap Pattern for Page Reload:
On page reload, in-memory tokens are lost but httpOnly cookie persists.
- If no CSRF header: Allow refresh (page reload scenario)
- If CSRF header provided: Validate it (legitimate request with cached token)
- Security: httpOnly cookie + SameSite=Strict prevents CSRF at browser level
- CSRF validation adds defense-in-depth but is not the primary protection
Args:
response: The HTTP response object.
request: The HTTP request object.
@@ -313,14 +320,14 @@ async def refresh_token(
token_manager: Utility for creating tokens.
db: Database session.
client_type: Client type (\"web\" or \"mobile\").
x_csrf_token: CSRF token header (web clients only, via dependency).
x_csrf_token: CSRF token header (web clients only, optional on page reload).
Returns:
dict: Contains session_id, access_token, csrf_token, token_type, expires_in.
Raises:
HTTPException: If session not found, refresh token invalid,
user is inactive, or CSRF token is missing/invalid.
user is inactive, or CSRF token is invalid (when provided).
"""
# Get the session from the database
session = session_crud.get_session_by_id(token_session_id, db)
@@ -338,11 +345,15 @@ async def refresh_token(
# Verify CSRF token for web clients only
# Mobile clients don't use CSRF tokens
# Note: Middleware already checks header presence for web clients
if client_type == "web":
if not x_csrf_token or not password_hasher.verify(
x_csrf_token, session.csrf_token_hash
):
# OAuth 2.1 Bootstrap Pattern for page reload:
# - On page reload, in-memory tokens are lost but httpOnly cookie persists
# - If x_csrf_token is None (missing from request), allow refresh anyway
# - Security: httpOnly cookie + SameSite=Strict prevents CSRF at browser level
# - If x_csrf_token is provided, it MUST be valid (prevent partial CSRF)
# - CSRF token is defense-in-depth; SameSite=Strict is primary protection
if client_type == "web" and x_csrf_token and session.csrf_token_hash is not None:
# CSRF token was provided: validate it
if not password_hasher.verify(x_csrf_token, session.csrf_token_hash):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid CSRF token",

View File

@@ -174,6 +174,10 @@ def complete_login(
) = create_tokens(user, token_manager)
# Create the session and store it in the database
# Note: csrf_token is NOT stored on initial login (csrf_token_hash = None).
# This enables the OAuth 2.1 bootstrap pattern where the first /refresh call
# after page reload establishes the CSRF binding. The httpOnly cookie is
# sufficient authentication for the bootstrap refresh.
session_utils.create_session(
session_id,
user,
@@ -181,7 +185,6 @@ def complete_login(
refresh_token,
password_hasher,
db,
csrf_token=csrf_token,
)
# Access token and CSRF token returned in body for in-memory storage

View File

@@ -115,6 +115,7 @@ class CSRFMiddleware(BaseHTTPMiddleware):
self.exempt_paths = [
"/api/v1/auth/login",
"/api/v1/auth/mfa/verify",
"/api/v1/auth/refresh", # Bootstrap pattern: first refresh has no CSRF
"/api/v1/password-reset/request",
"/api/v1/password-reset/confirm",
"/api/v1/sign-up/request",

View File

@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
import auth.oauth_state.models as oauth_state_models
import session.models as session_models
import session.schema as session_schema
import session.rotated_refresh_tokens.crud as rotated_tokens_crud
import core.logger as core_logger
@@ -321,10 +322,30 @@ def delete_session(session_id: str, user_id: int, db: Session) -> None:
HTTPException: If the session is not found (404) or if an error occurs during deletion (500).
Notes:
- Deletes rotated tokens associated with the session before deleting the session
- Rolls back the transaction and logs the error if an unexpected exception occurs.
- Commits the transaction if the session is successfully deleted.
"""
try:
# Get the session to retrieve token_family_id before deletion
session = (
db.query(session_models.UsersSessions)
.filter(
session_models.UsersSessions.id == session_id,
session_models.UsersSessions.user_id == user_id,
)
.first()
)
# Check if the session was found
if session is None:
raise SessionNotFoundError(
f"Session {session_id} not found for user {user_id}"
)
# Delete rotated tokens for this session's family (foreign key constraint)
rotated_tokens_crud.delete_by_family(session.token_family_id, db)
# Delete the session
num_deleted = (
db.query(session_models.UsersSessions)
@@ -335,12 +356,6 @@ def delete_session(session_id: str, user_id: int, db: Session) -> None:
.delete()
)
# Check if the session was found and deleted
if num_deleted == 0:
raise SessionNotFoundError(
f"Session {session_id} not found for user {user_id}"
)
# Commit the transaction
db.commit()
except SessionNotFoundError as err: