mirror of
https://github.com/joaovitoriasilva/endurain.git
synced 2026-01-07 23:13:57 -05:00
Add MFA backup codes support for user authentication
Introduces database models, migration, API endpoints, and business logic for MFA backup codes as a fallback authentication method. Users can generate, view status, and consume backup codes; codes are securely hashed and invalidated upon use. Integrates backup code verification into MFA flows, updates user and profile logic, and ensures codes are managed on MFA enable/disable actions.
This commit is contained in:
@@ -7,6 +7,7 @@ from alembic import context
|
||||
|
||||
|
||||
import auth.identity_providers.models
|
||||
import auth.mfa_backup_codes.models
|
||||
import auth.oauth_state.models
|
||||
import activities.activity.models
|
||||
import activities.activity_exercise_titles.models
|
||||
|
||||
@@ -245,11 +245,69 @@ def upgrade() -> None:
|
||||
["token_family_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Create mfa_backup_codes table
|
||||
op.create_table(
|
||||
"mfa_backup_codes",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
comment="User who owns this backup code",
|
||||
),
|
||||
sa.Column(
|
||||
"code_hash",
|
||||
sa.String(length=255),
|
||||
nullable=False,
|
||||
comment="Argon2 hash of the backup code",
|
||||
),
|
||||
sa.Column(
|
||||
"used",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
comment="Whether this code has been consumed",
|
||||
),
|
||||
sa.Column(
|
||||
"used_at", sa.DateTime(), nullable=True, comment="When this code was used"
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
comment="When this code was generated",
|
||||
),
|
||||
sa.Column(
|
||||
"expires_at",
|
||||
sa.DateTime(),
|
||||
nullable=True,
|
||||
comment="Optional expiry for code rotation policy",
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("code_hash"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_user_unused_codes", "mfa_backup_codes", ["user_id", "used"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_mfa_backup_codes_used"), "mfa_backup_codes", ["used"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_mfa_backup_codes_user_id"),
|
||||
"mfa_backup_codes",
|
||||
["user_id"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_mfa_backup_codes_user_id"), table_name="mfa_backup_codes")
|
||||
op.drop_index(op.f("ix_mfa_backup_codes_used"), table_name="mfa_backup_codes")
|
||||
op.drop_index("idx_user_unused_codes", table_name="mfa_backup_codes")
|
||||
op.drop_table("mfa_backup_codes")
|
||||
op.drop_constraint(None, "users_sessions", type_="foreignkey")
|
||||
op.drop_index(
|
||||
op.f("ix_users_sessions_token_family_id"), table_name="users_sessions"
|
||||
|
||||
0
backend/app/auth/mfa_backup_codes/__init__.py
Normal file
0
backend/app/auth/mfa_backup_codes/__init__.py
Normal file
170
backend/app/auth/mfa_backup_codes/crud.py
Normal file
170
backend/app/auth/mfa_backup_codes/crud.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
import auth.mfa_backup_codes.models as mfa_backup_codes_models
|
||||
import auth.mfa_backup_codes.utils as mfa_backup_codes_utils
|
||||
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
import core.logger as core_logger
|
||||
|
||||
|
||||
def get_user_backup_codes(
|
||||
user_id: int, db: Session
|
||||
) -> list[mfa_backup_codes_models.MFABackupCode]:
|
||||
try:
|
||||
return (
|
||||
db.query(mfa_backup_codes_models.MFABackupCode)
|
||||
.filter(
|
||||
mfa_backup_codes_models.MFABackupCode.user_id == user_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Error in get_user_backup_codes: {err}", "error", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve backup codes",
|
||||
) from err
|
||||
|
||||
|
||||
def get_user_unused_backup_codes(
|
||||
user_id: int, db: Session
|
||||
) -> list[mfa_backup_codes_models.MFABackupCode]:
|
||||
try:
|
||||
return (
|
||||
db.query(mfa_backup_codes_models.MFABackupCode)
|
||||
.filter(
|
||||
mfa_backup_codes_models.MFABackupCode.user_id == user_id,
|
||||
mfa_backup_codes_models.MFABackupCode.used == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Error in get_user_unused_backup_codes: {err}", "error", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve unused backup codes",
|
||||
) from err
|
||||
|
||||
|
||||
def create_backup_codes(
|
||||
user_id: int,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
count: int = 10,
|
||||
) -> list[str]:
|
||||
try:
|
||||
delete_user_backup_codes(user_id, db)
|
||||
|
||||
plaintext_codes = []
|
||||
|
||||
for _ in range(count):
|
||||
# Generate unique code
|
||||
code = mfa_backup_codes_utils.generate_backup_code()
|
||||
|
||||
# Hash the code (bcrypt)
|
||||
code_hash = password_hasher.hash_password(code)
|
||||
|
||||
# Store hash in database
|
||||
backup_code = mfa_backup_codes_models.MFABackupCode(
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(backup_code)
|
||||
|
||||
# Keep plaintext for return (only time it's available)
|
||||
plaintext_codes.append(code)
|
||||
|
||||
db.commit()
|
||||
|
||||
core_logger.print_to_log(f"Created backup codes for user ID {user_id}", "info")
|
||||
|
||||
return plaintext_codes
|
||||
except HTTPException as err:
|
||||
raise err
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Error creating backup codes for user ID {user_id}: {err}",
|
||||
"error",
|
||||
exc=err,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to regenerate backup codes",
|
||||
) from err
|
||||
|
||||
|
||||
def mark_backup_code_as_used(
|
||||
backup_code_hashed: str, user_id: int, db: Session
|
||||
) -> None:
|
||||
try:
|
||||
db_code = (
|
||||
db.query(mfa_backup_codes_models.MFABackupCode)
|
||||
.filter(
|
||||
mfa_backup_codes_models.MFABackupCode.user_id == user_id,
|
||||
mfa_backup_codes_models.MFABackupCode.code_hash == backup_code_hashed,
|
||||
mfa_backup_codes_models.MFABackupCode.used == False,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_code:
|
||||
db_code.used = True
|
||||
db_code.used_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(db_code)
|
||||
core_logger.print_to_log(
|
||||
f"Marked backup code as used for user ID {user_id}", "info"
|
||||
)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"No unused backup code found to mark as used for user ID {user_id}",
|
||||
"warning",
|
||||
)
|
||||
HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Backup code not found or already used",
|
||||
)
|
||||
except Exception as err:
|
||||
db.rollback()
|
||||
core_logger.print_to_log(
|
||||
f"Error in mark_backup_code_as_used: {err}", "error", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal Server Error",
|
||||
) from err
|
||||
|
||||
|
||||
def delete_user_backup_codes(user_id: int, db: Session) -> int:
|
||||
try:
|
||||
# Delete existing codes
|
||||
num_deleted = (
|
||||
db.query(mfa_backup_codes_models.MFABackupCode)
|
||||
.filter(mfa_backup_codes_models.MFABackupCode.user_id == user_id)
|
||||
.delete()
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Deleted {num_deleted} backup codes for user ID: {user_id}", "info"
|
||||
)
|
||||
|
||||
return num_deleted
|
||||
except Exception as err:
|
||||
db.rollback()
|
||||
core_logger.print_to_log(
|
||||
f"Error in delete_user_backup_codes: {err}", "error", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal Server Error",
|
||||
) from err
|
||||
72
backend/app/auth/mfa_backup_codes/models.py
Normal file
72
backend/app/auth/mfa_backup_codes/models.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Index
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime, timezone
|
||||
from core.database import Base
|
||||
|
||||
|
||||
class MFABackupCode(Base):
|
||||
"""
|
||||
SQLAlchemy model for MFA backup codes.
|
||||
|
||||
This model stores hashed backup codes that users can use as a fallback
|
||||
authentication method when their primary MFA device is unavailable.
|
||||
|
||||
Attributes:
|
||||
id (int): Primary key, auto-incrementing identifier.
|
||||
user_id (int): Foreign key to the users table, identifies the code owner.
|
||||
code_hash (str): Argon2 hash of the backup code for secure storage.
|
||||
used (bool): Flag indicating whether the code has been consumed.
|
||||
used_at (datetime): Timestamp when the code was used, if applicable.
|
||||
created_at (datetime): Timestamp when the code was generated (UTC).
|
||||
expires_at (datetime): Optional expiration timestamp for code rotation.
|
||||
|
||||
Relationships:
|
||||
user: Many-to-one relationship with the User model.
|
||||
|
||||
Indexes:
|
||||
- Primary index on user_id for foreign key constraint
|
||||
- Unique index on code_hash to prevent duplicates
|
||||
- Index on used for filtering consumed codes
|
||||
- Composite index (idx_user_unused_codes) on user_id and used for
|
||||
efficient lookups of available backup codes per user
|
||||
"""
|
||||
|
||||
__tablename__ = "mfa_backup_codes"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(
|
||||
Integer,
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="User who owns this backup code",
|
||||
)
|
||||
code_hash = Column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
comment="Argon2 hash of the backup code",
|
||||
)
|
||||
used = Column(
|
||||
Boolean,
|
||||
default=False,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Whether this code has been consumed",
|
||||
)
|
||||
used_at = Column(DateTime, nullable=True, comment="When this code was used")
|
||||
created_at = Column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
comment="When this code was generated",
|
||||
)
|
||||
expires_at = Column(
|
||||
DateTime, nullable=True, comment="Optional expiry for code rotation policy"
|
||||
)
|
||||
|
||||
# Establish relationship back to User model
|
||||
user = relationship("User", back_populates="mfa_backup_codes")
|
||||
|
||||
# Composite index for fast unused code lookups
|
||||
__table_args__ = (Index("idx_user_unused_codes", "user_id", "used"),)
|
||||
127
backend/app/auth/mfa_backup_codes/router.py
Normal file
127
backend/app/auth/mfa_backup_codes/router.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated, Callable
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
Response,
|
||||
Request,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import session.utils as session_utils
|
||||
import auth.security as auth_security
|
||||
import auth.utils as auth_utils
|
||||
import auth.constants as auth_constants
|
||||
import session.crud as session_crud
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
import auth.token_manager as auth_token_manager
|
||||
import auth.schema as auth_schema
|
||||
|
||||
import auth.mfa_backup_codes.schema as mfa_backup_codes_schema
|
||||
import auth.mfa_backup_codes.crud as mfa_backup_codes_crud
|
||||
|
||||
import auth.identity_providers.utils as idp_utils
|
||||
|
||||
import users.user.crud as users_crud
|
||||
import users.user.utils as users_utils
|
||||
import profile.utils as profile_utils
|
||||
|
||||
import core.database as core_database
|
||||
import core.rate_limit as core_rate_limit
|
||||
import core.logger as core_logger
|
||||
|
||||
import session.rotated_refresh_tokens.utils as rotated_tokens_utils
|
||||
|
||||
# Define the API router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/status",
|
||||
response_model=mfa_backup_codes_schema.MFABackupCodeStatus,
|
||||
)
|
||||
async def get_backup_code_status(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(auth_security.get_sub_from_refresh_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
codes = mfa_backup_codes_crud.get_user_backup_codes(token_user_id, db)
|
||||
|
||||
if not codes:
|
||||
return mfa_backup_codes_schema.MFABackupCodeStatus(
|
||||
has_codes=False,
|
||||
total=0,
|
||||
unused=0,
|
||||
used=0,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
unused = sum(1 for code in codes if not code.used)
|
||||
used = sum(1 for code in codes if code.used)
|
||||
created_at = codes[0].created_at if codes else None
|
||||
|
||||
return mfa_backup_codes_schema.MFABackupCodeStatus(
|
||||
has_codes=True,
|
||||
total=len(codes),
|
||||
unused=unused,
|
||||
used=used,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=mfa_backup_codes_schema.MFABackupCodesResponse,
|
||||
)
|
||||
@core_rate_limit.limiter.limit(core_rate_limit.MFA_VERIFY_LIMIT)
|
||||
async def generate_mfa_backup_codes(
|
||||
response: Response,
|
||||
request: Request,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(auth_security.get_sub_from_refresh_token),
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
user = users_crud.get_user_by_id(token_user_id, db)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
if not user.mfa_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MFA must be enabled to generate backup codes",
|
||||
)
|
||||
|
||||
# Generate codes (invalidates old codes)
|
||||
codes = mfa_backup_codes_crud.create_backup_codes(
|
||||
token_user_id, password_hasher, db
|
||||
)
|
||||
|
||||
# Log event
|
||||
core_logger.print_to_log(f"User {user.id} generated MFA backup codes", "info")
|
||||
|
||||
return mfa_backup_codes_schema.MFABackupCodesResponse(
|
||||
codes=codes,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
15
backend/app/auth/mfa_backup_codes/schema.py
Normal file
15
backend/app/auth/mfa_backup_codes/schema.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class MFABackupCodesResponse(BaseModel):
|
||||
codes: list[str] = Field(..., description="10 one-time backup codes")
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class MFABackupCodeStatus(BaseModel):
|
||||
has_codes: bool
|
||||
total: int
|
||||
unused: int
|
||||
used: int
|
||||
created_at: datetime | None = None
|
||||
58
backend/app/auth/mfa_backup_codes/utils.py
Normal file
58
backend/app/auth/mfa_backup_codes/utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import secrets
|
||||
import string
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import auth.mfa_backup_codes.crud as mfa_backup_codes_crud
|
||||
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def generate_backup_code() -> str:
|
||||
"""
|
||||
Generate a cryptographically secure 8-character backup code.
|
||||
|
||||
Format: XXXX-XXXX (uppercase alphanumeric, no ambiguous chars)
|
||||
Excludes: 0, O, 1, I, l (to prevent confusion)
|
||||
Entropy: ~40 bits per code (8 chars from 32-char alphabet)
|
||||
|
||||
Returns:
|
||||
str: Formatted backup code (e.g., "A3K9-7BDF")
|
||||
"""
|
||||
alphabet = string.ascii_uppercase + string.digits
|
||||
# Remove ambiguous characters
|
||||
alphabet = (
|
||||
alphabet.replace("0", "").replace("O", "").replace("1", "").replace("I", "")
|
||||
)
|
||||
|
||||
# Generate 8 random characters
|
||||
code = "".join(secrets.choice(alphabet) for _ in range(8))
|
||||
|
||||
# Format as XXXX-XXXX
|
||||
return f"{code[:4]}-{code[4:]}"
|
||||
|
||||
|
||||
def verify_and_consume_backup_code(
|
||||
user_id: int,
|
||||
code: str,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> bool:
|
||||
# Get all unused codes for this user
|
||||
unused_codes = mfa_backup_codes_crud.get_user_unused_backup_codes(user_id, db)
|
||||
|
||||
# Try each unused code (constant-time for each)
|
||||
for unused_code in unused_codes:
|
||||
if password_hasher.verify(code, unused_code.code_hash):
|
||||
# Valid code found - mark as used
|
||||
mfa_backup_codes_crud.mark_backup_code_as_used(
|
||||
unused_code.code_hash, user_id, db
|
||||
)
|
||||
|
||||
# Return success
|
||||
return True
|
||||
|
||||
# No matching code found
|
||||
return False
|
||||
@@ -227,8 +227,10 @@ async def verify_mfa_and_login(
|
||||
detail="No pending MFA login found for this username",
|
||||
)
|
||||
|
||||
# Verify the MFA code
|
||||
if not profile_utils.verify_user_mfa(user_id, mfa_request.mfa_code, db):
|
||||
# Verify the MFA code (TOTP or backup code)
|
||||
if not profile_utils.verify_user_mfa(
|
||||
user_id, mfa_request.mfa_code, password_hasher, db
|
||||
):
|
||||
# Record failed attempt and apply lockout if threshold exceeded
|
||||
failed_count = pending_mfa_store.record_failed_attempt(mfa_request.username)
|
||||
raise HTTPException(
|
||||
|
||||
@@ -16,6 +16,10 @@ import activities.activity_summaries.router as activity_summaries_router
|
||||
import activities.activity_workout_steps.router as activity_workout_steps_router
|
||||
import activities.activity_workout_steps.public_router as activity_workout_steps_public_router
|
||||
import auth.router as auth_router
|
||||
import auth.mfa_backup_codes.router as mfa_backup_codes_router
|
||||
import auth.identity_providers.router as identity_providers_router
|
||||
import auth.identity_providers.public_router as identity_providers_public_router
|
||||
import auth.security as auth_security
|
||||
import core.config as core_config
|
||||
import core.router as core_router
|
||||
import followers.router as followers_router
|
||||
@@ -26,8 +30,6 @@ import health_sleep.router as health_sleep_router
|
||||
import health_weight.router as health_weight_router
|
||||
import health_steps.router as health_steps_router
|
||||
import health_targets.router as health_targets_router
|
||||
import auth.identity_providers.router as identity_providers_router
|
||||
import auth.identity_providers.public_router as identity_providers_public_router
|
||||
import notifications.router as notifications_router
|
||||
import password_reset_tokens.router as password_reset_tokens_router
|
||||
import profile.browser_redirect_router as profile_browser_redirect_router
|
||||
@@ -35,7 +37,6 @@ import profile.router as profile_router
|
||||
import server_settings.public_router as server_settings_public_router
|
||||
import server_settings.router as server_settings_router
|
||||
import session.router as session_router
|
||||
import auth.security as auth_security
|
||||
import sign_up_tokens.router as sign_up_tokens_router
|
||||
import strava.router as strava_router
|
||||
import users.user.router as users_router
|
||||
@@ -102,6 +103,12 @@ router.include_router(
|
||||
prefix=core_config.ROOT_PATH + "/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
router.include_router(
|
||||
mfa_backup_codes_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/auth/mfa/backup-codes",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
followers_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/followers",
|
||||
|
||||
@@ -623,6 +623,10 @@ async def enable_mfa(
|
||||
int,
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
mfa_secret_store: Annotated[
|
||||
profile_schema.MFASecretStore, Depends(profile_schema.get_mfa_secret_store)
|
||||
@@ -634,6 +638,7 @@ async def enable_mfa(
|
||||
Args:
|
||||
request: MFA setup request with code.
|
||||
token_user_id: User ID from access token.
|
||||
password_hasher: Password hasher instance for backup code generation.
|
||||
db: Database session.
|
||||
mfa_secret_store: Temporary secret storage.
|
||||
|
||||
@@ -652,10 +657,15 @@ async def enable_mfa(
|
||||
)
|
||||
|
||||
try:
|
||||
profile_utils.enable_user_mfa(token_user_id, secret, request.mfa_code, db)
|
||||
backup_codes = profile_utils.enable_user_mfa(
|
||||
token_user_id, secret, request.mfa_code, password_hasher, db
|
||||
)
|
||||
# Clean up the temporary secret
|
||||
mfa_secret_store.delete_secret(token_user_id)
|
||||
return {"message": "MFA enabled successfully"}
|
||||
return {
|
||||
"message": "MFA enabled successfully",
|
||||
"backup_codes": backup_codes,
|
||||
}
|
||||
except HTTPException:
|
||||
# Clean up on error
|
||||
mfa_secret_store.delete_secret(token_user_id)
|
||||
@@ -693,6 +703,10 @@ async def verify_mfa(
|
||||
int,
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
"""
|
||||
@@ -701,6 +715,7 @@ async def verify_mfa(
|
||||
Args:
|
||||
request: MFA request with code to verify.
|
||||
token_user_id: User ID from access token.
|
||||
password_hasher: Password hasher instance for backup code verification.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
@@ -709,7 +724,9 @@ async def verify_mfa(
|
||||
Raises:
|
||||
HTTPException: If MFA code is invalid.
|
||||
"""
|
||||
is_valid = profile_utils.verify_user_mfa(token_user_id, request.mfa_code, db)
|
||||
is_valid = profile_utils.verify_user_mfa(
|
||||
token_user_id, request.mfa_code, password_hasher, db
|
||||
)
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid MFA code"
|
||||
|
||||
@@ -14,6 +14,9 @@ import core.cryptography as core_cryptography
|
||||
import core.logger as core_logger
|
||||
import profile.schema as profile_schema
|
||||
import users.user.crud as users_crud
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
import auth.mfa_backup_codes.crud as mfa_backup_codes_crud
|
||||
import auth.mfa_backup_codes.utils as mfa_backup_codes_utils
|
||||
from profile.exceptions import (
|
||||
MemoryAllocationError,
|
||||
)
|
||||
@@ -187,7 +190,13 @@ def setup_user_mfa(user_id: int, db: Session) -> profile_schema.MFASetupResponse
|
||||
)
|
||||
|
||||
|
||||
def enable_user_mfa(user_id: int, secret: str, mfa_code: str, db: Session):
|
||||
def enable_user_mfa(
|
||||
user_id: int,
|
||||
secret: str,
|
||||
mfa_code: str,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Enable MFA for user after verification.
|
||||
|
||||
@@ -195,8 +204,12 @@ def enable_user_mfa(user_id: int, secret: str, mfa_code: str, db: Session):
|
||||
user_id: User ID to enable MFA for.
|
||||
secret: TOTP secret to verify.
|
||||
mfa_code: MFA code to verify.
|
||||
password_hasher: Password hasher instance for backup code generation.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
List of generated backup codes.
|
||||
|
||||
Raises:
|
||||
HTTPException: If user not found, MFA enabled, code
|
||||
invalid, or encryption fails.
|
||||
@@ -232,8 +245,14 @@ def enable_user_mfa(user_id: int, secret: str, mfa_code: str, db: Session):
|
||||
# Update user with MFA enabled and secret
|
||||
users_crud.enable_user_mfa(user_id, encrypted_secret, db)
|
||||
|
||||
backup_codes = mfa_backup_codes_crud.create_backup_codes(
|
||||
user_id, password_hasher, db
|
||||
)
|
||||
|
||||
def disable_user_mfa(user_id: int, mfa_code: str, db: Session):
|
||||
return backup_codes
|
||||
|
||||
|
||||
def disable_user_mfa(user_id: int, mfa_code: str, db: Session) -> None:
|
||||
"""
|
||||
Disable MFA for user after verification.
|
||||
|
||||
@@ -277,14 +296,23 @@ def disable_user_mfa(user_id: int, mfa_code: str, db: Session):
|
||||
# Disable MFA for user
|
||||
users_crud.disable_user_mfa(user_id, db)
|
||||
|
||||
# Delete all backup codes for user
|
||||
mfa_backup_codes_crud.delete_user_backup_codes(user_id, db)
|
||||
|
||||
def verify_user_mfa(user_id: int, mfa_code: str, db: Session) -> bool:
|
||||
|
||||
def verify_user_mfa(
|
||||
user_id: int,
|
||||
mfa_code: str,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify MFA code for user.
|
||||
Verify MFA code for user (TOTP or backup code).
|
||||
|
||||
Args:
|
||||
user_id: User ID to verify MFA for.
|
||||
mfa_code: MFA code to verify.
|
||||
mfa_code: MFA code to verify (6-digit TOTP or 8-character backup code).
|
||||
password_hasher: Password hasher instance for backup code verification.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
@@ -292,6 +320,11 @@ def verify_user_mfa(user_id: int, mfa_code: str, db: Session) -> bool:
|
||||
|
||||
Raises:
|
||||
HTTPException: If user not found.
|
||||
|
||||
Notes:
|
||||
- First tries TOTP verification (6 digits)
|
||||
- If TOTP fails and code is 8 characters, tries backup code
|
||||
- Backup codes are consumed on successful verification
|
||||
"""
|
||||
user = users_crud.get_user_by_id(user_id, db)
|
||||
if not user:
|
||||
@@ -302,16 +335,46 @@ def verify_user_mfa(user_id: int, mfa_code: str, db: Session) -> bool:
|
||||
if not user.mfa_enabled or not user.mfa_secret:
|
||||
return False
|
||||
|
||||
# Decrypt the secret
|
||||
try:
|
||||
secret = core_cryptography.decrypt_token_fernet(user.mfa_secret)
|
||||
if not secret:
|
||||
core_logger.print_to_log("Failed to decrypt MFA secret", "error")
|
||||
# Normalize code (remove dashes, uppercase)
|
||||
normalized_code = mfa_code.strip().replace("-", "").upper()
|
||||
|
||||
# Try TOTP first (6 digits)
|
||||
if len(normalized_code) == 6 and normalized_code.isdigit():
|
||||
try:
|
||||
secret = core_cryptography.decrypt_token_fernet(user.mfa_secret)
|
||||
if not secret:
|
||||
core_logger.print_to_log("Failed to decrypt MFA secret", "error")
|
||||
return False
|
||||
|
||||
if verify_totp(secret, normalized_code):
|
||||
core_logger.print_to_log(
|
||||
f"User {user_id} verified MFA with TOTP", "info"
|
||||
)
|
||||
return True
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Error in TOTP verification: {err}", "error", exc=err
|
||||
)
|
||||
return False
|
||||
return verify_totp(secret, mfa_code)
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(f"Error in verify_user_mfa: {err}", "error", exc=err)
|
||||
return False
|
||||
|
||||
# Try backup code (8 alphanumeric characters)
|
||||
elif len(normalized_code) == 8:
|
||||
try:
|
||||
if mfa_backup_codes_utils.verify_and_consume_backup_code(
|
||||
user_id, normalized_code, password_hasher, db
|
||||
):
|
||||
core_logger.print_to_log(
|
||||
f"User {user_id} verified MFA with backup code", "warning"
|
||||
)
|
||||
return True
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Error in backup code verification: {err}", "error", exc=err
|
||||
)
|
||||
return False
|
||||
|
||||
# Invalid format or code didn't match
|
||||
return False
|
||||
|
||||
|
||||
def is_mfa_enabled_for_user(user_id: int, db: Session) -> bool:
|
||||
|
||||
@@ -52,6 +52,7 @@ class User(Base):
|
||||
goals: List of user goals.
|
||||
user_identity_providers: List of identity providers linked to the user.
|
||||
oauth_states: List of OAuth states for the user (link mode).
|
||||
mfa_backup_codes: List of MFA backup codes associated with the user.
|
||||
"""
|
||||
|
||||
__tablename__ = "users"
|
||||
@@ -274,3 +275,8 @@ class User(Base):
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Establish a one-to-many relationship with mfa_backup_codes
|
||||
mfa_backup_codes = relationship(
|
||||
"MFABackupCode", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user