mirror of
https://github.com/joaovitoriasilva/endurain.git
synced 2026-01-09 15:57:59 -05:00
Add tests for MFA backup codes and OAuth state modules
Added comprehensive unit tests for MFA backup codes CRUD, router, and utility functions, as well as for OAuth state CRUD and utility functions. Also fixed timezone handling in OAuth state expiry check. These tests improve coverage and reliability for authentication-related features.
This commit is contained in:
@@ -33,7 +33,7 @@ def get_oauth_state_by_id(
|
||||
return None
|
||||
|
||||
# Check expiry
|
||||
if datetime.now() > oauth_state.expires_at:
|
||||
if datetime.now(timezone.utc) > oauth_state.expires_at:
|
||||
core_logger.print_to_log(f"OAuth state expired: {state_id[:8]}...", "warning")
|
||||
return None
|
||||
|
||||
|
||||
1
backend/tests/auth/mfa_backup_codes/__init__.py
Normal file
1
backend/tests/auth/mfa_backup_codes/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for MFA backup codes module."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
258
backend/tests/auth/mfa_backup_codes/test_crud.py
Normal file
258
backend/tests/auth/mfa_backup_codes/test_crud.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""Tests for MFA backup codes CRUD operations."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
import auth.mfa_backup_codes.crud as backup_crud
|
||||
import auth.mfa_backup_codes.models as backup_models
|
||||
|
||||
|
||||
class TestGetUserBackupCodes:
|
||||
"""Test suite for get_user_backup_codes function."""
|
||||
|
||||
def test_get_user_backup_codes_success(self, mock_db):
|
||||
"""Test successful retrieval of user backup codes."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
mock_code1 = MagicMock(spec=backup_models.MFABackupCode)
|
||||
mock_code2 = MagicMock(spec=backup_models.MFABackupCode)
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.all.return_value = [mock_code1, mock_code2]
|
||||
|
||||
# Act
|
||||
result = backup_crud.get_user_backup_codes(user_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == [mock_code1, mock_code2]
|
||||
mock_db.query.assert_called_once_with(backup_models.MFABackupCode)
|
||||
|
||||
def test_get_user_backup_codes_exception(self, mock_db):
|
||||
"""Test exception handling in get_user_backup_codes."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
mock_db.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
backup_crud.get_user_backup_codes(user_id, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert exc_info.value.detail == "Failed to retrieve backup codes"
|
||||
|
||||
|
||||
class TestGetUserUnusedBackupCodes:
|
||||
"""Test suite for get_user_unused_backup_codes function."""
|
||||
|
||||
def test_get_unused_codes_success(self, mock_db):
|
||||
"""Test successful retrieval of unused backup codes."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
mock_code1 = MagicMock(spec=backup_models.MFABackupCode, used=False)
|
||||
mock_code2 = MagicMock(spec=backup_models.MFABackupCode, used=False)
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.all.return_value = [mock_code1, mock_code2]
|
||||
|
||||
# Act
|
||||
result = backup_crud.get_user_unused_backup_codes(user_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == [mock_code1, mock_code2]
|
||||
assert all(not code.used for code in result)
|
||||
|
||||
def test_get_unused_codes_exception(self, mock_db):
|
||||
"""Test exception handling in get_user_unused_backup_codes."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
mock_db.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
backup_crud.get_user_unused_backup_codes(user_id, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
class TestCreateBackupCodes:
|
||||
"""Test suite for create_backup_codes function."""
|
||||
|
||||
def test_create_backup_codes_success(self, mock_db, password_hasher):
|
||||
"""Test successful backup codes creation."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
count = 10
|
||||
|
||||
# Mock the model instantiation to avoid SQLAlchemy mapper issues
|
||||
with patch.object(backup_crud, "delete_user_backup_codes"), patch(
|
||||
"auth.mfa_backup_codes.crud.mfa_backup_codes_models.MFABackupCode"
|
||||
):
|
||||
# Act
|
||||
codes = backup_crud.create_backup_codes(
|
||||
user_id, password_hasher, mock_db, count
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(codes) == count
|
||||
assert all(len(code) == 9 for code in codes) # XXXX-XXXX format
|
||||
assert all("-" in code for code in codes)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_backup_codes_deletes_old_codes(self, mock_db, password_hasher):
|
||||
"""Test that old codes are deleted before creating new ones."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
|
||||
# Mock the model instantiation to avoid SQLAlchemy mapper issues
|
||||
with patch.object(
|
||||
backup_crud, "delete_user_backup_codes"
|
||||
) as mock_delete, patch(
|
||||
"auth.mfa_backup_codes.crud.mfa_backup_codes_models.MFABackupCode"
|
||||
):
|
||||
# Act
|
||||
backup_crud.create_backup_codes(user_id, password_hasher, mock_db)
|
||||
|
||||
# Assert
|
||||
mock_delete.assert_called_once_with(user_id, mock_db)
|
||||
|
||||
def test_create_backup_codes_custom_count(self, mock_db, password_hasher):
|
||||
"""Test creation with custom code count."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
custom_count = 5
|
||||
|
||||
# Mock the model instantiation to avoid SQLAlchemy mapper issues
|
||||
with patch.object(backup_crud, "delete_user_backup_codes"), patch(
|
||||
"auth.mfa_backup_codes.crud.mfa_backup_codes_models.MFABackupCode"
|
||||
):
|
||||
# Act
|
||||
codes = backup_crud.create_backup_codes(
|
||||
user_id, password_hasher, mock_db, custom_count
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(codes) == custom_count
|
||||
|
||||
def test_create_backup_codes_exception(self, mock_db, password_hasher):
|
||||
"""Test exception handling in create_backup_codes."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
mock_db.add.side_effect = Exception("Database error")
|
||||
|
||||
with patch.object(backup_crud, "delete_user_backup_codes"):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
backup_crud.create_backup_codes(user_id, password_hasher, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
class TestMarkBackupCodeAsUsed:
|
||||
"""Test suite for mark_backup_code_as_used function."""
|
||||
|
||||
def test_mark_code_as_used_success(self, mock_db):
|
||||
"""Test successful marking of backup code as used."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
code_hash = "hashed_code"
|
||||
|
||||
mock_code = MagicMock(spec=backup_models.MFABackupCode)
|
||||
mock_code.used = False
|
||||
mock_code.used_at = None
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = mock_code
|
||||
|
||||
# Act
|
||||
backup_crud.mark_backup_code_as_used(code_hash, user_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert mock_code.used is True
|
||||
assert mock_code.used_at is not None
|
||||
mock_db.commit.assert_called_once()
|
||||
mock_db.refresh.assert_called_once_with(mock_code)
|
||||
|
||||
def test_mark_code_as_used_not_found(self, mock_db):
|
||||
"""Test marking non-existent code doesn't raise exception."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
code_hash = "nonexistent_hash"
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = None
|
||||
|
||||
# Act (should not raise exception)
|
||||
backup_crud.mark_backup_code_as_used(code_hash, user_id, mock_db)
|
||||
|
||||
# Assert
|
||||
mock_db.commit.assert_not_called()
|
||||
|
||||
def test_mark_code_as_used_exception(self, mock_db):
|
||||
"""Test exception handling in mark_backup_code_as_used."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
code_hash = "hashed_code"
|
||||
mock_db.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
backup_crud.mark_backup_code_as_used(code_hash, user_id, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
mock_db.rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestDeleteUserBackupCodes:
|
||||
"""Test suite for delete_user_backup_codes function."""
|
||||
|
||||
def test_delete_codes_success(self, mock_db):
|
||||
"""Test successful deletion of user backup codes."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
expected_count = 10
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.delete.return_value = expected_count
|
||||
|
||||
# Act
|
||||
result = backup_crud.delete_user_backup_codes(user_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == expected_count
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_delete_codes_none_found(self, mock_db):
|
||||
"""Test deletion when no codes exist."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.delete.return_value = 0
|
||||
|
||||
# Act
|
||||
result = backup_crud.delete_user_backup_codes(user_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == 0
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_delete_codes_exception(self, mock_db):
|
||||
"""Test exception handling in delete_user_backup_codes."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
mock_db.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
backup_crud.delete_user_backup_codes(user_id, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
mock_db.rollback.assert_called_once()
|
||||
232
backend/tests/auth/mfa_backup_codes/test_router.py
Normal file
232
backend/tests/auth/mfa_backup_codes/test_router.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Tests for MFA backup codes API router."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from fastapi import HTTPException, status
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
import auth.mfa_backup_codes.router as backup_router
|
||||
import auth.mfa_backup_codes.crud as backup_crud
|
||||
import auth.mfa_backup_codes.schema as backup_schema
|
||||
import users.user.crud as users_crud
|
||||
|
||||
|
||||
class TestGetMFABackupCodeStatus:
|
||||
"""Test suite for get_mfa_backup_code_status endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_status_success(self, mock_db, sample_user_read):
|
||||
"""Test successful retrieval of backup code status."""
|
||||
# Arrange
|
||||
created_at = datetime.now(timezone.utc)
|
||||
mock_codes = [
|
||||
MagicMock(used=False, created_at=created_at),
|
||||
MagicMock(used=False, created_at=created_at),
|
||||
MagicMock(used=True, created_at=created_at),
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
backup_crud, "get_user_backup_codes", return_value=mock_codes
|
||||
):
|
||||
# Act
|
||||
result = await backup_router.get_backup_code_status(
|
||||
token_user_id=sample_user_read.id, db=mock_db
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, backup_schema.MFABackupCodeStatus)
|
||||
assert result.total == 3
|
||||
assert result.unused == 2
|
||||
assert result.used == 1
|
||||
assert result.created_at == created_at
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_status_no_codes(self, mock_db, sample_user_read):
|
||||
"""Test status when user has no backup codes."""
|
||||
# Arrange
|
||||
with patch.object(backup_crud, "get_user_backup_codes", return_value=[]):
|
||||
# Act
|
||||
result = await backup_router.get_backup_code_status(
|
||||
token_user_id=sample_user_read.id, db=mock_db
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.total == 0
|
||||
assert result.unused == 0
|
||||
assert result.used == 0
|
||||
assert result.created_at is None
|
||||
|
||||
def test_get_status_requires_mfa_enabled(self, mock_db):
|
||||
"""Test status endpoint requires MFA to be enabled."""
|
||||
# Arrange
|
||||
user_without_mfa = MagicMock()
|
||||
user_without_mfa.mfa_enabled = False
|
||||
|
||||
# This would be handled by dependency injection in router
|
||||
# but we verify the logic here
|
||||
assert user_without_mfa.mfa_enabled is False
|
||||
|
||||
|
||||
class TestGenerateMFABackupCodes:
|
||||
"""Test suite for generate_mfa_backup_codes endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_codes_success(
|
||||
self, mock_db, sample_user_read, password_hasher
|
||||
):
|
||||
"""Test successful generation of new backup codes."""
|
||||
# Arrange
|
||||
generated_codes = [
|
||||
"A3K9-7BDF",
|
||||
"X7Q4-MNPR",
|
||||
"K2W9-CGHS",
|
||||
"P5T8-VJBF",
|
||||
"N6R4-XDKL",
|
||||
"M9G2-WHTR",
|
||||
"Q8B5-YZCP",
|
||||
"H4V7-NSJF",
|
||||
"C3K6-PMDG",
|
||||
"Y2T9-XBRN",
|
||||
]
|
||||
|
||||
# Need to mock the user lookup
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = sample_user_read.id
|
||||
mock_user.mfa_enabled = True
|
||||
|
||||
# Create proper Response object for rate limiting
|
||||
mock_response = Response()
|
||||
|
||||
# Create proper Request object for rate limiting
|
||||
mock_request = Request(
|
||||
scope={
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
"headers": Headers({}).raw,
|
||||
"query_string": b"",
|
||||
"path": "/api/v1/auth/mfa/backup-codes",
|
||||
"client": ("testclient", 50000),
|
||||
"server": ("testserver", 80),
|
||||
}
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
users_crud, "get_user_by_id", return_value=mock_user
|
||||
), patch.object(
|
||||
backup_crud, "create_backup_codes", return_value=generated_codes
|
||||
):
|
||||
# Act
|
||||
result = await backup_router.generate_mfa_backup_codes(
|
||||
response=mock_response,
|
||||
request=mock_request,
|
||||
token_user_id=sample_user_read.id,
|
||||
db=mock_db,
|
||||
password_hasher=password_hasher,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, backup_schema.MFABackupCodesResponse)
|
||||
assert result.codes == generated_codes
|
||||
assert len(result.codes) == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_codes_invalidates_old_codes(
|
||||
self, mock_db, sample_user_read, password_hasher
|
||||
):
|
||||
"""Test that generating new codes deletes old codes."""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = sample_user_read.id
|
||||
mock_user.mfa_enabled = True
|
||||
|
||||
# Create proper Response object for rate limiting
|
||||
mock_response = Response()
|
||||
|
||||
# Create proper Request object for rate limiting
|
||||
mock_request = Request(
|
||||
scope={
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
"headers": Headers({}).raw,
|
||||
"query_string": b"",
|
||||
"path": "/api/v1/auth/mfa/backup-codes",
|
||||
"client": ("testclient", 50000),
|
||||
"server": ("testserver", 80),
|
||||
}
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
users_crud, "get_user_by_id", return_value=mock_user
|
||||
), patch.object(
|
||||
backup_crud, "create_backup_codes", return_value=["A3K9-7BDF"]
|
||||
) as mock_create:
|
||||
|
||||
# Act
|
||||
await backup_router.generate_mfa_backup_codes(
|
||||
response=mock_response,
|
||||
request=mock_request,
|
||||
token_user_id=sample_user_read.id,
|
||||
db=mock_db,
|
||||
password_hasher=password_hasher,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# delete_user_backup_codes is called inside create_backup_codes
|
||||
mock_create.assert_called_once()
|
||||
|
||||
def test_generate_codes_requires_mfa_enabled(self, mock_db, password_hasher):
|
||||
"""Test generation endpoint requires MFA to be enabled."""
|
||||
# Arrange
|
||||
user_without_mfa = MagicMock()
|
||||
user_without_mfa.mfa_enabled = False
|
||||
|
||||
# This would be handled by dependency injection in router
|
||||
# but we verify the logic here
|
||||
assert user_without_mfa.mfa_enabled is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_codes_format_validation(
|
||||
self, mock_db, sample_user_read, password_hasher
|
||||
):
|
||||
"""Test that generated codes have proper format."""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = sample_user_read.id
|
||||
mock_user.mfa_enabled = True
|
||||
|
||||
# Create proper Response object for rate limiting
|
||||
mock_response = Response()
|
||||
|
||||
# Create proper Request object for rate limiting
|
||||
mock_request = Request(
|
||||
scope={
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
"headers": Headers({}).raw,
|
||||
"query_string": b"",
|
||||
"path": "/api/v1/auth/mfa/backup-codes",
|
||||
"client": ("testclient", 50000),
|
||||
"server": ("testserver", 80),
|
||||
}
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
users_crud, "get_user_by_id", return_value=mock_user
|
||||
), patch.object(backup_crud, "create_backup_codes", return_value=["A3K9-7BDF"]):
|
||||
# Act
|
||||
result = await backup_router.generate_mfa_backup_codes(
|
||||
response=mock_response,
|
||||
request=mock_request,
|
||||
token_user_id=sample_user_read.id,
|
||||
db=mock_db,
|
||||
password_hasher=password_hasher,
|
||||
)
|
||||
|
||||
# Assert
|
||||
for code in result.codes:
|
||||
assert len(code) == 9
|
||||
assert code[4] == "-"
|
||||
assert code.isupper()
|
||||
153
backend/tests/auth/mfa_backup_codes/test_utils.py
Normal file
153
backend/tests/auth/mfa_backup_codes/test_utils.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Tests for MFA backup codes utilities."""
|
||||
|
||||
import pytest
|
||||
import string
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import auth.mfa_backup_codes.utils as backup_utils
|
||||
import auth.mfa_backup_codes.crud as backup_crud
|
||||
|
||||
|
||||
class TestGenerateBackupCode:
|
||||
"""Test suite for generate_backup_code function."""
|
||||
|
||||
def test_generate_backup_code_format(self):
|
||||
"""Test that generated code has correct format (XXXX-XXXX)."""
|
||||
code = backup_utils.generate_backup_code()
|
||||
|
||||
assert len(code) == 9, "Code should be 9 characters (8 + 1 dash)"
|
||||
assert code[4] == "-", "Character at position 4 should be dash"
|
||||
assert code[:4].isalnum(), "First 4 characters should be alphanumeric"
|
||||
assert code[5:].isalnum(), "Last 4 characters should be alphanumeric"
|
||||
assert code.isupper(), "Code should be uppercase"
|
||||
|
||||
def test_generate_backup_code_no_ambiguous_chars(self):
|
||||
"""Test that generated codes don't contain ambiguous characters."""
|
||||
ambiguous_chars = {"0", "O", "1", "I"}
|
||||
|
||||
# Generate multiple codes to test randomness
|
||||
for _ in range(100):
|
||||
code = backup_utils.generate_backup_code().replace("-", "")
|
||||
for char in code:
|
||||
assert (
|
||||
char not in ambiguous_chars
|
||||
), f"Code contains ambiguous character: {char}"
|
||||
|
||||
def test_generate_backup_code_uniqueness(self):
|
||||
"""Test that generated codes are unique."""
|
||||
codes = [backup_utils.generate_backup_code() for _ in range(100)]
|
||||
|
||||
# All codes should be unique (extremely high probability)
|
||||
assert len(set(codes)) == 100, "Generated codes should be unique"
|
||||
|
||||
def test_generate_backup_code_character_set(self):
|
||||
"""Test that generated codes only use allowed characters."""
|
||||
allowed_chars = set(string.ascii_uppercase + string.digits) - {
|
||||
"0",
|
||||
"O",
|
||||
"1",
|
||||
"I",
|
||||
}
|
||||
|
||||
for _ in range(50):
|
||||
code = backup_utils.generate_backup_code().replace("-", "")
|
||||
for char in code:
|
||||
assert char in allowed_chars, f"Invalid character in code: {char}"
|
||||
|
||||
|
||||
class TestVerifyAndConsumeBackupCode:
|
||||
"""Test suite for verify_and_consume_backup_code function."""
|
||||
|
||||
def test_verify_valid_code_success(self, mock_db, password_hasher):
|
||||
"""Test successful verification of valid backup code."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
code = "A3K97BDF"
|
||||
|
||||
mock_code_obj = MagicMock()
|
||||
mock_code_obj.code_hash = password_hasher.hash_password(code)
|
||||
mock_code_obj.used = False
|
||||
|
||||
with patch.object(
|
||||
backup_crud, "get_user_unused_backup_codes", return_value=[mock_code_obj]
|
||||
), patch.object(backup_crud, "mark_backup_code_as_used") as mock_mark_used:
|
||||
|
||||
# Act
|
||||
result = backup_utils.verify_and_consume_backup_code(
|
||||
user_id, code, password_hasher, mock_db
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_mark_used.assert_called_once_with(
|
||||
mock_code_obj.code_hash, user_id, mock_db
|
||||
)
|
||||
|
||||
def test_verify_invalid_code_failure(self, mock_db, password_hasher):
|
||||
"""Test verification fails with invalid code."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
correct_code = "A3K97BDF"
|
||||
wrong_code = "WRONGCOD"
|
||||
|
||||
mock_code_obj = MagicMock()
|
||||
mock_code_obj.code_hash = password_hasher.hash_password(correct_code)
|
||||
mock_code_obj.used = False
|
||||
|
||||
with patch.object(
|
||||
backup_crud, "get_user_unused_backup_codes", return_value=[mock_code_obj]
|
||||
), patch.object(backup_crud, "mark_backup_code_as_used") as mock_mark_used:
|
||||
|
||||
# Act
|
||||
result = backup_utils.verify_and_consume_backup_code(
|
||||
user_id, wrong_code, password_hasher, mock_db
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_mark_used.assert_not_called()
|
||||
|
||||
def test_verify_no_unused_codes(self, mock_db, password_hasher):
|
||||
"""Test verification fails when no unused codes exist."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
code = "A3K97BDF"
|
||||
|
||||
with patch.object(
|
||||
backup_crud, "get_user_unused_backup_codes", return_value=[]
|
||||
), patch.object(backup_crud, "mark_backup_code_as_used") as mock_mark_used:
|
||||
|
||||
# Act
|
||||
result = backup_utils.verify_and_consume_backup_code(
|
||||
user_id, code, password_hasher, mock_db
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_mark_used.assert_not_called()
|
||||
|
||||
def test_verify_multiple_codes_finds_match(self, mock_db, password_hasher):
|
||||
"""Test verification finds correct code among multiple codes."""
|
||||
# Arrange
|
||||
user_id = 1
|
||||
correct_code = "CORRECT9"
|
||||
|
||||
mock_codes = []
|
||||
for code_str in ["WRONG111", "WRONG222", correct_code, "WRONG333"]:
|
||||
mock_code = MagicMock()
|
||||
mock_code.code_hash = password_hasher.hash_password(code_str)
|
||||
mock_code.used = False
|
||||
mock_codes.append(mock_code)
|
||||
|
||||
with patch.object(
|
||||
backup_crud, "get_user_unused_backup_codes", return_value=mock_codes
|
||||
), patch.object(backup_crud, "mark_backup_code_as_used") as mock_mark_used:
|
||||
|
||||
# Act
|
||||
result = backup_utils.verify_and_consume_backup_code(
|
||||
user_id, correct_code, password_hasher, mock_db
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_mark_used.assert_called_once()
|
||||
1
backend/tests/auth/oauth_state/__init__.py
Normal file
1
backend/tests/auth/oauth_state/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for OAuth state module."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
420
backend/tests/auth/oauth_state/test_crud.py
Normal file
420
backend/tests/auth/oauth_state/test_crud.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""Tests for OAuth state CRUD operations."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
import auth.oauth_state.crud as oauth_state_crud
|
||||
import auth.oauth_state.models as oauth_state_models
|
||||
import session.models as session_models
|
||||
|
||||
|
||||
class TestGetOAuthStateById:
|
||||
"""Test suite for get_oauth_state_by_id function."""
|
||||
|
||||
def test_get_oauth_state_by_id_success(self, mock_db):
|
||||
"""Test successful retrieval of valid OAuth state."""
|
||||
# Arrange
|
||||
state_id = "test_state_12345678"
|
||||
mock_oauth_state = MagicMock(spec=oauth_state_models.OAuthState)
|
||||
mock_oauth_state.id = state_id
|
||||
mock_oauth_state.expires_at = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||
mock_oauth_state.used = False
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = mock_oauth_state
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.get_oauth_state_by_id(state_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == mock_oauth_state
|
||||
mock_db.query.assert_called_once_with(oauth_state_models.OAuthState)
|
||||
|
||||
def test_get_oauth_state_not_found(self, mock_db):
|
||||
"""Test OAuth state not found returns None."""
|
||||
# Arrange
|
||||
state_id = "nonexistent_state"
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.get_oauth_state_by_id(state_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_get_oauth_state_expired(self, mock_db):
|
||||
"""Test expired OAuth state returns None."""
|
||||
# Arrange
|
||||
state_id = "expired_state_12345678"
|
||||
mock_oauth_state = MagicMock(spec=oauth_state_models.OAuthState)
|
||||
mock_oauth_state.id = state_id
|
||||
mock_oauth_state.expires_at = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
mock_oauth_state.used = False
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = mock_oauth_state
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.get_oauth_state_by_id(state_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_get_oauth_state_already_used(self, mock_db):
|
||||
"""Test already used OAuth state returns None (replay protection)."""
|
||||
# Arrange
|
||||
state_id = "used_state_12345678"
|
||||
mock_oauth_state = MagicMock(spec=oauth_state_models.OAuthState)
|
||||
mock_oauth_state.id = state_id
|
||||
mock_oauth_state.expires_at = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||
mock_oauth_state.used = True
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = mock_oauth_state
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.get_oauth_state_by_id(state_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetOAuthStateBySessionId:
|
||||
"""Test suite for get_oauth_state_by_session_id function."""
|
||||
|
||||
def test_get_oauth_state_by_session_id_success(self, mock_db):
|
||||
"""Test successful retrieval of OAuth state via session."""
|
||||
# Arrange
|
||||
session_id = "session_123"
|
||||
oauth_state_id = "state_456"
|
||||
|
||||
mock_session = MagicMock(spec=session_models.UsersSessions)
|
||||
mock_session.id = session_id
|
||||
mock_session.oauth_state_id = oauth_state_id
|
||||
|
||||
mock_oauth_state = MagicMock(spec=oauth_state_models.OAuthState)
|
||||
mock_oauth_state.id = oauth_state_id
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.side_effect = [mock_session, mock_oauth_state]
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.get_oauth_state_by_session_id(mock_db, session_id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_oauth_state
|
||||
|
||||
def test_get_oauth_state_session_not_found(self, mock_db):
|
||||
"""Test OAuth state retrieval when session doesn't exist."""
|
||||
# Arrange
|
||||
session_id = "nonexistent_session"
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.get_oauth_state_by_session_id(mock_db, session_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_get_oauth_state_no_oauth_state_id(self, mock_db):
|
||||
"""Test OAuth state retrieval when session has no oauth_state_id."""
|
||||
# Arrange
|
||||
session_id = "session_without_oauth"
|
||||
|
||||
mock_session = MagicMock(spec=session_models.UsersSessions)
|
||||
mock_session.id = session_id
|
||||
mock_session.oauth_state_id = None
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = mock_session
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.get_oauth_state_by_session_id(mock_db, session_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreateOAuthState:
|
||||
"""Test suite for create_oauth_state function."""
|
||||
|
||||
def test_create_oauth_state_minimal(self, mock_db):
|
||||
"""Test OAuth state creation with minimal required fields."""
|
||||
# Arrange
|
||||
state_id = "test_state_12345678"
|
||||
idp_id = 1
|
||||
nonce = "test_nonce_123"
|
||||
client_type = "web"
|
||||
ip_address = "192.168.1.1"
|
||||
|
||||
with patch("auth.oauth_state.crud.oauth_state_models.OAuthState") as mock_model:
|
||||
mock_oauth_state = MagicMock()
|
||||
mock_model.return_value = mock_oauth_state
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.create_oauth_state(
|
||||
mock_db,
|
||||
state_id=state_id,
|
||||
idp_id=idp_id,
|
||||
nonce=nonce,
|
||||
client_type=client_type,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_model.assert_called_once()
|
||||
mock_db.add.assert_called_once_with(mock_oauth_state)
|
||||
mock_db.commit.assert_called_once()
|
||||
mock_db.refresh.assert_called_once_with(mock_oauth_state)
|
||||
assert result == mock_oauth_state
|
||||
|
||||
def test_create_oauth_state_with_pkce(self, mock_db):
|
||||
"""Test OAuth state creation with PKCE for mobile."""
|
||||
# Arrange
|
||||
state_id = "mobile_state_12345678"
|
||||
idp_id = 1
|
||||
nonce = "test_nonce_123"
|
||||
client_type = "mobile"
|
||||
ip_address = "192.168.1.1"
|
||||
code_challenge = "test_challenge"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
with patch("auth.oauth_state.crud.oauth_state_models.OAuthState") as mock_model:
|
||||
mock_oauth_state = MagicMock()
|
||||
mock_model.return_value = mock_oauth_state
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.create_oauth_state(
|
||||
mock_db,
|
||||
state_id=state_id,
|
||||
idp_id=idp_id,
|
||||
nonce=nonce,
|
||||
client_type=client_type,
|
||||
ip_address=ip_address,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_oauth_state
|
||||
call_kwargs = mock_model.call_args[1]
|
||||
assert call_kwargs["code_challenge"] == code_challenge
|
||||
assert call_kwargs["code_challenge_method"] == code_challenge_method
|
||||
|
||||
def test_create_oauth_state_with_user_id(self, mock_db):
|
||||
"""Test OAuth state creation with user_id for link mode."""
|
||||
# Arrange
|
||||
state_id = "link_state_12345678"
|
||||
idp_id = 1
|
||||
nonce = "test_nonce_123"
|
||||
client_type = "web"
|
||||
ip_address = "192.168.1.1"
|
||||
user_id = 42
|
||||
|
||||
with patch("auth.oauth_state.crud.oauth_state_models.OAuthState") as mock_model:
|
||||
mock_oauth_state = MagicMock()
|
||||
mock_model.return_value = mock_oauth_state
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.create_oauth_state(
|
||||
mock_db,
|
||||
state_id=state_id,
|
||||
idp_id=idp_id,
|
||||
nonce=nonce,
|
||||
client_type=client_type,
|
||||
ip_address=ip_address,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_oauth_state
|
||||
call_kwargs = mock_model.call_args[1]
|
||||
assert call_kwargs["user_id"] == user_id
|
||||
|
||||
def test_create_oauth_state_sets_expiry(self, mock_db):
|
||||
"""Test OAuth state creation sets 10-minute expiry."""
|
||||
# Arrange
|
||||
state_id = "expiry_test_state"
|
||||
idp_id = 1
|
||||
nonce = "test_nonce_123"
|
||||
client_type = "web"
|
||||
ip_address = "192.168.1.1"
|
||||
|
||||
with patch("auth.oauth_state.crud.oauth_state_models.OAuthState") as mock_model:
|
||||
mock_oauth_state = MagicMock()
|
||||
mock_model.return_value = mock_oauth_state
|
||||
|
||||
# Act
|
||||
oauth_state_crud.create_oauth_state(
|
||||
mock_db,
|
||||
state_id=state_id,
|
||||
idp_id=idp_id,
|
||||
nonce=nonce,
|
||||
client_type=client_type,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_model.call_args[1]
|
||||
expires_at = call_kwargs["expires_at"]
|
||||
now = datetime.now(timezone.utc)
|
||||
expected_expiry = now + timedelta(minutes=10)
|
||||
|
||||
# Allow 5 second tolerance for test execution time
|
||||
time_diff = abs((expires_at - expected_expiry).total_seconds())
|
||||
assert time_diff < 5
|
||||
|
||||
|
||||
class TestMarkOAuthStateUsed:
|
||||
"""Test suite for mark_oauth_state_used function."""
|
||||
|
||||
def test_mark_oauth_state_used_success(self, mock_db):
|
||||
"""Test successful marking of OAuth state as used."""
|
||||
# Arrange
|
||||
state_id = "test_state_12345678"
|
||||
mock_oauth_state = MagicMock(spec=oauth_state_models.OAuthState)
|
||||
mock_oauth_state.id = state_id
|
||||
mock_oauth_state.used = False
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = mock_oauth_state
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.mark_oauth_state_used(mock_db, state_id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_oauth_state
|
||||
assert mock_oauth_state.used is True
|
||||
mock_db.commit.assert_called_once()
|
||||
mock_db.refresh.assert_called_once_with(mock_oauth_state)
|
||||
|
||||
def test_mark_oauth_state_used_not_found(self, mock_db):
|
||||
"""Test marking non-existent OAuth state returns None."""
|
||||
# Arrange
|
||||
state_id = "nonexistent_state"
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.mark_oauth_state_used(mock_db, state_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_db.commit.assert_not_called()
|
||||
|
||||
|
||||
class TestDeleteOAuthState:
|
||||
"""Test suite for delete_oauth_state function."""
|
||||
|
||||
def test_delete_oauth_state_success(self, mock_db):
|
||||
"""Test successful deletion of OAuth state."""
|
||||
# Arrange
|
||||
oauth_state_id = "test_state_12345678"
|
||||
expected_count = 1
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.delete.return_value = expected_count
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.delete_oauth_state(oauth_state_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == expected_count
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_delete_oauth_state_not_found(self, mock_db):
|
||||
"""Test deletion when OAuth state doesn't exist."""
|
||||
# Arrange
|
||||
oauth_state_id = "nonexistent_state"
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.delete.return_value = 0
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.delete_oauth_state(oauth_state_id, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == 0
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_delete_oauth_state_exception(self, mock_db):
|
||||
"""Test exception handling in delete_oauth_state."""
|
||||
# Arrange
|
||||
oauth_state_id = "error_state"
|
||||
mock_db.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
oauth_state_crud.delete_oauth_state(oauth_state_id, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert exc_info.value.detail == "Failed to delete OAuth state"
|
||||
mock_db.rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestDeleteExpiredOAuthStates:
|
||||
"""Test suite for delete_expired_oauth_states function."""
|
||||
|
||||
def test_delete_expired_oauth_states_success(self, mock_db):
|
||||
"""Test successful deletion of expired OAuth states."""
|
||||
# Arrange
|
||||
expected_count = 5
|
||||
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.delete.return_value = expected_count
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.delete_expired_oauth_states(mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == expected_count
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_delete_expired_oauth_states_none_found(self, mock_db):
|
||||
"""Test deletion when no expired states exist."""
|
||||
# Arrange
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.delete.return_value = 0
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.delete_expired_oauth_states(mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == 0
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_delete_expired_oauth_states_cutoff(self, mock_db):
|
||||
"""Test expired states cutoff is 10 minutes in the past."""
|
||||
# Arrange
|
||||
mock_query = mock_db.query.return_value
|
||||
mock_filter = mock_query.filter.return_value
|
||||
mock_filter.delete.return_value = 0
|
||||
|
||||
# Act
|
||||
result = oauth_state_crud.delete_expired_oauth_states(mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == 0
|
||||
mock_db.query.assert_called_once()
|
||||
mock_query.filter.assert_called_once()
|
||||
mock_filter.delete.assert_called_once()
|
||||
94
backend/tests/auth/oauth_state/test_utils.py
Normal file
94
backend/tests/auth/oauth_state/test_utils.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Tests for OAuth state utility functions."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import auth.oauth_state.utils as oauth_state_utils
|
||||
import auth.oauth_state.crud as oauth_state_crud
|
||||
|
||||
|
||||
class TestDeleteExpiredOAuthStatesFromDb:
|
||||
"""Test suite for delete_expired_oauth_states_from_db function."""
|
||||
|
||||
def test_delete_expired_oauth_states_with_deletions(self):
|
||||
"""Test cleanup when expired states exist."""
|
||||
# Arrange
|
||||
num_deleted = 5
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch(
|
||||
"auth.oauth_state.utils.SessionLocal"
|
||||
) as mock_session_local, patch.object(
|
||||
oauth_state_crud, "delete_expired_oauth_states", return_value=num_deleted
|
||||
) as mock_delete:
|
||||
|
||||
mock_session_local.return_value.__enter__.return_value = mock_db
|
||||
|
||||
# Act
|
||||
oauth_state_utils.delete_expired_oauth_states_from_db()
|
||||
|
||||
# Assert
|
||||
mock_delete.assert_called_once_with(mock_db)
|
||||
mock_session_local.assert_called_once()
|
||||
|
||||
def test_delete_expired_oauth_states_no_deletions(self):
|
||||
"""Test cleanup when no expired states exist."""
|
||||
# Arrange
|
||||
num_deleted = 0
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch(
|
||||
"auth.oauth_state.utils.SessionLocal"
|
||||
) as mock_session_local, patch.object(
|
||||
oauth_state_crud, "delete_expired_oauth_states", return_value=num_deleted
|
||||
) as mock_delete:
|
||||
|
||||
mock_session_local.return_value.__enter__.return_value = mock_db
|
||||
|
||||
# Act
|
||||
oauth_state_utils.delete_expired_oauth_states_from_db()
|
||||
|
||||
# Assert
|
||||
mock_delete.assert_called_once_with(mock_db)
|
||||
|
||||
def test_delete_expired_oauth_states_exception_handling(self):
|
||||
"""Test exception handling in cleanup function."""
|
||||
# Arrange
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch(
|
||||
"auth.oauth_state.utils.SessionLocal"
|
||||
) as mock_session_local, patch.object(
|
||||
oauth_state_crud,
|
||||
"delete_expired_oauth_states",
|
||||
side_effect=Exception("Database error"),
|
||||
):
|
||||
|
||||
mock_session_local.return_value.__enter__.return_value = mock_db
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
oauth_state_utils.delete_expired_oauth_states_from_db()
|
||||
|
||||
def test_delete_expired_oauth_states_context_manager(self):
|
||||
"""Test session context manager is properly used."""
|
||||
# Arrange
|
||||
num_deleted = 3
|
||||
mock_db = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_db
|
||||
mock_context.__exit__.return_value = None
|
||||
|
||||
with patch(
|
||||
"auth.oauth_state.utils.SessionLocal", return_value=mock_context
|
||||
) as mock_session_local, patch.object(
|
||||
oauth_state_crud, "delete_expired_oauth_states", return_value=num_deleted
|
||||
):
|
||||
|
||||
# Act
|
||||
oauth_state_utils.delete_expired_oauth_states_from_db()
|
||||
|
||||
# Assert
|
||||
mock_session_local.assert_called_once()
|
||||
mock_context.__enter__.assert_called_once()
|
||||
mock_context.__exit__.assert_called_once()
|
||||
Reference in New Issue
Block a user