Add MFA backup codes support for user authentication

Introduces database models, migration, API endpoints, and business logic for MFA backup codes as a fallback authentication method. Users can generate, view status, and consume backup codes; codes are securely hashed and invalidated upon use. Integrates backup code verification into MFA flows, updates user and profile logic, and ensures codes are managed on MFA enable/disable actions.
This commit is contained in:
João Vitória Silva
2025-12-18 23:08:27 +00:00
parent 0ba4d7123c
commit 17ef865b5c
13 changed files with 618 additions and 22 deletions

View File

@@ -7,6 +7,7 @@ from alembic import context
import auth.identity_providers.models
import auth.mfa_backup_codes.models
import auth.oauth_state.models
import activities.activity.models
import activities.activity_exercise_titles.models

View File

@@ -245,11 +245,69 @@ def upgrade() -> None:
["token_family_id"],
unique=False,
)
# Create mfa_backup_codes table
op.create_table(
"mfa_backup_codes",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"user_id",
sa.Integer(),
nullable=False,
comment="User who owns this backup code",
),
sa.Column(
"code_hash",
sa.String(length=255),
nullable=False,
comment="Argon2 hash of the backup code",
),
sa.Column(
"used",
sa.Boolean(),
nullable=False,
comment="Whether this code has been consumed",
),
sa.Column(
"used_at", sa.DateTime(), nullable=True, comment="When this code was used"
),
sa.Column(
"created_at",
sa.DateTime(),
nullable=False,
comment="When this code was generated",
),
sa.Column(
"expires_at",
sa.DateTime(),
nullable=True,
comment="Optional expiry for code rotation policy",
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("code_hash"),
)
op.create_index(
"idx_user_unused_codes", "mfa_backup_codes", ["user_id", "used"], unique=False
)
op.create_index(
op.f("ix_mfa_backup_codes_used"), "mfa_backup_codes", ["used"], unique=False
)
op.create_index(
op.f("ix_mfa_backup_codes_user_id"),
"mfa_backup_codes",
["user_id"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_mfa_backup_codes_user_id"), table_name="mfa_backup_codes")
op.drop_index(op.f("ix_mfa_backup_codes_used"), table_name="mfa_backup_codes")
op.drop_index("idx_user_unused_codes", table_name="mfa_backup_codes")
op.drop_table("mfa_backup_codes")
op.drop_constraint(None, "users_sessions", type_="foreignkey")
op.drop_index(
op.f("ix_users_sessions_token_family_id"), table_name="users_sessions"

View File

@@ -0,0 +1,170 @@
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
import auth.mfa_backup_codes.models as mfa_backup_codes_models
import auth.mfa_backup_codes.utils as mfa_backup_codes_utils
import auth.password_hasher as auth_password_hasher
import core.logger as core_logger
def get_user_backup_codes(
user_id: int, db: Session
) -> list[mfa_backup_codes_models.MFABackupCode]:
try:
return (
db.query(mfa_backup_codes_models.MFABackupCode)
.filter(
mfa_backup_codes_models.MFABackupCode.user_id == user_id,
)
.all()
)
except Exception as err:
core_logger.print_to_log(
f"Error in get_user_backup_codes: {err}", "error", exc=err
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve backup codes",
) from err
def get_user_unused_backup_codes(
user_id: int, db: Session
) -> list[mfa_backup_codes_models.MFABackupCode]:
try:
return (
db.query(mfa_backup_codes_models.MFABackupCode)
.filter(
mfa_backup_codes_models.MFABackupCode.user_id == user_id,
mfa_backup_codes_models.MFABackupCode.used == False,
)
.all()
)
except Exception as err:
core_logger.print_to_log(
f"Error in get_user_unused_backup_codes: {err}", "error", exc=err
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve unused backup codes",
) from err
def create_backup_codes(
user_id: int,
password_hasher: auth_password_hasher.PasswordHasher,
db: Session,
count: int = 10,
) -> list[str]:
try:
delete_user_backup_codes(user_id, db)
plaintext_codes = []
for _ in range(count):
# Generate unique code
code = mfa_backup_codes_utils.generate_backup_code()
# Hash the code (bcrypt)
code_hash = password_hasher.hash_password(code)
# Store hash in database
backup_code = mfa_backup_codes_models.MFABackupCode(
user_id=user_id,
code_hash=code_hash,
created_at=datetime.now(timezone.utc),
)
db.add(backup_code)
# Keep plaintext for return (only time it's available)
plaintext_codes.append(code)
db.commit()
core_logger.print_to_log(f"Created backup codes for user ID {user_id}", "info")
return plaintext_codes
except HTTPException as err:
raise err
except Exception as err:
core_logger.print_to_log(
f"Error creating backup codes for user ID {user_id}: {err}",
"error",
exc=err,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to regenerate backup codes",
) from err
def mark_backup_code_as_used(
backup_code_hashed: str, user_id: int, db: Session
) -> None:
try:
db_code = (
db.query(mfa_backup_codes_models.MFABackupCode)
.filter(
mfa_backup_codes_models.MFABackupCode.user_id == user_id,
mfa_backup_codes_models.MFABackupCode.code_hash == backup_code_hashed,
mfa_backup_codes_models.MFABackupCode.used == False,
)
.first()
)
if db_code:
db_code.used = True
db_code.used_at = datetime.now(timezone.utc)
db.commit()
db.refresh(db_code)
core_logger.print_to_log(
f"Marked backup code as used for user ID {user_id}", "info"
)
else:
core_logger.print_to_log(
f"No unused backup code found to mark as used for user ID {user_id}",
"warning",
)
HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Backup code not found or already used",
)
except Exception as err:
db.rollback()
core_logger.print_to_log(
f"Error in mark_backup_code_as_used: {err}", "error", exc=err
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal Server Error",
) from err
def delete_user_backup_codes(user_id: int, db: Session) -> int:
try:
# Delete existing codes
num_deleted = (
db.query(mfa_backup_codes_models.MFABackupCode)
.filter(mfa_backup_codes_models.MFABackupCode.user_id == user_id)
.delete()
)
db.commit()
core_logger.print_to_log(
f"Deleted {num_deleted} backup codes for user ID: {user_id}", "info"
)
return num_deleted
except Exception as err:
db.rollback()
core_logger.print_to_log(
f"Error in delete_user_backup_codes: {err}", "error", exc=err
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal Server Error",
) from err

View File

@@ -0,0 +1,72 @@
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Index
from sqlalchemy.orm import relationship
from datetime import datetime, timezone
from core.database import Base
class MFABackupCode(Base):
"""
SQLAlchemy model for MFA backup codes.
This model stores hashed backup codes that users can use as a fallback
authentication method when their primary MFA device is unavailable.
Attributes:
id (int): Primary key, auto-incrementing identifier.
user_id (int): Foreign key to the users table, identifies the code owner.
code_hash (str): Argon2 hash of the backup code for secure storage.
used (bool): Flag indicating whether the code has been consumed.
used_at (datetime): Timestamp when the code was used, if applicable.
created_at (datetime): Timestamp when the code was generated (UTC).
expires_at (datetime): Optional expiration timestamp for code rotation.
Relationships:
user: Many-to-one relationship with the User model.
Indexes:
- Primary index on user_id for foreign key constraint
- Unique index on code_hash to prevent duplicates
- Index on used for filtering consumed codes
- Composite index (idx_user_unused_codes) on user_id and used for
efficient lookups of available backup codes per user
"""
__tablename__ = "mfa_backup_codes"
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(
Integer,
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="User who owns this backup code",
)
code_hash = Column(
String(255),
nullable=False,
unique=True,
comment="Argon2 hash of the backup code",
)
used = Column(
Boolean,
default=False,
nullable=False,
index=True,
comment="Whether this code has been consumed",
)
used_at = Column(DateTime, nullable=True, comment="When this code was used")
created_at = Column(
DateTime,
nullable=False,
default=lambda: datetime.now(timezone.utc),
comment="When this code was generated",
)
expires_at = Column(
DateTime, nullable=True, comment="Optional expiry for code rotation policy"
)
# Establish relationship back to User model
user = relationship("User", back_populates="mfa_backup_codes")
# Composite index for fast unused code lookups
__table_args__ = (Index("idx_user_unused_codes", "user_id", "used"),)

View File

@@ -0,0 +1,127 @@
import os
from datetime import datetime, timedelta, timezone
from typing import Annotated, Callable
from fastapi import (
APIRouter,
Depends,
HTTPException,
status,
Response,
Request,
)
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
import session.utils as session_utils
import auth.security as auth_security
import auth.utils as auth_utils
import auth.constants as auth_constants
import session.crud as session_crud
import auth.password_hasher as auth_password_hasher
import auth.token_manager as auth_token_manager
import auth.schema as auth_schema
import auth.mfa_backup_codes.schema as mfa_backup_codes_schema
import auth.mfa_backup_codes.crud as mfa_backup_codes_crud
import auth.identity_providers.utils as idp_utils
import users.user.crud as users_crud
import users.user.utils as users_utils
import profile.utils as profile_utils
import core.database as core_database
import core.rate_limit as core_rate_limit
import core.logger as core_logger
import session.rotated_refresh_tokens.utils as rotated_tokens_utils
# Define the API router
router = APIRouter()
@router.get(
"/status",
response_model=mfa_backup_codes_schema.MFABackupCodeStatus,
)
async def get_backup_code_status(
token_user_id: Annotated[
int,
Depends(auth_security.get_sub_from_refresh_token),
],
db: Annotated[
Session,
Depends(core_database.get_db),
],
):
codes = mfa_backup_codes_crud.get_user_backup_codes(token_user_id, db)
if not codes:
return mfa_backup_codes_schema.MFABackupCodeStatus(
has_codes=False,
total=0,
unused=0,
used=0,
created_at=None,
)
unused = sum(1 for code in codes if not code.used)
used = sum(1 for code in codes if code.used)
created_at = codes[0].created_at if codes else None
return mfa_backup_codes_schema.MFABackupCodeStatus(
has_codes=True,
total=len(codes),
unused=unused,
used=used,
created_at=created_at,
)
@router.post(
"",
response_model=mfa_backup_codes_schema.MFABackupCodesResponse,
)
@core_rate_limit.limiter.limit(core_rate_limit.MFA_VERIFY_LIMIT)
async def generate_mfa_backup_codes(
response: Response,
request: Request,
token_user_id: Annotated[
int,
Depends(auth_security.get_sub_from_refresh_token),
],
password_hasher: Annotated[
auth_password_hasher.PasswordHasher,
Depends(auth_password_hasher.get_password_hasher),
],
db: Annotated[
Session,
Depends(core_database.get_db),
],
):
user = users_crud.get_user_by_id(token_user_id, db)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
if not user.mfa_enabled:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MFA must be enabled to generate backup codes",
)
# Generate codes (invalidates old codes)
codes = mfa_backup_codes_crud.create_backup_codes(
token_user_id, password_hasher, db
)
# Log event
core_logger.print_to_log(f"User {user.id} generated MFA backup codes", "info")
return mfa_backup_codes_schema.MFABackupCodesResponse(
codes=codes,
created_at=datetime.now(timezone.utc),
)

View File

@@ -0,0 +1,15 @@
from pydantic import BaseModel, Field
from datetime import datetime
class MFABackupCodesResponse(BaseModel):
codes: list[str] = Field(..., description="10 one-time backup codes")
created_at: datetime
class MFABackupCodeStatus(BaseModel):
has_codes: bool
total: int
unused: int
used: int
created_at: datetime | None = None

View File

@@ -0,0 +1,58 @@
import secrets
import string
from datetime import datetime, timezone
import auth.mfa_backup_codes.crud as mfa_backup_codes_crud
import auth.password_hasher as auth_password_hasher
from sqlalchemy.orm import Session
def generate_backup_code() -> str:
"""
Generate a cryptographically secure 8-character backup code.
Format: XXXX-XXXX (uppercase alphanumeric, no ambiguous chars)
Excludes: 0, O, 1, I, l (to prevent confusion)
Entropy: ~40 bits per code (8 chars from 32-char alphabet)
Returns:
str: Formatted backup code (e.g., "A3K9-7BDF")
"""
alphabet = string.ascii_uppercase + string.digits
# Remove ambiguous characters
alphabet = (
alphabet.replace("0", "").replace("O", "").replace("1", "").replace("I", "")
)
# Generate 8 random characters
code = "".join(secrets.choice(alphabet) for _ in range(8))
# Format as XXXX-XXXX
return f"{code[:4]}-{code[4:]}"
def verify_and_consume_backup_code(
user_id: int,
code: str,
password_hasher: auth_password_hasher.PasswordHasher,
db: Session,
) -> bool:
# Get all unused codes for this user
unused_codes = mfa_backup_codes_crud.get_user_unused_backup_codes(user_id, db)
# Try each unused code (constant-time for each)
for unused_code in unused_codes:
if password_hasher.verify(code, unused_code.code_hash):
# Valid code found - mark as used
mfa_backup_codes_crud.mark_backup_code_as_used(
unused_code.code_hash, user_id, db
)
# Return success
return True
# No matching code found
return False

View File

@@ -227,8 +227,10 @@ async def verify_mfa_and_login(
detail="No pending MFA login found for this username",
)
# Verify the MFA code
if not profile_utils.verify_user_mfa(user_id, mfa_request.mfa_code, db):
# Verify the MFA code (TOTP or backup code)
if not profile_utils.verify_user_mfa(
user_id, mfa_request.mfa_code, password_hasher, db
):
# Record failed attempt and apply lockout if threshold exceeded
failed_count = pending_mfa_store.record_failed_attempt(mfa_request.username)
raise HTTPException(

View File

@@ -16,6 +16,10 @@ import activities.activity_summaries.router as activity_summaries_router
import activities.activity_workout_steps.router as activity_workout_steps_router
import activities.activity_workout_steps.public_router as activity_workout_steps_public_router
import auth.router as auth_router
import auth.mfa_backup_codes.router as mfa_backup_codes_router
import auth.identity_providers.router as identity_providers_router
import auth.identity_providers.public_router as identity_providers_public_router
import auth.security as auth_security
import core.config as core_config
import core.router as core_router
import followers.router as followers_router
@@ -26,8 +30,6 @@ import health_sleep.router as health_sleep_router
import health_weight.router as health_weight_router
import health_steps.router as health_steps_router
import health_targets.router as health_targets_router
import auth.identity_providers.router as identity_providers_router
import auth.identity_providers.public_router as identity_providers_public_router
import notifications.router as notifications_router
import password_reset_tokens.router as password_reset_tokens_router
import profile.browser_redirect_router as profile_browser_redirect_router
@@ -35,7 +37,6 @@ import profile.router as profile_router
import server_settings.public_router as server_settings_public_router
import server_settings.router as server_settings_router
import session.router as session_router
import auth.security as auth_security
import sign_up_tokens.router as sign_up_tokens_router
import strava.router as strava_router
import users.user.router as users_router
@@ -102,6 +103,12 @@ router.include_router(
prefix=core_config.ROOT_PATH + "/auth",
tags=["auth"],
)
router.include_router(
mfa_backup_codes_router.router,
prefix=core_config.ROOT_PATH + "/auth/mfa/backup-codes",
tags=["auth"],
dependencies=[Depends(auth_security.validate_access_token)],
)
router.include_router(
followers_router.router,
prefix=core_config.ROOT_PATH + "/followers",

View File

@@ -623,6 +623,10 @@ async def enable_mfa(
int,
Depends(auth_security.get_sub_from_access_token),
],
password_hasher: Annotated[
auth_password_hasher.PasswordHasher,
Depends(auth_password_hasher.get_password_hasher),
],
db: Annotated[Session, Depends(core_database.get_db)],
mfa_secret_store: Annotated[
profile_schema.MFASecretStore, Depends(profile_schema.get_mfa_secret_store)
@@ -634,6 +638,7 @@ async def enable_mfa(
Args:
request: MFA setup request with code.
token_user_id: User ID from access token.
password_hasher: Password hasher instance for backup code generation.
db: Database session.
mfa_secret_store: Temporary secret storage.
@@ -652,10 +657,15 @@ async def enable_mfa(
)
try:
profile_utils.enable_user_mfa(token_user_id, secret, request.mfa_code, db)
backup_codes = profile_utils.enable_user_mfa(
token_user_id, secret, request.mfa_code, password_hasher, db
)
# Clean up the temporary secret
mfa_secret_store.delete_secret(token_user_id)
return {"message": "MFA enabled successfully"}
return {
"message": "MFA enabled successfully",
"backup_codes": backup_codes,
}
except HTTPException:
# Clean up on error
mfa_secret_store.delete_secret(token_user_id)
@@ -693,6 +703,10 @@ async def verify_mfa(
int,
Depends(auth_security.get_sub_from_access_token),
],
password_hasher: Annotated[
auth_password_hasher.PasswordHasher,
Depends(auth_password_hasher.get_password_hasher),
],
db: Annotated[Session, Depends(core_database.get_db)],
):
"""
@@ -701,6 +715,7 @@ async def verify_mfa(
Args:
request: MFA request with code to verify.
token_user_id: User ID from access token.
password_hasher: Password hasher instance for backup code verification.
db: Database session.
Returns:
@@ -709,7 +724,9 @@ async def verify_mfa(
Raises:
HTTPException: If MFA code is invalid.
"""
is_valid = profile_utils.verify_user_mfa(token_user_id, request.mfa_code, db)
is_valid = profile_utils.verify_user_mfa(
token_user_id, request.mfa_code, password_hasher, db
)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid MFA code"

View File

@@ -14,6 +14,9 @@ import core.cryptography as core_cryptography
import core.logger as core_logger
import profile.schema as profile_schema
import users.user.crud as users_crud
import auth.password_hasher as auth_password_hasher
import auth.mfa_backup_codes.crud as mfa_backup_codes_crud
import auth.mfa_backup_codes.utils as mfa_backup_codes_utils
from profile.exceptions import (
MemoryAllocationError,
)
@@ -187,7 +190,13 @@ def setup_user_mfa(user_id: int, db: Session) -> profile_schema.MFASetupResponse
)
def enable_user_mfa(user_id: int, secret: str, mfa_code: str, db: Session):
def enable_user_mfa(
user_id: int,
secret: str,
mfa_code: str,
password_hasher: auth_password_hasher.PasswordHasher,
db: Session,
) -> list[str]:
"""
Enable MFA for user after verification.
@@ -195,8 +204,12 @@ def enable_user_mfa(user_id: int, secret: str, mfa_code: str, db: Session):
user_id: User ID to enable MFA for.
secret: TOTP secret to verify.
mfa_code: MFA code to verify.
password_hasher: Password hasher instance for backup code generation.
db: Database session.
Returns:
List of generated backup codes.
Raises:
HTTPException: If user not found, MFA enabled, code
invalid, or encryption fails.
@@ -232,8 +245,14 @@ def enable_user_mfa(user_id: int, secret: str, mfa_code: str, db: Session):
# Update user with MFA enabled and secret
users_crud.enable_user_mfa(user_id, encrypted_secret, db)
backup_codes = mfa_backup_codes_crud.create_backup_codes(
user_id, password_hasher, db
)
def disable_user_mfa(user_id: int, mfa_code: str, db: Session):
return backup_codes
def disable_user_mfa(user_id: int, mfa_code: str, db: Session) -> None:
"""
Disable MFA for user after verification.
@@ -277,14 +296,23 @@ def disable_user_mfa(user_id: int, mfa_code: str, db: Session):
# Disable MFA for user
users_crud.disable_user_mfa(user_id, db)
# Delete all backup codes for user
mfa_backup_codes_crud.delete_user_backup_codes(user_id, db)
def verify_user_mfa(user_id: int, mfa_code: str, db: Session) -> bool:
def verify_user_mfa(
user_id: int,
mfa_code: str,
password_hasher: auth_password_hasher.PasswordHasher,
db: Session,
) -> bool:
"""
Verify MFA code for user.
Verify MFA code for user (TOTP or backup code).
Args:
user_id: User ID to verify MFA for.
mfa_code: MFA code to verify.
mfa_code: MFA code to verify (6-digit TOTP or 8-character backup code).
password_hasher: Password hasher instance for backup code verification.
db: Database session.
Returns:
@@ -292,6 +320,11 @@ def verify_user_mfa(user_id: int, mfa_code: str, db: Session) -> bool:
Raises:
HTTPException: If user not found.
Notes:
- First tries TOTP verification (6 digits)
- If TOTP fails and code is 8 characters, tries backup code
- Backup codes are consumed on successful verification
"""
user = users_crud.get_user_by_id(user_id, db)
if not user:
@@ -302,16 +335,46 @@ def verify_user_mfa(user_id: int, mfa_code: str, db: Session) -> bool:
if not user.mfa_enabled or not user.mfa_secret:
return False
# Decrypt the secret
try:
secret = core_cryptography.decrypt_token_fernet(user.mfa_secret)
if not secret:
core_logger.print_to_log("Failed to decrypt MFA secret", "error")
# Normalize code (remove dashes, uppercase)
normalized_code = mfa_code.strip().replace("-", "").upper()
# Try TOTP first (6 digits)
if len(normalized_code) == 6 and normalized_code.isdigit():
try:
secret = core_cryptography.decrypt_token_fernet(user.mfa_secret)
if not secret:
core_logger.print_to_log("Failed to decrypt MFA secret", "error")
return False
if verify_totp(secret, normalized_code):
core_logger.print_to_log(
f"User {user_id} verified MFA with TOTP", "info"
)
return True
except Exception as err:
core_logger.print_to_log(
f"Error in TOTP verification: {err}", "error", exc=err
)
return False
return verify_totp(secret, mfa_code)
except Exception as err:
core_logger.print_to_log(f"Error in verify_user_mfa: {err}", "error", exc=err)
return False
# Try backup code (8 alphanumeric characters)
elif len(normalized_code) == 8:
try:
if mfa_backup_codes_utils.verify_and_consume_backup_code(
user_id, normalized_code, password_hasher, db
):
core_logger.print_to_log(
f"User {user_id} verified MFA with backup code", "warning"
)
return True
except Exception as err:
core_logger.print_to_log(
f"Error in backup code verification: {err}", "error", exc=err
)
return False
# Invalid format or code didn't match
return False
def is_mfa_enabled_for_user(user_id: int, db: Session) -> bool:

View File

@@ -52,6 +52,7 @@ class User(Base):
goals: List of user goals.
user_identity_providers: List of identity providers linked to the user.
oauth_states: List of OAuth states for the user (link mode).
mfa_backup_codes: List of MFA backup codes associated with the user.
"""
__tablename__ = "users"
@@ -274,3 +275,8 @@ class User(Base):
back_populates="user",
cascade="all, delete-orphan",
)
# Establish a one-to-many relationship with mfa_backup_codes
mfa_backup_codes = relationship(
"MFABackupCode", back_populates="user", cascade="all, delete-orphan"
)